//==============================================================================================
//
//  Innovative Computing Laboratory - Computer Science Department - University of Tennessee
//  Written by Jakub Kurzak
//
//==============================================================================================

#include <cbe_mfc.h>
#include <spu_mfcio.h>
#include <vec_literal.h>
#include <spu_intrinsics.h>

#include "spu_blas.h"
#include "spu_strsm_1rhs_trans.h"

//----------------------------------------------------------------------------------------------

// Run-time state shared with the rest of the SPU program (defined in another unit).
extern int          NB;                 // tile edge in elements; M/NB gives the tile count
extern int          spus_num;           // number of SPUs sharing the work
extern int          my_spu_id;          // this SPU's index (presumably 0..spus_num-1)
extern CallArgs     spu_call_args;      // current call's arguments (filled by spu_recv_call_args, presumably)
extern GlobalParams spu_global_params;  // local_store[] gives each SPU's local-store base address

//----------------------------------------------------------------------------------------------

//  Helpers for spu_strsm_trans(): progress-flag protocol, DMA bookkeeping, event logging.

extern volatile unsigned char spu_progress[];

// Value a tile's progress flag takes once that tile's output block of B has been stored.
enum { STRSM_TILE_DONE = 3 };

// DMA sizes. They match the 64x64 tile kernels in this file (the 16384 / 256 byte sizes
// this code has always transferred), not the run-time NB.
enum {
    STRSM_T_BYTES = 64*64*sizeof(float),    // one tile of T
    STRSM_B_BYTES = 64*sizeof(float)        // one block of B
};

// Source bytes for the progress-flag DMAs, all equal to STRSM_TILE_DONE. There are 16 of
// them so the source byte can sit at the same offset within its quadword as the destination
// flag, as the MFC requires for transfers shorter than 16 bytes. Static storage (not a
// stack local) because the flag DMAs read it asynchronously, after mfc_putf has returned.
static volatile unsigned char strsm_progress_src[16] __attribute__ ((aligned (16)))
    = {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3};

// Blocks until every DMA issued under `tag` has completed.
static inline void strsm_wait_tag(unsigned int tag)
{
    mfc_write_tag_mask(0x01 << tag);
    mfc_read_tag_status_all();
}

// Exchanges two values; used to flip a double buffer's address and its DMA tag together.
static inline void strsm_swap(unsigned int *a, unsigned int *b)
{
    unsigned int t = *a;
    *a = *b;
    *b = t;
}

// Steps (*i, *j) backwards through the tile order (rows bottom-up, each row from the
// diagonal leftwards) until *j is a valid column again. Leaves *i < 0 once the order is
// exhausted; the i >= 0 guard is what makes the walk terminate in that case.
static inline void strsm_wrap(int *i, int *j)
{
    while (*j < 0 && *i >= 0)
    {
        (*i)--;
        *j += *i + 1;
    }
}

// Nonzero once B block `col` carries every update that must precede tile (row, col), i.e.
// tile (row+1, col) has published its flag. The bottom tile row has no predecessor and is
// ready without reading spu_progress (index BB*BB+col would lie past the BB x BB flag grid).
static inline int strsm_col_ready(int BB, int row, int col)
{
    return row >= BB-1 || spu_progress[BB*(row+1)+col] == STRSM_TILE_DONE;
}

// Nonzero once diagonal tile (row, row) has published its flag, i.e. B block `row` has been
// solved and stored.
static inline int strsm_row_solved(int BB, int row)
{
    return spu_progress[BB*row+row] == STRSM_TILE_DONE;
}

// Sets progress flag `flag` to STRSM_TILE_DONE in every SPU's copy of spu_progress. The puts
// are fenced and use the same tag as the B store just issued, so each is ordered after it.
static void strsm_publish(int flag, unsigned int tag)
{
    int spu;

    for (spu = 0; spu < spus_num; spu++)
        mfc_putf(
            (void*)&strsm_progress_src[flag & 0x0F],
            spu_global_params.local_store[spu] + (unsigned int)(&spu_progress[flag]),
            sizeof(unsigned char), tag, 0, 0);
}

