//==============================================================================================
//
//  Innovative Computing Laboratory - Computer Science Department - University of Tennessee
//  Written by Jakub Kurzak
//
//==============================================================================================

#include <cbe_mfc.h>
#include <spu_mfcio.h>
#include <vec_literal.h>
#include <spu_intrinsics.h>

#include "spu_blas.h"
#include "spu_dgemm.h"

//----------------------------------------------------------------------------------------------

// Globals defined in another translation unit of this SPU program.

// Not referenced in this part of the file.
extern int NB;
// Block (tile) edge length in doubles used by spu_dgemm and spu_dgemm_tile.
extern int DNB;
// Number of SPUs sharing the work; spu_dgemm deals tiles out with stride spus_num.
extern int spus_num;
// This SPU's index; spu_dgemm starts at tile my_spu_id and offsets its C slice by it.
extern int my_spu_id;
// Call arguments read by spu_dgemm after spu_recv_call_args() (gemm.M/N/K, LDA/LDB/LDC, A/B/C addresses).
extern CallArgs spu_call_args __attribute__ ((aligned (128)));

// Scratch buffers used by spu_dgemm as DMA targets:
//   dbuffer1..3 -> initial front A, B, C
//   dbuffer4..5 -> initial back A, B
//   dbuffer6..7 -> initial incoming / outgoing C
// NOTE(review): sizes are set where they are defined; A buffers must hold a DNB x DNB tile.
extern double *dbuffer1;
extern double *dbuffer2;
extern double *dbuffer3;
extern double *dbuffer4;
extern double *dbuffer5;
extern double *dbuffer6;
extern double *dbuffer7;

//----------------------------------------------------------------------------------------------

/*
 * Map a linear tile index to block coordinates:
 *   *k - block column of A / block of B : (tile / Bm) % Bk
 *   *m - block row of A / block of C    : tile % Bm, optionally rotated
 * With roll_rows set the row is rotated by the column, m = (m + k) % Bm.
 * spu_dgemm enables this when Bm == spus_num, presumably so that the
 * consecutive tiles of one SPU (tile and tile + spus_num, which share
 * tile % Bm) target different C blocks.
 */
static void spu_dgemm_tile_coords(int tile, int Bm, int Bk, int roll_rows,
                                  int *m, int *k)
{
    int col = (tile / Bm) % Bk;
    int row = tile % Bm;

    if (roll_rows)
        row = (row + col) % Bm;

    *m = row;
    *k = col;
}

/*
 * SPU worker for the blocked matrix product started by the PPU.
 *
 * Reads M, K, LDA and the effective addresses A, B, C from spu_call_args
 * (after spu_recv_call_args()). A is split into Bm x Bk = (M/DNB) x (K/DNB)
 * tiles of DNB x DNB doubles. This SPU handles tiles my_spu_id,
 * my_spu_id + spus_num, my_spu_id + 2*spus_num, ... For each tile (m, k) it:
 *   1. fetches the A tile, DNB doubles of B starting at element k*DNB, and
 *      DNB doubles of C starting at element m*DNB;
 *   2. calls spu_dgemm_tile() to accumulate into the C block;
 *   3. writes the C block back.
 * Each SPU uses its own M-element slice of C starting at element my_spu_id*M
 * (presumably reduced by the caller afterwards - confirm on the PPU side).
 *
 * DMA is software-pipelined. A and B are double-buffered (front/back). C is
 * triple-buffered: front is being computed, back_in is being fetched, and
 * back_out is being written back. One MFC tag is used per stage.
 *
 * NOTE(review):
 *  - M and K are assumed to be multiples of DNB; any remainder is ignored.
 *  - gemm.N, gemm.LDB and gemm.LDC are not used: B and C blocks are read as
 *    contiguous runs of DNB doubles.
 *  - The A address (m*DNB*LDA + k*DNB*DNB) assumes A is stored as
 *    consecutive DNB x DNB tiles within each block row - confirm on PPU side.
 *  - If two consecutive tiles of this SPU map to the same C block (spus_num a
 *    multiple of Bm but not equal to it, or Bm == 1), the later tile's C is
 *    fetched before the earlier result is written back, so that contribution
 *    is lost.
 */
