//==============================================================================================
//
//  Innovative Computing Laboratory - Computer Science Department - University of Tennessee
//  Written by Jakub Kurzak
//
//==============================================================================================

#include <cbe_mfc.h>
#include <spu_mfcio.h>
#include <vec_literal.h>
#include <spu_intrinsics.h>

#include "spu_blas.h"
#include "spu_blocked_format.h"

//----------------------------------------------------------------------------------------------

extern int NB;
extern int DNB;
extern int spus_num;
extern int my_spu_id;
extern CallArgs spu_call_args __attribute__ ((aligned (128)));

extern float  *buffer[];
extern double *dbuffer[];

//----------------------------------------------------------------------------------------------

//
// Convert an M x N single-precision matrix from the strided ("LAPACK") layout
// at spu_call_args.blocked_format.A into the tiled ("blocked") layout at
// spu_call_args.blocked_format.B.
//
// The matrix is handled as Bm x Bn tiles of NB x NB floats, with Bm = M/NB and
// Bn = N/NB (integer division, so any remainder rows/columns are not touched).
// Tiles are dealt out round-robin: this SPU processes tiles my_spu_id,
// my_spu_id + spus_num, my_spu_id + 2*spus_num, ...
//
// Buffering: slot k (k < BUFFERS/2) receives a tile into buffer[k].
// spu_trans_tile() transposes it into buffer[k + BUFFERS/2], which is then sent out.
// Both transfers of slot k use DMA tag k.
// Assumes spu_wait_tag(t) blocks until every DMA issued with tag t has
// completed.
//
void spu_lapack2blocked()
{
    unsigned int A, B;

    int recv_tile, send_tile;
    int recv_buff, send_buff;
    int tag;

    int M, N;
    int Bm, Bn;
    int m, n;

    spu_recv_call_args();

    M = spu_call_args.blocked_format.M;
    N = spu_call_args.blocked_format.N;

    Bm = M/NB;
    Bn = N/NB;

    recv_tile = my_spu_id;
    send_tile = my_spu_id;

    // Initial receives: fill up to BUFFERS/2 input buffers.
    recv_buff = 0;
    while (recv_tile <= Bm*Bn-1 && recv_buff <= BUFFERS/2-1)
    {
        m =   (recv_tile) % Bm;
        n =  ((recv_tile) / Bm) % Bn;
        A = spu_call_args.blocked_format.A + (m*NB + n*NB*N) * sizeof(float);

        spu_recv_tile(buffer[recv_buff], A, N, recv_buff);

        recv_tile += spus_num;
        recv_buff++;
    }

    // Pipelined operation: transpose and send the batch just received,
    // then refill the input buffers with the next batch.
    while (recv_tile <= Bm*Bn-1)
    {
        send_buff = 0;
        while (send_tile <= Bm*Bn-1 && send_buff <= BUFFERS/2-1)
        {
            m =   (send_tile) % Bm;
            n =  ((send_tile) / Bm) % Bn;
            B = spu_call_args.blocked_format.B + (n*NB*NB + m*NB*N) * sizeof(float);

            // Tag send_buff covers both the receive into buffer[send_buff]
            // and the previous send out of buffer[send_buff+BUFFERS/2].
            spu_wait_tag(send_buff);
            spu_trans_tile(buffer[send_buff], buffer[send_buff+BUFFERS/2]);
            spu_send_tile_blocked(buffer[send_buff+BUFFERS/2], B, send_buff);

            send_tile += spus_num;
            send_buff++;
        }

        recv_buff = 0;
        while (recv_tile <= Bm*Bn-1 && recv_buff <= BUFFERS/2-1)
        {
            m =   (recv_tile) % Bm;
            n =  ((recv_tile) / Bm) % Bn;
            A = spu_call_args.blocked_format.A + (m*NB + n*NB*N) * sizeof(float);

            // Wait for the outstanding send on this slot's tag before
            // issuing the receive on the same tag.
            spu_wait_tag(recv_buff);
            spu_recv_tile(buffer[recv_buff], A, N, recv_buff);

            recv_tile += spus_num;
            recv_buff++;
        }
    }

    // Final sends: process the last (possibly partial) batch, using the same
    // per-slot tag as the pipelined sends.
    send_buff = 0;
    while (send_tile <= Bm*Bn-1 && send_buff <= BUFFERS/2-1)
    {
        m =   (send_tile) % Bm;
        n =  ((send_tile) / Bm) % Bn;
        B = spu_call_args.blocked_format.B + (n*NB*NB + m*NB*N) * sizeof(float);

        spu_wait_tag(send_buff);
        spu_trans_tile(buffer[send_buff], buffer[send_buff+BUFFERS/2]);
        spu_send_tile_blocked(buffer[send_buff+BUFFERS/2], B, send_buff);

        send_tile += spus_num;
        send_buff++;
    }

    // Drain every slot tag before returning. If the final batch is shorter
    // than BUFFERS/2, the sends that the last pipelined pass issued on the
    // higher tags were never waited on above.
    for (tag = 0; tag <= BUFFERS/2-1; tag++)
    {
        spu_wait_tag(tag);
    }
}