// Appends a (start, end, event) record to the event log; records are 4 ints apart and the
// write index wraps at 1024 ints.
static inline void strsm_log_event(int start, int end, int event)
{
    extern int spu_event_num;
    extern int spu_event_log[];

    spu_event_log[spu_event_num+0] = start;
    spu_event_log[spu_event_num+1] = end;
    spu_event_log[spu_event_num+2] = event;
    spu_event_num = (spu_event_num + 4) & (1024-1);
}

//----------------------------------------------------------------------------------------------

// Solves T^T x = B in place in B for one right-hand side, splitting the work over all SPUs.
//
// T (spu_call_args.trsm.T) is addressed as a BB x BB row-major grid of NB x NB float tiles,
// BB = M/NB, of which only tiles (i, j) with j <= i are visited. B (spu_call_args.trsm.B) is
// BB blocks of NB floats. A diagonal tile (i, i) solves block i in place
// (spu_strsm_trans_tile_); an off-diagonal tile (i, j) subtracts T(i,j)^T * B[i] from B[j]
// (spu_strsm_trans_tile).
//
// Tiles are ordered bottom row first, each row from the diagonal leftwards; this SPU takes
// tiles my_spu_id, my_spu_id + spus_num, ... of that order. Cross-SPU dependencies go
// through spu_progress: flag BB*i+j is set in every SPU's copy once tile (i, j)'s output
// block is stored, and a tile spins on the flags of the tiles it depends on.
//
// T and B blocks are double-buffered (spuX / spuX_, tagged tagX / tagX_): while one tile is
// computed, the next tile's T and any of its B blocks that are already final are prefetched.
// All DMA is complete when the function returns.
void spu_strsm_trans()
{
    extern unsigned char mem_pool[];

    // Local-store buffers: T tile, B[i] (BJ) and B[j] (BI), each with a spare for prefetch.
    unsigned int spuT   = (unsigned int)(mem_pool + 0*NB*NB*sizeof(float));
    unsigned int spuT_  = (unsigned int)(mem_pool + 1*NB*NB*sizeof(float));
    unsigned int spuBJ  = (unsigned int)(mem_pool + 2*NB*NB*sizeof(float));
    unsigned int spuBJ_ = (unsigned int)(mem_pool + 3*NB*NB*sizeof(float));
    unsigned int spuBI  = (unsigned int)(mem_pool + 4*NB*NB*sizeof(float));
    unsigned int spuBI_ = (unsigned int)(mem_pool + 5*NB*NB*sizeof(float));

    unsigned int tagT   = 1;
    unsigned int tagT_  = 2;
    unsigned int tagBJ  = 3;
    unsigned int tagBJ_ = 4;
    unsigned int tagBI  = 5;
    unsigned int tagBI_ = 6;

    spu_recv_call_args();

    unsigned int ppuT = spu_call_args.trsm.T;
    unsigned int ppuB = spu_call_args.trsm.B;

    int N  = spu_call_args.trsm.M;
    int BB = N/NB;

    int i, j;
    int next_i, next_j;
    int BJ_pulled = 0;      // B[i] for the current tile was already prefetched into spuBJ
    int BI_pulled = 0;      // B[j] for the current tile was already prefetched into spuBI
    int start, end;

    // Skip the first my_spu_id tiles. If that exhausts the order (more SPUs than tiles, or
    // N < NB), this SPU has no work; without the check the search would never terminate.
    i = BB-1;
    j = BB-1-my_spu_id;
    strsm_wrap(&i, &j);
    if (i < 0)
        return;

    mfc_get((void*)spuT, ppuT + (BB*i+j)*NB*NB*sizeof(float), STRSM_T_BYTES, tagT, 0, 0);

    while (j >= 0 && i >= 0)
    {
        // Next tile for this SPU: spus_num places further along the order.
        next_i = i;
        next_j = j - spus_num;
        strsm_wrap(&next_i, &next_j);
        int has_next = (next_i >= 0 && next_j >= 0);

        // Fetch the operands the previous iteration could not prefetch, once their
        // producers have published.
        if (i == j)
        {
            if (!BJ_pulled)
            {
                while (!strsm_col_ready(BB, i, j));
                mfc_getf((void*)spuBJ, ppuB + (i)*NB*sizeof(float), STRSM_B_BYTES, tagBJ, 0, 0);
            }
        }
        else
        {
            if (!BI_pulled)
            {
                while (!strsm_col_ready(BB, i, j));
                mfc_getf((void*)spuBI, ppuB + (j)*NB*sizeof(float), STRSM_B_BYTES, tagBI, 0, 0);
            }
            if (!BJ_pulled)
            {
                while (!strsm_row_solved(BB, i));
                mfc_getf((void*)spuBJ, ppuB + (i)*NB*sizeof(float), STRSM_B_BYTES, tagBJ, 0, 0);
            }
        }

        // Prefetch the next tile into the spare buffers, but only if there is a next tile;
        // otherwise the addresses would lie before ppuT / ppuB. The B gets are fenced because
        // a spare B buffer may still be the source of the previous iteration's store, issued
        // under this same tag (tags are swapped together with their buffers).
        if (has_next)
        {
            mfc_get((void*)spuT_, ppuT + (BB*next_i+next_j)*NB*NB*sizeof(float), STRSM_T_BYTES, tagT_, 0, 0);

            BJ_pulled = 0;
            if (next_i == next_j ? strsm_col_ready(BB, next_i, next_j)
                                 : strsm_row_solved(BB, next_i))
            {
                mfc_getf((void*)spuBJ_, ppuB + (next_i)*NB*sizeof(float), STRSM_B_BYTES, tagBJ_, 0, 0);
                BJ_pulled = 1;
            }

            BI_pulled = 0;
            if (next_i != next_j && strsm_col_ready(BB, next_i, next_j))
            {
                mfc_getf((void*)spuBI_, ppuB + (next_j)*NB*sizeof(float), STRSM_B_BYTES, tagBI_, 0, 0);
                BI_pulled = 1;
            }
        }

        strsm_wait_tag(tagT);
        strsm_wait_tag(tagBJ);

        if (i == j)
        {
            // Diagonal tile: solve B[i] in place, store it, publish flag (i, i).
            start = spu_read_decrementer();
            spu_strsm_trans_tile_((float*)spuT, (float*)spuBJ);
            end = spu_read_decrementer();
            strsm_log_event(start, end, 0xA02060);

            mfc_put((void*)spuBJ, ppuB + (i)*NB*sizeof(float), STRSM_B_BYTES, tagBJ, 0, 0);
            strsm_publish(BB*j+j, tagBJ);
        }
        else
        {
            // Off-diagonal tile: B[j] -= T(i,j)^T * B[i], store B[j], publish flag (i, j).
            strsm_wait_tag(tagBI);

            start = spu_read_decrementer();
            spu_strsm_trans_tile((float*)spuT, (float*)spuBI, (float*)spuBJ);
            end = spu_read_decrementer();
            strsm_log_event(start, end, 0xF06080);

            mfc_put((void*)spuBI, ppuB + (j)*NB*sizeof(float), STRSM_B_BYTES, tagBI, 0, 0);
            strsm_publish(BB*i+j, tagBI);
        }

        // Flip every double buffer together with its tag.
        strsm_swap(&spuT,  &spuT_);   strsm_swap(&tagT,  &tagT_);
        strsm_swap(&spuBJ, &spuBJ_);  strsm_swap(&tagBJ, &tagBJ_);
        strsm_swap(&spuBI, &spuBI_);  strsm_swap(&tagBI, &tagBI_);

        i = next_i;
        j = next_j;
    }

    // The final B store and progress-flag puts are still in flight; drain every tag so the
    // caller never observes completion (or reuses mem_pool) before they land.
    mfc_write_tag_mask((1u << tagT)  | (1u << tagT_)  |
                       (1u << tagBJ) | (1u << tagBJ_) |
                       (1u << tagBI) | (1u << tagBI_));
    mfc_read_tag_status_all();
}

