//==============================================================================================
//
//  Innovative Computing Laboratory - Computer Science Department - University of Tennessee
//  Written by Jakub Kurzak
//
//==============================================================================================

#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h>

#include <cblas.h>
#include <lapack.h>

#include "../cbe/cbe_blas.h"
#include "ppu_blas.h"
#include "ppu_timer.h"
#include "iter_ref_cholesky.h"

//----------------------------------------------------------------------------------------------

// Scratch buffer of N*SPUS_MAX doubles carved out of the huge-page region by main().
// NOTE(review): presumably consumed by ppu_dgemm (defined elsewhere) -- confirm.
double *DGEMM_WORK;

//----------------------------------------------------------------------------------------------

// Replaces the m-by-n column-major matrix A (leading dimension lda) in place with an
// orthogonal factor computed via LAPACK dgeqrf/dorgqr. Defined below.
int orth( int m, int n, double *A, int lda);
// Fills the n-by-n matrix A (leading dimension lda) with a test matrix built from a random
// orthogonal factor and the log-spaced spectrum 10^(cond*i/(n-1)), i = 0..n-1, i.e. intended
// to have 2-norm condition number 10^cond (for n > 1). Defined below.
int generate_fixed_cond_Asym( int n, double *A, int lda, int cond);

// Element-wise precision conversion of N-element vectors: double -> float and float -> double.
void convert_d2s(double *D, float *S, int N);
void convert_s2d(float *S, double *D, int N);

//----------------------------------------------------------------------------------------------

/*
 * Parse `text` as a base-10 integer in [lo, hi] and store it in *out.
 * Returns 0 on success; prints a diagnostic naming `what` and returns -1 otherwise.
 */
static int parse_int_arg(const char *text, const char *what, long lo, long hi, int *out)
{
    char *end;
    long value;

    errno = 0;
    value = strtol(text, &end, 10);
    if (end == text || *end != '\0' || errno == ERANGE || value < lo || value > hi)
    {
        fprintf(stderr, "invalid %s '%s' (expected an integer in [%ld, %ld])\n",
                what, text, lo, hi);
        return -1;
    }
    *out = (int)value;
    return 0;
}

/*
 * Reserve count*elem bytes at the end of the layout whose current size is *cursor and
 * return the byte offset of the reservation. If the layout would no longer fit in size_t,
 * sets *overflow and leaves *cursor unchanged.
 */
static size_t layout_reserve(size_t *cursor, size_t count, size_t elem, int *overflow)
{
    size_t offset = *cursor;

    if (elem != 0 && count > (SIZE_MAX - *cursor) / elem)
        *overflow = 1;
    else
        *cursor += count * elem;
    return offset;
}

/*
 * Mixed-precision Cholesky solve with iterative refinement.
 *
 * Usage: prog <BB> <spus_num> <cond> <iter>
 *   BB       - system size in 64-wide single precision blocks (N = 64 * BB)
 *   spus_num - number of SPUs to use (1 .. SPUS_MAX)
 *   cond     - log10 of the condition number of the generated test matrix
 *   iter     - number of iterative refinement steps
 *
 * All work arrays live in one huge-page backed mapping of /huge/huge_page.bin.
 * Returns EXIT_SUCCESS, or EXIT_FAILURE on bad arguments, mapping failure or a failed
 * single precision factorization.
 */
