//==============================================================================================
//
//  Innovative Computing Laboratory - Computer Science Department - University of Tennessee
//  Written by Jakub Kurzak
//
//==============================================================================================

#include <assert.h>
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <libspe.h>
#include <sched.h>
#include <cbe_mfc.h>
#include <time.h>

#include "../cbe/cbe_blas.h"
#include "ppu_timer.h"
#include "ppu_blas.h"

//----------------------------------------------------------------------------------------------

// SPU program image passed to spe_create_thread, and a work buffer (defined in another
// translation unit) that ppu_dgemm hands to the SPUs as the gemm "C" operand; it must hold
// at least M * spus_num doubles for each ppu_dgemm call.
extern spe_program_handle_t spu_blas;
extern double *DGEMM_WORK;

//----------------------------------------------------------------------------------------------

// Number of SPE threads started by ppu_spus_init, their group, and their thread handles
// (entries 0 .. spus_num-1 are valid).
static int          spus_num;
static spe_gid_t    spu_group_id;
static speid_t      spu_id[SPUS_MAX];

// Parameter blocks whose addresses are given to the SPUs. ppu_global_params is filled in by
// ppu_spus_init; ppu_call_args is refilled by every ppu_* wrapper before a command is sent.
// NOTE(review): the 128-byte alignment is presumably for efficient DMA -- confirm.
GlobalParams    ppu_global_params   __attribute__ ((aligned (128)));
CallArgs        ppu_call_args       __attribute__ ((aligned (128)));

// Completion flags, one row per SPU. The PPU clears ppu_spu_ack[i][0] before sending a
// command and spins until it becomes non-zero (presumably written by SPU i when done).
volatile int    ppu_spu_ack[SPUS_MAX][4]    __attribute__ ((aligned (128)));

// Event logs, 1024 ints each. dump_trace reads spu_event_log as 4-int records per SPU.
int ppu_event_num = 0;
int ppu_event_log[       1*1024]    __attribute__ ((aligned (128)));
int spu_event_log[SPUS_MAX*1024]    __attribute__ ((aligned (128)));

//----------------------------------------------------------------------------------------------

void ppu_spus_init(int num, int NB, int DNB)
{
    int i;


    spus_num = num;


    // Create SPU group
    printf("\n\tCreating SPEs Group ----------------> "); fflush(stdout);
    assert(spu_group_id = spe_create_group(SCHED_OTHER, 0, 1));
    printf("DONE!\n"); fflush(stdout);


    // Create SPU threads
    printf("\tLaunchning SPEs Threads ------------> "); fflush(stdout);
    for (i = 0; i < spus_num; i++)
        assert(spu_id[i] =
            spe_create_thread(spu_group_id, &spu_blas, &ppu_global_params, 0, -1, 0));
    printf("DONE!\n"); fflush(stdout);


    // Pass control structures
    printf("\tSending SPEs Init. Params. ---------> "); fflush(stdout);
    
    ppu_global_params.ppu_call_args_p = (unsigned int)&ppu_call_args;
    ppu_global_params.spus_num = spus_num;
    ppu_global_params.NB = NB;
    ppu_global_params.DNB = DNB;

    for (i = 0; i < spus_num; i++)
        ppu_global_params.ppu_spu_ack_p[i] = (unsigned int)&ppu_spu_ack[i][0];

    for(i = 0; i < spus_num; i++)
        ppu_global_params.local_store[i]    = (unsigned int)spe_get_ls(spu_id[i]);

    for(i = 0; i < spus_num; i++)
        ppu_global_params.spu_event_log[i] = (unsigned int)&spu_event_log[i*1024];

    for (i = 0; i < spus_num; i++)
        spe_write_in_mbox(spu_id[i], i);

    printf("DONE!\n\n"); fflush(stdout);
}

//----------------------------------------------------------------------------------------------

