/* ///////////////////////////// P /// L /// A /// S /// M /// A /////////////////////////////// */
/* ///                    PLASMA computational routine (version 2.1.0)                       ///
 * ///                    Author: Emmanuel Agullo                                            ///
 * ///                    Release Date: November, 15th 2009                                  ///
 * ///                    PLASMA is a software package provided by Univ. of Tennessee,       ///
 * ///                    Univ. of California Berkeley and Univ. of Colorado Denver          /// */

/* /////////////////////////// P /// U /// R /// P /// O /// S /// E /////////////////////////// */
// PLASMA_sgemm - Performs one of the matrix-matrix operations 
//
//   C = alpha*op( A )*op( B ) + beta*C,
//
// where op( X ) is one of 
//
//   op( X ) = X  or op( X ) = X'
//
// alpha and beta are scalars, and A, B and C are matrices, with op( A )
// an M by K matrix, op( B ) a K by N matrix and C an M by N matrix.

/* ///////////////////// A /// R /// G /// U /// M /// E /// N /// T /// S ///////////////////// */
// transA   PLASMA_enum (IN)
//          Specifies whether the matrix A is transposed or not:
//          = PlasmaNoTrans:   A is not transposed;
//          = PlasmaTrans:     A is transposed.
//          (For real matrices the conjugate transpose equals the transpose.)
//          Currently only PlasmaNoTrans is supported
//
// transB   PLASMA_enum (IN)
//          Specifies whether the matrix B is transposed or not:
//          = PlasmaNoTrans:   B is not transposed;
//          = PlasmaTrans:     B is transposed.
//          (For real matrices the conjugate transpose equals the transpose.)
//          Currently only PlasmaNoTrans is supported
//
// M        int (IN)
//          M specifies the number of rows of the matrix op( A ) and of the matrix C. M >= 0.
//
// N        int (IN)
//          N specifies the number of columns of the matrix op( B ) and of the matrix C. N >= 0.
//
// K        int (IN)
//          K specifies the number of columns of the matrix op( A ) and the number of rows of 
//          the matrix op( B ). K >= 0.
//
// alpha    float (IN)
//          alpha specifies the scalar alpha
//
// A        float* (IN)
//          A is a LDA-by-ka matrix, where ka is K when  transA = PlasmaNoTrans,  
//          and is  M  otherwise.
//
// LDA      int (IN)
//          The leading dimension of the array A. LDA >= max(1,M).
//
// B        float* (IN)
//          B is a LDB-by-kb matrix, where kb is N when  transB = PlasmaNoTrans,  
//          and is  K  otherwise.
//
// LDB      int (IN)
//          The leading dimension of the array B. LDB >= max(1,K).
//
// beta     float (IN)
//          beta specifies the scalar beta
//
// C        float* (INOUT)
//          C is a LDC-by-N matrix.
//          On exit, the array is overwritten by the M by N matrix ( alpha*op( A )*op( B ) + beta*C )
//
// LDC      int (IN)
//          The leading dimension of the array C. LDC >= max(1,M).

/* ///////////// R /// E /// T /// U /// R /// N /////// V /// A /// L /// U /// E ///////////// */
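//          = 0: successful exit
//          > 0: i, the 1-based position of the argument that had an illegal value
//          otherwise: a PLASMA error code (e.g. PLASMA_ERR_NOT_INITIALIZED,
//                     PLASMA_ERR_OUT_OF_RESOURCES, or the status of plasma_tune())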

/* //////////////////////////////////// C /// O /// D /// E //////////////////////////////////// */
#include "common.h"
#include "lapack.h"

int PLASMA_sgemm(PLASMA_enum transA, PLASMA_enum transB, int M, int N, int K,
                 float alpha, float *A, int LDA,
                 float *B, int LDB,
                 float beta, float *C, int LDC)
{
    int NB, MT, NT, KT;
    int nrowA, nrowB;
    int status;
    float *Abdl;
    float *Bbdl;
    float *Cbdl;
    plasma_context_t *plasma;

    plasma = plasma_context_self();
    if (plasma == NULL) {
        plasma_fatal_error("PLASMA_sgemm", "PLASMA not initialized");
        return PLASMA_ERR_NOT_INITIALIZED;
    }

    /* Number of rows of A and B in their LAPACK-layout storage.
     * Only transA == transB == PlasmaNoTrans is supported, so A is M-by-K
     * and B is K-by-N; this matches descA/descB built below and the rows
     * read by plasma_lapack_to_tile().
     * TODO: adapt nrowA and nrowB to transA and transB, respectively. */
    nrowA = M;
    nrowB = K;

    /* Check input arguments. A positive return value is the 1-based
     * position of the offending argument in the parameter list. */
    if (transA != PlasmaNoTrans && transA != PlasmaTrans) {
        plasma_error("PLASMA_sgemm", "illegal value of transA");
        return 1;
    }
    if (transB != PlasmaNoTrans && transB != PlasmaTrans) {
        plasma_error("PLASMA_sgemm", "illegal value of transB");
        return 2;
    }
    if (M < 0) {
        plasma_error("PLASMA_sgemm", "illegal value of M");
        return 3;
    }
    if (N < 0) {
        plasma_error("PLASMA_sgemm", "illegal value of N");
        return 4;
    }
    if (K < 0) {
        plasma_error("PLASMA_sgemm", "illegal value of K");
        return 5;
    }
    if (LDA < max(1, nrowA)) {
        plasma_error("PLASMA_sgemm", "illegal value of LDA");
        return 8;
    }
    if (LDB < max(1, nrowB)) {
        plasma_error("PLASMA_sgemm", "illegal value of LDB");
        return 10;
    }
    if (LDC < max(1, M)) {
        plasma_error("PLASMA_sgemm", "illegal value of LDC");
        return 13;
    }

    /* Quick return: C is left untouched when it is empty, or when the
     * alpha*op(A)*op(B) term contributes nothing and beta == 1. */
    if (M == 0 || N == 0 ||
        ((alpha == (float)0.0 || K == 0) && beta == (float)1.0)) {
        return PLASMA_SUCCESS;
    }
    /* NOTE(review): K == 0 with beta != 1 falls through to zero-sized
     * allocations for A and B (KT == 0); confirm that plasma_shared_alloc
     * and plasma_psgemm still apply C = beta*C in that case. */

    /* Tune NB depending on M and N; sets PLASMA_NB and PLASMA_NBNBSIZE. */
    status = plasma_tune(PLASMA_FUNC_SGEMM, M, N, 0);
    if (status != PLASMA_SUCCESS) {
        plasma_error("PLASMA_sgemm", "plasma_tune() failed");
        return status;
    }

    /* Number of NB-sized tiles along M, N and K (ceiling division). */
    NB = PLASMA_NB;
    MT = (M%NB==0) ? (M/NB) : (M/NB+1);
    NT = (N%NB==0) ? (N/NB) : (N/NB+1);
    KT = (K%NB==0) ? (K/NB) : (K/NB+1);

    /* Allocate memory for the matrices in tile (block) layout. */
    Abdl = (float *)plasma_shared_alloc(plasma, MT*KT*PLASMA_NBNBSIZE, PlasmaRealFloat);
    Bbdl = (float *)plasma_shared_alloc(plasma, KT*NT*PLASMA_NBNBSIZE, PlasmaRealFloat);
    Cbdl = (float *)plasma_shared_alloc(plasma, MT*NT*PLASMA_NBNBSIZE, PlasmaRealFloat);
    if (Abdl == NULL || Bbdl == NULL || Cbdl == NULL) {
        plasma_error("PLASMA_sgemm", "plasma_shared_alloc() failed");
        plasma_shared_free(plasma, Abdl);
        plasma_shared_free(plasma, Bbdl);
        plasma_shared_free(plasma, Cbdl);
        return PLASMA_ERR_OUT_OF_RESOURCES;
    }

    /* Tile descriptors: A is M-by-K, B is K-by-N, C is M-by-N. */
    PLASMA_desc descA = plasma_desc_init(
        Abdl, PlasmaRealFloat,
        PLASMA_NB, PLASMA_NB, PLASMA_NBNBSIZE,
        M, K, 0, 0, M, K);

    PLASMA_desc descB = plasma_desc_init(
        Bbdl, PlasmaRealFloat,
        PLASMA_NB, PLASMA_NB, PLASMA_NBNBSIZE,
        K, N, 0, 0, K, N);

    PLASMA_desc descC = plasma_desc_init(
        Cbdl, PlasmaRealFloat,
        PLASMA_NB, PLASMA_NB, PLASMA_NBNBSIZE,
        M, N, 0, 0, M, N);

    /* Convert the user's LAPACK-layout inputs to tile layout. */
    plasma_parallel_call_3(plasma_lapack_to_tile,
        float*, A,
        int, LDA,
        PLASMA_desc, descA);

    plasma_parallel_call_3(plasma_lapack_to_tile,
        float*, B,
        int, LDB,
        PLASMA_desc, descB);

    plasma_parallel_call_3(plasma_lapack_to_tile,
        float*, C,
        int, LDC,
        PLASMA_desc, descC);

    /* Tile-level computation of C = alpha*op(A)*op(B) + beta*C. */
    plasma_parallel_call_7(plasma_psgemm,
        PLASMA_enum, transA,
        PLASMA_enum, transB,
        float, alpha,
        PLASMA_desc, descA,
        PLASMA_desc, descB,
        float, beta,
        PLASMA_desc, descC);

    /* Copy the result back into the user's C only if the tile
     * computation reported success. */
    if (PLASMA_INFO == PLASMA_SUCCESS) {
        plasma_parallel_call_3(plasma_tile_to_lapack,
            PLASMA_desc, descC,
            float*, C,
            int, LDC);
    }
    plasma_shared_free(plasma, Abdl);
    plasma_shared_free(plasma, Bbdl);
    plasma_shared_free(plasma, Cbdl);
    return PLASMA_INFO;
}
