//==============================================================================================
//
//  Innovative Computing Laboratory - Computer Science Department - University of Tennessee
//  Written by Jakub Kurzak
//
//==============================================================================================

#include <cbe_mfc.h>
#include <spu_mfcio.h>
#include <vec_literal.h>
#include <spu_intrinsics.h>

#include "spu_blas.h"
#include "spu_spotrf.h"

#include "spu_ssyrk_tile.h"
#include "spu_sgemm_tile.h"
#include "spu_strsm_tile.h"
#include "spu_spotrf_tile.h"

//----------------------------------------------------------------------------------------------

// Globals defined elsewhere in the SPU program.
// NB: tile dimension; every tile is NB x NB floats (see the NB*NB*sizeof(float) strides).
extern int NB;
// spus_num: number of SPUs taking part; progress flags are broadcast to each of them.
extern int spus_num;
// my_spu_id: this SPU's index (presumably 0 .. spus_num-1 -- TODO confirm); selects its first task.
extern int my_spu_id;
// spu_call_args: read after spu_recv_call_args() (presumably filled by it); .potrf.A / .potrf.N are used.
extern CallArgs spu_call_args;
// spu_global_params: .local_store[spu] is used as the DMA base address of SPU `spu`'s local store.
extern GlobalParams spu_global_params;

//----------------------------------------------------------------------------------------------