// Send the END command to every SPU, wait for all SPE threads to exit, then write the SPU
// event trace to "trace_<seconds since epoch>.svg" via dump_trace.
void ppu_spus_finish(void)
{
    // dump_trace is defined later in this file; declaring it here guarantees a prototype
    // in scope for the call below even if ppu_blas.h does not declare it.
    int dump_trace(char trace_file_name[]);

    int i;
    int status;
    char trace_file_name[32];


    // Send termination message
    printf("\n\tSending SPEs Termination Command --->  "); fflush(stdout);
    for (i = 0; i < spus_num; i++)
        spe_write_in_mbox(spu_id[i], END);
    printf("DONE!\n"); fflush(stdout);


    // Wait for termination
    printf("\tWaiting for SPEs Terminations ------>  "); fflush(stdout);
    for (i = 0; i < spus_num; i++)
        spe_wait(spu_id[i], NULL, 0);
    printf("DONE!\n"); fflush(stdout);


    // Dumping the trace.
    // time_t is not an int, so it is printed through an explicit cast to long; snprintf
    // bounds the write to the buffer.
    snprintf(trace_file_name, sizeof trace_file_name, "trace_%ld.svg", (long)time(NULL));
    printf("\tDumping %s ------->  ", trace_file_name); fflush(stdout);

    // -1 means a log buffer was full (partial trace). Any other negative value is reported
    // as a failure to write the file.
    status = dump_trace(trace_file_name);
    if (status == -1)
        printf("Trace buffer overflow -> Partial trace dupmed!\n");
    else if (status < 0)
        printf("FAILED!\n");
    else
        printf("DONE!\n");
    fflush(stdout);
}

//----------------------------------------------------------------------------------------------

// Write the event logs of the first spus_num SPUs (from spu_event_log) to an SVG file.
//
// Each SPU owns 1024 ints of spu_event_log, read here as 4-int records:
//   [0] start time, [1] end time, [2] fill color (0xRRGGBB), [3] not read here.
// A record whose color is 0 ends that SPU's log. Both times are negated before use,
// presumably because the SPU decrementer counts down (TODO confirm on the SPU side).
// Every record becomes one rectangle: x = start * 0.05, width = (end - start) * 0.05,
// and SPU n is drawn in the row at y = 25 * n.
//
// trace_file_name : path of the SVG file to create (truncated if it already exists).
//
// Returns  0 on success,
//         -1 if some SPU's log had no terminating record (buffer full, so the trace may
//            be truncated); the file is still written,
//         -2 if the file could not be opened or an I/O error occurred while writing it.
int dump_trace(char trace_file_name[])
{
    enum { LOG_INTS_PER_SPU = 1024, INTS_PER_EVENT = 4 };

    int spu;
    int event;
    int overflow = 0;
    int write_error;

    FILE *trace_file;


    trace_file = fopen(trace_file_name, "w");
    if (trace_file == NULL)
    {
        perror(trace_file_name);
        return (-2);
    }

    // A standalone .svg file must declare the SVG namespace to be rendered as SVG.
    fprintf(trace_file,
        "<svg xmlns=\"http://www.w3.org/2000/svg\" "
        "width=\"100mm\" height=\"20mm\" viewBox=\"0 0 10000 2000\">\n"
        "  <g>\n");

    for (spu = 0; spu < spus_num; spu++)
    {
        const int *spu_log = &spu_event_log[spu * LOG_INTS_PER_SPU];

        for (event = 0; event < LOG_INTS_PER_SPU; event += INTS_PER_EVENT)
        {
            int start = -spu_log[event + 0];
            int end   = -spu_log[event + 1];
            int color =  spu_log[event + 2];

            if (color == 0)
                break;

            // The stroke color needs the same '#RRGGBB' form as the fill; a bare "0" is
            // not a valid SVG color and the outline would be dropped.
            fprintf(trace_file,
                "    "
                "<rect x=\"%.2lf\" y=\"%.0lf\" width=\"%.2lf\" height=\"%.0lf\" "
                "fill=\"#%06x\" stroke=\"#%06x\" stroke-width=\"1\"/>\n",
                0.0 + start * 0.05,
                25.0 * spu,
                (end - start) * 0.05,
                22.0,
                (unsigned int)color,
                0x000000u);
        }

        // Check every SPU's log, not only the last one: reaching the end of the buffer
        // without a terminating record means it was full.
        if (event >= LOG_INTS_PER_SPU)
            overflow = 1;

        fprintf(trace_file, "\n");
    }
    fprintf(trace_file,
        "  </g>\n"
        "</svg>\n");

    write_error = ferror(trace_file);
    if (fclose(trace_file) != 0 || write_error)
    {
        perror(trace_file_name);
        return (-2);
    }

    return overflow ? (-1) : 0;
}

//----------------------------------------------------------------------------------------------

