//==============================================================================================
//
//  Innovative Computing Laboratory - Computer Science Department - University of Tennessee
//  Written by Jakub Kurzak
//
//==============================================================================================

#include <cbe_mfc.h>
#include <spu_mfcio.h>
#include <vec_literal.h>
#include <spu_intrinsics.h>

#include "spu_blas.h"
#include "spu_strsm_1rhs_notrans.h"

//----------------------------------------------------------------------------------------------

// Tile edge length in floats. In this file it sizes the local-store buffers
// and scales the tile / vector-segment addresses.
extern int NB;
// Stride, in tiles, between consecutive tiles processed by one SPU, and the
// number of local stores that receive each progress update (presumably the
// number of participating SPUs).
extern int spus_num;
// Linear index of the first tile this SPU processes (presumably this SPU's id).
extern int my_spu_id;
// Call arguments; this file reads trsm.T, trsm.B and trsm.M after
// spu_recv_call_args() has been called.
extern CallArgs spu_call_args;
// Global parameters; this file reads local_store[spu] as the base address used
// to write progress bytes into each SPU's copy of spu_progress.
extern GlobalParams spu_global_params;

//----------------------------------------------------------------------------------------------

// SPU-side driver of the tiled triangular solve with a single right-hand side
// (no-transpose case).
//
// The matrix T is addressed as BB x BB tiles, tile (i,j) starting at
// ppuT + (BB*i+j)*NB*NB*sizeof(float); the vector B is addressed as BB segments
// of NB floats. Only tiles with j <= i are visited. They are numbered column by
// column (column j holds rows j..BB-1); this SPU starts at tile number my_spu_id
// and then advances by spus_num tiles at a time.
//
//   - diagonal tile (j,j):      segment j of B is solved against the tile
//                               (spu_strsm_no_trans_tile_) and written back;
//   - off-diagonal tile (i,j):  segment i of B is updated with the tile and the
//                               solved segment j (spu_strsm_no_trans_tile) and
//                               written back.
//
// spu_progress[BB*i+j] == 2 marks tile (i,j) as done. After each tile the byte
// is written into the spu_progress array of every local store listed in
// spu_global_params.local_store, using a fenced put on the same tag as the
// result write-back.
//
// T tiles and B segments are double-buffered: while the current tile is being
// processed the next tile is prefetched into the "_" buffers, and the next B
// segments are prefetched as well if their producers have already finished.
//
// The DMA sizes (16384 and 256 bytes) are hard-coded for NB == 64.
void spu_strsm_no_trans()
{
    extern volatile unsigned char spu_progress[];
    // Source bytes for the one-byte progress puts. Every element is 2, so any
    // element can serve as the source; the element is picked with the same low
    // four index bits as the destination (presumably because sub-quadword MFC
    // transfers need matching quadword offsets - assumes spu_progress is
    // 16-byte aligned, TODO confirm).
    volatile unsigned char spu_progress_src[16] __attribute__ ((aligned (16)))
        = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2};

    extern unsigned char mem_pool[];

    // Local-store buffers carved out of mem_pool, NB*NB floats apart.
    // The "_" variants are the prefetch targets of the double buffering.
    unsigned int spuT   = (unsigned int)(mem_pool + 0*NB*NB*sizeof(float));
    unsigned int spuT_  = (unsigned int)(mem_pool + 1*NB*NB*sizeof(float));
    unsigned int spuBJ  = (unsigned int)(mem_pool + 2*NB*NB*sizeof(float));
    unsigned int spuBJ_ = (unsigned int)(mem_pool + 3*NB*NB*sizeof(float));
    unsigned int spuBI  = (unsigned int)(mem_pool + 4*NB*NB*sizeof(float));
    unsigned int spuBI_ = (unsigned int)(mem_pool + 5*NB*NB*sizeof(float));

    // One DMA tag per buffer; tags are swapped together with the buffers.
    unsigned int tagT   = 1;
    unsigned int tagT_  = 2;
    unsigned int tagBJ  = 3;
    unsigned int tagBJ_ = 4;
    unsigned int tagBI  = 5;
    unsigned int tagBI_ = 6;

    spu_recv_call_args();

    // Addresses of T and B as passed to the MFC get/put commands.
    unsigned int ppuT = spu_call_args.trsm.T;
    unsigned int ppuB = spu_call_args.trsm.B;

    int N  = spu_call_args.trsm.M;
    int BB = N/NB;      // number of tiles per row/column
    
    int i, j;           // current tile: row i, column j
    int next_j;
    int next_i;
    int spu;

    // Set when the matching segment for the *current* tile was already fetched
    // by the prefetch code of the previous iteration.
    int BJ_pulled = 0;
    int BI_pulled = 0;

    //----------------------------------------------------------

    // Exchange the current and the prefetch buffer (and their tags) of kind A.
    #define swap(A)\
        temp = spu##A; spu##A = spu##A##_; spu##A##_ = temp;\
        temp = tag##A; tag##A = tag##A##_; tag##A##_ = temp;

    unsigned int temp;

    // Block until every DMA issued on the tag of buffer A has completed.
    #define wait(A)\
        mfc_write_tag_mask(0x01 << tag##A);\
        mfc_read_tag_status_all();

    //----------------------------------------------------------

    extern int spu_event_num;
    extern int spu_event_log[];

    // Append (start, end, event) to the event log. Each record takes four
    // slots (the fourth is not written here) and the index wraps at 1024.
    #define spu_log_event(start, end, event)\
        spu_event_log[spu_event_num+0] = start;\
        spu_event_log[spu_event_num+1] = end;\
        spu_event_log[spu_event_num+2] = event;\
        spu_event_num += 4;\
        spu_event_num &= 1024-1;\

    int start;
    int end;

    //----------------------------------------------------------

    // Map my_spu_id onto the first tile (i,j) of this SPU.
    // The j < BB bound mirrors the next_i/next_j walk below: when my_spu_id is
    // not smaller than the number of tiles (or BB is 0), i stops decreasing
    // once j reaches BB and an unbounded loop would never terminate.
    j = 0;
    i = my_spu_id;
    while (i >= BB && j < BB)
    {
        j++;
        i = i-BB+j;
    }


    // Start fetching the first T tile - only if this SPU owns a tile at all.
    if (j < BB && i < BB)
        mfc_get((void*)spuT, ppuT + (BB*i+j)*NB*NB*sizeof(float), 16384, tagT, 0, 0);


    while (j < BB && i < BB)
    {

        next_i = i;
        next_j = j;


        // Advance by spus_num tiles, wrapping into the following columns.
        next_i += spus_num;
        while (next_i >= BB && next_j < BB)
        {
            next_j++;
            next_i = next_i-BB+next_j;
        }


        if (i == j)
        {

            // Pull BJ
            // Segment j is final input for the diagonal solve only after tile
            // (j, j-1) has been applied; spin on its progress byte first.
            if (!BJ_pulled)
            {
                if(j > 0)
                    while ((volatile unsigned char)(spu_progress[BB*i+(j-1)]) != 2);
                mfc_getf((void*)spuBJ, ppuB + (j)*NB*sizeof(float), 256, tagBJ, 0, 0);
            }


            if (next_j < BB && next_i < BB)
            {

                // Prefetch T
                mfc_get(
                    (void*)spuT_, ppuT + (BB*next_i+next_j)*NB*NB*sizeof(float), 16384, tagT_, 0, 0);


                // Prefetch BJ
                // Only if the producer of segment next_j is already done;
                // otherwise the next iteration pulls it after spinning.
                BJ_pulled = 0;
                if (next_i == next_j && (volatile unsigned char)(spu_progress[BB*next_i+(next_j-1)]) == 2)
                {
                    mfc_getf((void*)spuBJ_, ppuB + (next_j)*NB*sizeof(float), 256, tagBJ_, 0, 0);
                    BJ_pulled = 1;
                }
                if (next_i != next_j && (volatile unsigned char)(spu_progress[BB*next_j+next_j]) == 2)
                {
                    mfc_getf((void*)spuBJ_, ppuB + (next_j)*NB*sizeof(float), 256, tagBJ_, 0, 0);
                    BJ_pulled = 1;
                }


                // Prefetch BI
                // NOTE(review): for next_j == 0 the index below is BB*next_i-1;
                // the blocking pull in the next iteration skips the wait when
                // j == 0, so a missed prefetch only costs latency - confirm
                // that this entry can never hold 2.
                BI_pulled = 0;
                if (next_i != next_j && ((volatile unsigned char)(spu_progress[BB*next_i+(next_j-1)]) == 2))
                {
                    mfc_getf((void*)spuBI_, ppuB + (next_i)*NB*sizeof(float), 256, tagBI_, 0, 0);
                    BI_pulled = 1;
                }

            }


            // Wait for T, BJ
            wait(T);
            wait(BJ);


            // Do STRSM
            start = spu_read_decrementer();
            spu_strsm_no_trans_tile_((float*)spuT, (float*)spuBJ);
            end = spu_read_decrementer();
            spu_log_event(start, end, 0xA02060);


            // Write the solved segment j back.
            mfc_put((void*)spuBJ, ppuB + (j)*NB*sizeof(float), 256, tagBJ, 0, 0);


            // Update progress
            // Fenced puts on the same tag as the write-back above, one per
            // local store.
            for (spu = 0; spu < spus_num; spu++)
                mfc_putf(
                    (void*)&spu_progress_src[(BB*j+j) & 0x0F],
                    spu_global_params.local_store[spu] + (unsigned int)(&spu_progress[BB*j+j]),
                    sizeof(unsigned char), tagBJ, 0, 0);


            // Swap buffers
            swap(T);
            swap(BJ);
            swap(BI);

        }
        else
        {

            // Pull BI
            // Segment i must already contain the update from tile (i, j-1).
            if (!BI_pulled)
            {
                if(j > 0)
                    while ((volatile unsigned char)(spu_progress[BB*i+(j-1)]) != 2);
                mfc_getf((void*)spuBI, ppuB + (i)*NB*sizeof(float), 256, tagBI, 0, 0);
            }


            // Pull BJ
            // Segment j must already be solved by the diagonal tile (j, j).
            if(!BJ_pulled)
            {
                while ((volatile unsigned char)(spu_progress[BB*j+j]) != 2);
                mfc_getf((void*)spuBJ, ppuB + (j)*NB*sizeof(float), 256, tagBJ, 0, 0);
            }


            if (next_j < BB && next_i < BB)
            {

                // Prefetch T
                mfc_get((void*)spuT_, ppuT + (BB*next_i+next_j)*NB*NB*sizeof(float), 16384, tagT_, 0, 0);


                // Prefetch BJ
                // Same readiness tests as in the diagonal branch.
                BJ_pulled = 0;
                if (next_i == next_j && (volatile unsigned char)(spu_progress[BB*next_i+(next_j-1)]) == 2)
                {
                    mfc_getf((void*)spuBJ_, ppuB + (next_j)*NB*sizeof(float), 256, tagBJ_, 0, 0);
                    BJ_pulled = 1;
                }
                if (next_i != next_j && (volatile unsigned char)(spu_progress[BB*next_j+next_j]) == 2)
                {
                    mfc_getf((void*)spuBJ_, ppuB + (next_j)*NB*sizeof(float), 256, tagBJ_, 0, 0);
                    BJ_pulled = 1;
                }


                // Prefetch BI
                BI_pulled = 0;
                if (next_i != next_j && ((volatile unsigned char)(spu_progress[BB*next_i+(next_j-1)]) == 2))
                {
                    mfc_getf((void*)spuBI_, ppuB + (next_i)*NB*sizeof(float), 256, tagBI_, 0, 0);
                    BI_pulled = 1;
                }

            }


            // Wait for T, BJ, BI
            wait(T);
            wait(BJ);
            wait(BI);


            // Do STRSM
            start = spu_read_decrementer();
            spu_strsm_no_trans_tile((float*)spuT, (float*)spuBI, (float*)spuBJ);
            end = spu_read_decrementer();
            spu_log_event(start, end, 0xF06080);


            // Write the updated segment i back.
            mfc_put((void*)spuBI, ppuB + (i)*NB*sizeof(float), 256, tagBI, 0, 0);


            // Update progress
            // Fenced puts on the same tag as the write-back above, one per
            // local store.
            for (spu = 0; spu < spus_num; spu++)
                mfc_putf(
                    (void*)&spu_progress_src[(BB*i+j) & 0x0F],
                    spu_global_params.local_store[spu] + (unsigned int)(&spu_progress[BB*i+j]),
                    sizeof(unsigned char), tagBI, 0, 0);


            // Swap buffers
            swap(T);
            swap(BJ);
            swap(BI);

        }

        i = next_i;
        j = next_j;

    }
}