//----------------------------------------------------------------------------------------------

// Edge of the square tiles the kernels below are unrolled for: BLK floats, which is
// VBLK 4-float SPU vectors.
#define BLK  64
#define VBLK (BLK/4)

//----------------------------------------------------------------------------------------------

// Diagonal-tile solve for one right-hand side, in place: on return B holds x with
// TT^T x = B (input). TT is a BLK x BLK row-major tile of floats, and only entries on or
// below its diagonal are read. B is BLK floats.
//
// B is solved in 4-element blocks from the bottom (elements BLK-4..BLK-1) upwards. For each
// block, the contributions of every already-solved block below it are subtracted with vector
// nmsub (trsm_4x4). Then the 4x4 diagonal sub-block is solved with scalar code (trsm_4x4_),
// and its four results are splatted into Bj_splat for the vector updates of the blocks above.
//
// TT and B are also accessed as vector float, so they are presumably 16-byte aligned
// (TODO confirm at the call sites).
void spu_strsm_trans_tile_(float *TT, float *B)
{
    // Scalar views. T walks 4x4 sub-blocks of TT. Bi and Bj both alias B and advance with T,
    // so both point at the block being solved whenever trsm_4x4_ runs.
    float *T  = (float*)TT;
    float *Bi = (float*)B;
    float *Bj = (float*)B;

    // Bj_splat[k] holds solved x[k] in all four lanes, filled as blocks are solved.
    vector float Bj_splat[BLK];

    // Vector views: Tp is one 4-column vector of a TT row, BIp the 4 elements of B being
    // updated, BJp the splatted solved values being applied.
    vector float *Tp  = (vector float*)TT;
    vector float *BIp = (vector float*)B;
    vector float *BJp = (vector float*)Bj_splat;

    int i, j;

    //----------------------------------------------------------

    // Scalar solve of the 4x4 diagonal sub-block at T for the 4 elements at Bi/Bj, bottom
    // element first (divide by the diagonal, then eliminate upwards). The four results are
    // then splatted into BJp[0..3].
    #define trsm_4x4_()\
    \
        Bj[3] /= T[BLK*3+3];\
        Bi[2] -= T[BLK*3+2] * Bj[3];\
        Bi[1] -= T[BLK*3+1] * Bj[3];\
        Bi[0] -= T[BLK*3+0] * Bj[3];\
        \
        Bj[2] /= T[BLK*2+2];\
        Bi[1] -= T[BLK*2+1] * Bj[2];\
        Bi[0] -= T[BLK*2+0] * Bj[2];\
        \
        Bj[1] /= T[BLK*1+1];\
        Bi[0] -= T[BLK*1+0] * Bj[1];\
        \
        Bj[0] /= T[BLK*0+0];\
        \
        BJp[3] = spu_splats(Bj[3]);\
        BJp[2] = spu_splats(Bj[2]);\
        BJp[1] = spu_splats(Bj[1]);\
        BJp[0] = spu_splats(Bj[0]);\

    //----------------------------------------------------------

    // Vector update. From the 4 elements at BIp, subtract rows 0..3 of the sub-block at Tp,
    // each scaled by the matching splatted solved value BJp[0..3].
    #define trsm_4x4()\
    \
        *BIp = spu_nmsub(Tp[VBLK*3], BJp[3], *BIp);\
        *BIp = spu_nmsub(Tp[VBLK*2], BJp[2], *BIp);\
        *BIp = spu_nmsub(Tp[VBLK*1], BJp[1], *BIp);\
        *BIp = spu_nmsub(Tp[VBLK*0], BJp[0], *BIp);\

    //----------------------------------------------------------

    // Start at the last 4x4 diagonal sub-block. T and Tp point at row BLK-4, columns
    // BLK-4..BLK-1; Bi, Bj and BIp point at B[BLK-4..BLK-1]; BJp points at Bj_splat[BLK-4].
    T += BLK * BLK;
    T -= BLK*4;
    T += BLK;
    T -= 4;

    Bj += BLK;
    Bi += BLK;

    Bj -= 4;
    Bi -= 4;

    Tp += VBLK * BLK;
    Tp -= VBLK*4;
    Tp += VBLK;
    Tp -= 1;

    BJp += BLK;
    BIp += VBLK;

    BJp -= 4;
    BIp -= 1;

    // j = number of 4-element blocks already solved below the current one.
    for (j = 0; j < VBLK; j++)
    {
        // Apply each solved block below, walking T/Tp/Bj/BJp up one 4-row band per pass.
        for (i = 0; i < j; i++)
        {
            trsm_4x4();

            T  -= BLK*4;
            Bj -= 4;

            Tp  -= VBLK*4;
            BJp -= 4;
        }
        trsm_4x4_();

        // Return to the bottom 4-row band, then step one 4-column block left, to the next
        // (higher) block of B.
        // NOTE(review): on the last pass Bi and BIp end up before the start of B. They are
        // not dereferenced afterwards, but forming such pointers is formally undefined in C.
        T  += j*BLK*4;
        Bj += j*4;

        T  -= 4;
        Bi -= 4;

        Tp  += j*VBLK*4;
        BJp += j*4;

        Tp  -= 1;
        BIp -= 1;
    }

    #undef trsm_4x4_
    #undef trsm_4x4

}

