// Simple CUDA batch_dot with test main() driver.
// Only very lightly tested.

#include <assert.h>
#include <float.h>
#include <math.h>
#include <stdio.h>

#include <magma_v2.h>
#include <magma_lapack.h>


// -----------------------------------------------------------------------------
// CUDA GPU kernel for batch_dot().
//
// Each thread handles one row: thread `row` computes
//     c[row] = sum_{j=0}^{n-1} A[row + j*lda] * B[row + j*ldb].
// Launch with a 1D grid of 1D blocks providing at least m threads in total;
// threads whose global index is >= m do nothing. No shared memory is used.
//
// m, n     number of rows and columns of A and B.
// A, lda   m-by-n matrix; element (i, j) is read from A[i + j*lda].
// B, ldb   m-by-n matrix; element (i, j) is read from B[i + j*ldb].
// c        output vector of length m; every entry is overwritten
//          (with 0 when n == 0).
template <typename value_type>
__global__
void batch_dot_kernel(
    int m, int n,
    value_type const* A, int lda,
    value_type const* B, int ldb,
    value_type* c )
{
    // Form the global index in 64 bits so that a grid rounded up past
    // INT_MAX cannot wrap around to a negative (and thus "valid") row.
    const long long row = (long long) blockIdx.x*blockDim.x + threadIdx.x;
    if (row < m) {
        // adjust pointers to A[row, 0], B[row, 0]
        A += row;
        B += row;

        // Accumulate in a local variable and store once at the end,
        // instead of doing a read-modify-write on c[row] every iteration.
        value_type sum = 0;
        for (int j = 0; j < n; ++j) {
            // 64-bit offsets: j*lda can exceed the range of int for large arrays.
            sum += A[(long long) j*lda] * B[(long long) j*ldb];
        }
        c[row] = sum;
    }
}


// -----------------------------------------------------------------------------
// Computes dot product of each row of A with corresponding row of B;
// sets c[i] = dot( A[i,:], B[i,:] ) for i = 0, ..., m-1.
// A is an m-by-n matrix in an lda-by-n array (lda >= m).
// B is an m-by-n matrix in an ldb-by-n array (ldb >= m).
// c is a vector of length m.
// (In batch terminology, m is the batch_count. This does m independent dots.)
//
// A, B, and c are dereferenced by the GPU kernel, so they must point to
// memory the device can access.
// The kernel is enqueued on the CUDA stream of `queue`; this routine does not
// synchronize, so c is valid only after work on that stream has completed.
// Returns immediately without touching c when m == 0.
// A failed kernel launch is reported on stderr; the routine still returns
// normally because its interface has no error return.
template <typename value_type>
void batch_dot(
    int m, int n,
    value_type const* A, int lda,
    value_type const* B, int ldb,
    value_type* c,
    magma_queue_t queue )
{
    assert( m >= 0 );
    assert( n >= 0 );
    // rows of A and B are addressed as A[i + j*lda], B[i + j*ldb]
    assert( lda >= m );
    assert( ldb >= m );
    if (m == 0)
        return;
    
    const int nb = 32; // block size: one thread per row, nb rows per block
    dim3 blocks( magma_ceildiv( m, nb ) );
    dim3 threads( nb );
    cudaStream_t stream = magma_queue_get_cuda_stream( queue );
    batch_dot_kernel<<< blocks, threads, 0, stream >>>
        ( m, n, A, lda, B, ldb, c );
    
    // Kernel launches return no status; launch-configuration errors are
    // only visible through cudaGetLastError().
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf( stderr, "batch_dot: kernel launch failed: %s\n",
                 cudaGetErrorString( err ) );
    }
}


