#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <cblas.h>
#include <lapack.h>
#include <plasma.h>
#include "auxiliary.h"

/*-------------------------------------------------------------------
 * Check the orthogonality of Q
 *
 * Q is an M x N column-major matrix with leading dimension LDQ.  The
 * routine forms Q^H*Q - Id when M >= N, or Q*Q^H - Id when M < N, in the
 * upper triangle of a min(M,N) x min(M,N) workspace, takes the infinity
 * norm of that matrix and divides it by min(M,N)*eps.  The result is
 * printed on stdout.
 *
 * Returns 0 when the scaled norm is at most 10.  Returns 1 when it is
 * larger, when it is NaN, or when a workspace could not be allocated.
 */

int zcheck_orthogonality(int M, int N, int LDQ, PLASMA_Complex64_t *Q)
{
    const double alpha = 1.0;
    const double beta  = -1.0;
    const int minMN = min(M, N);
    const double eps = lapack_dlamch(lapack_eps);
    /* Widen before multiplying so that minMN*minMN is not computed in int. */
    const size_t order = (size_t)minMN;
    double normQ, ratio;
    int info_ortho = 1;
    int i;

    double *work = malloc(order * sizeof *work);
    /* calloc gives the zero matrix the identity is built on. */
    PLASMA_Complex64_t *Id = calloc(order * order, sizeof *Id);

    if (work == NULL || Id == NULL) {
        fprintf(stderr, "zcheck_orthogonality: workspace allocation failed\n");
        goto cleanup;
    }

    /* Build the identity matrix */
    for (i = 0; i < minMN; i++)
        Id[(size_t)i * order + (size_t)i] = (PLASMA_Complex64_t)1.0;

    /* zherk computes alpha*op(Q) + beta*Id, i.e. Q'Q - Id (or QQ' - Id);
     * only the upper triangle of Id is updated. */
    if (M >= N)
        cblas_zherk(CblasColMajor, CblasUpper, CblasConjTrans, N, M, alpha, Q, LDQ, beta, Id, N);
    else
        cblas_zherk(CblasColMajor, CblasUpper, CblasNoTrans, M, N, alpha, Q, LDQ, beta, Id, M);

    normQ = lapack_zlansy(lapack_inf_norm, (enum lapack_uplo_type)PlasmaUpper, minMN, Id, minMN, work);
    ratio = normQ / (minMN * eps);

    printf("============\n");
    printf("Checking the orthogonality of Q \n");
    printf("||Id-Q'*Q||_oo / (N*eps) = %e \n", ratio);

    if ( isnan(ratio) || (ratio > 10.0) ) {
        printf("-- Orthogonality is suspicious ! \n");
        info_ortho = 1;
    }
    else {
        printf("-- Orthogonality is CORRECT ! \n");
        info_ortho = 0;
    }

cleanup:
    free(work);
    free(Id);

    return info_ortho;
}

/*------------------------------------------------------------
 *  Check the factorization QR (M >= N) or LQ (M < N)
 *
 *  A1  : original M x N matrix, leading dimension LDA; not modified
 *  A2  : factored matrix, leading dimension LDA; its upper triangle (QR)
 *        or lower triangle (LQ) is read as the triangular factor
 *  Q   : unitary factor, read with leading dimension LDA
 *
 *  Rebuilds Q*R (or L*Q), subtracts it from A1 and prints
 *  ||A - QR||_oo / (||A||_oo * N * eps).
 *
 *  Returns 0 when that ratio is at most 10.  Returns 1 when it is larger,
 *  when it is NaN, or when a workspace could not be allocated.
 */

