/* ///////////////////////////////////////////////////////////////////////////////////////
 *  -- PLASMA --
 *     University of Tennessee
 */
#include <stdlib.h>
#include "common.h"
#include <math.h>
#include "lapack.h"
/* ///////////////////////////// P /// L /// A /// S /// M /// A /////////////////////////////// */
/* ///                    PLASMA computational routine (version 2.1.0)                       ///
 * ///                    Release Date: November, 15th 2009                                  ///
 * ///                    PLASMA is a software package provided by Univ. of Tennessee,       ///
 * ///                    Univ. of California Berkeley and Univ. of Colorado Denver          /// */

/* /////////////////////////// P /// U /// R /// P /// O /// S /// E /////////////////////////// */
// PLASMA_dsgesv - Computes the solution to a system of linear equations A * X = B,
// where A is an N-by-N matrix and X and B are N-by-NRHS matrices.
//
// PLASMA_dsgesv first attempts to factorize the matrix in SINGLE PRECISION and
// uses this factorization within an iterative refinement procedure to produce a
// solution with DOUBLE PRECISION normwise backward error quality (see below).
// If the approach fails, the method switches to a DOUBLE PRECISION
// factorization and solve.
//
// The iterative refinement is not going to be a winning strategy if
// the ratio of SINGLE PRECISION performance over DOUBLE PRECISION performance
// is too small. A reasonable strategy should take the number of right-hand
// sides and the size of the matrix into account. This might be done
// with a call to ILAENV in the future. Up to now, we always try
// iterative refinement.
//
// The iterative refinement process is stopped if ITER > ITERMAX or
// for all the RHS we have: RNRM < N*XNRM*ANRM*EPS*BWDMAX
// where
//
// - ITER is the number of the current iteration in the iterative refinement process
// - RNRM is the infinity-norm of the residual
// - XNRM is the infinity-norm of the solution
// - ANRM is the infinity-operator-norm of the matrix A
// - EPS is the machine epsilon returned by DLAMCH('Epsilon').
//
// Actually, in its current state (PLASMA 2.1.0), the test is slightly relaxed.
//
// The values ITERMAX and BWDMAX are fixed to 30 and 1.0D+00 respectively.

/* ///////////////////// A /// R /// G /// U /// M /// E /// N /// T /// S ///////////////////// */
// N        int (IN)
//          The number of linear equations, i.e., the order of the matrix A. N >= 0.
//
// NRHS     int (IN)
//          The number of right hand sides, i.e., the number of columns of the matrix B.
//          NRHS >= 0.
//
// A        double* (IN)
//          The N-by-N coefficient matrix A. This matrix is not modified.
//
// LDA      int (IN)
//          The leading dimension of the array A. LDA >= max(1,N).
//
// B        double* (IN)
//          The N-by-NRHS matrix of right hand side matrix B.
//
// LDB      int (IN)
//          The leading dimension of the array B. LDB >= max(1,N).
//
// X        double* (OUT)
//          If return value = 0, the N-by-NRHS solution matrix X.
//
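// LDX      int (IN)
//          The leading dimension of the array X. LDX >= max(1,N).
//
// ITER     int* (OUT)
//          On exit, reports the outcome of the iterative refinement process:
//          a non-negative value when the stopping criterion was satisfied,
//          -(ITERMAX+1) = -31 when refinement did not converge and the
//          double precision factorization and solve were used instead.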


/* ///////////// R /// E /// T /// U /// R /// N /////// V /// A /// L /// U /// E ///////////// */
//          = 0: successful exit
//          < 0: if -i, the i-th argument had an illegal value
//          > 0: if i, U(i,i) is exactly zero. The factorization has been completed,
//               but the factor U is exactly singular, so the solution could not be computed.

/* //////////////////////////////////// C /// O /// D /// E //////////////////////////////////// */

/* Convert the double precision tile matrix _descA to single precision, storing the result in _descB. */
#define PLASMA_dlag2s(_descA, _descB) plasma_parallel_call_2(plasma_pdlag2s, PLASMA_desc, _descA, PLASMA_desc, _descB)

/* Convert the single precision tile matrix _descA to double precision, storing the result in _descB. */
#define PLASMA_slag2d(_descA, _descB) plasma_parallel_call_2(plasma_pslag2d, PLASMA_desc, _descA, PLASMA_desc, _descB)

/*
 * Run plasma_pdlange on _descA with norm selector _norm, then store in _result
 * the largest of the first PLASMA_SIZE entries of _work (floored at 0).
 * _work must hold at least PLASMA_SIZE doubles; presumably each entry is a
 * per-thread partial result - TODO confirm against plasma_pdlange.
 * _counter is a caller-supplied integer lvalue used as the scan index.
 *
 * The body is wrapped in do { } while (0) so the macro expands to exactly one
 * statement and stays correct inside an unbraced if/else.
 */
#define PLASMA_dlange(_norm, _descA, _result, _work, _counter) \
    do { \
        (_result) = 0; \
        plasma_parallel_call_3(plasma_pdlange, char, _norm, PLASMA_desc, _descA, double*, _work); \
        for ((_counter) = 0; (_counter) < PLASMA_SIZE; (_counter)++) { \
            if (((double *)(_work))[(_counter)] > (_result)) \
                (_result) = ((double *)(_work))[(_counter)]; \
        } \
    } while (0)

/* Run plasma_pdlacpy on (_descA, _descB); used in this file to copy _descA into _descB. */
#define PLASMA_dlacpy(_descA, _descB) plasma_parallel_call_2(plasma_pdlacpy, PLASMA_desc, _descA, PLASMA_desc, _descB)

/* Run plasma_pdaxpy on (_alpha, _descA, _descB); used in this file to update _descB with _alpha * _descA. */
#define PLASMA_daxpy(_alpha, _descA, _descB) plasma_parallel_call_3(plasma_pdaxpy, double, _alpha, PLASMA_desc, _descA, PLASMA_desc, _descB)

/*
 * PLASMA_dsgesv - LAPACK-layout entry point of the mixed-precision solver.
 *
 * Validates the arguments, copies A, B and X into freshly allocated tile
 * storage, calls PLASMA_dsgesv_Tile() and copies the solution tiles back
 * into X. A and B are never written.
 *
 * Returns the status of PLASMA_dsgesv_Tile(), or a negative value /
 * PLASMA error code when argument checking, tuning or allocation fails.
 */
int PLASMA_dsgesv(int N, int NRHS, double *A, int LDA, double *B, int LDB, double *X, int LDX, int *ITER)
{
    int NB, NT, NTRHS;
    int status;
    double *Abdl;
    double *Lbdl;
    double *Bbdl;
    double *Xbdl;
    plasma_context_t *plasma;
    /* NULL-initialized so they can be passed to free() on every error path. */
    double *L = NULL;
    int *IPIV = NULL;

    plasma = plasma_context_self();
    if (plasma == NULL) {
        plasma_fatal_error("PLASMA_dsgesv", "PLASMA not initialized");
        return PLASMA_ERR_NOT_INITIALIZED;
    }
    /* Check input arguments */
    if (N < 0) {
        plasma_error("PLASMA_dsgesv", "illegal value of N");
        return -1;
    }
    if (NRHS < 0) {
        plasma_error("PLASMA_dsgesv", "illegal value of NRHS");
        return -2;
    }
    if (LDA < max(1, N)) {
        plasma_error("PLASMA_dsgesv", "illegal value of LDA");
        return -4;
    }
    /* NOTE(review): LDB and LDX are the 6th and 8th arguments of this
       routine, yet the codes below are -8 and -10. They are kept unchanged
       so existing callers that test for these values keep working. */
    if (LDB < max(1, N)) {
        plasma_error("PLASMA_dsgesv", "illegal value of LDB");
        return -8;
    }
    if (LDX < max(1, N)) {
        plasma_error("PLASMA_dsgesv", "illegal value of LDX");
        return -10;
    }
    /* Quick return */
    if (min(N, NRHS) == 0)
        return PLASMA_SUCCESS;

    /* Tune NB & IB depending on M, N & NRHS; Set NBNBSIZE */
    status = plasma_tune(PLASMA_TUNE_DSGESV, N, N, NRHS);
    if (status != PLASMA_SUCCESS) {
        plasma_error("PLASMA_dsgesv", "plasma_tune() failed");
        return status;
    }

    /* Set NT & NTRHS: number of tile rows/columns, rounded up */
    NB = PLASMA_NB;
    NT = (N%NB==0) ? (N/NB) : (N/NB+1);
    NTRHS = (NRHS%NB==0) ? (NRHS/NB) : (NRHS/NB+1);

    /* DOUBLE PRECISION INITIALIZATION */
    /* Allocate memory for double precision matrices in block layout */
    Abdl = (double *)plasma_shared_alloc(plasma, NT*NT*PLASMA_NBNBSIZE, PlasmaRealDouble);
    Lbdl = (double *)plasma_shared_alloc(plasma, NT*NT*PLASMA_IBNBSIZE, PlasmaRealDouble);
    Bbdl = (double *)plasma_shared_alloc(plasma, NT*NTRHS*PLASMA_NBNBSIZE, PlasmaRealDouble);
    Xbdl = (double *)plasma_shared_alloc(plasma, NT*NTRHS*PLASMA_NBNBSIZE, PlasmaRealDouble);
    /* Every buffer is checked, including Bbdl. */
    if (Abdl == NULL || Lbdl == NULL || Bbdl == NULL || Xbdl == NULL) {
        plasma_error("PLASMA_dsgesv", "plasma_shared_alloc() failed");
        plasma_shared_free(plasma, Abdl);
        plasma_shared_free(plasma, Lbdl);
        plasma_shared_free(plasma, Bbdl);
        plasma_shared_free(plasma, Xbdl);
        return PLASMA_ERR_OUT_OF_RESOURCES;
    }

    PLASMA_desc descA = plasma_desc_init(
        Abdl, PlasmaRealDouble,
        PLASMA_NB, PLASMA_NB, PLASMA_NBNBSIZE,
        N, N, 0, 0, N, N);

    PLASMA_desc descL = plasma_desc_init(
        Lbdl, PlasmaRealDouble,
        PLASMA_IB, PLASMA_NB, PLASMA_IBNBSIZE,
        N, N, 0, 0, N, N);

    PLASMA_desc descB = plasma_desc_init(
        Bbdl, PlasmaRealDouble,
        PLASMA_NB, PLASMA_NB, PLASMA_NBNBSIZE,
        N, NRHS, 0, 0, N, NRHS);

    PLASMA_desc descX = plasma_desc_init(
        Xbdl, PlasmaRealDouble,
        PLASMA_NB, PLASMA_NB, PLASMA_NBNBSIZE,
        N, NRHS, 0, 0, N, NRHS);

    /* Copy the LAPACK-layout inputs into tile layout */
    plasma_parallel_call_3(plasma_lapack_to_tile,
        double*, A,
        int, LDA,
        PLASMA_desc, descA);

    plasma_parallel_call_3(plasma_lapack_to_tile,
        double*, B,
        int, LDB,
        PLASMA_desc, descB);

    /* NOTE(review): X is documented as output-only, so this conversion looks
       unnecessary; it is kept to preserve the existing behavior. */
    plasma_parallel_call_3(plasma_lapack_to_tile,
        double*, X,
        int, LDX,
        PLASMA_desc, descX);

    /* Allocate workspace. Only IPIV is handed to the tile routine; L is
       allocated alongside it and released below. */
    status = PLASMA_Alloc_Workspace_dgesv(N, &L, &IPIV);
    if (status != PLASMA_SUCCESS) {
        plasma_error("PLASMA_dsgesv", "PLASMA_Alloc_Workspace_dgesv() failed");
        free(L);
        free(IPIV);
        plasma_shared_free(plasma, Abdl);
        plasma_shared_free(plasma, Lbdl);
        plasma_shared_free(plasma, Bbdl);
        plasma_shared_free(plasma, Xbdl);
        return status;
    }

    /* Call the native interface */
    status = PLASMA_dsgesv_Tile(&descA, &descL, IPIV, &descB, &descX, ITER);

    /* Copy the solution tiles back to the caller's LAPACK-layout X */
    plasma_parallel_call_3(plasma_tile_to_lapack,
        PLASMA_desc, descX,
        double*, X,
        int, LDX);

    plasma_shared_free(plasma, Abdl);
    plasma_shared_free(plasma, Lbdl);
    plasma_shared_free(plasma, Bbdl);
    plasma_shared_free(plasma, Xbdl);
    free(L);
    free(IPIV);

    /* Return what the tile routine reported, so early failures in it (for
       example out-of-resources) are not replaced by a stale PLASMA_INFO. */
    return status;
}



/* /////////////////////////// P /// U /// R /// P /// O /// S /// E /////////////////////////// */
// PLASMA_dsgesv_Tile - Computes the solution to a system of linear equations A * X = B,
// where A is an N-by-N matrix and X and B are N-by-NRHS matrices.
// All matrices are passed through descriptors. All dimensions are taken from the descriptors.
//
// PLASMA_dsgesv_Tile first attempts to factorize the matrix in SINGLE PRECISION and
// uses this factorization within an iterative refinement procedure to produce a
// solution with DOUBLE PRECISION normwise backward error quality (see below).
// If the approach fails, the method switches to a DOUBLE PRECISION
// factorization and solve.
//
// The iterative refinement is not going to be a winning strategy if
// the ratio of SINGLE PRECISION performance over DOUBLE PRECISION performance
// is too small. A reasonable strategy should take the number of right-hand
// sides and the size of the matrix into account. This might be done
// with a call to ILAENV in the future. Up to now, we always try
// iterative refinement.
//
// The iterative refinement process is stopped if ITER > ITERMAX or
// for all the RHS we have: RNRM < N*XNRM*ANRM*EPS*BWDMAX
// where
//
// - ITER is the number of the current iteration in the iterative refinement process
// - RNRM is the infinity-norm of the residual
// - XNRM is the infinity-norm of the solution
// - ANRM is the infinity-operator-norm of the matrix A
// - EPS is the machine epsilon returned by DLAMCH('Epsilon').
//
// Actually, in its current state (PLASMA 2.1.0), the test is slightly relaxed.
//
// The values ITERMAX and BWDMAX are fixed to 30 and 1.0D+00 respectively.

/* ///////////////////// A /// R /// G /// U /// M /// E /// N /// T /// S ///////////////////// */
// A        PLASMA_desc* (IN or INOUT)
//          On entry, descriptor of the N-by-N coefficient matrix A.
//          - if the iterative refinement converged, A is not modified;
//          - otherwise, the routine fell back to the double precision solution,
//          and then A contains the tile L and U factors from the factorization (not equivalent to LAPACK).
//
// L        PLASMA_desc* (NODEP or OUT)
//          On exit:
//          - if the iterative refinement converged, L is not modified;
//          - otherwise, the routine fell back to the double precision solution,
//          and then L is an auxiliary factorization data, related to the tile L factor,
//          necessary to solve the system of equations (not equivalent to LAPACK).
//
// IPIV     int* (OUT)
//          On exit, the pivot indices that define the permutations (not equivalent to LAPACK).
//
// B        PLASMA_desc* (IN)
//          Descriptor of the N-by-NRHS right hand side matrix B. B is only read.
//
// X        PLASMA_desc* (OUT)
//          On exit, if return value = 0, descriptor of the N-by-NRHS solution matrix X.
//
// ITER     int* (OUT)
//          On exit, reports the outcome of the iterative refinement process:
//          a non-negative value when the stopping criterion was satisfied,
//          -(ITERMAX+1) = -31 when refinement did not converge and the
//          double precision factorization and solve were used instead.

/* ///////////// R /// E /// T /// U /// R /// N /////// V /// A /// L /// U /// E ///////////// */
//          = 0: successful exit
//          > 0: if i, U(i,i) is exactly zero. The factorization has been completed,
//               but the factor U is exactly singular, so the solution could not be computed.

/* //////////////////////////////////// C /// O /// D /// E //////////////////////////////////// */
/*
 * PLASMA_dsgesv_Tile - tile-layout mixed-precision solve of A * X = B.
 *
 * Steps performed:
 *   1. factorize a single precision copy of A and solve for a first X;
 *   2. refine X in double precision for at most `itermax` iterations until
 *      Rnorm < Xnorm * Anorm * eps * N * bwdmax;
 *   3. if refinement never converges, factorize A itself in double
 *      precision (overwriting A and L) and solve for X from B.
 *
 * *ITER is set to 0 when step 1 already satisfies the criterion, to the
 * number of refinement iterations (>= 1) when step 2 converges, and to
 * -(itermax + 1) when step 3 is used.
 *
 * Returns PLASMA_INFO from the last factorization, or a PLASMA error code
 * on an invalid descriptor or allocation failure. All temporary buffers
 * are released on every path that follows their allocation.
 */
int PLASMA_dsgesv_Tile(PLASMA_desc *A, PLASMA_desc *L, int *IPIV, PLASMA_desc *B, PLASMA_desc *X, int *ITER)
{
    int N, NRHS, NB, NT, NTRHS;
    PLASMA_desc descA = *A;
    PLASMA_desc descL = *L;
    PLASMA_desc descB = *B;
    PLASMA_desc descX = *X;
    PLASMA_desc descR, descSA, descSL, descSX;
    /* Temporaries start as NULL so the shared cleanup can release them
       regardless of how far the allocations progressed. */
    float *SAbdl = NULL;
    float *SLbdl = NULL;
    float *SXbdl = NULL;
    double *Rbdl = NULL;
    double *work = NULL;
    plasma_context_t *plasma;
    int counter;
    int retval;

    const int itermax = 30;
    const double bwdmax = 1.0;
    const double negone = -1.0;
    const double one = 1.0;
    char norm='I';
    int iiter;
    double Anorm, cte, eps, Rnorm, Xnorm;
    *ITER=0;

    plasma = plasma_context_self();
    if (plasma == NULL) {
        plasma_fatal_error("PLASMA_dsgesv_Tile", "PLASMA not initialized");
        return PLASMA_ERR_NOT_INITIALIZED;
    }
    /* Check descriptors for correctness */
    if (plasma_desc_check(&descA) != PLASMA_SUCCESS) {
        plasma_error("PLASMA_dsgesv_Tile", "invalid first descriptor");
        return PLASMA_ERR_ILLEGAL_VALUE;
    }
    if (plasma_desc_check(&descL) != PLASMA_SUCCESS) {
        plasma_error("PLASMA_dsgesv_Tile", "invalid second descriptor");
        return PLASMA_ERR_ILLEGAL_VALUE;
    }
    if (plasma_desc_check(&descB) != PLASMA_SUCCESS) {
        plasma_error("PLASMA_dsgesv_Tile", "invalid third descriptor");
        return PLASMA_ERR_ILLEGAL_VALUE;
    }
    if (plasma_desc_check(&descX) != PLASMA_SUCCESS) {
        plasma_error("PLASMA_dsgesv_Tile", "invalid fourth descriptor");
        return PLASMA_ERR_ILLEGAL_VALUE;
    }
    /* Check input arguments */
    if (descA.nb != descA.mb || descB.nb != descB.mb || descX.nb != descX.mb) {
        plasma_error("PLASMA_dsgesv_Tile", "only square tiles supported");
        return PLASMA_ERR_ILLEGAL_VALUE;
    }

    /* Set N, NRHS, NT & NTRHS */
    N = descA.lm;
    NRHS = descB.ln;
    NB = PLASMA_NB;
    NT = (N%NB==0) ? (N/NB) : (N/NB+1);
    NTRHS = (NRHS%NB==0) ? (NRHS/NB) : (NRHS/NB+1);

    /* Scratch for the norm reduction, the double precision residual, and
       the single precision copies of A, L and X, all in block layout */
    work = (double *)plasma_shared_alloc(plasma, PLASMA_SIZE, PlasmaRealDouble);
    Rbdl = (double *)plasma_shared_alloc(plasma, NT*NTRHS*PLASMA_NBNBSIZE, PlasmaRealDouble);
    SAbdl = (float *)plasma_shared_alloc(plasma, NT*NT*PLASMA_NBNBSIZE, PlasmaRealFloat);
    SLbdl = (float *)plasma_shared_alloc(plasma, NT*NT*PLASMA_IBNBSIZE, PlasmaRealFloat);
    SXbdl = (float *)plasma_shared_alloc(plasma, NT*NTRHS*PLASMA_NBNBSIZE, PlasmaRealFloat);
    if (work == NULL || Rbdl == NULL || SAbdl == NULL || SLbdl == NULL || SXbdl == NULL) {
        plasma_error("PLASMA_dsgesv_Tile", "plasma_shared_alloc() failed");
        retval = PLASMA_ERR_OUT_OF_RESOURCES;
        goto cleanup;
    }

    descR = plasma_desc_init(
        Rbdl, PlasmaRealDouble,
        PLASMA_NB, PLASMA_NB, PLASMA_NBNBSIZE,
        N, NRHS, 0, 0, N, NRHS);

    descSA = plasma_desc_init(
        SAbdl, PlasmaRealFloat,
        PLASMA_NB, PLASMA_NB, PLASMA_NBNBSIZE,
        N, N, 0, 0, N, N);

    descSL = plasma_desc_init(
        SLbdl, PlasmaRealFloat,
        PLASMA_IB, PLASMA_NB, PLASMA_IBNBSIZE,
        N, N, 0, 0, N, N);

    descSX = plasma_desc_init(
        SXbdl, PlasmaRealFloat,
        PLASMA_NB, PLASMA_NB, PLASMA_NBNBSIZE,
        N, NRHS, 0, 0, N, NRHS);

    /* Compute some constants: cte is the factor multiplying Xnorm in the
       stopping criterion */
    PLASMA_dlange(norm, descA, Anorm, work, counter);
    eps = dlamch("Epsilon");
    cte = Anorm*eps*((double) N)*bwdmax;

    /* Convert B from double precision to single precision and store
       the result in SX. */
    PLASMA_dlag2s(descB, descSX);

    /* Convert A from double precision to single precision and store
       the result in SA. */
    PLASMA_dlag2s(descA, descSA);

    /* Clear IPIV and the single precision L */
    plasma_memzero(IPIV, NT*NT*PLASMA_NB, PlasmaInteger);
    plasma_memzero(SLbdl, NT*NT*PLASMA_IBNBSIZE, PlasmaRealFloat);

    /* Set INFO to ZERO */
    PLASMA_INFO = PLASMA_SUCCESS;

    /* Compute the LU factorization of SA */
    plasma_parallel_call_3(plasma_psgetrf,
        PLASMA_desc, descSA,
        PLASMA_desc, descSL,
        int*, IPIV);

    if (PLASMA_INFO != PLASMA_SUCCESS) {
        /* NOTE(review): the routine description says a failed single
           precision attempt switches to double precision; this path has
           always returned the factorization status instead, and that
           behavior is kept. */
        retval = PLASMA_INFO;
        goto cleanup;
    }

    /* Solve the system SA*SX = SB */

    /* Forward substitution */
    plasma_parallel_call_4(plasma_pstrsmpl,
        PLASMA_desc, descSA,
        PLASMA_desc, descSX,
        PLASMA_desc, descSL,
        int*, IPIV);

    /* Backward substitution */
    plasma_parallel_call_7(plasma_pstrsm,
        PLASMA_enum, PlasmaLeft,
        PLASMA_enum, PlasmaUpper,
        PLASMA_enum, PlasmaNoTrans,
        PLASMA_enum, PlasmaNonUnit,
        float, 1.0,
        PLASMA_desc, descSA,
        PLASMA_desc, descSX);

    /* Convert SX back to double precision */
    PLASMA_slag2d(descSX, descX);

    /* Compute R = B - AX. */
    PLASMA_dlacpy(descB,descR);
    plasma_parallel_call_7(plasma_pdgemm,
        PLASMA_enum, PlasmaNoTrans,
        PLASMA_enum, PlasmaNoTrans,
        double, negone,
        PLASMA_desc, descA,
        PLASMA_desc, descX,
        double, one,
        PLASMA_desc, descR);

    /* Check whether the NRHS normwise backward error satisfies the
       stopping criterion. If yes return. Note that ITER=0 (already set). */
    PLASMA_dlange(norm, descX, Xnorm, work, counter);
    PLASMA_dlange(norm, descR, Rnorm, work, counter);

    if (Rnorm < Xnorm * cte) {
        /* The NRHS normwise backward errors satisfy the
           stopping criterion. We are good to exit. */
        retval = PLASMA_INFO;
        goto cleanup;
    }

    /* Iterative refinement */
    for (iiter = 0; iiter < itermax; iiter++) {

        /* Convert R from double precision to single precision
           and store the result in SX. */
        PLASMA_dlag2s(descR, descSX);

        /* Solve the system SA*SX = SR */

        /* Forward substitution */
        plasma_parallel_call_4(plasma_pstrsmpl,
            PLASMA_desc, descSA,
            PLASMA_desc, descSX,
            PLASMA_desc, descSL,
            int*, IPIV);

        /* Backward substitution */
        plasma_parallel_call_7(plasma_pstrsm,
            PLASMA_enum, PlasmaLeft,
            PLASMA_enum, PlasmaUpper,
            PLASMA_enum, PlasmaNoTrans,
            PLASMA_enum, PlasmaNonUnit,
            float, 1.0,
            PLASMA_desc, descSA,
            PLASMA_desc, descSX);

        /* Convert SX back to double precision and update the current
           iterate. */
        PLASMA_slag2d(descSX, descR);
        PLASMA_daxpy(one, descR, descX);

        /* Compute R = B - AX. */
        PLASMA_dlacpy(descB,descR);
        plasma_parallel_call_7(plasma_pdgemm,
            PLASMA_enum, PlasmaNoTrans,
            PLASMA_enum, PlasmaNoTrans,
            double, negone,
            PLASMA_desc, descA,
            PLASMA_desc, descX,
            double, one,
            PLASMA_desc, descR);

        /* Check whether the NRHS normwise backward errors satisfy the
           stopping criterion. If yes, set ITER=IITER>0 and return. */
        PLASMA_dlange(norm, descX, Xnorm, work, counter);
        PLASMA_dlange(norm, descR, Rnorm, work, counter);

        if (Rnorm < Xnorm * cte) {
            /* The NRHS normwise backward errors satisfy the
               stopping criterion. We are good to exit.
               iiter is zero-based, so the number of refinement steps
               performed is iiter + 1; this keeps ITER > 0 and distinct
               from the "no refinement needed" value 0. */
            *ITER = iiter + 1;
            retval = PLASMA_INFO;
            goto cleanup;
        }
    }

    /* We have performed ITER=itermax iterations and never satisfied
       the stopping criterion, set up the ITER flag accordingly and
       follow up on double precision routine. */
    *ITER = -itermax - 1;

    /* Release every temporary before the double precision factorization;
       none of them is needed from here on. */
    plasma_shared_free(plasma, SAbdl);
    plasma_shared_free(plasma, SLbdl);
    plasma_shared_free(plasma, SXbdl);
    plasma_shared_free(plasma, Rbdl);
    plasma_shared_free(plasma, work);

    /* Single-precision iterative refinement failed to converge to a
       satisfactory solution, so we resort to double precision. */

    /* Clear IPIV and Lbdl */
    plasma_memzero(IPIV, NT*NT*PLASMA_NB, PlasmaInteger);
    plasma_memzero(((double *)(descL.mat)), NT*NT*PLASMA_IBNBSIZE, PlasmaRealDouble);

    /* Set INFO to ZERO */
    PLASMA_INFO = PLASMA_SUCCESS;

    plasma_parallel_call_3(plasma_pdgetrf,
        PLASMA_desc, descA,
        PLASMA_desc, descL,
        int*, IPIV);

    if (PLASMA_INFO == PLASMA_SUCCESS)
    {
        /* X still holds the last, non-converged iterate. The triangular
           solves work in place, so load the right hand side B into X
           first; otherwise the system would be solved for the wrong
           right hand side. */
        PLASMA_dlacpy(descB, descX);

        plasma_parallel_call_4(plasma_pdtrsmpl,
            PLASMA_desc, descA,
            PLASMA_desc, descX,
            PLASMA_desc, descL,
            int*, IPIV);

        plasma_parallel_call_7(plasma_pdtrsm,
            PLASMA_enum, PlasmaLeft,
            PLASMA_enum, PlasmaUpper,
            PLASMA_enum, PlasmaNoTrans,
            PLASMA_enum, PlasmaNonUnit,
            double, 1.0,
            PLASMA_desc, descA,
            PLASMA_desc, descX);
    }

    return PLASMA_INFO;

cleanup:
    /* Shared exit for every path that leaves while the temporaries are
       still allocated (pointers may be NULL after a failed allocation). */
    plasma_shared_free(plasma, SAbdl);
    plasma_shared_free(plasma, SLbdl);
    plasma_shared_free(plasma, SXbdl);
    plasma_shared_free(plasma, Rbdl);
    plasma_shared_free(plasma, work);
    return retval;
}