// -----------------------------------------------------------------------------
// Test driver for batch_dot().
// Fills two m-by-n matrices with random values, computes the row-wise dot
// products on the GPU, prints the inputs and the result, and compares the
// result against a reference computed on the CPU.
// Returns 0 if every entry matches within tolerance, 1 otherwise
// (including initialization or allocation failure).
int main()
{
    // Start from NULL so that a failed allocation never leaves an
    // indeterminate pointer, and so the cleanup below is always safe.
    double *A = NULL, *B = NULL, *c = NULL;
    double *dA = NULL, *dB = NULL, *dc = NULL;
    int m = 100, n = 20;
    int lda = magma_roundup( m, 32 );  // round up to multiple of 32
    int ldb = lda;
    int status = 0;
    
    if (magma_init() != MAGMA_SUCCESS) {
        fprintf( stderr, "magma_init failed\n" );
        return 1;
    }
    
    // Check the return codes rather than assert(), which is compiled out
    // under NDEBUG. Short-circuiting stops at the first failure.
    bool alloc_ok =
           magma_dmalloc_cpu( &A, lda*n ) == MAGMA_SUCCESS
        && magma_dmalloc_cpu( &B, ldb*n ) == MAGMA_SUCCESS
        && magma_dmalloc_cpu( &c, m )     == MAGMA_SUCCESS
        && magma_dmalloc( &dA, lda*n )    == MAGMA_SUCCESS
        && magma_dmalloc( &dB, ldb*n )    == MAGMA_SUCCESS
        && magma_dmalloc( &dc, m )        == MAGMA_SUCCESS;
    
    if (alloc_ok) {
        int dev = 0;
        magma_queue_t queue;
        magma_queue_create( dev, &queue );
        
        // magma_int_t so the LAPACK call also compiles when MAGMA is built
        // with 64-bit integers.
        magma_int_t idist = 3; // normal
        magma_int_t iseed[4] = { 0, 1, 2, 3 };
        magma_int_t sizeA = lda*n;
        magma_int_t sizeB = ldb*n;
        lapackf77_dlarnv( &idist, iseed, &sizeA, A );
        lapackf77_dlarnv( &idist, iseed, &sizeB, B );
        
        printf( "A = " );  magma_dprint( m, n, A, lda );
        printf( "B = " );  magma_dprint( m, n, B, ldb );
        magma_dsetmatrix( m, n, A, lda, dA, lda, queue );
        magma_dsetmatrix( m, n, B, ldb, dB, ldb, queue );
        
        // batch_dot dots each row of A with a row of B:
        // for (i = 0; i < m; ++i)
        //     c[i] = dot( A[i,:], B[i,:] )
        batch_dot( m, n, dA, lda, dB, ldb, dc, queue );
        
        magma_dgetvector( m, dc, 1, c, 1, queue );
        printf( "c = " );  magma_dprint( m, 1, c, m );
        
        // Compare against a CPU reference. Exact equality is not expected
        // (summation/rounding may differ between host and device), so allow
        // a tolerance proportional to n * eps * sum_j |A[i,j] * B[i,j]|.
        double max_err = 0;
        int num_bad = 0;
        for (int i = 0; i < m; ++i) {
            double ref = 0;
            double abs_sum = 0;
            for (int j = 0; j < n; ++j) {
                double prod = A[i + j*lda] * B[i + j*ldb];
                ref     += prod;
                abs_sum += fabs( prod );
            }
            double err = fabs( c[i] - ref );
            double tol = 2 * n * DBL_EPSILON * abs_sum;
            if (! (err <= tol)) {  // negated form also flags NaN
                ++num_bad;
            }
            if (err > max_err) {
                max_err = err;
            }
        }
        printf( "max |c_gpu - c_cpu| = %.2e   %s\n",
                max_err, (num_bad == 0 ? "ok" : "FAILED") );
        if (num_bad != 0) {
            status = 1;
        }
        
        magma_queue_destroy( queue );
    }
    else {
        fprintf( stderr, "memory allocation failed\n" );
        status = 1;
    }
    
    // Pointers that were never allocated are still NULL here.
    magma_free_cpu( A );
    magma_free_cpu( B );
    magma_free_cpu( c );
    
    magma_free( dA );
    magma_free( dB );
    magma_free( dc );
    
    magma_finalize();
    return status;
}