int zcheck_QRfactorization(int M, int N, PLASMA_Complex64_t *A1, PLASMA_Complex64_t *A2, int LDA, PLASMA_Complex64_t *Q)
{
    double Anorm, Rnorm, ratio;
    PLASMA_Complex64_t alpha = 1.0;
    PLASMA_Complex64_t beta  = 0.0;
    int info_factorization = 1;
    int i, j;
    const int minMN = min(M, N);
    const double eps = lapack_dlamch(lapack_eps);

    /* Sizes are widened to size_t before multiplying to avoid int overflow. */
    const size_t rows = (size_t)M;
    const size_t cols = (size_t)N;
    const size_t tri  = (size_t)minMN;

    PLASMA_Complex64_t *Ql       = calloc(rows * cols, sizeof *Ql);
    PLASMA_Complex64_t *Residual = calloc(rows * cols, sizeof *Residual);
    /* Triangular factor: R (N x N) when M >= N, L (M x M) otherwise. */
    PLASMA_Complex64_t *factor   = calloc(tri * tri, sizeof *factor);
    double *work                 = malloc((size_t)max(M, N) * sizeof *work);

    if (Ql == NULL || Residual == NULL || factor == NULL || work == NULL) {
        fprintf(stderr, "zcheck_QRfactorization: workspace allocation failed\n");
        goto cleanup;
    }

    if (M >= N) {
        /* Extract the R: only the leading N x N upper triangle holds data,
         * so copy exactly that many rows into the N x N buffer. */
        lapack_zlacpy((enum lapack_uplo_type)PlasmaUpper, N, N, A2, LDA, factor, N);

        /* Perform Ql=Q*R */
        cblas_zgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, M, N, N, CBLAS_SADDR(alpha), Q, LDA, factor, N, CBLAS_SADDR(beta), Ql, M);
    }
    else {
        /* Extract the L: only the leading M x M lower triangle holds data. */
        lapack_zlacpy((enum lapack_uplo_type)PlasmaLower, M, M, A2, LDA, factor, M);

        /* Perform Ql=LQ */
        cblas_zgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, M, N, M, CBLAS_SADDR(alpha), factor, M, Q, LDA, CBLAS_SADDR(beta), Ql, M);
    }

    /* Compute the Residual, walking the column-major arrays in memory order */
    for (j = 0; j < N; j++)
        for (i = 0; i < M; i++)
            Residual[(size_t)j * rows + (size_t)i] =
                A1[(size_t)j * (size_t)LDA + (size_t)i] - Ql[(size_t)j * rows + (size_t)i];

    Rnorm = lapack_zlange(lapack_inf_norm, M, N, Residual, M, work);
    /* Normalise by the original matrix A1, as the printed formula states;
     * A2 holds the overwritten factorization output, not A. */
    Anorm = lapack_zlange(lapack_inf_norm, M, N, A1, LDA, work);
    ratio = Rnorm / (Anorm * N * eps);

    printf("============\n");
    if (M >= N) {
        printf("Checking the QR Factorization \n");
        printf("-- ||A-QR||_oo/(||A||_oo.N.eps) = %e \n", ratio);
    }
    else {
        printf("Checking the LQ Factorization \n");
        printf("-- ||A-LQ||_oo/(||A||_oo.N.eps) = %e \n", ratio);
    }

    if ( isnan(ratio) || (ratio > 10.0) ) {
        printf("-- Factorization is suspicious ! \n");
        info_factorization = 1;
    }
    else {
        printf("-- Factorization is CORRECT ! \n");
        info_factorization = 0;
    }

cleanup:
    free(work);
    free(factor);
    free(Residual);
    free(Ql);

    return info_factorization;
}

/*------------------------------------------------------------------------
 *  Check the Cholesky factorization of the matrix A2
 *
 *  A1   : original N x N matrix, leading dimension LDA; all N x N entries
 *         are read, and it is not modified
 *  A2   : factored matrix, leading dimension LDA; only the triangle
 *         selected by uplo is read
 *  uplo : PlasmaUpper rebuilds U'*U from the upper triangle; any other
 *         value rebuilds L*L' from the lower triangle
 *
 *  Prints the infinity norm of (rebuilt product - A1) divided by
 *  ||A1||_oo * N * eps.
 *
 *  Returns 0 when that ratio is at most 10.  Returns 1 when it is larger,
 *  when it is NaN, or when a workspace could not be allocated.
 */