// Dispatch the SPOTRF command to all SPUs, wait for every SPU to acknowledge, and print
// the wall time (microseconds) and GFLOP/s. The rate uses 0.33*N^3 + 0.5*N^2 + 0.17*N
// operations (presumably the single-precision Cholesky count, per the LAPACK name).
//
// M, N, A, LDA : forwarded in ppu_call_args.potrf; A is passed as a 32-bit address.
// INFO         : address forwarded to the SPUs; this function does not read it.
void ppu_spotrf(int M, int N, float *A, int LDA, int *INFO)
{
    TimeStruct t0;
    double seconds;
    double flops;
    int spu;

    // Publish the call arguments before any SPU is told to start.
    ppu_call_args.potrf.M    = M;
    ppu_call_args.potrf.N    = N;
    ppu_call_args.potrf.A    = (unsigned int)A;
    ppu_call_args.potrf.LDA  = LDA;
    ppu_call_args.potrf.INFO = (unsigned int)INFO;

    t0 = get_current_time();

    // Clear all completion flags, then dispatch, then spin until every SPU has answered.
    for (spu = 0; spu < spus_num; ++spu)
        ppu_spu_ack[spu][0] = 0;
    for (spu = 0; spu < spus_num; ++spu)
        spe_write_in_mbox(spu_id[spu], SPOTRF);
    for (spu = 0; spu < spus_num; ++spu)
        while (!ppu_spu_ack[spu][0])
            ;   // busy-wait

    seconds = get_elapsed_time(t0, get_current_time());

    flops = (double)0.33*N*N*N + 0.5*N*N + 0.17*N;
    printf("\t\tspotrf:\ttime:%9.0lf\tgflops:%7.2lf\n",
        seconds * 1000000,
        flops / seconds / 1000000000);
}

//----------------------------------------------------------------------------------------------

// Dispatch the DGEMM command to all SPUs and fold their partial results into C.
//
// The SPUs are handed DGEMM_WORK (not C) as their output operand. DGEMM_WORK is cleared
// for M * spus_num doubles first. Once every SPU has acknowledged, each SPU's M-element
// slice of DGEMM_WORK is subtracted from C[0 .. M-1].
// NOTE(review): only the first M entries of C are updated, regardless of N and LDC;
// confirm that this matches the SPU kernel's output layout.
// Prints wall time (microseconds, including the clear and the reduction) and GFLOP/s
// based on 2*M*N*K operations.
void ppu_dgemm(int M, int N, int K, double *A, int LDA, double *B, int LDB, double *C, int LDC)
{
    TimeStruct t0;
    double seconds;
    int k, spu, work_len;

    ppu_call_args.gemm.M   = M;
    ppu_call_args.gemm.N   = N;
    ppu_call_args.gemm.K   = K;
    ppu_call_args.gemm.A   = (unsigned int)A;
    ppu_call_args.gemm.LDA = LDA;
    ppu_call_args.gemm.B   = (unsigned int)B;
    ppu_call_args.gemm.LDB = LDB;
    ppu_call_args.gemm.C   = (unsigned int)DGEMM_WORK;
    ppu_call_args.gemm.LDC = LDC;

    t0 = get_current_time();

    // One M-element slice of partial results per SPU.
    work_len = M * spus_num;
    for (k = 0; k < work_len; k++)
        DGEMM_WORK[k] = 0.0;

    for (spu = 0; spu < spus_num; ++spu)
        ppu_spu_ack[spu][0] = 0;
    for (spu = 0; spu < spus_num; ++spu)
        spe_write_in_mbox(spu_id[spu], DGEMM);
    for (spu = 0; spu < spus_num; ++spu)
        while (!ppu_spu_ack[spu][0])
            ;   // busy-wait

    // Reduce the per-SPU slices into C (SPU 0 first, then SPU 1, ...).
    for (spu = 0; spu < spus_num; ++spu)
    {
        const double *partial = &DGEMM_WORK[spu * M];

        for (k = 0; k < M; k++)
            C[k] -= partial[k];
    }

    seconds = get_elapsed_time(t0, get_current_time());

    printf("\t\tdgemm:\ttime:%9.0lf\tgflops:%7.2lf\n",
        seconds * 1000000,
        2.0*M*N*K / seconds / 1000000000);
}

//----------------------------------------------------------------------------------------------

