# Some matrix multiplication tricks

*Published: March 04, 2016*

The first two tricks are sort of obvious.

We want to multiply some matrix \(\mathbf{A}\) from the left with some diagonal matrix
\(\mathbf{D}\). Instead of carrying out a bunch of multiplications by zero, we
replace the matrix multiplication by a `for` loop and elementwise
multiplication. In other words,

\begin{equation} \mathbf{D}\mathbf{A} = \left[ \mathbf{d} \odot \mathbf{a}_1, \mathbf{d} \odot \mathbf{a}_2, \ldots, \mathbf{d} \odot \mathbf{a}_N \right], \end{equation}

where \(\mathbf{a}_n\) is the \(n\)-th column of \(\mathbf{A}\), \(\mathbf{D} = \mathrm{diag}(\mathbf{d})\), and \(\odot\) denotes the elementwise product. In GNU Octave, this would look like the following:

```octave
N = 3;
A = randi([0 9], N, N);
d = randi([0 9], N, 1);          % column vector
D = diag(d);
DA1 = D*A;                       % full matrix product
DA2 = A;
for k = 1:N
  DA2(:, k) = DA2(:, k) .* d;    % scale column k elementwise by d
end
if ~all(DA2(:) == DA1(:))
  s = 'NOT equal';
else
  s = 'equal';
end
printf('DA1 and DA2 are %s.\n', s);
```

As is to be expected, the above snippet outputs:

```
DA1 and DA2 are equal.
```
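The explicit `for` loop is mostly there to make the column-wise scaling visible. As far as I know, GNU Octave has supported automatic broadcasting since version 3.6 (MATLAB since R2016b), so the same product can be written without any loop. `bsxfun` does the equivalent job on versions without broadcasting. Here is a minimal sketch, assuming such a version. The names `DA3` and `DA4` are just my own:

```octave
% Sketch: compute diag(d)*A without forming D and without a loop.
% Assumes an Octave version with automatic broadcasting.
N = 3;
A = randi([0 9], N, N);
d = randi([0 9], N, 1);       % column vector, one entry per row of A

DA1 = diag(d) * A;            % reference result via the full product
DA3 = d .* A;                 % broadcasting: row n of A is multiplied by d(n)
DA4 = bsxfun(@times, d, A);   % same thing for versions without broadcasting

printf('DA3 matches: %d, DA4 matches: %d\n', isequal(DA1, DA3), isequal(DA1, DA4));
```

Both variants touch each entry of \(\mathbf{A}\) exactly once. That is the same amount of work as the loop, but it is usually faster in an interpreted language.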

We do a similar thing when multiplying with a diagonal matrix from the right,
i.e. \(\mathbf{A}\mathbf{D}\). The difference here is that we pointwise
multiply the **rows** of \(\mathbf{A}\) (instead of its columns) by \(\mathbf{d}\), which is now a row vector:

```octave
N = 3;
A = randi([0 9], N, N);
d = randi([0 9], 1, N);          % this is now a row vector
D = diag(d);
AD1 = A*D;                       % full matrix product
AD2 = A;
for k = 1:N
  AD2(k, :) = AD2(k, :) .* d;    % now pointwise multiplying the rows
end
if ~all(AD2(:) == AD1(:))
  s = 'NOT equal';
else
  s = 'equal';
end
printf('AD1 and AD2 are %s.\n', s);
```

We again get a similar output:

```
AD1 and AD2 are equal.
```
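With square matrices it is easy to miss which dimension each part belongs to. Below is a small sketch with a non-square \(\mathbf{A}\) of size \(M \times N\); the sizes \(M = 4\), \(N = 3\) and the name `AD3` are arbitrary choices of mine. The vector \(\mathbf{d}\) needs one entry per **column** of \(\mathbf{A}\), while the loop runs over the \(M\) **rows**. The broadcasting line again assumes an Octave version with automatic broadcasting.

```octave
% Sketch: A*diag(d) for a non-square A (M x N); d has length N.
M = 4; N = 3;
A = randi([0 9], M, N);
d = randi([0 9], 1, N);       % row vector: one entry per column of A

AD1 = A * diag(d);            % reference result via the full product

% Loop over the M rows; each row is multiplied elementwise by d.
AD2 = A;
for k = 1:size(A, 1)
  AD2(k, :) = AD2(k, :) .* d;
end

AD3 = A .* d;                 % broadcasting: column n of A is scaled by d(n)

printf('AD2 matches: %d, AD3 matches: %d\n', isequal(AD1, AD2), isequal(AD1, AD3));
```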

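Now, suppose you find yourself in need of **only the diagonal elements of a
matrix product**, i.e. you are looking for something along the lines of

\begin{equation} \mathbf{x} = \mathrm{diag} \left( \mathbf{A} \mathbf{B} \right). \end{equation}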

(Note that I am using the \(\mathrm{diag}(\cdot)\) operator here in the MATLAB/Octave sense: when applied to a vector, it returns a diagonal matrix, but when applied to a matrix, it returns the matrix's diagonal elements. Also, \(\mathbf{A}\) and \(\mathbf{B}\) are any two matrices of compatible dimensions, i.e. \(\mathbf{A}\) is \(M \times K\) and \(\mathbf{B}\) is \(K \times M\), so that \(\mathbf{A}\mathbf{B}\) is square.)

We can then save some computations by doing

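\begin{equation} \mathbf{x} = \mathrm{diag} \left( \mathbf{A} \mathbf{B} \right) = \mathtt{sum}\left( \mathbf{A} \odot \mathbf{B}^{\mathsf{T}}, \mathtt{2} \right), \end{equation}

which means that we first pointwise multiply \(\mathbf{A}\) by the transpose of \(\mathbf{B}\) and then sum along each row (the \(\mathtt{2}\) refers to the second dimension). This works because \(x_n = (\mathbf{A}\mathbf{B})_{nn} = \sum_k A_{nk} B_{kn} = \sum_k A_{nk} (\mathbf{B}^{\mathsf{T}})_{nk}\). For two \(N \times N\) matrices it takes \(N^2\) multiplications instead of the \(N^3\) needed for the full product. I apologize for my abuse of notation here, but the following snippet makes the operation a little clearer: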

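```octave
N = 3;
A = randi([0 9], N, N);
B = randi([0 9], N, N);
x1 = diag(A*B);            % full product, then keep the diagonal
x2 = sum(A.*B.', 2);       % elementwise product with B transposed, then row sums
if ~all(x1 == x2)
  s = 'NOT equal';
else
  s = 'equal';
end
printf('x1 and x2 are %s.\n', s);
```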

And of course, we obtain the expected result:

```
x1 and x2 are equal.
```
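The snippet above uses square matrices, but the trick only needs \(\mathbf{A}\mathbf{B}\) to be square, i.e. \(\mathbf{A}\) of size \(M \times K\) and \(\mathbf{B}\) of size \(K \times M\). The full product then costs \(M^2 K\) multiplications, whereas \(\mathbf{A} \odot \mathbf{B}^{\mathsf{T}}\) needs only \(MK\). Here is a small sketch with the arbitrarily chosen sizes \(M = 4\) and \(K = 2\) (32 versus 8 multiplications). For comparison, it also spells out the per-element sum as an explicit loop; `x3` is just my own name for that result:

```octave
% Sketch: diagonal of A*B for non-square A (M x K) and B (K x M).
M = 4; K = 2;
A = randi([0 9], M, K);
B = randi([0 9], K, M);

x1 = diag(A*B);            % forms the full M x M product first
x2 = sum(A .* B.', 2);     % only the M*K products that end up on the diagonal

% The same thing as an explicit loop over the diagonal entries,
% using x(m) = sum_k A(m, k) * B(k, m).
x3 = zeros(M, 1);
for m = 1:M
  x3(m) = A(m, :) * B(:, m);
end

printf('x2 matches: %d, x3 matches: %d\n', isequal(x1, x2), isequal(x1, x3));
```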