int main(int argc, char *argv[])
{
    double *A;          // Coefficient matrix A - double precision, standard layout
    double *Ablk;       // Coefficient matrix A - double precision, block layout
    double *B;          // Right hand side vector B in double precision
    double *X;          // Solution vector X in double precision
    double *R;          // Residual vector R in double precision
    double *Z;          // Correction vector Z in double precision

    float *A_;          // Coefficient matrix A - single precision, standard layout
    float *Ablk_;       // Coefficient matrix A - single precision, block layout
    float *B_;          // Right hand side vector B in single precision
    float *X_;          // Solution vector X in single precision
    float *R_;          // Residual vector R in single precision
    float *Z_;          // Correction vector Z in single precision

    int INFO = 0;   // Return info (pre-set so it is defined even if the callee never writes it)

    int NB  = 64;   // Block size in single precision
    int DNB = 32;   // Block size in double precision
    int BB;         // System size in single precision blocks
    int N;          // System size
    int LDA;        // Leading dimension of the coefficients matrix A
    int LDB;        // Leading dimension of the right hand side vector B
    int LDX;        // Leading dimension of the solution vector X
    int NRHS = 1;   // Number of right hand sides
    int spus_num;   // Number of SPUs
    int cond;       // log10 of the condition number
    int iter;       // Number of iterative refinement steps

    const char *mem_file = "/huge/huge_page.bin";   // Huge pages file name
    char  *mem_addr;                                // Huge pages memory address
    size_t huge_size;                               // Required huge pages memory size
    int    fmem;                                    // Huge pages file handle
    int    status = EXIT_FAILURE;                   // Process exit status

    // Byte offsets of every array inside the huge-page region
    size_t n_elems, nn_elems, cursor = 0;
    int    overflow = 0;
    size_t off_A, off_Ablk, off_A_, off_Ablk_;
    size_t off_B, off_X, off_B_, off_X_;
    size_t off_R, off_Z, off_R_, off_Z_;
    size_t off_work;

    double max_res;
    int i, j;

    //----------------------------------------------------------

    // Process call arguments (explicit checks: assert() disappears under NDEBUG)
    if (argc != 5)
    {
        fprintf(stderr, "usage: %s <blocks> <spus> <log10 cond> <iterations>\n",
                argc > 0 ? argv[0] : "iter_ref_cholesky");
        return EXIT_FAILURE;
    }
    if (parse_int_arg(argv[1], "block count", 1, INT_MAX / NB, &BB) != 0 ||
        parse_int_arg(argv[2], "SPU count", 1, SPUS_MAX, &spus_num) != 0 ||
        parse_int_arg(argv[3], "log10 condition number", INT_MIN, INT_MAX, &cond) != 0 ||
        parse_int_arg(argv[4], "iteration count", 0, INT_MAX, &iter) != 0)
        return EXIT_FAILURE;

    N    = NB * BB;     // cannot overflow: BB <= INT_MAX / NB
    LDA  = N;
    LDB  = LDA;
    LDX  = LDA;

    //----------------------------------------------------------

    // Lay out all arrays back to back as byte offsets (size_t, no pointer arithmetic on NULL)
    n_elems = (size_t)N;
    if (n_elems > SIZE_MAX / n_elems)
        overflow = 1;
    nn_elems = overflow ? 0 : n_elems * n_elems;

    off_A     = layout_reserve(&cursor, nn_elems, sizeof(double), &overflow);
    off_Ablk  = layout_reserve(&cursor, nn_elems, sizeof(double), &overflow);
    off_A_    = layout_reserve(&cursor, nn_elems, sizeof(float),  &overflow);
    off_Ablk_ = layout_reserve(&cursor, nn_elems, sizeof(float),  &overflow);

    off_B     = layout_reserve(&cursor, n_elems, sizeof(double), &overflow);
    off_X     = layout_reserve(&cursor, n_elems, sizeof(double), &overflow);
    off_B_    = layout_reserve(&cursor, n_elems, sizeof(float),  &overflow);
    off_X_    = layout_reserve(&cursor, n_elems, sizeof(float),  &overflow);

    off_R     = layout_reserve(&cursor, n_elems, sizeof(double), &overflow);
    off_Z     = layout_reserve(&cursor, n_elems, sizeof(double), &overflow);
    off_R_    = layout_reserve(&cursor, n_elems, sizeof(float),  &overflow);
    off_Z_    = layout_reserve(&cursor, n_elems, sizeof(float),  &overflow);

    off_work  = layout_reserve(&cursor, n_elems, (size_t)SPUS_MAX * sizeof(double), &overflow);

    if (!overflow && cursor > SIZE_MAX - ((size_t)HUGE_PAGE_SIZE - 1))
        overflow = 1;
    if (overflow)
    {
        fprintf(stderr, "\n\tSystem size %d is too large to map\n", N);
        return EXIT_FAILURE;
    }
    // Round up to a whole number of huge pages (assumes HUGE_PAGE_SIZE is a power of two)
    huge_size = (cursor + (size_t)HUGE_PAGE_SIZE - 1) & ~((size_t)HUGE_PAGE_SIZE - 1);

    fmem = open(mem_file, O_CREAT | O_RDWR, 0755);
    if (fmem == -1)
    {
        perror(mem_file);
        return EXIT_FAILURE;
    }
    // Unlink immediately: the open descriptor (and later the mapping) keeps the file alive,
    // and it disappears on its own once the process lets go of it.
    remove(mem_file);

    mem_addr = (char *)mmap(NULL, huge_size, PROT_READ | PROT_WRITE, MAP_SHARED, fmem, 0);
    if (mem_addr == MAP_FAILED)
    {
        perror("mmap");
        close(fmem);
        return EXIT_FAILURE;
    }

    //----------------------------------------------------------

    memset(mem_addr, 0, huge_size);

    //----------------------------------------------------------

    A     = (double *)(mem_addr + off_A    );
    Ablk  = (double *)(mem_addr + off_Ablk );
    A_    = (float *) (mem_addr + off_A_   );
    Ablk_ = (float *) (mem_addr + off_Ablk_);

    B     = (double *)(mem_addr + off_B    );
    X     = (double *)(mem_addr + off_X    );
    B_    = (float *) (mem_addr + off_B_   );
    X_    = (float *) (mem_addr + off_X_   );

    R     = (double *)(mem_addr + off_R    );
    Z     = (double *)(mem_addr + off_Z    );
    R_    = (float *) (mem_addr + off_R_   );
    Z_    = (float *) (mem_addr + off_Z_   );

    DGEMM_WORK = (double *)(mem_addr + off_work);

    // Not referenced below; kept so the layout matches the original allocation exactly.
    (void)B_; (void)R_; (void)DNB;

    //----------------------------------------------------------

    // Initialize input arrays
    printf("\n\tInitializing input arrays ----------> "); fflush(stdout);
    generate_fixed_cond_Asym(N, A, N, cond);
    for (i = 0; i < N; i++)
        B[i] = (double)rand() / (double)RAND_MAX - 0.5;
    printf("DONE!"); fflush(stdout);

    //----------------------------------------------------------

    // Fire SPUs
    ppu_spus_init(spus_num, NB, DNB);


    // Convert A to single prec.
    ppu_convert_d2s(N, N, A, A_);

    // Convert A in SP to BDL
    ppu_lapack2blocked(N, N, A_, Ablk_);

    // Convert A in DP to BDL
    ppu_lapack2blocked_double(N, N, A, Ablk);


    // Cholesky factorization of A in SP; refinement is meaningless if it failed
    ppu_spotrf(N, N, Ablk_, N, &INFO);
    if (INFO != 0)
    {
        fprintf(stderr, "\n\tppu_spotrf failed (INFO = %d)\n", INFO);
        ppu_spus_finish();
        goto cleanup;
    }

    // x_32 <- spotrs(L, L', b_32)
    convert_d2s(B, X_, N);

    // Solve
    ppu_strsm_no_trans(N, 1, Ablk_, N, X_, N);
    ppu_strsm_trans(N, 1, Ablk_, N, X_, N);

    // x <- x_32
    convert_s2d(X_, X, N);


    // Iterate
    for (j = 0; j < iter; j++)
    {
        // r <- b - Ax
        memcpy(R, B, n_elems * sizeof(double));

        // NOTE(review): assumed to compute R <- R - Ablk*X -- confirm against ppu_blas
        ppu_dgemm(N, NRHS, N, Ablk, LDA, X, LDX, R, LDB);

        // z_32 <- spotrs(L, L', r_32)
        convert_d2s(R, Z_, N);

        ppu_strsm_no_trans(N, 1, Ablk_, N, Z_, N);
        ppu_strsm_trans(N, 1, Ablk_, N, Z_, N);

        // z <- z_32
        convert_s2d(Z_, Z, N);

        // x <- x + z
        cblas_daxpy(N, 1.0, Z, 1, X, 1);
    }


    // Honorably discharge SPUs
    ppu_spus_finish();


    // r <- b - Ax, recomputed on the host from the standard-layout double precision A
    memcpy(R, B, n_elems * sizeof(double));

    cblas_dgemm(
        CblasColMajor,
        CblasNoTrans, CblasNoTrans,
        N, NRHS, N,
       -1.0, A, LDA,
             X, LDX,
        1.0, R, LDB);

    // NOTE(review): both figures below are unscaled residual norms (2-norm and max-norm of r),
    // not normalized by ||A||*||x||.
    printf("\n\t\tNorm-wise backward error: %le", cblas_dnrm2(N, R, 1));

    max_res = 0.0;
    for (i = 0; i < N; i++)
        if (fabs(R[i]) > max_res)
            max_res = fabs(R[i]);

    printf("\n\t\tElem-wise backward error: %le\n\n", max_res);

    status = EXIT_SUCCESS;

cleanup:
    // Finish: release the huge-page mapping and its descriptor
    munmap(mem_addr, huge_size);
    close(fmem);
    return status;
}

