#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <cblas.h>
#include <lapack.h>
#include <plasma.h>
#include "auxiliary.h"

/*-------------------------------------------------------------------
 * Check the orthogonality of Q
 */

/*
 * Checks the orthogonality of Q by forming Q'*Q - Id (M >= N) or
 * Q*Q' - Id (M < N) and comparing its infinity norm with min(M,N) * eps.
 *
 * M, N  dimensions of Q
 * LDQ   leading dimension of Q (column-major storage)
 * Q     matrix to check
 *
 * Returns 0 when the scaled norm is at most 10. Returns 1 when it is larger,
 * when it is NaN, or when the scratch buffers cannot be allocated.
 */
int scheck_orthogonality(int M, int N, int LDQ, float *Q)
{
    const int   minMN = min(M, N);
    const float alpha = 1.0f;
    const float beta  = -1.0f;
    /* Never request zero bytes: malloc(0) may return NULL, which would be
     * indistinguishable from a real allocation failure. */
    const size_t dim = (minMN > 0) ? (size_t)minMN : 1;
    float eps, normQ, ratio;
    int info_ortho;
    int i;

    float *work = malloc(dim * sizeof *work);
    float *Id   = calloc(dim * dim, sizeof *Id);   /* zero-filled */

    if (work == NULL || Id == NULL) {
        fprintf(stderr, "scheck_orthogonality: cannot allocate workspace\n");
        free(work);
        free(Id);
        return 1;
    }

    eps = lapack_slamch(lapack_eps);

    /* Build the identity matrix of order minMN (leading dimension minMN) */
    for (i = 0; i < minMN; i++)
        Id[(size_t)i * dim + i] = 1.0f;

    /* Id := Q'*Q - Id (or Q*Q' - Id); only the upper triangle is updated */
    if (M >= N)
        cblas_ssyrk(CblasColMajor, CblasUpper, CblasTrans, N, M, alpha, Q, LDQ, beta, Id, N);
    else
        cblas_ssyrk(CblasColMajor, CblasUpper, CblasNoTrans, M, N, alpha, Q, LDQ, beta, Id, M);

    normQ = lapack_slansy(lapack_inf_norm, (enum lapack_uplo_type)PlasmaUpper, minMN, Id, minMN, work);
    ratio = normQ / (minMN * eps);

    printf("============\n");
    printf("Checking the orthogonality of Q \n");
    printf("||Id-Q'*Q||_oo / (N*eps) = %e \n", ratio);

    /* A NaN ratio (e.g. from a degenerate size) is reported as a failure */
    if (isnan(ratio) || ratio > 10.0) {
        printf("-- Orthogonality is suspicious ! \n");
        info_ortho = 1;
    }
    else {
        printf("-- Orthogonality is CORRECT ! \n");
        info_ortho = 0;
    }

    free(work);
    free(Id);

    return info_ortho;
}

/*------------------------------------------------------------
 *  Check the QR (M >= N) or LQ (M < N) factorization
 */

/*
 * Checks a QR (M >= N) or LQ (M < N) factorization by rebuilding the product
 * of the two factors and comparing it with the original matrix.
 *
 * M, N  dimensions of the factored matrix
 * A1    original matrix (leading dimension LDA)
 * A2    factored matrix (leading dimension LDA); its upper triangle is used
 *       as R when M >= N, its lower triangle as L otherwise
 * LDA   leading dimension used for A1, A2 and Q
 * Q     the Q factor
 *
 * Returns 0 when the scaled residual is at most 10. Returns 1 when it is
 * larger, when it is NaN, or when the scratch buffers cannot be allocated.
 */