//----------------------------------------------------------------------------------------------

//
// Convert an M x N single-precision matrix from the tiled ("blocked") layout
// at spu_call_args.blocked_format.A back into the strided ("LAPACK") layout
// at spu_call_args.blocked_format.B. This is the inverse direction of
// spu_lapack2blocked().
//
// The matrix is handled as Bm x Bn tiles of NB x NB floats, with Bm = M/NB and
// Bn = N/NB (integer division, so any remainder rows/columns are not touched).
// Tiles are dealt out round-robin: this SPU processes tiles my_spu_id,
// my_spu_id + spus_num, my_spu_id + 2*spus_num, ...
//
// Buffering: slot k (k < BUFFERS/2) receives a tile into buffer[k].
// spu_trans_tile() transposes it into buffer[k + BUFFERS/2], which is then sent out.
// Both transfers of slot k use DMA tag k.
// Assumes spu_wait_tag(t) blocks until every DMA issued with tag t has
// completed.
//
void spu_blocked2lapack()
{
    unsigned int A, B;

    int recv_tile, send_tile;
    int recv_buff, send_buff;
    int tag;

    int M, N;
    int Bm, Bn;
    int m, n;

    spu_recv_call_args();

    M = spu_call_args.blocked_format.M;
    N = spu_call_args.blocked_format.N;

    Bm = M/NB;
    Bn = N/NB;

    recv_tile = my_spu_id;
    send_tile = my_spu_id;

    // Initial receives: fill up to BUFFERS/2 input buffers.
    recv_buff = 0;
    while (recv_tile <= Bm*Bn-1 && recv_buff <= BUFFERS/2-1)
    {
        m =   (recv_tile) % Bm;
        n =  ((recv_tile) / Bm) % Bn;
        A = spu_call_args.blocked_format.A + (n*NB*NB + m*NB*N) * sizeof(float);

        spu_recv_tile_blocked(buffer[recv_buff], A, recv_buff);

        recv_tile += spus_num;
        recv_buff++;
    }

    // Pipelined operation: transpose and send the batch just received,
    // then refill the input buffers with the next batch.
    while (recv_tile <= Bm*Bn-1)
    {
        send_buff = 0;
        while (send_tile <= Bm*Bn-1 && send_buff <= BUFFERS/2-1)
        {
            m =   (send_tile) % Bm;
            n =  ((send_tile) / Bm) % Bn;
            B = spu_call_args.blocked_format.B + (m*NB + n*NB*N) * sizeof(float);

            // Tag send_buff covers both the receive into buffer[send_buff]
            // and the previous send out of buffer[send_buff+BUFFERS/2].
            spu_wait_tag(send_buff);
            spu_trans_tile(buffer[send_buff], buffer[send_buff+BUFFERS/2]);
            spu_send_tile(buffer[send_buff+BUFFERS/2], B, N, send_buff);

            send_tile += spus_num;
            send_buff++;
        }

        recv_buff = 0;
        while (recv_tile <= Bm*Bn-1 && recv_buff <= BUFFERS/2-1)
        {
            m =   (recv_tile) % Bm;
            n =  ((recv_tile) / Bm) % Bn;
            A = spu_call_args.blocked_format.A + (n*NB*NB + m*NB*N) * sizeof(float);

            // Wait for the outstanding send on this slot's tag before
            // issuing the receive on the same tag.
            spu_wait_tag(recv_buff);
            spu_recv_tile_blocked(buffer[recv_buff], A, recv_buff);

            recv_tile += spus_num;
            recv_buff++;
        }
    }

    // Final sends: process the last (possibly partial) batch, using the same
    // per-slot tag as the pipelined sends.
    send_buff = 0;
    while (send_tile <= Bm*Bn-1 && send_buff <= BUFFERS/2-1)
    {
        m =   (send_tile) % Bm;
        n =  ((send_tile) / Bm) % Bn;
        B = spu_call_args.blocked_format.B + (m*NB + n*NB*N) * sizeof(float);

        spu_wait_tag(send_buff);
        spu_trans_tile(buffer[send_buff], buffer[send_buff+BUFFERS/2]);
        spu_send_tile(buffer[send_buff+BUFFERS/2], B, N, send_buff);

        send_tile += spus_num;
        send_buff++;
    }

    // Drain every slot tag before returning. If the final batch is shorter
    // than BUFFERS/2, the sends that the last pipelined pass issued on the
    // higher tags were never waited on above.
    for (tag = 0; tag <= BUFFERS/2-1; tag++)
    {
        spu_wait_tag(tag);
    }
}