//----------------------------------------------------------------------------------------------

/* Allocate the LAPACK workspace size reported by a lwork = -1 query (at least 1 element).
 * Exits with EXIT_FAILURE if the allocation fails. */
static double *orth_alloc_work(double wkopt, int *lwork)
{
	double *work;

	*lwork = (int)wkopt;
	if (*lwork < 1)
		*lwork = 1;
	work = malloc((size_t)*lwork * sizeof *work);
	if (work == NULL) {
		fprintf(stderr, "error of memory allocation work in orth \n");
		exit(EXIT_FAILURE);
	}
	return work;
}

/*
 * orth - replace the m-by-n column-major matrix A (leading dimension lda) in place with an
 * orthogonal factor obtained from its QR factorization (LAPACK dgeqrf followed by dorgqr).
 *
 * Only the FIRST Householder reflector is used when forming Q (k = ione = 1), so the result
 * is the first n columns of H(1) = I - tau*v*v^T. For m == n that matrix is symmetric as well
 * as orthogonal. Any caller that forms U*S*U (instead of U*S*U^T) and expects a symmetric
 * result depends on this, so audit callers before raising k.
 *
 * Returns 0 on success (also for an empty matrix), or the nonzero LAPACK INFO of the first
 * failing call. Exits the process with EXIT_FAILURE if memory allocation fails.
 */