//----------------------------------------------------------------------------------------------

void spu_strsm_trans_tile(float *TT, float *BI, float *BJ)
{
    float *T  = (float*)TT;
    float *Bi = (float*)BI;
    float *Bj = (float*)BJ;

    vector float Bj_splat[BLK];

    vector float *Tp  = (vector float*)TT;
    vector float *BIp = (vector float*)BI;
    vector float *BJp = (vector float*)Bj_splat;

    vector float b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15;

    int i, j;

    //----------------------------------------------------------

    #define trsm_splat_4(OFFB)\
    \
        Bj_splat[OFFB+0] = spu_splats(BJ[OFFB+0]);\
        Bj_splat[OFFB+1] = spu_splats(BJ[OFFB+1]);\
        Bj_splat[OFFB+2] = spu_splats(BJ[OFFB+2]);\
        Bj_splat[OFFB+3] = spu_splats(BJ[OFFB+3]);\

    #define trsm_splat_16(OFFB)\
    \
        trsm_splat_4(OFFB+ 0);\
        trsm_splat_4(OFFB+ 4);\
        trsm_splat_4(OFFB+ 8);\
        trsm_splat_4(OFFB+12);\

    #define trsm_splat_64()\
    \
        trsm_splat_16( 0);\
        trsm_splat_16(16);\
        trsm_splat_16(32);\
        trsm_splat_16(48);\

    //----------------------------------------------------------

    #define trsm_4x4_load()\
    \
        b0  = spu_nmsub(Tp[ -0], *BJp, BIp[ -0]);\
        b1  = spu_nmsub(Tp[ -1], *BJp, BIp[ -1]);\
        b2  = spu_nmsub(Tp[ -2], *BJp, BIp[ -2]);\
        b3  = spu_nmsub(Tp[ -3], *BJp, BIp[ -3]);\
        \
        b4  = spu_nmsub(Tp[ -4], *BJp, BIp[ -4]);\
        b5  = spu_nmsub(Tp[ -5], *BJp, BIp[ -5]);\
        b6  = spu_nmsub(Tp[ -6], *BJp, BIp[ -6]);\
        b7  = spu_nmsub(Tp[ -7], *BJp, BIp[ -7]);\
        \
        b8  = spu_nmsub(Tp[ -8], *BJp, BIp[ -8]);\
        b9  = spu_nmsub(Tp[ -9], *BJp, BIp[ -9]);\
        b10 = spu_nmsub(Tp[-10], *BJp, BIp[-10]);\
        b11 = spu_nmsub(Tp[-11], *BJp, BIp[-11]);\
        \
        b12 = spu_nmsub(Tp[-12], *BJp, BIp[-12]);\
        b13 = spu_nmsub(Tp[-13], *BJp, BIp[-13]);\
        b14 = spu_nmsub(Tp[-14], *BJp, BIp[-14]);\
        b15 = spu_nmsub(Tp[-15], *BJp, BIp[-15]);\

    //----------------------------------------------------------

    #define trsm_4x4(OFFT, OFFBJ)\
    \
        b0  = spu_nmsub(Tp[OFFT -0], BJp[OFFBJ],  b0);\
        b1  = spu_nmsub(Tp[OFFT -1], BJp[OFFBJ],  b1);\
        b2  = spu_nmsub(Tp[OFFT -2], BJp[OFFBJ],  b2);\
        b3  = spu_nmsub(Tp[OFFT -3], BJp[OFFBJ],  b3);\
        \
        b4  = spu_nmsub(Tp[OFFT -4], BJp[OFFBJ],  b4);\
        b5  = spu_nmsub(Tp[OFFT -5], BJp[OFFBJ],  b5);\
        b6  = spu_nmsub(Tp[OFFT -6], BJp[OFFBJ],  b6);\
        b7  = spu_nmsub(Tp[OFFT -7], BJp[OFFBJ],  b7);\
        \
        b8  = spu_nmsub(Tp[OFFT -8], BJp[OFFBJ],  b8);\
        b9  = spu_nmsub(Tp[OFFT -9], BJp[OFFBJ],  b9);\
        b10 = spu_nmsub(Tp[OFFT-10], BJp[OFFBJ], b10);\
        b11 = spu_nmsub(Tp[OFFT-11], BJp[OFFBJ], b11);\
        \
        b12 = spu_nmsub(Tp[OFFT-12], BJp[OFFBJ], b12);\
        b13 = spu_nmsub(Tp[OFFT-13], BJp[OFFBJ], b13);\
        b14 = spu_nmsub(Tp[OFFT-14], BJp[OFFBJ], b14);\
        b15 = spu_nmsub(Tp[OFFT-15], BJp[OFFBJ], b15);\

    //----------------------------------------------------------

    #define trsm_4x4_store(OFFT, OFFBJ)\
    \
        BIp[ -0] = spu_nmsub(Tp[OFFT -0], BJp[OFFBJ],  b0);\
        BIp[ -1] = spu_nmsub(Tp[OFFT -1], BJp[OFFBJ],  b1);\
        BIp[ -2] = spu_nmsub(Tp[OFFT -2], BJp[OFFBJ],  b2);\
        BIp[ -3] = spu_nmsub(Tp[OFFT -3], BJp[OFFBJ],  b3);\
        \
        BIp[ -4] = spu_nmsub(Tp[OFFT -4], BJp[OFFBJ],  b4);\
        BIp[ -5] = spu_nmsub(Tp[OFFT -5], BJp[OFFBJ],  b5);\
        BIp[ -6] = spu_nmsub(Tp[OFFT -6], BJp[OFFBJ],  b6);\
        BIp[ -7] = spu_nmsub(Tp[OFFT -7], BJp[OFFBJ],  b7);\
        \
        BIp[ -8] = spu_nmsub(Tp[OFFT -8], BJp[OFFBJ],  b8);\
        BIp[ -9] = spu_nmsub(Tp[OFFT -9], BJp[OFFBJ],  b9);\
        BIp[-10] = spu_nmsub(Tp[OFFT-10], BJp[OFFBJ], b10);\
        BIp[-11] = spu_nmsub(Tp[OFFT-11], BJp[OFFBJ], b11);\
        \
        BIp[-12] = spu_nmsub(Tp[OFFT-12], BJp[OFFBJ], b12);\
        BIp[-13] = spu_nmsub(Tp[OFFT-13], BJp[OFFBJ], b13);\
        BIp[-14] = spu_nmsub(Tp[OFFT-14], BJp[OFFBJ], b14);\
        BIp[-15] = spu_nmsub(Tp[OFFT-15], BJp[OFFBJ], b15);\

    //----------------------------------------------------------

    trsm_splat_64();

    Tp += VBLK*BLK;
    Tp -= 1;

    BJp +=  BLK;
    BJp -= 1;

    BIp += VBLK;
    BIp -= 1;

    trsm_4x4_load();
    
    trsm_4x4(-VBLK*1, -1);
    trsm_4x4(-VBLK*2, -2);
    trsm_4x4(-VBLK*3, -3);

    Tp -= BLK;
    BJp -= 4;
 
    for (i = 1; i < VBLK-1; i++)
    {
        trsm_4x4(-VBLK*0, -0);
        trsm_4x4(-VBLK*1, -1);
        trsm_4x4(-VBLK*2, -2);
        trsm_4x4(-VBLK*3, -3);

        Tp -= BLK;
        BJp -= 4;
    }

    trsm_4x4(-VBLK*0, -0);
    trsm_4x4(-VBLK*1, -1);
    trsm_4x4(-VBLK*2, -2);

    trsm_4x4_store(-VBLK*3, -3);
}

//----------------------------------------------------------------------------------------------