int zcheck_LLTfactorization(int N, PLASMA_Complex64_t *A1, PLASMA_Complex64_t *A2, int LDA, int uplo)
{
    double Anorm, Rnorm, ratio;
    PLASMA_Complex64_t alpha = 1.0;
    int info_factorization = 1;
    int i, j;
    const double eps = lapack_dlamch(lapack_eps);

    /* Widen before multiplying so that N*N is not computed in int. */
    const size_t order = (size_t)N;

    PLASMA_Complex64_t *Residual = calloc(order * order, sizeof *Residual);
    /* L1 and L2 must start zeroed: zlacpy fills one triangle only. */
    PLASMA_Complex64_t *L1       = calloc(order * order, sizeof *L1);
    PLASMA_Complex64_t *L2       = calloc(order * order, sizeof *L2);
    double *work                 = malloc(order * sizeof *work);

    if (Residual == NULL || L1 == NULL || L2 == NULL || work == NULL) {
        fprintf(stderr, "zcheck_LLTfactorization: workspace allocation failed\n");
        goto cleanup;
    }

    /* Residual starts as a full copy of the original matrix. */
    lapack_zlacpy(lapack_upper_lower, N, N, A1, LDA, Residual, N);

    if (uplo == PlasmaUpper) {
        /* L2 := U' * U */
        lapack_zlacpy((enum lapack_uplo_type)PlasmaUpper, N, N, A2, LDA, L1, N);
        lapack_zlacpy((enum lapack_uplo_type)PlasmaUpper, N, N, A2, LDA, L2, N);
        cblas_ztrmm(CblasColMajor, CblasLeft, CblasUpper, CblasConjTrans, CblasNonUnit, N, N, CBLAS_SADDR(alpha), L1, N, L2, N);
    }
    else {
        /* L2 := L * L' */
        lapack_zlacpy((enum lapack_uplo_type)PlasmaLower, N, N, A2, LDA, L1, N);
        lapack_zlacpy((enum lapack_uplo_type)PlasmaLower, N, N, A2, LDA, L2, N);
        cblas_ztrmm(CblasColMajor, CblasRight, CblasLower, CblasConjTrans, CblasNonUnit, N, N, CBLAS_SADDR(alpha), L1, N, L2, N);
    }

    /* Compute the Residual: rebuilt product minus A, in memory order */
    for (j = 0; j < N; j++)
        for (i = 0; i < N; i++) {
            const size_t k = (size_t)j * order + (size_t)i;
            Residual[k] = L2[k] - Residual[k];
        }

    Rnorm = lapack_zlange(lapack_inf_norm, N, N, Residual, N, work);
    Anorm = lapack_zlange(lapack_inf_norm, N, N, A1, LDA, work);
    ratio = Rnorm / (Anorm * N * eps);

    printf("============\n");
    printf("Checking the Cholesky Factorization \n");
    printf("-- ||L'L-A||_oo/(||A||_oo.N.eps) = %e \n", ratio);

    if ( isnan(ratio) || (ratio > 10.0) ) {
        printf("-- Factorization is suspicious ! \n");
        info_factorization = 1;
    }
    else {
        printf("-- Factorization is CORRECT ! \n");
        info_factorization = 0;
    }

cleanup:
    free(Residual);
    free(L1);
    free(L2);
    free(work);

    return info_factorization;
}

/*--------------------------------------------------------------
 * Check the gemm
 *
 * Recomputes alpha*op(A)*op(B) + beta*Cref with cblas_zgemm and compares
 * it with Cplasma.  Cref and Cplasma both use leading dimension LDC.
 *
 * Cref is overwritten: first with the reference product, then with
 * (reference product - Cplasma).
 *
 * Outputs (all infinity norms over the M x N part):
 *   *Cinitnorm   : norm of Cref on entry
 *   *Cplasmanorm : norm of Cplasma
 *   *Clapacknorm : norm of the reference product
 *
 * Returns the infinity norm of the difference.  If the workspace cannot
 * be allocated, the three outputs and the return value are all NaN.
 */
