# Some matrix multiplication tricks

The first two are sort of obvious:

Suppose we want to multiply some matrix $$\mathbf{A}$$ from the left by some diagonal matrix $$\mathbf{D}$$. Instead of carrying out a bunch of multiplications by zero, we replace the matrix multiplication with a for loop and elementwise multiplication. In other words,

$$\mathbf{D}\mathbf{A} = \begin{bmatrix} &&&\\ \mathbf{d} \odot \mathbf{a}_{1} & \mathbf{d} \odot \mathbf{a}_{2} & \cdots & \mathbf{d} \odot \mathbf{a}_{n} \\ &&&\\ \end{bmatrix},$$

where $$\mathbf{a}_k$$ is the $$k$$-th column of $$\mathbf{A}$$ (with $$k = 1, \ldots, n$$), $$\mathbf{D} = \mathrm{diag}(\mathbf{d})$$, and $$\odot$$ denotes the elementwise product. In GNU Octave, this would look like the following:

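N = 3;
A = randi([0 9], N, N); % random integer test matrix
d = randi([0 9], N, 1); % diagonal entries as a column vector
D = diag(d);
DA1 = D*A; % reference: full matrix product
DA2 = A;
for k = 1:N
    DA2(:, k) = DA2(:, k) .* d; % scale the k-th column elementwise by d
end
if ~all(DA2(:) == DA1(:))
    s = 'NOT equal';
else
    s = 'equal';
end
disp(sprintf('DA1 and DA2 are %s.', s)); % print the result of the comparison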


As expected, the above snippet outputs:

DA1 and DA2 are equal.
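
To see why this works, it may help to write out a single entry of the product. What follows is only a short sketch: the index names $$i$$, $$j$$, $$k$$ and the $$2 \times 2$$ numbers are chosen for illustration and are not taken from the snippet above. Since $$\mathbf{D}$$ is zero off its diagonal, only the term with $$k = i$$ survives in the sum:

$$[\mathbf{D}\mathbf{A}]_{ij} = \sum_{k} [\mathbf{D}]_{ik} \, a_{kj} = d_i \, a_{ij}.$$

So the $$i$$-th entry of every column of $$\mathbf{A}$$ is scaled by $$d_i$$, which is exactly $$\mathbf{d} \odot \mathbf{a}_j$$. For example, with $$\mathbf{d} = [2, 3]^{\mathsf{T}}$$:

$$\begin{bmatrix} 2 & 0 \\ 0 & 3 \end{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} = \begin{bmatrix} 2 \cdot 1 & 2 \cdot 2 \\ 3 \cdot 3 & 3 \cdot 4 \end{bmatrix} = \begin{bmatrix} 2 & 4 \\ 9 & 12 \end{bmatrix}.$$

For an $$N \times N$$ matrix, this takes $$N^2$$ scalar multiplications, compared to the $$N^3$$ of a naive dense matrix product.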


We do a similar thing when multiplying with a diagonal matrix on the right, i.e. $$\mathbf{A}\mathbf{D}$$. The difference here is that we now pointwise multiply each row of $$\mathbf{A}$$, instead of each column, by $$\mathbf{d}$$:

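N = 3;
A = randi([0 9], N, N);
d = randi([0 9], 1, N); % this is now a row vector
D = diag(d);
AD1 = A*D; % reference: full matrix product
AD2 = A;
for k = 1:N
    AD2(k, :) = AD2(k, :) .* d; % now pointwise multiplying the rows
end
if ~all(AD2(:) == AD1(:))
    s = 'NOT equal';
else
    s = 'equal';
end
disp(sprintf('AD1 and AD2 are %s.', s)); % print the result of the comparison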


We again get a similar output:

AD1 and AD2 are equal.
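
The entrywise argument for the right-hand case looks almost the same. Again, this is just an illustrative sketch with index names and numbers that I am assuming for the example. This time only the term with $$k = j$$ survives:

$$[\mathbf{A}\mathbf{D}]_{ij} = \sum_{k} a_{ik} \, [\mathbf{D}]_{kj} = a_{ij} \, d_j.$$

So the $$j$$-th entry of every row of $$\mathbf{A}$$ is scaled by $$d_j$$. With the same $$\mathbf{A}$$ as before and $$\mathbf{d} = [2, 3]$$ as a row vector:

$$\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \begin{bmatrix} 2 & 0 \\ 0 & 3 \end{bmatrix} = \begin{bmatrix} 1 \cdot 2 & 2 \cdot 3 \\ 3 \cdot 2 & 4 \cdot 3 \end{bmatrix} = \begin{bmatrix} 2 & 6 \\ 6 & 12 \end{bmatrix}.$$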


Now, suppose you need only the diagonal elements of a matrix product, i.e. you are looking for something along the lines of

$$\mathbf{x} = \mathrm{diag} \left( \mathbf{A} \mathbf{B} \right)$$

(Note that I am using the $$\mathrm{diag}(\cdot)$$ operator here in the MATLAB/Octave sense: when applied to a vector, it returns a diagonal matrix, but when applied to a matrix, it returns the matrix's diagonal elements as a vector. Also, $$\mathbf{A}$$ and $$\mathbf{B}$$ are any two matrices with compatible dimensions, i.e. if $$\mathbf{A}$$ is $$m \times n$$, then $$\mathbf{B}$$ is $$n \times m$$.)

We can then save some computations by doing

$$\mathbf{x} = \mathrm{diag} \left( \mathbf{A} \mathbf{B} \right) = \mathtt{sum}\left( \mathbf{A} \odot \mathbf{B}^{\mathsf{T}}, \mathtt{2} \right),$$

which means that we first multiply $$\mathbf{A}$$ pointwise by the transpose of $$\mathbf{B}$$ and then sum the entries of each row. I apologize for my abuse of notation here, but the following snippet makes the operation a little clearer:

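N = 3;
A = randi([0 9], N, N);
B = randi([0 9], N, N);
x1 = diag(A*B); % reference: full product, then keep only the diagonal
x2 = sum(A.*B.', 2); % pointwise product with the transpose, summed along each row
if ~all(x1 == x2)
    s = 'NOT equal';
else
    s = 'equal';
end
disp(sprintf('x1 and x2 are %s.', s)); % print the result of the comparison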


And of course, we obtain the expected result:

x1 and x2 are equal.
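
If the identity still looks a bit like magic, writing out the $$i$$-th diagonal entry may help. This is again only a sketch; the index names and the $$2 \times 2$$ numbers are my own assumptions for illustration:

$$x_i = [\mathbf{A}\mathbf{B}]_{ii} = \sum_{k} a_{ik} \, b_{ki} = \sum_{k} a_{ik} \, [\mathbf{B}^{\mathsf{T}}]_{ik} = \sum_{k} [\mathbf{A} \odot \mathbf{B}^{\mathsf{T}}]_{ik},$$

which is the sum over the $$i$$-th row of $$\mathbf{A} \odot \mathbf{B}^{\mathsf{T}}$$. As a small example:

$$\mathbf{A} = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}, \quad \mathbf{B} = \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix}, \quad \mathbf{A}\mathbf{B} = \begin{bmatrix} 19 & 22 \\ 43 & 50 \end{bmatrix}, \quad \mathbf{A} \odot \mathbf{B}^{\mathsf{T}} = \begin{bmatrix} 5 & 14 \\ 18 & 32 \end{bmatrix}.$$

The row sums of the last matrix are $$5 + 14 = 19$$ and $$18 + 32 = 50$$, which matches the diagonal of $$\mathbf{A}\mathbf{B}$$. For $$N \times N$$ matrices, this needs $$N^2$$ scalar multiplications, whereas forming the full product naively takes $$N^3$$, most of which go into off-diagonal entries that are thrown away.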