//----------------------------------------------------------------------------------------------

//
// Transpose one NB x NB float tile from A into B (via transpose_matrix) and
// append a timing record to the SPU event log.
//
// The record holds the raw decrementer readings taken before and after the
// transpose, plus the event code 0x60A0E0. Each record occupies
// slots [k], [k+1], [k+2] of spu_event_log. The write index then advances by 4
// and is masked with 1024-1, so it wraps within the first 1024 entries.
//
void spu_trans_tile(float *A, float *B)
{
    //----------------------------------------------------------

    extern int spu_event_num;
    extern int spu_event_log[];

    // NOTE: kept token-identical to the definition in spu_trans_tile_double,
    // which redefines this macro later in the file.
    #define spu_log_event(start, end, event)\
        spu_event_log[spu_event_num+0] = start;\
        spu_event_log[spu_event_num+1] = end;\
        spu_event_log[spu_event_num+2] = event;\
        spu_event_num += 4;\
        spu_event_num &= 1024-1;\

    int t_before = spu_read_decrementer();
    int t_after;

    //----------------------------------------------------------

    transpose_matrix(NB, NB, A, NB, B, NB);
    t_after = spu_read_decrementer();

    spu_log_event(t_before, t_after, 0x60A0E0);
}

//----------------------------------------------------------------------------------------------

void transpose_matrix(int m, int n, float* a, int lda, float* b, int ldb)
{
    int i, j;

    vector float abcd, efgh, ijkl, mnop;	// input vectors
    vector float aeim, bfjn, cgko, dhlp;	// output vectors
    vector float aibj, ckdl, emfn, gohp;	// intermediate vectors

    vector unsigned char shufflehi = VEC_LITERAL(vector unsigned char,
        0x00, 0x01, 0x02, 0x03, 0x10, 0x11, 0x12, 0x13,
        0x04, 0x05, 0x06, 0x07, 0x14, 0x15, 0x16, 0x17);

    vector unsigned char shufflelo = VEC_LITERAL(vector unsigned char,
        0x08, 0x09, 0x0A, 0x0B, 0x18, 0x19, 0x1A, 0x1B,
        0x0C, 0x0D, 0x0E, 0x0F, 0x1C, 0x1D, 0x1E, 0x1F);

    for (i = 0; i < m; i += 4)
    {
        for (j = 0; j < n; j += 4)
        {
            abcd = *((vector float*) (&a[(i  )*lda + j]));
            efgh = *((vector float*) (&a[(i+1)*lda + j]));
            ijkl = *((vector float*) (&a[(i+2)*lda + j]));
            mnop = *((vector float*) (&a[(i+3)*lda + j]));

            aibj = spu_shuffle(abcd, ijkl, shufflehi);
            ckdl = spu_shuffle(abcd, ijkl, shufflelo);
            emfn = spu_shuffle(efgh, mnop, shufflehi);
            gohp = spu_shuffle(efgh, mnop, shufflelo);

            aeim = spu_shuffle(aibj, emfn, shufflehi);
            bfjn = spu_shuffle(aibj, emfn, shufflelo);
            cgko = spu_shuffle(ckdl, gohp, shufflehi);
            dhlp = spu_shuffle(ckdl, gohp, shufflelo);

            *((vector float*) (&b[(j  )*ldb + i])) = aeim;
            *((vector float*) (&b[(j+1)*ldb + i])) = bfjn;
            *((vector float*) (&b[(j+2)*ldb + i])) = cgko;
            *((vector float*) (&b[(j+3)*ldb + i])) = dhlp;
        }
    }
}