// Dispatch the STRSM_NOTRANS command (triangular solve with T, no transpose) to all
// SPUs, wait for every acknowledgement, and print wall time (microseconds) and GFLOP/s.
//
// M, N   : forwarded in ppu_call_args.trsm. The rate assumes M*M*N operations
//          (presumably T is M x M and B is M x N; confirm with the SPU kernel).
// T, LDT : triangular operand and its leading dimension (passed as a 32-bit address).
// B, LDB : second operand and its leading dimension (passed as a 32-bit address).
void ppu_strsm_no_trans(int M, int N, float *T, int LDT, float *B, int LDB)
{
    TimeStruct start;
    double elapsed;

    int i;


    ppu_call_args.trsm.M = M;
    ppu_call_args.trsm.N = N;

    ppu_call_args.trsm.T   = (unsigned int)T;
    ppu_call_args.trsm.LDT = LDT;
    ppu_call_args.trsm.B   = (unsigned int)B;
    ppu_call_args.trsm.LDB = LDB;

    start = get_current_time();

    for (i = 0; i < spus_num; i++)
        ppu_spu_ack[i][0] = 0;

    for (i = 0; i < spus_num; i++)
        spe_write_in_mbox(spu_id[i], STRSM_NOTRANS);

    for (i = 0; i < spus_num; i++)
        while (ppu_spu_ack[i][0] == 0);

    elapsed = get_elapsed_time(start, get_current_time());

    // The operation count is formed in double: M*M*N in int overflows (undefined behaviour)
    // once it exceeds INT_MAX, e.g. for M = N = 1300.
    printf("\t\tstrsm:\ttime:%9.0lf\tgflops:%7.2lf\n",
        elapsed * 1000000,
        (double)M * M * N / elapsed / 1000000000);
}

//----------------------------------------------------------------------------------------------

// Dispatch the STRSM_TRANS command (triangular solve with T, transposed) to all SPUs,
// wait for every acknowledgement, and print wall time (microseconds) and GFLOP/s.
//
// M, N   : forwarded in ppu_call_args.trsm. The rate assumes M*M*N operations
//          (presumably T is M x M and B is M x N; confirm with the SPU kernel).
// T, LDT : triangular operand and its leading dimension (passed as a 32-bit address).
// B, LDB : second operand and its leading dimension (passed as a 32-bit address).
void ppu_strsm_trans(int M, int N, float *T, int LDT, float *B, int LDB)
{
    TimeStruct start;
    double elapsed;

    int i;


    ppu_call_args.trsm.M = M;
    ppu_call_args.trsm.N = N;

    ppu_call_args.trsm.T   = (unsigned int)T;
    ppu_call_args.trsm.LDT = LDT;
    ppu_call_args.trsm.B   = (unsigned int)B;
    ppu_call_args.trsm.LDB = LDB;

    start = get_current_time();

    for (i = 0; i < spus_num; i++)
        ppu_spu_ack[i][0] = 0;

    for (i = 0; i < spus_num; i++)
        spe_write_in_mbox(spu_id[i], STRSM_TRANS);

    for (i = 0; i < spus_num; i++)
        while (ppu_spu_ack[i][0] == 0);

    elapsed = get_elapsed_time(start, get_current_time());

    // The operation count is formed in double: M*M*N in int overflows (undefined behaviour)
    // once it exceeds INT_MAX, e.g. for M = N = 1300.
    printf("\t\tstrsm:\ttime:%9.0lf\tgflops:%7.2lf\n",
        elapsed * 1000000,
        (double)M * M * N / elapsed / 1000000000);
}

//----------------------------------------------------------------------------------------------

// Ask the SPUs to run LAPACK2BLOCKED on the single-precision matrices A -> B (presumably
// a standard LAPACK layout to a tiled layout, per the command name). Waits for every
// SPU to acknowledge, then prints the wall time in microseconds.
void ppu_lapack2blocked(int M, int N, float *A, float *B)
{
    TimeStruct t0;
    double usec;
    int spu;

    ppu_call_args.blocked_format.M = M;
    ppu_call_args.blocked_format.N = N;
    ppu_call_args.blocked_format.A = (unsigned int)A;
    ppu_call_args.blocked_format.B = (unsigned int)B;

    t0 = get_current_time();

    // Clear all completion flags, then dispatch, then wait for all of them.
    for (spu = 0; spu < spus_num; ++spu)
        ppu_spu_ack[spu][0] = 0;
    for (spu = 0; spu < spus_num; ++spu)
        spe_write_in_mbox(spu_id[spu], LAPACK2BLOCKED);
    for (spu = 0; spu < spus_num; ++spu)
        while (!ppu_spu_ack[spu][0])
            ;   // busy-wait

    usec = get_elapsed_time(t0, get_current_time()) * 1000000;
    printf("\t\tl2b:\ttime:%9.0lf\n", usec);
}

//----------------------------------------------------------------------------------------------