int scheck_QRfactorization(int M, int N, float *A1, float *A2, int LDA, float *Q)
{
    const int   is_qr = (M >= N);
    const int   K     = is_qr ? N : M;   /* order of the triangular factor */
    const float alpha = 1.0f;
    const float beta  = 0.0f;
    /* Element counts are computed in size_t and clamped to at least one so
     * that a degenerate size never turns into a zero-byte request. */
    const size_t mn   = (M > 0 && N > 0) ? (size_t)M * (size_t)N : 1;
    const size_t kk   = (K > 0) ? (size_t)K * (size_t)K : 1;
    const size_t wlen = (max(M, N) > 0) ? (size_t)max(M, N) : 1;
    float Anorm, Rnorm, ratio, eps;
    int info_factorization;
    int i, j;

    float *T        = calloc(kk, sizeof *T);     /* R or L, zero-filled */
    float *Ql       = calloc(mn, sizeof *Ql);    /* Q*R or L*Q */
    float *Residual = malloc(mn * sizeof *Residual);
    float *work     = malloc(wlen * sizeof *work);

    if (T == NULL || Ql == NULL || Residual == NULL || work == NULL) {
        fprintf(stderr, "scheck_QRfactorization: cannot allocate workspace\n");
        free(T); free(Ql); free(Residual); free(work);
        return 1;
    }

    eps = lapack_slamch(lapack_eps);

    if (is_qr) {
        /* Extract R (N x N) and form Ql = Q*R */
        lapack_slacpy((enum lapack_uplo_type)PlasmaUpper, M, N, A2, LDA, T, N);
        cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, M, N, N, (alpha), Q, LDA, T, N, (beta), Ql, M);
    }
    else {
        /* Extract L (M x M) and form Ql = L*Q */
        lapack_slacpy((enum lapack_uplo_type)PlasmaLower, M, N, A2, LDA, T, M);
        cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, M, N, M, (alpha), T, M, Q, LDA, (beta), Ql, M);
    }
    free(T);

    /* Residual = A1 - Ql, walked column by column (column-major storage) */
    for (j = 0; j < N; j++)
        for (i = 0; i < M; i++)
            Residual[(size_t)j * M + i] = A1[(size_t)j * LDA + i] - Ql[(size_t)j * M + i];

    Rnorm = lapack_slange(lapack_inf_norm, M, N, Residual, M, work);
    /* NOTE(review): the scaling norm is taken on A2 (the factored matrix)
     * although the message below labels it ||A||; A1 looks like the intended
     * argument. Left unchanged here - confirm before switching. */
    Anorm = lapack_slange(lapack_inf_norm, M, N, A2, LDA, work);
    ratio = Rnorm / (Anorm * N * eps);

    printf("============\n");
    if (is_qr) {
        printf("Checking the QR Factorization \n");
        printf("-- ||A-QR||_oo/(||A||_oo.N.eps) = %e \n", ratio);
    }
    else {
        printf("Checking the LQ Factorization \n");
        printf("-- ||A-LQ||_oo/(||A||_oo.N.eps) = %e \n", ratio);
    }

    if (isnan(ratio) || ratio > 10.0) {
        printf("-- Factorization is suspicious ! \n");
        info_factorization = 1;
    }
    else {
        printf("-- Factorization is CORRECT ! \n");
        info_factorization = 0;
    }

    free(work); free(Ql); free(Residual);

    return info_factorization;
}

/*------------------------------------------------------------------------
 *  Check the Cholesky factorization stored in A2 against the original matrix A1
 */

/*
 * Checks a Cholesky factorization by multiplying the triangular factor with
 * its transpose and comparing the product with the original matrix.
 *
 * N     order of the matrices
 * A1    original matrix (leading dimension LDA)
 * A2    factored matrix (leading dimension LDA); only the triangle selected
 *       by uplo is used
 * LDA   leading dimension of A1 and A2
 * uplo  PlasmaUpper checks U'*U; any other value checks L*L'
 *
 * Returns 0 when the scaled residual is at most 10. Returns 1 when it is
 * larger, when it is NaN, or when the scratch buffers cannot be allocated.
 */