//----------------------------------------------------------------------------------------------

//
// Double-precision counterpart of spu_lapack2blocked(). It converts an M x N
// matrix from the strided ("LAPACK") layout at spu_call_args.blocked_format.A
// into the tiled ("blocked") layout at spu_call_args.blocked_format.B.
//
// The matrix is handled as Bm x Bn tiles of DNB x DNB doubles, with
// Bm = M/DNB and Bn = N/DNB (integer division, so any remainder rows/columns
// are not touched). Tiles are dealt out round-robin: this SPU processes tiles
// my_spu_id, my_spu_id + spus_num, ...
//
// Buffering: slot k (k < DBUFFERS/2) receives a tile into dbuffer[k].
// spu_trans_tile_double() transposes it into dbuffer[k + DBUFFERS/2], which is
// then sent out. Both transfers of slot k use DMA tag k.
// Assumes spu_wait_tag(t) blocks until every DMA issued with tag t has
// completed.
//
void spu_lapack2blocked_double()
{
    unsigned int A, B;

    int recv_tile, send_tile;
    int recv_buff, send_buff;
    int tag;

    int M, N;
    int Bm, Bn;
    int m, n;

    spu_recv_call_args();

    M = spu_call_args.blocked_format.M;
    N = spu_call_args.blocked_format.N;

    Bm = M/DNB;
    Bn = N/DNB;

    recv_tile = my_spu_id;
    send_tile = my_spu_id;

    // Initial receives: fill up to DBUFFERS/2 input buffers.
    recv_buff = 0;
    while (recv_tile <= Bm*Bn-1 && recv_buff <= DBUFFERS/2-1)
    {
        m =   (recv_tile) % Bm;
        n =  ((recv_tile) / Bm) % Bn;
        A = spu_call_args.blocked_format.A + (m*DNB + n*DNB*N) * sizeof(double);

        spu_recv_tile_double(dbuffer[recv_buff], A, N, recv_buff);

        recv_tile += spus_num;
        recv_buff++;
    }

    // Pipelined operation: transpose and send the batch just received,
    // then refill the input buffers with the next batch.
    while (recv_tile <= Bm*Bn-1)
    {
        send_buff = 0;
        while (send_tile <= Bm*Bn-1 && send_buff <= DBUFFERS/2-1)
        {
            m =   (send_tile) % Bm;
            n =  ((send_tile) / Bm) % Bn;
            B = spu_call_args.blocked_format.B + (n*DNB*DNB + m*DNB*N) * sizeof(double);

            // Tag send_buff covers both the receive into dbuffer[send_buff]
            // and the previous send out of dbuffer[send_buff+DBUFFERS/2].
            spu_wait_tag(send_buff);
            spu_trans_tile_double(dbuffer[send_buff], dbuffer[send_buff+DBUFFERS/2]);
            spu_send_tile_blocked_double(dbuffer[send_buff+DBUFFERS/2], B, send_buff);

            send_tile += spus_num;
            send_buff++;
        }

        recv_buff = 0;
        while (recv_tile <= Bm*Bn-1 && recv_buff <= DBUFFERS/2-1)
        {
            m =   (recv_tile) % Bm;
            n =  ((recv_tile) / Bm) % Bn;
            A = spu_call_args.blocked_format.A + (m*DNB + n*DNB*N) * sizeof(double);

            // Wait for the outstanding send on this slot's tag before
            // issuing the receive on the same tag.
            spu_wait_tag(recv_buff);
            spu_recv_tile_double(dbuffer[recv_buff], A, N, recv_buff);

            recv_tile += spus_num;
            recv_buff++;
        }
    }

    // Final sends: process the last (possibly partial) batch, using the same
    // per-slot tag as the pipelined sends.
    send_buff = 0;
    while (send_tile <= Bm*Bn-1 && send_buff <= DBUFFERS/2-1)
    {
        m =   (send_tile) % Bm;
        n =  ((send_tile) / Bm) % Bn;
        B = spu_call_args.blocked_format.B + (n*DNB*DNB + m*DNB*N) * sizeof(double);

        spu_wait_tag(send_buff);
        spu_trans_tile_double(dbuffer[send_buff], dbuffer[send_buff+DBUFFERS/2]);
        spu_send_tile_blocked_double(dbuffer[send_buff+DBUFFERS/2], B, send_buff);

        send_tile += spus_num;
        send_buff++;
    }

    // Drain every slot tag before returning. If the final batch is shorter
    // than DBUFFERS/2, the sends that the last pipelined pass issued on the
    // higher tags were never waited on above.
    for (tag = 0; tag <= DBUFFERS/2-1; tag++)
    {
        spu_wait_tag(tag);
    }
}