// Ask the SPUs to run BLOCKED2LAPACK on the single-precision matrices A -> B (presumably
// the inverse of ppu_lapack2blocked, per the command name). Waits for every SPU to
// acknowledge, then prints the wall time in microseconds.
void ppu_blocked2lapack(int M, int N, float *A, float *B)
{
    TimeStruct t0;
    double usec;
    int spu;

    ppu_call_args.blocked_format.M = M;
    ppu_call_args.blocked_format.N = N;
    ppu_call_args.blocked_format.A = (unsigned int)A;
    ppu_call_args.blocked_format.B = (unsigned int)B;

    t0 = get_current_time();

    spu = 0;
    while (spu < spus_num)
        ppu_spu_ack[spu++][0] = 0;

    for (spu = 0; spu < spus_num; ++spu)
        spe_write_in_mbox(spu_id[spu], BLOCKED2LAPACK);

    for (spu = 0; spu < spus_num; ++spu)
    {
        while (!ppu_spu_ack[spu][0])
            ;   // busy-wait for this SPU's acknowledgement
    }

    usec = get_elapsed_time(t0, get_current_time()) * 1000000;
    printf("\t\tb2l:\ttime:%9.0lf\n", usec);
}

//----------------------------------------------------------------------------------------------

// Double-precision variant of ppu_lapack2blocked: forwards M, N and the addresses of A
// and B, sends LAPACK2BLOCKED_DOUBLE to every SPU, waits for all acknowledgements, and
// prints the wall time in microseconds.
void ppu_lapack2blocked_double(int M, int N, double *A, double *B)
{
    TimeStruct t0;
    double usec;
    int spu;

    ppu_call_args.blocked_format.M = M;
    ppu_call_args.blocked_format.N = N;
    ppu_call_args.blocked_format.A = (unsigned int)A;
    ppu_call_args.blocked_format.B = (unsigned int)B;

    t0 = get_current_time();

    for (spu = 0; spu < spus_num; ++spu)
        ppu_spu_ack[spu][0] = 0;
    for (spu = 0; spu < spus_num; ++spu)
        spe_write_in_mbox(spu_id[spu], LAPACK2BLOCKED_DOUBLE);
    for (spu = 0; spu < spus_num; ++spu)
        while (!ppu_spu_ack[spu][0])
            ;   // busy-wait

    usec = get_elapsed_time(t0, get_current_time()) * 1000000;
    printf("\t\tl2b_d:\ttime:%9.0lf\n", usec);
}

//----------------------------------------------------------------------------------------------

// Double-precision variant of ppu_blocked2lapack: forwards M, N and the addresses of A
// and B, sends BLOCKED2LAPACK_DOUBLE to every SPU, waits for all acknowledgements, and
// prints the wall time in microseconds.
void ppu_blocked2lapack_double(int M, int N, double *A, double *B)
{
    TimeStruct t0;
    double usec;
    int spu;

    ppu_call_args.blocked_format.M = M;
    ppu_call_args.blocked_format.N = N;
    ppu_call_args.blocked_format.A = (unsigned int)A;
    ppu_call_args.blocked_format.B = (unsigned int)B;

    t0 = get_current_time();

    spu = 0;
    while (spu < spus_num)
        ppu_spu_ack[spu++][0] = 0;

    for (spu = 0; spu < spus_num; ++spu)
        spe_write_in_mbox(spu_id[spu], BLOCKED2LAPACK_DOUBLE);

    for (spu = 0; spu < spus_num; ++spu)
    {
        while (ppu_spu_ack[spu][0] == 0)
            ;   // busy-wait for this SPU's acknowledgement
    }

    usec = get_elapsed_time(t0, get_current_time()) * 1000000;
    printf("\t\tb2l_d:\ttime:%9.0lf\n", usec);
}

//----------------------------------------------------------------------------------------------

// Ask the SPUs to run CONVERT_D2S: D (double) is presumably converted into S (float),
// with M and N forwarded unchanged in ppu_call_args.convert. Waits for every SPU to
// acknowledge, then prints the wall time in microseconds.
void ppu_convert_d2s(int M, int N, double *D, float *S)
{
    TimeStruct t0;
    double usec;
    int spu;

    ppu_call_args.convert.M = M;
    ppu_call_args.convert.N = N;
    ppu_call_args.convert.D = (unsigned int)D;
    ppu_call_args.convert.S = (unsigned int)S;

    t0 = get_current_time();

    // Reset flags, dispatch, then wait for every SPU in turn.
    for (spu = 0; spu < spus_num; ++spu)
        ppu_spu_ack[spu][0] = 0;
    for (spu = 0; spu < spus_num; ++spu)
        spe_write_in_mbox(spu_id[spu], CONVERT_D2S);
    for (spu = 0; spu < spus_num; ++spu)
    {
        while (!ppu_spu_ack[spu][0])
            ;   // busy-wait
    }

    usec = get_elapsed_time(t0, get_current_time()) * 1000000;
    printf("\t\td2s:\ttime:%9.0lf\n", usec);
}

//----------------------------------------------------------------------------------------------
