/// ex03-matrix-multiply.cu
///
/// Code from CUDA C Programming Guide, with minor modifications:
/// - Use column-major order, instead of row-major.
/// - Rename width => cols, height => rows.
/// - Rename As => sA, Bs => sB.

#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <string>
#include <vector>

#include "util.hh"

//------------------------------------------------------------------------------
// Thread block size
//
#define BLOCK_SIZE 16

//------------------------------------------------------------------------------
// Matrices are stored in column-major order:
//     M(row, col) = *(M.elements + row + col*M.stride)
//
typedef struct {
    int cols;
    int rows;
    int stride;
    float* elements;
} Matrix;

//------------------------------------------------------------------------------
// Get a matrix element, A( row, col ).
//
__device__ float GetElement( const Matrix A, int row, int col )
{
    return A.elements[row + col*A.stride];
}

//------------------------------------------------------------------------------
// Set a matrix element, A( row, col ) = value.
//
__device__ void SetElement( Matrix A, int row, int col, float value )
{
    A.elements[row + col*A.stride] = value;
}

//------------------------------------------------------------------------------
// Get block sub-matrix A( blockRow, blockCol ).
//
// Get the BLOCK_SIZE x BLOCK_SIZE sub-matrix Asub of A
// that is located blockCol sub-matrices to the right and
// blockRow sub-matrices down from the upper-left corner of A.
//
__device__ Matrix GetSubMatrix( Matrix A, int blockRow, int blockCol )
{
    Matrix Asub;
    Asub.cols   = BLOCK_SIZE;
    Asub.rows   = BLOCK_SIZE;
    Asub.stride = A.stride;
    Asub.elements = &A.elements[BLOCK_SIZE * blockRow
                                + BLOCK_SIZE * blockCol * A.stride];
    return Asub;
}

//------------------------------------------------------------------------------
// Matrix multiplication kernel called by MatMul().
//
__global__ void MatMulKernel( Matrix A, Matrix B, Matrix C )
{
    // Block row and column.
    int blockRow = blockIdx.x;
    int blockCol = blockIdx.y;

    // Each thread block computes one sub-matrix Csub of C.
    Matrix Csub = GetSubMatrix( C, blockRow, blockCol );

    // Each thread computes one element of Csub
    // by accumulating results into Cvalue.
    float Cvalue = 0;

    // Thread row and column within Csub.
    int row = threadIdx.x;
    int col = threadIdx.y;

    // Loop over all the sub-matrices of A and B that are
    // required to compute Csub. Multiply each pair of
    // sub-matrices together and accumulate the results.
    for (int k = 0; k < (A.cols / BLOCK_SIZE); ++k) {
        // Get sub-matrix Asub of A.
        Matrix Asub = GetSubMatrix( A, blockRow, k );

        // Get sub-matrix Bsub of B.
        Matrix Bsub = GetSubMatrix( B, k, blockCol );

        // Shared memory used to store Asub and Bsub respectively.
        // Note: adding +1 here makes performance worse (454 vs. 502 gflop/s).
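        // (Presumably the +1 padding, the usual trick to avoid shared-memory
        // bank conflicts, buys nothing here: with the transposed [col][row]
        // layout below, a warp's threads vary in threadIdx.x == row, so they
        // already read consecutive 4-byte words from distinct banks, and the
        // padding only costs extra shared memory.)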
        __shared__ float sA[ BLOCK_SIZE ][ BLOCK_SIZE ];
        __shared__ float sB[ BLOCK_SIZE ][ BLOCK_SIZE ];

        // Load Asub and Bsub from device memory to shared memory.
        // Each thread loads one element of each sub-matrix.
        sA[col][row] = GetElement( Asub, row, col );
        sB[col][row] = GetElement( Bsub, row, col );

        // Synchronize to make sure the sub-matrices are loaded
        // before starting the computation.
        __syncthreads();

        // Multiply Asub and Bsub together.
        for (int e = 0; e < BLOCK_SIZE; ++e)
            Cvalue += sA[e][row] * sB[col][e];

        // Synchronize to make sure that the preceding
        // computation is done before loading two new
        // sub-matrices of A and B in the next iteration.
        __syncthreads();
    }

    // Write Csub to device memory.
    // Each thread writes one element.
    SetElement( Csub, row, col, Cvalue );
}

//------------------------------------------------------------------------------
// Matrix multiplication - Host code.
// Matrix dimensions are assumed to be multiples of BLOCK_SIZE.
//
void MatMul( const Matrix d_A, const Matrix d_B, Matrix d_C,
             cudaStream_t stream )
{
    // Check matrix dimensions.
    assert( d_A.rows == d_C.rows );
    assert( d_A.cols == d_B.rows );
    assert( d_B.cols == d_C.cols );

    // Code assumes cols and rows evenly divisible.
    assert( d_C.rows % BLOCK_SIZE == 0 );
    assert( d_C.cols % BLOCK_SIZE == 0 );
    assert( d_A.cols % BLOCK_SIZE == 0 );

    // Invoke kernel.
    dim3 dimBlock( BLOCK_SIZE, BLOCK_SIZE );
    dim3 dimGrid( d_C.rows / dimBlock.x, d_C.cols / dimBlock.y );
    MatMulKernel<<< dimGrid, dimBlock, 0, stream >>>( d_A, d_B, d_C );
    throw_error( cudaGetLastError() );
}

//==============================================================================
// Note: the Matrix struct is float-only, so instantiate with scalar_t = float.
template <typename scalar_t>
void test( int m, int n, int k, int verbose )
{
    std::vector<scalar_t> A( m * k );
    std::vector<scalar_t> B( k * n );
    std::vector<scalar_t> C( m * n );
    rand_matrix( m, k, A.data(), m );
    rand_matrix( k, n, B.data(), k );
    if (verbose >= 2) {
        print_matrix( "A", m, k, A.data(), m );
        print_matrix( "B", k, n, B.data(), k );
    }

    // Load A to device memory.
    Matrix d_A;
    d_A.rows   = m;
    d_A.cols   = k;
    d_A.stride = m;  // col-major
    size_t size_A = A.size() * sizeof(scalar_t);
    throw_error( cudaMalloc( &d_A.elements, size_A ) );
    throw_error( cudaMemcpy( d_A.elements, A.data(), size_A,
                             cudaMemcpyHostToDevice ) );

    // Load B to device memory.
    Matrix d_B;
    d_B.rows   = k;
    d_B.cols   = n;
    d_B.stride = k;  // col-major
    size_t size_B = B.size() * sizeof(scalar_t);
    throw_error( cudaMalloc( &d_B.elements, size_B ) );
    throw_error( cudaMemcpy( d_B.elements, B.data(), size_B,
                             cudaMemcpyHostToDevice ) );

    // Allocate C in device memory.
    Matrix d_C;
    d_C.rows   = m;
    d_C.cols   = n;
    d_C.stride = m;  // col-major
    size_t size_C = C.size() * sizeof(scalar_t);
    throw_error( cudaMalloc( &d_C.elements, size_C ) );

    cudaStream_t stream = nullptr;
    throw_error( cudaStreamCreate( &stream ) );

    //--------------------
    // Multiply d_C = d_A * d_B.
    double time = get_wtime();
    MatMul( d_A, d_B, d_C, stream );
    throw_error( cudaStreamSynchronize( stream ) );
    time = get_wtime() - time;

    // Print results. The flop count is 2*m*n*k: each of the m*n entries
    // of C takes k multiplies and k adds.
    double gflops = 2. * m * n * k * 1e-9 / time;
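    // For example, with m = n = k = 1024, that is 2 * 1024^3 ~ 2.147e9 flops,
    // so a 4.3 ms run corresponds to roughly 500 Gflop/s, the ballpark quoted
    // in the shared-memory note above.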
    printf( "m %4d, n %4d, k %4d, time %10.6f, Gflop/s %10.6f\n",
            m, n, k, time, gflops );

    // Read C from device memory.
    throw_error( cudaMemcpy( C.data(), d_C.elements, size_C,
                             cudaMemcpyDeviceToHost ) );

    // Free device memory and the stream.
    throw_error( cudaFree( d_A.elements ) );
    throw_error( cudaFree( d_B.elements ) );
    throw_error( cudaFree( d_C.elements ) );
    throw_error( cudaStreamDestroy( stream ) );

    if (verbose >= 2) {
        print_matrix( "C", m, n, C.data(), m );
        printf( "%% In Matlab, check that max( abs( A*B - C ), [], 'all' ) is ~ 1e-04\n"
                "%% since data is printed with 4 digits.\n" );
    }
}

//------------------------------------------------------------------------------
int main( int argc, char** argv )
{
    int m = 16;
    int n = 16;
    int k = 16;
    int verbose = 0;
    int repeat  = 5;

    // Rudimentary argument parsing.
    for (int i = 1; i < argc; ++i) {
        std::string arg( argv[i] );
        if (arg == "-m" && i+1 < argc) {
            m = strtol( argv[++i], nullptr, 0 );
        }
        else if (arg == "-n" && i+1 < argc) {
            n = strtol( argv[++i], nullptr, 0 );
        }
        else if (arg == "-k" && i+1 < argc) {
            k = strtol( argv[++i], nullptr, 0 );
        }
        else if (arg == "-v") {
            verbose += 1;
        }
        else if (arg == "-repeat" && i+1 < argc) {
            repeat = strtol( argv[++i], nullptr, 0 );
        }
        else {
            printf( "Unknown argument: %s\n", argv[i] );
        }
    }
    printf( "m %4d, n %4d, k %4d\n", m, n, k );

    try {
        for (int i = 0; i < repeat; ++i) {
            test<float>( m, n, k, verbose );
        }
    }
    catch (std::exception const& ex) {
        fprintf( stderr, "Error: %s\n", ex.what() );
    }
    return 0;
}
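
//------------------------------------------------------------------------------
// Build/run sketch. Assumes util.hh provides throw_error, get_wtime,
// rand_matrix, and print_matrix, as used above; adjust -arch for your GPU.
//
//     nvcc -O3 -arch=sm_70 ex03-matrix-multiply.cu -o ex03-matrix-multiply
//     ./ex03-matrix-multiply -m 1024 -n 1024 -k 1024 -repeat 5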