3 Dimensional Matrix Multiplication
I have a 3-dimensional array S with dimensions [A, R, M], and a 2-dimensional kernel K with dimensions [M, M].
I need to generate the output array Q with dimensions [A, R], where element Q(i,j) is calculated as S(i,j,:) * K * S(i,j,:)^H (^H denoting the Hermitian conjugate).
I can do this by brute force, running gemm on the flattened array S' with dimensions [A*R, M] and then extracting the diagonal elements of the resulting 2D array, but this is not efficient.
Any idea how this could be done better?
I understand this question is probably more appropriate for the NVIDIA and SO people, but I wanted to start here.
Thanks!
Re: 3 Dimensional Matrix Multiplication
For concreteness, does the below Matlab code do what you want? If so, it can be done with a gemm + batched dot products. There may be something in MAGMA that can help with the batched dot products.
-mark
Code:
% arbitrary dimensions
a = 2;
r = 3;
m = 4;
S = rand( a, r, m );
K = rand( m, m );
% naive implementation
Q = zeros( a, r );
for i = 1:a
    for j = 1:r
        x = reshape( S(i,j,:), [m,1] );
        Q(i,j) = x' * K * x;
    end
end
% gemm + dot products
Q2 = zeros( a*r, 1 );
S2 = reshape( S, [a*r, m] );
tmp = S2*K; % gemm
for i = 1:a*r
    Q2(i) = tmp(i,:) * S2(i,:)';
end
Q2 = reshape( Q2, [a, r] );
Re: 3 Dimensional Matrix Multiplication
Thanks a lot!
Couple of questions:
For the first part, the gemm side,
Code:
tmp = S2*K; % gemm
aren't we doing a bunch of unnecessary processing here?
For the batch dot product, forgive my ignorance, but I do not see how to plug in batch processing instead of the loop
Code:
for i = 1:a*r
    Q2(i) = tmp(i,:) * S2(i,:)';
end
Re: 3 Dimensional Matrix Multiplication
I think both the naive and the gemm implementations do the same computations, taking 2AR(M^2) + ARM - AR flops: AR*M*(2M - 1) for the vector-matrix products (the gemm) plus AR*(2M - 1) for the dot products.
Is that the operation you are trying to do?
If you multiply (tmp * S2'), instead of the loop of dots, then you would be doing a lot of extra work: that product is [A*R, A*R], and you only need its diagonal entries, as you said.
A batch dot would, for instance, take the two arrays tmp and S2 and dot all their rows (or cols). So a single function call would replace that loop over dots. However, after further investigating, it doesn’t appear that MAGMA currently has the batch dot that you would need. It shouldn’t be a hard kernel to write, though.
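In Matlab itself, the loop of dots from the earlier code can be collapsed into an element-wise multiply followed by a row sum, which is the same row-by-row pattern a batch dot kernel would follow. The sketch below is not taken from MAGMA or from any attachment in this thread. It reuses the names S2, K, and tmp from the earlier code, and it assumes complex test data (the earlier example used real data) so that the conjugation is actually exercised.

```matlab
% arbitrary dimensions, as in the earlier example
a = 2;  r = 3;  m = 4;
S = rand( a, r, m ) + 1i*rand( a, r, m );   % complex test data (assumption)
K = rand( m, m )    + 1i*rand( m, m );

S2  = reshape( S, [a*r, m] );   % each row of S2 is one vector S(i,j,:)
tmp = S2*K;                     % gemm, result is [a*r, m]

% "batch dot": row i of tmp dotted with the conjugate of row i of S2,
% i.e. the same value as tmp(i,:) * S2(i,:)' for every i at once
Q2 = sum( tmp .* conj(S2), 2 );
Q2 = reshape( Q2, [a, r] );

% check against the explicit double loop
Qref = zeros( a, r );
for i = 1:a
    for j = 1:r
        s = reshape( S(i,j,:), [1,m] );   % row vector
        Qref(i,j) = s * K * s';           % s' is the Hermitian conjugate
    end
end
disp( norm( Q2 - Qref, 'fro' ) );   % should be at rounding-error level
```

On a GPU, the same idea would map to one reduction per row of tmp and S2, which is presumably why a dedicated batch dot kernel is suggested rather than a second gemm.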
Re: 3 Dimensional Matrix Multiplication
Sorry, I do not see an "obvious" way to write a kernel :(
Re: 3 Dimensional Matrix Multiplication
Here's a simple implementation. Only very lightly tested. Probably needs modifications for your specific environment and purposes, but should be a good starting point.
-mark
Attachments: batch_dot.tar.gz (1.46 KiB)