int scheck_LLTfactorization(int N, float *A1, float *A2, int LDA, int uplo)
{
    const int   upper = (uplo == PlasmaUpper);
    const float alpha = 1.0f;
    /* Sizes are computed in size_t and clamped to at least one element so a
     * degenerate order never turns into a zero-byte request. */
    const size_t n  = (N > 0) ? (size_t)N : 1;
    const size_t nn = n * n;
    float Anorm, Rnorm, ratio, eps;
    int info_factorization;
    int i, j;

    float *Residual = malloc(nn * sizeof *Residual);
    float *L1       = calloc(nn, sizeof *L1);   /* triangular factor, rest zero */
    float *L2       = malloc(nn * sizeof *L2);  /* becomes the product */
    float *work     = malloc(n * sizeof *work);

    if (Residual == NULL || L1 == NULL || L2 == NULL || work == NULL) {
        fprintf(stderr, "scheck_LLTfactorization: cannot allocate workspace\n");
        free(Residual); free(L1); free(L2); free(work);
        return 1;
    }

    eps = lapack_slamch(lapack_eps);

    /* Keep a dense copy of the original matrix with leading dimension N */
    lapack_slacpy(lapack_upper_lower, N, N, A1, LDA, Residual, N);

    /* Pull the requested triangle out of A2; L2 starts as an exact copy */
    lapack_slacpy(upper ? (enum lapack_uplo_type)PlasmaUpper
                        : (enum lapack_uplo_type)PlasmaLower,
                  N, N, A2, LDA, L1, N);
    memcpy(L2, L1, nn * sizeof *L2);

    if (upper) {
        /* L2 := U' * U */
        cblas_strmm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit, N, N, (alpha), L1, N, L2, N);
    }
    else {
        /* L2 := L * L' */
        cblas_strmm(CblasColMajor, CblasRight, CblasLower, CblasTrans, CblasNonUnit, N, N, (alpha), L1, N, L2, N);
    }

    /* Residual = (rebuilt product) - A1, walked column by column */
    for (j = 0; j < N; j++)
        for (i = 0; i < N; i++)
            Residual[(size_t)j * N + i] = L2[(size_t)j * N + i] - Residual[(size_t)j * N + i];

    Rnorm = lapack_slange(lapack_inf_norm, N, N, Residual, N, work);
    Anorm = lapack_slange(lapack_inf_norm, N, N, A1, LDA, work);
    ratio = Rnorm / (Anorm * N * eps);

    printf("============\n");
    printf("Checking the Cholesky Factorization \n");
    printf("-- ||L'L-A||_oo/(||A||_oo.N.eps) = %e \n", ratio);

    if (isnan(ratio) || ratio > 10.0) {
        printf("-- Factorization is suspicious ! \n");
        info_factorization = 1;
    }
    else {
        printf("-- Factorization is CORRECT ! \n");
        info_factorization = 0;
    }

    free(Residual); free(L1); free(L2); free(work);

    return info_factorization;
}

/*--------------------------------------------------------------
 * Check the gemm
 */
/*
 * Compares a GEMM result with a reference computed by cblas_sgemm.
 *
 * transA, transB  transposition flags forwarded to cblas_sgemm
 * M, N, K         problem dimensions forwarded to cblas_sgemm
 * alpha, beta     scalars forwarded to cblas_sgemm
 * A, LDA          first operand and its leading dimension
 * B, LDB          second operand and its leading dimension
 * Cplasma         result under test (leading dimension LDC); not modified
 * Cref            on entry the initial C; overwritten first with the
 *                 reference result, then with (reference - Cplasma)
 * LDC             leading dimension of Cplasma and Cref
 * Cinitnorm       out: infinity norm of Cref on entry
 * Cplasmanorm     out: infinity norm of Cplasma
 * Clapacknorm     out: infinity norm of the reference result
 *
 * Returns the infinity norm of (reference - Cplasma). If the workspace
 * cannot be allocated, the three output norms are set to 0 and INFINITY is
 * returned so that the comparison cannot be mistaken for a success.
 */