// Tiled single-precision Cholesky factorization (SPOTRF), executed cooperatively by all SPUs.
//
// The N x N matrix (N = B*NB) lives in main memory at spu_call_args.potrf.A as B x B
// contiguous NB x NB tiles; tile (i, j) starts at byte offset (B*i + j)*NB*NB*sizeof(float).
// Only tiles on or below the diagonal (i >= j) are read or written.
//
// The work is a sequence of B*(B+1)/2 tasks (step, m), m = step .. B-1, enumerated column
// by column (step = 0, 1, ...). Each SPU takes every spus_num-th task, starting at linear
// index (my_spu_id + shift) % spus_num.
//   - m == step (diagonal): SSYRK-update T = tile (step, step) with tile (m, n) for every
//     n < step, then factor it with spu_spotrf_tile().
//   - m >  step (panel):    SGEMM-update C = tile (m, step) with tiles (m, n) and (step, n)
//     for every n < step, then solve with the factored T via spu_strsm_tile().
//   (The exact update formulas are those of the spu_*_tile kernels.)
//
// A finished tile is written back to main memory. Its byte in spu_progress[] is then set
// to 1 in every SPU's local store by DMA. Consumers spin on those bytes before loading a
// tile.
//
// Each operand (A, B, C, T) has two local-store buffers and two DMA tags (X and X_). The
// tile for the next use is prefetched into X_ while X is being computed on, and swap()
// exchanges the two buffers.
void spu_spotrf()
{
    // Declared volatile: other SPUs write these flags by DMA, so every read goes to memory.
    extern volatile unsigned char spu_progress[];

    // Source bytes for the progress-flag DMAs. A 1-byte DMA requires the local-store and
    // effective addresses to share their low 4 bits. The broadcasts therefore pick the byte
    // whose offset in this 16-byte-aligned array matches the destination's. This assumes
    // spu_progress itself is 16-byte aligned -- TODO confirm.
    volatile unsigned char spu_progress_src[16] __attribute__ ((aligned (16)))
        = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};

    extern unsigned char mem_pool[];

    // Bytes in one NB x NB tile. This is the stride between tiles in mem_pool and in main
    // memory, and the size of every tile DMA. A single MFC transfer is limited to 16 KB,
    // so NB must not exceed 64.
    const unsigned int tile_bytes = NB*NB*sizeof(float);

    // Eight tile buffers in local store: current / next for each of A, B, C and T.
    unsigned int spuA  = (unsigned int)(mem_pool + 0*tile_bytes);
    unsigned int spuA_ = (unsigned int)(mem_pool + 1*tile_bytes);
    unsigned int spuB  = (unsigned int)(mem_pool + 2*tile_bytes);
    unsigned int spuB_ = (unsigned int)(mem_pool + 3*tile_bytes);
    unsigned int spuC  = (unsigned int)(mem_pool + 4*tile_bytes);
    unsigned int spuC_ = (unsigned int)(mem_pool + 5*tile_bytes);
    unsigned int spuT  = (unsigned int)(mem_pool + 6*tile_bytes);
    unsigned int spuT_ = (unsigned int)(mem_pool + 7*tile_bytes);

    // One DMA tag group per buffer, so each buffer can be waited on independently.
    unsigned int tagA  = 1;
    unsigned int tagA_ = 2;
    unsigned int tagB  = 3;
    unsigned int tagB_ = 4;
    unsigned int tagC  = 5;
    unsigned int tagC_ = 6;
    unsigned int tagT  = 7;
    unsigned int tagT_ = 8;

    spu_recv_call_args();

    unsigned int ppuA = spu_call_args.potrf.A;

    int N = spu_call_args.potrf.N;
    int B = N/NB;   // tiles per matrix dimension

    int step;
    int m, n;
    int next_m;
    int next_step;

    int spu;
    int shift;

    // Non-zero when a prefetch has already issued the DMA for the tile about to be used,
    // so the matching "pull" can skip its blocking load.
    int A_pulled = 0;
    int B_pulled = 0;
    int T_pulled = 0;
    int C_pulled = 0;

    //----------------------------------------------------------

    // Effective address of tile (i, j) in main memory.
    #define TILE_EA(i, j) (ppuA + (B*(i)+(j))*tile_bytes)

    // Exchange the current and next buffer of operand X together with their tags.
    #define swap(X)\
        do {\
            unsigned int swap_tmp_;\
            swap_tmp_ = spu##X; spu##X = spu##X##_; spu##X##_ = swap_tmp_;\
            swap_tmp_ = tag##X; tag##X = tag##X##_; tag##X##_ = swap_tmp_;\
        } while (0)

    // Block until every DMA issued under operand X's current tag has completed.
    #define wait(X)\
        do {\
            mfc_write_tag_mask(0x01 << tag##X);\
            mfc_read_tag_status_all();\
        } while (0)

    //----------------------------------------------------------

    extern int spu_event_num;
    extern int spu_event_log[];

    // Append a (start, end, event) record to the trace buffer. Records are 4 ints apart
    // and the index wraps at 1024 ints.
    #define spu_log_event(start, end, event)\
        do {\
            spu_event_log[spu_event_num+0] = (start);\
            spu_event_log[spu_event_num+1] = (end);\
            spu_event_log[spu_event_num+2] = (event);\
            spu_event_num += 4;\
            spu_event_num &= 1024-1;\
        } while (0)

    int start;
    int end;

    //----------------------------------------------------------

    shift = ((B*(B+1))/2)%spus_num;

    // Convert this SPU's first linear task index into (step, m): column `step` holds
    // B - step tasks, m = step .. B-1.
    step = 0;
    m = ((my_spu_id+shift)%spus_num);
    while (m >= B)
    {
        step++;
        m = m-B+step;
    }

    while (step < B && m < B)
    {
        // Locate this SPU's next task (spus_num tasks further on) so it can be prefetched.
        next_m = m;
        next_step = step;

        next_m += spus_num;
        while (next_m >= B && next_step < B)
        {
            next_step++;
            next_m = next_m-B+next_step;
        }

        if (m == step)
        {
            // ---- Diagonal task: SSYRK updates, then SPOTRF ----

            // Pull T. Tile (step, step) is written only by this task, so no progress check
            // is needed. The get is fenced because this buffer may still be the source of
            // an earlier PUT on the same tag.
            if (!T_pulled)
            {
                mfc_getf((void*)spuT, TILE_EA(step, step), tile_bytes, tagT, 0, 0);
            }

            // Prefetch C for the next task if it is a panel task. C is not used here, so
            // swap immediately. The get is fenced because spuC_ may still be pushing an
            // earlier result.
            C_pulled = 0;
            if (next_step < B && next_step != next_m)
            {
                mfc_getf((void*)spuC_, TILE_EA(next_m, next_step), tile_bytes, tagC_, 0, 0);
                C_pulled = 1;
                swap(C);
            }

            wait(T);

            for (n = 0; n < step; n++)
            {
                // Pull A = tile (m, n) once its producer has published it
                if (!A_pulled)
                {
                    while (spu_progress[B*m+n] == 0)
                        ;   // spin
                    mfc_get((void*)spuA, TILE_EA(m, n), tile_bytes, tagA, 0, 0);
                }

                // Prefetch A = tile (m, n+1) if it is already available
                if (n+1 < step)
                {
                    A_pulled = 0;
                    if (spu_progress[B*m+(n+1)] != 0)
                    {
                        mfc_get((void*)spuA_, TILE_EA(m, n+1), tile_bytes, tagA_, 0, 0);
                        A_pulled = 1;
                    }
                }

                wait(A);

                // Do SSYRK
                start = spu_read_decrementer();
                spu_ssyrk_tile((float*)spuA, (float*)spuT);
                end = spu_read_decrementer();
                spu_log_event(start, end, 0x60A0C0);

                swap(A);
            }

            // Prefetch A = tile (next_step, 0), the first A operand of the next task
            // (whether that task is diagonal or a panel task).
            if (next_step < B)
            {
                A_pulled = 0;
                if (spu_progress[B*next_step] != 0)
                {
                    mfc_get((void*)spuA, TILE_EA(next_step, 0), tile_bytes, tagA, 0, 0);
                    A_pulled = 1;
                }
            }

            // Prefetch B = tile (next_m, 0) if the next task is a panel task
            if (next_step < B && next_step != next_m)
            {
                B_pulled = 0;
                if (spu_progress[B*next_m] != 0)
                {
                    mfc_get((void*)spuB, TILE_EA(next_m, 0), tile_bytes, tagB, 0, 0);
                    B_pulled = 1;
                }
            }

            // Prefetch T = tile (next_step, next_step). If the next task is that diagonal
            // tile itself, load it unconditionally (only that task writes it). Otherwise
            // load it only once it has been factored. Fenced: spuT_ may still be pushing.
            if (next_step < B)
            {
                if (next_step == next_m)
                {
                    mfc_getf((void*)spuT_, TILE_EA(next_step, next_step), tile_bytes, tagT_, 0, 0);
                    T_pulled = 1;
                }
                else
                {
                    T_pulled = 0;
                    if (spu_progress[B*next_step+next_step] == 1)
                    {
                        mfc_getf((void*)spuT_, TILE_EA(next_step, next_step), tile_bytes, tagT_, 0, 0);
                        T_pulled = 1;
                    }
                }
            }

            // Do SPOTRF
            start = spu_read_decrementer();
            spu_spotrf_tile((float*)spuT);
            end = spu_read_decrementer();
            spu_log_event(start, end, 0x2060A0);

            // Push T to mem
            mfc_put((void*)spuT, TILE_EA(step, step), tile_bytes, tagT, 0, 0);

            // Update progress - SPOTRF. The flag writes are fenced on the same tag, so they
            // are not performed before the tile PUT above.
            for (spu = 0; spu < spus_num; spu++)
                mfc_putf(
                    (void*)&spu_progress_src[(B*step+step) & 0x0F],
                    spu_global_params.local_store[spu] + (unsigned int)(&spu_progress[B*step+step]),
                    sizeof(unsigned char), tagT, 0, 0);

            swap(T);
        }
        else
        {
            // ---- Panel task (m > step): SGEMM updates, then STRSM ----

            // Pull C = tile (m, step). It is written only by this task, so no progress
            // check is needed. Fenced: this buffer may still be pushing an earlier result.
            if (!C_pulled)
            {
                mfc_getf((void*)spuC, TILE_EA(m, step), tile_bytes, tagC, 0, 0);
            }

            // Prefetch C for the next task if it is a panel task (fenced for the same reason)
            C_pulled = 0;
            if (next_step < B && next_step != next_m)
            {
                mfc_getf((void*)spuC_, TILE_EA(next_m, next_step), tile_bytes, tagC_, 0, 0);
                C_pulled = 1;
            }

            wait(C);

            for (n = 0; n < step; n++)
            {
                // Pull A = tile (step, n) once published
                if (!A_pulled)
                {
                    while (spu_progress[B*step+n] == 0)
                        ;   // spin
                    mfc_get((void*)spuA, TILE_EA(step, n), tile_bytes, tagA, 0, 0);
                }

                // Pull B = tile (m, n) once published
                if (!B_pulled)
                {
                    while (spu_progress[B*m+n] == 0)
                        ;   // spin
                    mfc_get((void*)spuB, TILE_EA(m, n), tile_bytes, tagB, 0, 0);
                }

                // Prefetch A = tile (step, n+1) if it is already available
                if (n+1 < step)
                {
                    A_pulled = 0;
                    if (spu_progress[B*step+(n+1)] != 0)
                    {
                        mfc_get((void*)spuA_, TILE_EA(step, n+1), tile_bytes, tagA_, 0, 0);
                        A_pulled = 1;
                    }
                }

                // Prefetch B = tile (m, n+1) if it is already available
                if (n+1 < step)
                {
                    B_pulled = 0;
                    if (spu_progress[B*m+(n+1)] != 0)
                    {
                        mfc_get((void*)spuB_, TILE_EA(m, n+1), tile_bytes, tagB_, 0, 0);
                        B_pulled = 1;
                    }
                }

                // Start loading T = tile (step, step) as soon as it has been factored.
                // Fenced: this buffer may still be pushing an earlier result.
                if (!T_pulled && spu_progress[B*step+step] == 1)
                {
                    mfc_getf((void*)spuT, TILE_EA(step, step), tile_bytes, tagT, 0, 0);
                    T_pulled = 1;
                }

                wait(A);
                wait(B);

                // Do SGEMM
                start = spu_read_decrementer();
                spu_sgemm_tile((float*)spuB, (float*)spuA, (float*)spuC);
                end = spu_read_decrementer();
                spu_log_event(start, end, ((140 + 50*(step%2)) << 8) + (100 + 20*(step%2)));

                swap(A);
                swap(B);
            }

            // Prefetch A = tile (next_step, 0), the first A operand of the next task
            if (next_step < B)
            {
                A_pulled = 0;
                if (spu_progress[B*next_step] != 0)
                {
                    mfc_get((void*)spuA, TILE_EA(next_step, 0), tile_bytes, tagA, 0, 0);
                    A_pulled = 1;
                }
            }

            // Prefetch B = tile (next_m, 0) if the next task is a panel task
            if (next_step < B && next_step != next_m)
            {
                B_pulled = 0;
                if (spu_progress[B*next_m] != 0)
                {
                    mfc_get((void*)spuB, TILE_EA(next_m, 0), tile_bytes, tagB, 0, 0);
                    B_pulled = 1;
                }
            }

            // Pull T = tile (step, step), waiting until it has been factored (fenced)
            if (!T_pulled)
            {
                while (spu_progress[B*step+step] != 1)
                    ;   // spin
                mfc_getf((void*)spuT, TILE_EA(step, step), tile_bytes, tagT, 0, 0);
            }

            // Prefetch T = tile (next_step, next_step) for the next task (see the diagonal
            // branch for the rules). Fenced: spuT_ may still be pushing.
            if (next_step < B)
            {
                if (next_step == next_m)
                {
                    mfc_getf((void*)spuT_, TILE_EA(next_step, next_step), tile_bytes, tagT_, 0, 0);
                    T_pulled = 1;
                }
                else
                {
                    T_pulled = 0;
                    if (spu_progress[B*next_step+next_step] == 1)
                    {
                        mfc_getf((void*)spuT_, TILE_EA(next_step, next_step), tile_bytes, tagT_, 0, 0);
                        T_pulled = 1;
                    }
                }
            }

            wait(T);

            // Do STRSM
            start = spu_read_decrementer();
            spu_strsm_tile((float*)spuT, (float*)spuC);
            end = spu_read_decrementer();
            spu_log_event(start, end, 0xA02060);

            // Push C
            mfc_put((void*)spuC, TILE_EA(m, step), tile_bytes, tagC, 0, 0);

            // Update progress - STRSM. The flag writes are fenced on the same tag, so they
            // are not performed before the tile PUT above.
            for (spu = 0; spu < spus_num; spu++)
                mfc_putf(
                    (void*)&spu_progress_src[(B*m+step) & 0x0F],
                    spu_global_params.local_store[spu] + (unsigned int)(&spu_progress[B*m+step]),
                    sizeof(unsigned char), tagC, 0, 0);

            swap(T);
            swap(C);
        }

        m = next_m;
        step = next_step;
    }

    // These helpers are local to this function; do not leak them into the rest of the file.
    #undef TILE_EA
    #undef swap
    #undef wait
    #undef spu_log_event
}

//----------------------------------------------------------------------------------------------
