Some matrix multiplication tricks
Published: March 04, 2016
The first two are sort of obvious:
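We want to multiply some matrix \(\mathbf{A}\) from the left by some diagonal matrix \(\mathbf{D}\), i.e. compute \(\mathbf{D}\mathbf{A}\). Instead of carrying out a bunch of multiplications by zero, we replace the matrix multiplication with a for loop and elementwise multiplication. In other words,
\begin{equation} \mathbf{D}\mathbf{A} = \begin{bmatrix} \mathbf{a}_1 \odot \mathbf{d} & \mathbf{a}_2 \odot \mathbf{d} & \cdots & \mathbf{a}_N \odot \mathbf{d} \end{bmatrix}, \end{equation}
where \(\mathbf{a}_n\) is the \(n\)-th column of \(\mathbf{A}\), \(\mathbf{D} = \mathrm{diag}(\mathbf{d})\), and \(\odot\) denotes the elementwise product. In GNU Octave, this would look like the following: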
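```octave
N = 3;
A = randi([0 9], N, N);
d = randi([0 9], N, 1);   % column vector holding the diagonal of D
D = diag(d);

DA1 = D*A;                % full matrix product

DA2 = A;
for k = 1:N
  % multiply the k-th column of A elementwise by d
  DA2(:, k) = DA2(:, k) .* d;
end

if ~all(DA2(:) == DA1(:))
  s = 'NOT equal';
else
  s = 'equal';
end
printf('DA1 and DA2 are %s.\n', s);
```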
As expected, the snippet above prints:
DA1 and DA2 are equal.
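In Octave the loop itself can be dropped. The following is a minimal sketch, assuming a version with automatic broadcasting (Octave 3.6 or later; MATLAB calls this implicit expansion and added it in R2016b): multiplying the \(N \times 1\) column vector d elementwise with A multiplies every column of A by d, which gives the same result as \(\mathbf{D}\mathbf{A}\). The column count M = 5 is an arbitrary choice for illustration, to show that A does not need to be square.

```octave
N = 3;
M = 5;                      % arbitrary: A only needs N rows, not N columns
A = randi([0 9], N, M);
d = randi([0 9], N, 1);     % column vector holding the diagonal of D
D = diag(d);                % N-by-N diagonal matrix

DA1 = D*A;                  % full N-by-N times N-by-M product
DA3 = d .* A;               % broadcasting: every column of A is multiplied elementwise by d

if isequal(DA1, DA3)
  s = 'equal';
else
  s = 'NOT equal';
end
printf('DA1 and DA3 are %s.\n', s);
```

On Octave versions without automatic broadcasting, bsxfun(@times, d, A) computes the same thing.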
We do a similar thing when multiplying by a diagonal matrix on the right, i.e. \(\mathbf{A}\mathbf{D}\). The difference here is that we now multiply the rows of \(\mathbf{A}\) elementwise by \(\mathbf{d}\), instead of its columns:
```octave
N = 3;
A = randi([0 9], N, N);
d = randi([0 9], 1, N);   % this is now a row vector
D = diag(d);

AD1 = A*D;                % full matrix product

AD2 = A;
for k = 1:N
  % now pointwise multiplying the rows
  AD2(k, :) = AD2(k, :) .* d;
end

if ~all(AD2(:) == AD1(:))
  s = 'NOT equal';
else
  s = 'equal';
end
printf('AD1 and AD2 are %s.\n', s);
```
We again get a similar output:
AD1 and AD2 are equal.
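The loop over rows can be replaced by broadcasting as well. A minimal sketch, assuming automatic broadcasting (Octave 3.6 or later): with d stored as a \(1 \times M\) row vector, A .* d multiplies every row of A elementwise by d, matching \(\mathbf{A}\mathbf{D}\). The sizes N = 3 and M = 5 are arbitrary choices for illustration; note that \(\mathbf{D}\) must then be \(M \times M\).

```octave
N = 3;
M = 5;                      % arbitrary: A is N-by-M, so D must be M-by-M
A = randi([0 9], N, M);
d = randi([0 9], 1, M);     % row vector holding the diagonal of D
D = diag(d);                % M-by-M diagonal matrix

AD1 = A*D;                  % full N-by-M times M-by-M product
AD3 = A .* d;               % broadcasting: every row of A is multiplied elementwise by d

if isequal(AD1, AD3)
  s = 'equal';
else
  s = 'NOT equal';
end
printf('AD1 and AD3 are %s.\n', s);
```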
Now, suppose you find yourself in need of only the diagonal elements of a matrix product, i.e. you are looking for something along the lines of
\begin{equation} \mathbf{x} = \mathrm{diag} \left( \mathbf{A} \mathbf{B} \right) \end{equation}
(Note that I am using the \(\mathrm{diag}(\cdot)\) operator here in the MATLAB/Octave sense: when applied to a vector, it returns a diagonal matrix, but when applied to a matrix, it returns the matrix's diagonal elements. Also, \(\mathbf{A}\) and \(\mathbf{B}\) are any two matrices with compatible dimensions.)
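Forming the full product only to throw away its off-diagonal entries is wasteful: if \(\mathbf{A}\) is \(M \times N\) and \(\mathbf{B}\) is \(N \times M\), a standard matrix product takes \(M^2 N\) multiplications, but only \(MN\) of them contribute to the diagonal. We can skip the rest by computing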
\begin{equation} \mathbf{x} = \mathrm{diag} \left( \mathbf{A} \mathbf{B} \right) = \mathtt{sum}\left( \mathbf{A} \odot \mathbf{B}^{\mathsf{T}}, \mathtt{2} \right), \end{equation}
which means that we first multiply \(\mathbf{A}\) elementwise by the transpose of \(\mathbf{B}\) and then sum along each row (the second argument \(\mathtt{2}\) selects the dimension to sum over, as in Octave's sum). I apologize for the abuse of notation here, but the following snippet makes the operation a little clearer:
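```octave
N = 3;
A = randi([0 9], N, N);
B = randi([0 9], N, N);

x1 = diag(A*B);           % full product, then keep only the diagonal
x2 = sum(A.*B.', 2);      % elementwise product with B transposed, summed along each row

if ~all(x1 == x2)
  s = 'NOT equal';
else
  s = 'equal';
end
printf('x1 and x2 are %s.\n', s);
```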
And of course, we obtain the expected result:
x1 and x2 are equal.
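Why this works: the \(i\)-th diagonal element of \(\mathbf{A}\mathbf{B}\) is the inner product of the \(i\)-th row of \(\mathbf{A}\) with the \(i\)-th column of \(\mathbf{B}\), and that column is the \(i\)-th row of \(\mathbf{B}^{\mathsf{T}}\):
\begin{equation} [\mathbf{A}\mathbf{B}]_{ii} = \sum_{j} A_{ij} B_{ji} = \sum_{j} A_{ij} [\mathbf{B}^{\mathsf{T}}]_{ij} = \sum_{j} [\mathbf{A} \odot \mathbf{B}^{\mathsf{T}}]_{ij}. \end{equation}
The identity does not require square matrices. Below is a minimal sketch checking it for an \(M \times N\) matrix \(\mathbf{A}\) and an \(N \times M\) matrix \(\mathbf{B}\); the sizes M = 4 and N = 6 are arbitrary choices for illustration.

```octave
M = 4;
N = 6;
A = randi([0 9], M, N);     % M-by-N
B = randi([0 9], N, M);     % N-by-M, so A*B is M-by-M and B.' is M-by-N

x1 = diag(A*B);             % computes all M^2 entries of A*B, keeps only M of them
x2 = sum(A .* B.', 2);      % M*N multiplications, then one sum per row

if isequal(x1, x2)
  s = 'equal';
else
  s = 'NOT equal';
end
printf('x1 and x2 are %s.\n', s);
```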