//----------------------------------------------------------------------------------------------

// Tile geometry used by the kernels below:
// BLK  - row stride of a tile, in floats (also the tile's edge length);
// VBLK - the same row stride counted in 4-float vectors (BLK/4).
#define BLK  64
#define VBLK 16

//----------------------------------------------------------------------------------------------

// Diagonal-tile kernel: solves the lower-triangular system TT * x = B in place
// by forward substitution, four rows at a time.
//
//   TT - BLK x BLK tile with a row stride of BLK floats; only elements on and
//        below the diagonal are read.
//   B  - BLK floats; holds the right-hand side on entry and the solution on
//        return.
//
// Both arguments are also accessed through vector float pointers, so they are
// assumed to be 16-byte aligned. Diagonal elements are used as divisors without
// any check for zero.
void spu_strsm_no_trans_tile_(float *TT, float *B)
{
    // Scalar cursors: T points into the tile, Bi at the four elements being
    // solved, Bj at already-solved elements.
    float *T  = (float*)TT;
    float *Bi = (float*)B;
    float *Bj = (float*)B;

    // Vector cursors over the same memory as the scalar cursors above.
    vector float *Tp  = (vector float*)TT;
    vector float *BIp = (vector float*)B;
    vector float *BJp = (vector float*)B;
    
    // Partial products of four consecutive tile rows with the solved part of B.
    vector float b0, b1, b2, b3;

    // Byte-selection patterns for spu_shuffle: interleave words 0,1 (hi) or
    // words 2,3 (lo) of the first operand with the same words of the second.
    vector unsigned char shufflehi = VEC_LITERAL(vector unsigned char,
        0x00, 0x01, 0x02, 0x03, 0x10, 0x11, 0x12, 0x13,
        0x04, 0x05, 0x06, 0x07, 0x14, 0x15, 0x16, 0x17);

    vector unsigned char shufflelo = VEC_LITERAL(vector unsigned char,
        0x08, 0x09, 0x0A, 0x0B, 0x18, 0x19, 0x1A, 0x1B,
        0x0C, 0x0D, 0x0E, 0x0F, 0x1C, 0x1D, 0x1E, 0x1F);

    // Temporaries of the transpose.
    vector float aibj;
    vector float ckdl;
    vector float emfn;
    vector float gohp;

    int i, j;

    //----------------------------------------------------------

    // Transpose, in place, the 4x4 matrix whose rows are the four arguments.
    #define shuffle_4x1(abcd, efgh, ijkl, mnop)\
    \
        aibj = spu_shuffle(abcd, ijkl, shufflehi);\
        ckdl = spu_shuffle(abcd, ijkl, shufflelo);\
        emfn = spu_shuffle(efgh, mnop, shufflehi);\
        gohp = spu_shuffle(efgh, mnop, shufflelo);\
        \
        abcd = spu_shuffle(aibj, emfn, shufflehi);\
        efgh = spu_shuffle(aibj, emfn, shufflelo);\
        ijkl = spu_shuffle(ckdl, gohp, shufflehi);\
        mnop = spu_shuffle(ckdl, gohp, shufflelo);\

    //----------------------------------------------------------

    // Scalar forward substitution on the 4x4 lower-triangular block at T.
    // Bi and Bj point at the same four elements whenever this is expanded.
    #define trsm_4x4_()\
    \
        Bj[0] /= T[BLK*0+0];\
        Bi[1] -= T[BLK*1+0] * Bj[0];\
        Bi[2] -= T[BLK*2+0] * Bj[0];\
        Bi[3] -= T[BLK*3+0] * Bj[0];\
        \
        Bj[1] /= T[BLK*1+1];\
        Bi[2] -= T[BLK*2+1] * Bj[1];\
        Bi[3] -= T[BLK*3+1] * Bj[1];\
        \
        Bj[2] /= T[BLK*2+2];\
        Bi[3] -= T[BLK*3+2] * Bj[2];\
        \
        Bj[3] /= T[BLK*3+3];\

    //----------------------------------------------------------

    // Start the four row accumulators with the first 4-column product.
    #define trsm_4x4_load()\
    \
        b0 = spu_mul(Tp[VBLK*0], *BJp);\
        b1 = spu_mul(Tp[VBLK*1], *BJp);\
        b2 = spu_mul(Tp[VBLK*2], *BJp);\
        b3 = spu_mul(Tp[VBLK*3], *BJp);\

    //----------------------------------------------------------

    // Accumulate the product of the next 4 columns into the row accumulators.
    #define trsm_4x4()\
    \
        b0 = spu_madd(Tp[VBLK*0], *BJp, b0);\
        b1 = spu_madd(Tp[VBLK*1], *BJp, b1);\
        b2 = spu_madd(Tp[VBLK*2], *BJp, b2);\
        b3 = spu_madd(Tp[VBLK*3], *BJp, b3);\

    //---------------------------------------------------------- 

    // Transpose the accumulators so that their sum holds one dot product per
    // row, then subtract that sum from the four elements at BIp.
    #define trsm_4x4_store()\
    \
        shuffle_4x1(b0, b1, b2, b3);\
        *BIp = spu_sub(*BIp, spu_add(b3, spu_add(b2, spu_add(b0, b1))));\
        
    //----------------------------------------------------------

    // Rows 0..3: nothing to eliminate yet, solve the first diagonal block.
    trsm_4x4_();

    T  += BLK*4;
    Bi += 4;

    Tp  += VBLK*4;
    BIp += 1;

    // Row blocks 1..VBLK-1: eliminate the already-solved unknowns, then solve
    // the 4x4 diagonal block.
    for (j = 1; j < VBLK; j++)
    {
        trsm_4x4_load();

        T  += 4;
        Bj += 4;

        Tp  += 1;
        BJp += 1;

        for (i = 1; i < j; i++)
        {
            trsm_4x4();

            T  += 4;
            Bj += 4;

            Tp  += 1;
            BJp += 1;
        }
        trsm_4x4_store();

        // T and Bj have advanced by 4*j floats, so T is at the diagonal block
        // of this row block and Bj equals Bi.
        trsm_4x4_();

        // Rewind the column cursors to the start of the row.
        T  -= 4*j;
        Bj -= 4*j;

        // Advance to the next block of four rows.
        T  += BLK*4;
        Bi += 4;

        Tp  -= j;
        BJp -= j;

        Tp  += VBLK*4;
        BIp += 1;
    }

    #undef shuffle_4x1
    #undef trsm_4x4_
    #undef trsm_4x4_load
    #undef trsm_4x4
    #undef trsm_4x4_store
}

