Some matrix multiplication tricks

Published: March 04, 2016

The first two are sort of obvious:

We want to multiply some matrix \(\mathbf{A}\) with some diagonal matrix \(\mathbf{D}\). Instead of carrying out a bunch of multiplications by zero, we replace the matrix multiplication with a for loop and elementwise multiplications. In other words,

\begin{equation} \mathbf{D}\mathbf{A} = \begin{bmatrix} \mathbf{d} \odot \mathbf{a}_{1} & \mathbf{d} \odot \mathbf{a}_{2} & \cdots & \mathbf{d} \odot \mathbf{a}_{n} \end{bmatrix}, \end{equation}

where \(\mathbf{a}_n\) is the \(n\)-th column of \(\mathbf{A}\), \(\mathbf{D} = \mathrm{diag}(\mathbf{d})\), and \(\odot\) denotes the elementwise product. In GNU Octave, this would look like the following:

N = 3;
A = randi([0 9], N, N);   % random N-by-N test matrix
d = randi([0 9], N, 1);   % diagonal entries as a column vector
D = diag(d);
DA1 = D*A;                % the full matrix product
DA2 = A;
for k = 1:N
    % scale the k-th column of A elementwise by d
    DA2(:, k) = DA2(:, k) .* d;
end
if ~all(DA2(:) == DA1(:))
    s = 'NOT equal';
else
    s = 'equal';
end
printf('DA1 and DA2 are %s.\n', s);

The output of the above snippet is, as expected:

DA1 and DA2 are equal.
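As an aside, once we know the trick, we don't even need the explicit loop: Octave's automatic broadcasting (or bsxfun in older versions) expands the column vector \(\mathbf{d}\) across the columns of \(\mathbf{A}\) for us. A minimal sketch, reusing the variables from the snippet above:

DA3 = d .* A;                  % broadcasting expands d across the columns of A
% DA3 = bsxfun(@times, d, A); % equivalent, for older Octave versions
all(DA3(:) == DA1(:))          % returns 1 (true)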

We do a similar thing when multiplying with a diagonal matrix on the right, i.e. \(\mathbf{A}\mathbf{D}\). The difference here is that we're pointwise multiplying the rows instead of the columns of \(\mathbf{A}\):

N = 3;
A = randi([0 9], N, N);
d = randi([0 9], 1, N);   % this is now a row vector
D = diag(d);
AD1 = A*D;
AD2 = A;
for k = 1:N
    % now pointwise multiplying the rows
    AD2(k, :) = AD2(k, :) .* d;
end
if ~all(AD2(:) == AD1(:))
    s = 'NOT equal';
else
    s = 'equal';
end
printf('AD1 and AD2 are %s.\n', s);

We again get a similar output:

AD1 and AD2 are equal.
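The same loop-free broadcasting version works here; since \(\mathbf{d}\) is now a row vector, it gets expanded down the rows of \(\mathbf{A}\) instead. Again a minimal sketch with the variables from above:

AD3 = A .* d;                  % broadcasting expands the row vector d down the rows of A
% AD3 = bsxfun(@times, A, d); % equivalent, for older Octave versions
all(AD3(:) == AD1(:))          % returns 1 (true)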

Now, suppose you find yourself in need of only the diagonal elements of a matrix product, i.e. you are looking for something along the lines of

\begin{equation} \mathbf{x} = \mathrm{diag} \left( \mathbf{A} \mathbf{B} \right) \end{equation}

(Note that I am using the \(\mathrm{diag}(\cdot)\) operator here in the MATLAB/Octave sense: when applied to a vector, it returns a diagonal matrix, but when applied to a matrix, it returns the matrix's diagonal elements as a vector. Also, \(\mathbf{A}\) and \(\mathbf{B}\) are any two matrices of compatible dimensions.)

We can then save quite a few computations (\(\mathcal{O}(N^2)\) operations instead of the \(\mathcal{O}(N^3)\) needed for the full product of two \(N \times N\) matrices, most of which we would throw away anyway) by doing

\begin{equation} \mathbf{x} = \mathrm{diag} \left( \mathbf{A} \mathbf{B} \right) = \mathtt{sum}\left( \mathbf{A} \odot \mathbf{B}^{\mathsf{T}}, \mathtt{2} \right), \end{equation}

which means that we first pointwise multiply \(\mathbf{A}\) by the transpose of \(\mathbf{B}\) and then sum along each row. This works because \(\left(\mathbf{A}\mathbf{B}\right)_{ii} = \sum_{j} A_{ij} B_{ji} = \sum_{j} A_{ij} \left(\mathbf{B}^{\mathsf{T}}\right)_{ij}\). I apologize for the abuse of notation here, but the following snippet makes the operation a little clearer:

N = 3;
A = randi([0 9], N, N);
B = randi([0 9], N, N);
x1 = diag(A*B);        % naive: full product, then extract the diagonal
x2 = sum(A.*B.', 2);   % trick: pointwise multiply with B.', then sum each row
if ~all(x1 == x2)
    s = 'NOT equal';
else
    s = 'equal';
end
printf('x1 and x2 are %s.\n', s);

And of course, we obtain the expected result:

x1 and x2 are equal.
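To get a rough feel for the savings, here is a quick (and admittedly crude) tic/toc comparison on larger matrices. The actual timings will depend on your machine and BLAS, so treat this as a sketch rather than a proper benchmark:

N = 2000;
A = rand(N);
B = rand(N);
tic; x1 = diag(A*B);      t1 = toc;  % full O(N^3) product, diagonal extracted afterwards
tic; x2 = sum(A.*B.', 2); t2 = toc;  % O(N^2) elementwise trick
printf('naive: %.3f s, trick: %.3f s\n', t1, t2);

On any reasonably sized problem, the elementwise version should come out well ahead.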