int orth( int m, int n, double *A, int lda){
	int min_mn, lwork;
	int info = 0;
	int ione = 1;        /* number of reflectors dorgqr multiplies together (see above) */
	double *tau, *work;
	double wkopt;        /* receives the optimal workspace size from lwork = -1 queries */

	if (m <= 0 || n <= 0)
		return 0;        /* empty matrix: nothing to do (and avoids malloc(0)) */

	min_mn = m < n ? m : n;
	tau = malloc((size_t)min_mn * sizeof *tau);
	if (tau == NULL) {
		fprintf(stderr, "error of memory allocation tau in orth \n");
		exit(EXIT_FAILURE);
	}

	/* QR factorization A = Q*R: workspace query, then the real call. */
	lapack_dgeqrf(m, n, A, lda, tau, &wkopt, -1, &info);
	if (info == 0) {
		work = orth_alloc_work(wkopt, &lwork);
		lapack_dgeqrf(m, n, A, lda, tau, work, lwork, &info);
		free(work);
	}

	/* Overwrite A with the explicit orthogonal factor: workspace query, then the real call. */
	if (info == 0)
		lapack_dorgqr(m, n, ione, A, lda, tau, &wkopt, -1, &info);
	if (info == 0) {
		work = orth_alloc_work(wkopt, &lwork);
		lapack_dorgqr(m, n, ione, A, lda, tau, work, lwork, &info);
		free(work);
	}

	free(tau);
	return info;
}