//----------------------------------------------------------------------------------------------

void spu_strsm_no_trans_tile(float *TT, float *BI, float *BJ)
{
    float *T  = (float*)TT;
    float *Bi = (float*)BI;
    float *Bj = (float*)BJ;

    vector float *Tp  = (vector float*)TT;
    vector float *BIp = (vector float*)BI;
    vector float *BJp = (vector float*)BJ;
    
    vector float b0, b1, b2, b3;

    vector unsigned char shufflehi = VEC_LITERAL(vector unsigned char,
        0x00, 0x01, 0x02, 0x03, 0x10, 0x11, 0x12, 0x13,
        0x04, 0x05, 0x06, 0x07, 0x14, 0x15, 0x16, 0x17);

    vector unsigned char shufflelo = VEC_LITERAL(vector unsigned char,
        0x08, 0x09, 0x0A, 0x0B, 0x18, 0x19, 0x1A, 0x1B,
        0x0C, 0x0D, 0x0E, 0x0F, 0x1C, 0x1D, 0x1E, 0x1F);

    vector float aibj;
    vector float ckdl;
    vector float emfn;
    vector float gohp;

    int i, j;

    //----------------------------------------------------------

    #define shuffle_4x1(abcd, efgh, ijkl, mnop)\
    \
        aibj = spu_shuffle(abcd, ijkl, shufflehi);\
        ckdl = spu_shuffle(abcd, ijkl, shufflelo);\
        emfn = spu_shuffle(efgh, mnop, shufflehi);\
        gohp = spu_shuffle(efgh, mnop, shufflelo);\
        \
        abcd = spu_shuffle(aibj, emfn, shufflehi);\
        efgh = spu_shuffle(aibj, emfn, shufflelo);\
        ijkl = spu_shuffle(ckdl, gohp, shufflehi);\
        mnop = spu_shuffle(ckdl, gohp, shufflelo);\

    //----------------------------------------------------------

    #define trsm_4x4_load()\
    \
        b0 = spu_mul(Tp[VBLK*0], *BJp);\
        b1 = spu_mul(Tp[VBLK*1], *BJp);\
        b2 = spu_mul(Tp[VBLK*2], *BJp);\
        b3 = spu_mul(Tp[VBLK*3], *BJp);\

    //----------------------------------------------------------

    #define trsm_4x4(OFFT, OFFBJ)\
    \
        b0 = spu_madd(Tp[VBLK*0+OFFT], BJp[OFFBJ], b0);\
        b1 = spu_madd(Tp[VBLK*1+OFFT], BJp[OFFBJ], b1);\
        b2 = spu_madd(Tp[VBLK*2+OFFT], BJp[OFFBJ], b2);\
        b3 = spu_madd(Tp[VBLK*3+OFFT], BJp[OFFBJ], b3);\

    //---------------------------------------------------------- 

    #define trsm_4x4_store()\
    \
        shuffle_4x1(b0, b1, b2, b3);\
        *BIp = spu_sub(*BIp, spu_add(b3, spu_add(b2, spu_add(b0, b1))));\
        
    //----------------------------------------------------------

    #define trsm_4x64()\
    \
        trsm_4x4_load();\
        \
        trsm_4x4( 1,  1);\
        trsm_4x4( 2,  2);\
        trsm_4x4( 3,  3);\
        trsm_4x4( 4,  4);\
        trsm_4x4( 5,  5);\
        trsm_4x4( 6,  6);\
        trsm_4x4( 7,  7);\
        trsm_4x4( 8,  8);\
        trsm_4x4( 9,  9);\
        trsm_4x4(10, 10);\
        trsm_4x4(11, 11);\
        trsm_4x4(12, 12);\
        trsm_4x4(13, 13);\
        trsm_4x4(14, 14);\
        trsm_4x4(15, 15);\
        \
        trsm_4x4_store();\
    
    //----------------------------------------------------------

    for (j = 0; j < VBLK; j++)
    {
        trsm_4x64();

        Tp  += BLK;
        BIp += 1;
    }

    #undef shuffle_4x1
    #undef trsm_4x4_load
    #undef trsm_4x4
    #undef trsm_4x4_store
    #undef trsm_4x64
}

//----------------------------------------------------------------------------------------------