void spu_dgemm()
{
    double *spuA_front;
    double *spuB_front;
    double *spuC_front;

    double *spuA_back;
    double *spuB_back;
    double *spuC_back_in;
    double *spuC_back_out;

    int tile, next_tile, prev_tile, num_tiles;
    int LDA;
    int Bm, Bk;
    int M, K;
    int roll_rows;

    unsigned int tag_front, tag_back;
    unsigned int ppuA, ppuB, ppuC;


    spu_recv_call_args();

    M   = spu_call_args.gemm.M;
    K   = spu_call_args.gemm.K;
    LDA = spu_call_args.gemm.LDA;

    Bm = M/DNB;
    Bk = K/DNB;

    num_tiles = Bm*Bk;

    spuA_front = dbuffer1;
    spuB_front = dbuffer2;
    spuC_front = dbuffer3;

    spuA_back     = dbuffer4;
    spuB_back     = dbuffer5;
    spuC_back_in  = dbuffer6;
    spuC_back_out = dbuffer7;

    tag_front = 1;
    tag_back  = 2;

    ppuA = spu_call_args.gemm.A;
    ppuB = spu_call_args.gemm.B;
    ppuC = spu_call_args.gemm.C + (my_spu_id * M * sizeof(double));

    roll_rows = (Bm == spus_num);


    /*
     * Queue the DMA gets for one tile under `tag`. C is fetched with a
     * fenced get so it is ordered after any command already queued in the
     * same tag group. That includes the write-back of the previous tile,
     * which may target the same C block.
     */
    #define receive_tile(tile, spuA, spuB, spuC, tag)                   \
    do {                                                                \
        int m_, k_;                                                     \
        unsigned int A_, B_, C_;                                        \
                                                                        \
        spu_dgemm_tile_coords((tile), Bm, Bk, roll_rows, &m_, &k_);     \
        A_ = ppuA + (m_*DNB*LDA + k_*DNB*DNB) * sizeof(double);         \
        B_ = ppuB + (k_*DNB) * sizeof(double);                          \
        C_ = ppuC + (m_*DNB) * sizeof(double);                          \
                                                                        \
        spu_recv_tile_blocked_double((spuA), A_, (tag));                \
        mfc_get((spuB), B_, DNB*sizeof(double), (tag), 0, 0);           \
        mfc_getf((spuC), C_, DNB*sizeof(double), (tag), 0, 0);          \
    } while (0)

    /* Queue the write-back of one tile's C block under `tag`. */
    #define send_tile(tile, spuC, tag)                                  \
    do {                                                                \
        int m_, k_;                                                     \
        unsigned int C_;                                                \
                                                                        \
        spu_dgemm_tile_coords((tile), Bm, Bk, roll_rows, &m_, &k_);     \
        C_ = ppuC + (m_*DNB) * sizeof(double);                          \
        mfc_put((spuC), C_, DNB*sizeof(double), (tag), 0, 0);           \
    } while (0)

    /*
     * Rotate the pipeline stages:
     *   - A/B: front <-> back.
     *   - C: front <- back_in, back_in <- back_out, back_out <- front.
     *   - Tags: front <-> back.
     */
    #define swap_buffers()                                              \
    do {                                                                \
        double *tmp_p; unsigned int tmp_t;                              \
        tmp_p = spuA_front; spuA_front = spuA_back; spuA_back = tmp_p;  \
        tmp_p = spuB_front; spuB_front = spuB_back; spuB_back = tmp_p;  \
        tmp_p = spuC_front; spuC_front = spuC_back_in;                  \
        spuC_back_in = spuC_back_out; spuC_back_out = tmp_p;            \
        tmp_t = tag_front; tag_front = tag_back; tag_back = tmp_t;      \
    } while (0)


    /* Prologue: fetch the first tile (N-1) into the front buffers */
    if (my_spu_id < num_tiles)
        receive_tile(my_spu_id, spuA_front, spuB_front, spuC_front, tag_front);

    /* ... and the second tile (N) into the back buffers */
    if (my_spu_id + spus_num < num_tiles)
        receive_tile(my_spu_id + spus_num, spuA_back, spuB_back, spuC_back_in, tag_back);

    /* Compute the first tile; skipped when this SPU has no tiles at all */
    if (my_spu_id < num_tiles)
    {
        spu_wait_tag(tag_front);
        spu_dgemm_tile(spuA_front, spuB_front, spuC_front);
    }

    swap_buffers();

    /* Steady state: send N-1, receive N+1, compute N */
    for (tile = my_spu_id + spus_num; tile < num_tiles - spus_num; tile += spus_num)
    {
        prev_tile = tile - spus_num;
        next_tile = tile + spus_num;

        /*
         * Wait on tag_front BEFORE issuing new DMA. The wait covers two
         * things:
         *   - the gets for tile N;
         *   - the write-back issued from the buffer that is now
         *     spuC_back_in. That put was queued last iteration under the
         *     tag that swap_buffers() turned into tag_front.
         * Fetching into spuC_back_in before that put completes would race
         * with it.
         */
        spu_wait_tag(tag_front);

        send_tile(prev_tile, spuC_back_out, tag_back);
        receive_tile(next_tile, spuA_back, spuB_back, spuC_back_in, tag_back);

        spu_dgemm_tile(spuA_front, spuB_front, spuC_front);

        swap_buffers();
    }

    /* Epilogue: send the last computed tile */
    if (my_spu_id < num_tiles)
        send_tile(tile - spus_num, spuC_back_out, tag_back);

    /* Compute the final tile (its gets and the last loop put are on tag_front) */
    if (my_spu_id + spus_num < num_tiles)
    {
        spu_wait_tag(tag_front);
        spu_dgemm_tile(spuA_front, spuB_front, spuC_front);
    }

    /* Send the final tile */
    if (my_spu_id + spus_num < num_tiles)
        send_tile(tile, spuC_front, tag_back);

    /* Both epilogue puts use tag_back; wait for them before returning */
    if (my_spu_id < num_tiles)
        spu_wait_tag(tag_back);
}