//----------------------------------------------------------------------------------------------

//
// Double-precision counterpart of spu_blocked2lapack(). It converts an M x N
// matrix from the tiled ("blocked") layout at spu_call_args.blocked_format.A
// back into the strided ("LAPACK") layout at spu_call_args.blocked_format.B.
//
// The matrix is handled as Bm x Bn tiles of DNB x DNB doubles, with
// Bm = M/DNB and Bn = N/DNB (integer division, so any remainder rows/columns
// are not touched). Tiles are dealt out round-robin: this SPU processes tiles
// my_spu_id, my_spu_id + spus_num, ...
//
// Buffering: slot k (k < DBUFFERS/2) receives a tile into dbuffer[k].
// spu_trans_tile_double() transposes it into dbuffer[k + DBUFFERS/2], which is
// then sent out. Both transfers of slot k use DMA tag k.
// Assumes spu_wait_tag(t) blocks until every DMA issued with tag t has
// completed.
//
void spu_blocked2lapack_double()
{
    unsigned int A, B;

    int recv_tile, send_tile;
    int recv_buff, send_buff;
    int tag;

    int M, N;
    int Bm, Bn;
    int m, n;

    spu_recv_call_args();

    M = spu_call_args.blocked_format.M;
    N = spu_call_args.blocked_format.N;

    Bm = M/DNB;
    Bn = N/DNB;

    recv_tile = my_spu_id;
    send_tile = my_spu_id;

    // Initial receives: fill up to DBUFFERS/2 input buffers.
    recv_buff = 0;
    while (recv_tile <= Bm*Bn-1 && recv_buff <= DBUFFERS/2-1)
    {
        m =   (recv_tile) % Bm;
        n =  ((recv_tile) / Bm) % Bn;
        A = spu_call_args.blocked_format.A + (n*DNB*DNB + m*DNB*N) * sizeof(double);

        spu_recv_tile_blocked_double(dbuffer[recv_buff], A, recv_buff);

        recv_tile += spus_num;
        recv_buff++;
    }

    // Pipelined operation: transpose and send the batch just received,
    // then refill the input buffers with the next batch.
    while (recv_tile <= Bm*Bn-1)
    {
        send_buff = 0;
        while (send_tile <= Bm*Bn-1 && send_buff <= DBUFFERS/2-1)
        {
            m =   (send_tile) % Bm;
            n =  ((send_tile) / Bm) % Bn;
            B = spu_call_args.blocked_format.B + (m*DNB + n*DNB*N) * sizeof(double);

            // Tag send_buff covers both the receive into dbuffer[send_buff]
            // and the previous send out of dbuffer[send_buff+DBUFFERS/2].
            spu_wait_tag(send_buff);
            spu_trans_tile_double(dbuffer[send_buff], dbuffer[send_buff+DBUFFERS/2]);
            spu_send_tile_double(dbuffer[send_buff+DBUFFERS/2], B, N, send_buff);

            send_tile += spus_num;
            send_buff++;
        }

        recv_buff = 0;
        while (recv_tile <= Bm*Bn-1 && recv_buff <= DBUFFERS/2-1)
        {
            m =   (recv_tile) % Bm;
            n =  ((recv_tile) / Bm) % Bn;
            A = spu_call_args.blocked_format.A + (n*DNB*DNB + m*DNB*N) * sizeof(double);

            // Wait for the outstanding send on this slot's tag before
            // issuing the receive on the same tag.
            spu_wait_tag(recv_buff);
            spu_recv_tile_blocked_double(dbuffer[recv_buff], A, recv_buff);

            recv_tile += spus_num;
            recv_buff++;
        }
    }

    // Final sends: process the last (possibly partial) batch, using the same
    // per-slot tag as the pipelined sends.
    send_buff = 0;
    while (send_tile <= Bm*Bn-1 && send_buff <= DBUFFERS/2-1)
    {
        m =   (send_tile) % Bm;
        n =  ((send_tile) / Bm) % Bn;
        B = spu_call_args.blocked_format.B + (m*DNB + n*DNB*N) * sizeof(double);

        spu_wait_tag(send_buff);
        spu_trans_tile_double(dbuffer[send_buff], dbuffer[send_buff+DBUFFERS/2]);
        spu_send_tile_double(dbuffer[send_buff+DBUFFERS/2], B, N, send_buff);

        send_tile += spus_num;
        send_buff++;
    }

    // Drain every slot tag before returning. If the final batch is shorter
    // than DBUFFERS/2, the sends that the last pipelined pass issued on the
    // higher tags were never waited on above.
    for (tag = 0; tag <= DBUFFERS/2-1; tag++)
    {
        spu_wait_tag(tag);
    }
}