double zcheck_gemm(PLASMA_enum transA, PLASMA_enum transB, int M, int N, int K,
                   PLASMA_Complex64_t alpha, PLASMA_Complex64_t *A, int LDA, 
                   PLASMA_Complex64_t *B, int LDB, 
                   PLASMA_Complex64_t beta, PLASMA_Complex64_t *Cplasma, 
                   PLASMA_Complex64_t *Cref, int LDC,
                   double *Cinitnorm, double *Cplasmanorm, double *Clapacknorm )
{
    PLASMA_Complex64_t beta_const = -1.0;
    double Rnorm;
    double *work;
    int lwork = max(K, max(M, N));

    /* Ask for at least one element: malloc(0) may legitimately return NULL,
     * which must not be mistaken for an allocation failure. */
    if (lwork < 1)
        lwork = 1;

    work = malloc((size_t)lwork * sizeof *work);
    if (work == NULL) {
        fprintf(stderr, "zcheck_gemm: workspace allocation failed\n");
        *Cinitnorm   = NAN;
        *Cplasmanorm = NAN;
        *Clapacknorm = NAN;
        return NAN;
    }

    *Cinitnorm   = lapack_zlange(lapack_inf_norm, M, N, Cref,    LDC, work);
    *Cplasmanorm = lapack_zlange(lapack_inf_norm, M, N, Cplasma, LDC, work);

    /* Reference result, computed in place in Cref. */
    cblas_zgemm(CblasColMajor, (enum CBLAS_TRANSPOSE)transA, (enum CBLAS_TRANSPOSE)transB, M, N, K, 
                CBLAS_SADDR(alpha), A, LDA, B, LDB, CBLAS_SADDR(beta), Cref, LDC);

    *Clapacknorm = lapack_zlange(lapack_inf_norm, M, N, Cref, LDC, work);

    /* Cref -= Cplasma over the whole LDC x N storage; the norm below only
     * looks at the leading M rows. */
    cblas_zaxpy(LDC * N, CBLAS_SADDR(beta_const), Cplasma, 1, Cref, 1);

    Rnorm = lapack_zlange(lapack_inf_norm, M, N, Cref, LDC, work);

    free(work);

    return Rnorm;
}

/*--------------------------------------------------------------
 * Check the solution
 *
 * Computes the residual A*X - B in place in B and returns its infinity
 * norm.  Dimensions follow the zgemm call below:
 *   A : M x N,    leading dimension LDA
 *   X : N x NRHS, leading dimension LDB
 *   B : M x NRHS, leading dimension LDB (overwritten with A*X - B)
 *
 * Outputs (infinity norms, taken before B is overwritten):
 *   *anorm : norm of A
 *   *xnorm : norm of X
 *   *bnorm : norm of B
 *
 * If the workspace cannot be allocated, the three outputs and the return
 * value are all NaN and B is left untouched.
 */

double zcheck_solution(int M, int N, int NRHS, PLASMA_Complex64_t *A, int LDA, 
                      PLASMA_Complex64_t *B,  PLASMA_Complex64_t *X, int LDB,
                      double *anorm, double *bnorm, double *xnorm )
{
    double Rnorm;
    PLASMA_Complex64_t zone  =  1.0;
    PLASMA_Complex64_t mzone = -1.0;
    double *work;
    int lwork = max(M, N);

    /* Ask for at least one element: malloc(0) may legitimately return NULL,
     * which must not be mistaken for an allocation failure. */
    if (lwork < 1)
        lwork = 1;

    work = malloc((size_t)lwork * sizeof *work);
    if (work == NULL) {
        fprintf(stderr, "zcheck_solution: workspace allocation failed\n");
        *anorm = NAN;
        *xnorm = NAN;
        *bnorm = NAN;
        return NAN;
    }

    /* X has N rows and B has M rows; the two only coincide when M == N. */
    *anorm = lapack_zlange(lapack_inf_norm, M, N,    A, LDA, work);
    *xnorm = lapack_zlange(lapack_inf_norm, N, NRHS, X, LDB, work);
    *bnorm = lapack_zlange(lapack_inf_norm, M, NRHS, B, LDB, work);

    /* B := A*X - B */
    cblas_zgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, M, NRHS, N, CBLAS_SADDR(zone), A, LDA, X, LDB, CBLAS_SADDR(mzone), B, LDB);

    /* The residual is M x NRHS and lives in B, so it must be read with B's
     * own leading dimension LDB. */
    Rnorm = lapack_zlange(lapack_inf_norm, M, NRHS, B, LDB, work);

    free(work);

    return Rnorm;
}