#undef receive_tile
#undef send_tile
#undef swap_buffers

//----------------------------------------------------------------------------------------------

/*
 * Accumulate one DNB x DNB tile of A times a DNB-element block of B into C:
 *
 *     C[m] += sum_j A[m*DNB + j] * B[j],    m = 0 .. DNB-1
 *
 * A is read row by row (DNB doubles per row). B and C hold DNB doubles each.
 * A and B are accessed as `vector double`, so they must be 16-byte aligned
 * and DNB is assumed even (NOTE(review): confirm DNB is always even).
 *
 * For DNB == 32 the single accumulator performs exactly the same 16 fused
 * multiply-adds, in the same order, as a fully unrolled chain, so results
 * are bit-identical to that form.
 *
 * The decrementer readings taken before and after the computation are
 * appended to spu_event_log with event id 0x60C0A0.
 */
void spu_dgemm_tile(double *A, double *B, double *C)
{
    int m, j;
    int vecs_per_row;

    vector double *vecA = (vector double*)A;
    vector double *Ap = vecA;
    vector double *vecB = (vector double*)B;

    vector double vec_sum;

    //----------------------------------------------------------

    extern int spu_event_num;
    extern int spu_event_log[];

    /*
     * Append (start, end, event) to spu_event_log. Each record occupies
     * 4 slots and the index wraps at 1024.
     * NOTE(review): assumes spu_event_log holds at least 1024 ints.
     */
    #define spu_log_event(start, end, event)            \
        do {                                            \
            spu_event_log[spu_event_num+0] = (start);   \
            spu_event_log[spu_event_num+1] = (end);     \
            spu_event_log[spu_event_num+2] = (event);   \
            spu_event_num += 4;                         \
            spu_event_num &= 1024-1;                    \
        } while (0)

    int start;
    int end;

    //----------------------------------------------------------

    start = spu_read_decrementer();

    /* One row of the tile is DNB doubles = DNB/2 vector doubles */
    vecs_per_row = DNB / 2;

    for (m = 0; m < DNB; m++)
    {
        vec_sum = spu_splats((double)0.0);

        for (j = 0; j < vecs_per_row; j++)
            vec_sum = spu_madd(Ap[j], vecB[j], vec_sum);

        /* Fold the two vector lanes into C[m], lane 0 first */
        C[m] = C[m] + spu_extract(vec_sum, 0);
        C[m] = C[m] + spu_extract(vec_sum, 1);

        Ap += vecs_per_row;
    }

    end = spu_read_decrementer();
    spu_log_event(start, end, 0x60C0A0);

}

//----------------------------------------------------------------------------------------------