//----------------------------------------------------------------------------------------------

//
// Transpose one DNB x DNB double tile from A into B (via
// transpose_matrix_double) and append a timing record to the SPU event log.
//
// The record holds the raw decrementer readings taken before and after the
// transpose, plus the event code 0x60A0E0. Each record occupies
// slots [k], [k+1], [k+2] of spu_event_log. The write index then advances by 4
// and is masked with 1024-1, so it wraps within the first 1024 entries.
//
void spu_trans_tile_double(double *A, double *B)
{
    //----------------------------------------------------------

    extern int spu_event_num;
    extern int spu_event_log[];

    // NOTE: kept token-identical to the earlier definition in spu_trans_tile,
    // since this is a redefinition of the same macro.
    #define spu_log_event(start, end, event)\
        spu_event_log[spu_event_num+0] = start;\
        spu_event_log[spu_event_num+1] = end;\
        spu_event_log[spu_event_num+2] = event;\
        spu_event_num += 4;\
        spu_event_num &= 1024-1;\

    int t_before = spu_read_decrementer();
    int t_after;

    //----------------------------------------------------------

    transpose_matrix_double(DNB, DNB, A, DNB, B, DNB);
    t_after = spu_read_decrementer();

    spu_log_event(t_before, t_after, 0x60A0E0);
}

//----------------------------------------------------------------------------------------------

void transpose_matrix_double(int m, int n, double* a, int lda, double* b, int ldb)
{
    int i, j;
    
    vector double ab, cd;
    vector double ac, bd;

    vector unsigned char shufflelo = VEC_LITERAL(vector unsigned char,
        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
        0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17);

    vector unsigned char shufflehi = VEC_LITERAL(vector unsigned char,
        0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
        0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F);


    for (i = 0; i < m; i += 2)
    {
        for (j = 0; j < n; j += 2)
        {
            ab = *((vector double*) (&a[(i  )*lda + j]));
            cd = *((vector double*) (&a[(i+1)*lda + j]));

            ac = spu_shuffle(ab, cd, shufflelo);
            bd = spu_shuffle(ab, cd, shufflehi);
            
            *((vector double*) (&b[(j  )*ldb + i])) = ac;
            *((vector double*) (&b[(j+1)*ldb + i])) = bd;
        }
    }
}

//----------------------------------------------------------------------------------------------