float scheck_gemm(PLASMA_enum transA, PLASMA_enum transB, int M, int N, int K,
                   float alpha, float *A, int LDA,
                   float *B, int LDB,
                   float beta, float *Cplasma,
                   float *Cref, int LDC,
                   float *Cinitnorm, float *Cplasmanorm, float *Clapacknorm )
{
    const float  minus_one = -1.0f;
    const int    wdim = max(K, max(M, N));
    /* Clamp to one element so a degenerate size is not a zero-byte request */
    const size_t wlen = (wdim > 0) ? (size_t)wdim : 1;
    float Rnorm;
    float *work = malloc(wlen * sizeof *work);

    if (work == NULL) {
        fprintf(stderr, "scheck_gemm: cannot allocate workspace\n");
        *Cinitnorm   = 0.0f;
        *Cplasmanorm = 0.0f;
        *Clapacknorm = 0.0f;
        return INFINITY;
    }

    /* Norms of the inputs, taken before Cref is overwritten */
    *Cinitnorm   = lapack_slange(lapack_inf_norm, M, N, Cref,    LDC, work);
    *Cplasmanorm = lapack_slange(lapack_inf_norm, M, N, Cplasma, LDC, work);

    /* Reference result: Cref := alpha*op(A)*op(B) + beta*Cref */
    cblas_sgemm(CblasColMajor, (enum CBLAS_TRANSPOSE)transA, (enum CBLAS_TRANSPOSE)transB, M, N, K,
                (alpha), A, LDA, B, LDB, (beta), Cref, LDC);

    *Clapacknorm = lapack_slange(lapack_inf_norm, M, N, Cref, LDC, work);

    /* Cref := Cref - Cplasma over the whole LDC x N storage */
    cblas_saxpy(LDC * N, (minus_one), Cplasma, 1, Cref, 1);

    Rnorm = lapack_slange(lapack_inf_norm, M, N, Cref, LDC, work);

    free(work);

    return Rnorm;
}

/*--------------------------------------------------------------
 * Check the solution
 */

/*
 * Computes the residual of a computed solution X of A*X = B.
 *
 * M, N   dimensions of A
 * NRHS   number of right-hand sides (columns of X and B)
 * A      coefficient matrix (leading dimension LDA)
 * B      right-hand side, M x NRHS in the product below (leading dimension
 *        LDB); overwritten with A*X - B
 * X      computed solution, N x NRHS in the product below (leading
 *        dimension LDB)
 * LDB    leading dimension shared by B and X
 * anorm  out: infinity norm of A
 * bnorm  out: infinity norm of B on entry
 * xnorm  out: infinity norm of X
 *
 * Returns the infinity norm of A*X - B. If the workspace cannot be
 * allocated, the three output norms are set to 0 and INFINITY is returned so
 * that the check cannot be mistaken for a success.
 */
float scheck_solution(int M, int N, int NRHS, float *A, int LDA,
                      float *B,  float *X, int LDB,
                      float *anorm, float *bnorm, float *xnorm )
{
    const float  one       =  1.0f;
    const float  minus_one = -1.0f;
    /* Clamp to one element so a degenerate size is not a zero-byte request */
    const size_t wlen = (max(M, N) > 0) ? (size_t)max(M, N) : 1;
    float Rnorm;
    float *work = malloc(wlen * sizeof *work);

    if (work == NULL) {
        fprintf(stderr, "scheck_solution: cannot allocate workspace\n");
        *anorm = 0.0f;
        *bnorm = 0.0f;
        *xnorm = 0.0f;
        return INFINITY;
    }

    /* Row counts follow the product below: X has N rows and B has M rows */
    *anorm = lapack_slange(lapack_inf_norm, M, N,    A, LDA, work);
    *xnorm = lapack_slange(lapack_inf_norm, N, NRHS, X, LDB, work);
    *bnorm = lapack_slange(lapack_inf_norm, M, NRHS, B, LDB, work);

    /* B := A*X - B */
    cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, M, NRHS, N, (one), A, LDA, X, LDB, (minus_one), B, LDB);

    /* B is M x NRHS and stored with leading dimension LDB, not M */
    Rnorm = lapack_slange(lapack_inf_norm, M, NRHS, B, LDB, work);

    free(work);

    return Rnorm;
}