//----------------------------------------------------------------------------------------------

/*
 * generate_fixed_cond_Asym - fill the n-by-n column-major matrix A (leading dimension lda)
 * with the symmetric test matrix A = U * diag(S) * U^T, where:
 *   - U is orth() applied to a matrix of uniform random entries in [-0.5, 0.5];
 *   - S[i] = 10^(i * cond / (n-1)), i = 0..n-1 (S[0] = 1 when n == 1).
 * With U orthogonal, A is symmetric positive definite with eigenvalues S, so its 2-norm
 * condition number is 10^cond.
 *
 * Forming U*S*U^T (rather than U*S*U) keeps A symmetric by construction instead of relying
 * on orth() happening to return a symmetric U.
 *
 * Returns 0. Does nothing for n <= 0. Exits with EXIT_FAILURE if allocation or orth() fails.
 */
int generate_fixed_cond_Asym( int n, double *A, int lda, int cond){
	int i, ione=1;
	size_t nn, p;
	double *U, *S, *V;
	double pas, dzro=0.0e0, done=1.0e0;

	if (n <= 0)
		return 0;
	nn = (size_t)n * (size_t)n;   /* size_t: n*n in int overflows for n > 46340 */

	U = malloc(nn * sizeof *U);
	if (U==NULL){ fprintf(stderr, "error of memory allocation U in generate_fixed_cond_A \n"); exit(EXIT_FAILURE); }
	S = malloc((size_t)n * sizeof *S);
	if (S==NULL){ fprintf(stderr, "error of memory allocation S in generate_fixed_cond_A \n"); exit(EXIT_FAILURE); }
	V = malloc(nn * sizeof *V);
	if (V==NULL){ fprintf(stderr, "error of memory allocation V in generate_fixed_cond_A \n"); exit(EXIT_FAILURE); }

	for (p = 0; p < nn; p++)  U[p] = ((double) rand()) / ((double) RAND_MAX) - 0.5 ;

	if (orth(n,n,U,n) != 0){ fprintf(stderr, "orth failed in generate_fixed_cond_A \n"); exit(EXIT_FAILURE); }
	for (p = 0; p < nn; p++)  V[p] = U[p];   /* unscaled copy: U is scaled in place below */

	/* Log-spaced spectrum 10^0 .. 10^cond; guard n == 1 against cond/0 (which gave NaN). */
	pas = (n > 1) ? (double) cond / (double) (n-1) : 0.0;
	for (i = 0; i < n; i++)  S[i] = pow(10.0, ((double) i) * pas);

	/* U <- U * diag(S): scale column i (contiguous, column-major) of U by S[i]. */
	for (i = 0; i < n; i++)  dscal_(&n, &S[i], &U[(size_t)i * (size_t)n], &ione);

	/* A <- (U*diag(S)) * V^T = U * diag(S) * U^T */
	dgemm_("N","T",&n,&n,&n,&done,U,&n,V,&n,&dzro,A,&lda);

	free(U);
	free(S);
	free(V);
	return 0;
}

//----------------------------------------------------------------------------------------------

void convert_d2s(double *D, float *S, int N)
{
    /* Narrow the first N doubles of D into S, walking both arrays in lockstep.
     * Nothing is written when N <= 0. */
    for (int left = N; left > 0; --left)
        *S++ = (float)*D++;
}

//----------------------------------------------------------------------------------------------

void convert_s2d(float *S, double *D, int N)
{
    /* Widen the first N floats of S into D, walking both arrays in lockstep.
     * Nothing is written when N <= 0. */
    for (int left = N; left > 0; --left)
        *D++ = (double)*S++;
}

//----------------------------------------------------------------------------------------------
