/*////////////////////////////////////////////////////////////////////////////////////////
 *  -- PLASMA --
 *     University of Tennessee
 */
#include "common.h"
#include "auxiliary.h"
#include "allocate.h"
#include "bdl_convert.h"
#include "barrier.h"

#include <string.h>

/*////////////////////////////////////////////////////////////////////////////////////////
 *  Find the least squares solution of an overdetermined problem using QR factorization
 *  Find minimum norm solution of an underdetermined problem using LQ factorization
 *
 *  Differences with LAPACK:
 *  - if one of the dimensions is zero, B is not set to zero
 *  - A and B are not scaled
 *  - if M < N, B(M+1:N,1:NRHS) is not set to zero
 */
int plasma_DGELS(PLASMA_enum trans, int M, int N, int NRHS, double *A,
                 int LDA, double *T, double *B, int LDB)
{
    int NB, MT, NT, NTRHS;
    int status;
    double *Abdl;
    double *Bbdl;
    double *Tbdl;
    double *bdl_mem;
    PLASMA_long size_elems;

    /* Check if initialized */
    if (!plasma_cntrl.initialized) {
        plasma_warning("plasma_DGELS", "PLASMA not initialized");
        return PLASMA_ERR_NOT_INITIALIZED;
    }

    /* Check input arguments */
    if (trans != PlasmaNoTrans) {
        plasma_error("plasma_DGELS", "only PlasmaNoTrans supported");
        return PLASMA_ERR_NOT_SUPPORTED;
    }
    if (M < 0) {
        plasma_error("plasma_DGELS", "illegal value of M");
        return PLASMA_ERR_ILLEGAL_VALUE;
    }
    if (N < 0) {
        plasma_error("plasma_DGELS", "illegal value of N");
        return PLASMA_ERR_ILLEGAL_VALUE;
    }
    if (NRHS < 0) {
        plasma_error("plasma_DGELS", "illegal value of NRHS");
        return PLASMA_ERR_ILLEGAL_VALUE;
    }
    if (LDA < max(1, M)) {
        plasma_error("plasma_DGELS", "illegal value of LDA");
        return PLASMA_ERR_ILLEGAL_VALUE;
    }
    if (LDB < max(1, max(M, N))) {
        plasma_error("plasma_DGELS", "illegal value of LDB");
        return PLASMA_ERR_ILLEGAL_VALUE;
    }
    /* Quick return - currently NOT equivalent to LAPACK's:
     * CALL DLASET( 'Full', MAX( M, N ), NRHS, ZERO, ZERO, B, LDB ) */
    if (min(M, min(N, NRHS)) == 0)
        return PLASMA_SUCCESS;

    /* Tune NB & IB depending on M, N & NRHS; Set NBNBSIZE */
    status = plasma_tune(PLASMA_TUNE_DGELS, M, N, NRHS);
    if (status != PLASMA_SUCCESS) {
        plasma_error("plasma_DGELS", "plasma_tune() failed");
        return status;
    }

    /* Set MT, NT & NTRHS */
    NB = plasma_cntrl.NB;
    NT = (N%NB==0) ? (N/NB) : (N/NB+1);
    MT = (M%NB==0) ? (M/NB) : (M/NB+1);
    NTRHS = (NRHS%NB==0) ? (NRHS/NB) : (NRHS/NB+1);

    /* If NB larger than NB_max, set NB_max to NB, reallocate WORK & TAU */
    if (plasma_cntrl.NB > plasma_cntrl.NB_max) {
        status = plasma_free_aux_work_tau();
        if (status != PLASMA_SUCCESS) {
            plasma_error("plasma_DGELS", "plasma_free_aux_work_tau() failed");
            return status;
        }
        plasma_cntrl.NB_max = plasma_cntrl.NB;
        status = plasma_alloc_aux_work_tau();
        if (status != PLASMA_SUCCESS) {
            plasma_error("plasma_DGELS", "plasma_alloc_aux_work_tau() failed");
            return status;
        }
    }

    /* If progress table too small, reallocate */
    size_elems = max(MT, NT)*max(NT, NTRHS);
    if (plasma_cntrl.progress_size_elems < size_elems) {
        status = plasma_free_aux_progress();
        if (status != PLASMA_SUCCESS) {
            plasma_error("plasma_DGELS", "plasma_free_aux_progress() failed");
        }
        status = plasma_alloc_aux_progress(size_elems);
        if (status != PLASMA_SUCCESS) {
            plasma_error("plasma_DGELS", "plasma_alloc_aux_progress() failed");
            return status;
        }
    }

    /* Assign arrays to BDL storage */
    bdl_mem = plasma_aux.bdl_mem;
    Abdl = bdl_mem; bdl_mem += MT*NT*plasma_cntrl.NBNBSIZE;
    Tbdl = bdl_mem; bdl_mem += MT*NT*plasma_cntrl.IBNBSIZE;
    Bbdl = bdl_mem; bdl_mem += max(MT, NT)*NTRHS*plasma_cntrl.NBNBSIZE;
    /* If BDL storage too small, reallocate & reassign */
    size_elems = bdl_mem - plasma_aux.bdl_mem;
    if (plasma_cntrl.bdl_size_elems < size_elems) {
        status = plasma_free_aux_bdl();
        if (status != PLASMA_SUCCESS) {
            plasma_error("plasma_DGELS", "plasma_free_aux_bdl() failed");
            return status;
        }
        status = plasma_alloc_aux_bdl(size_elems, PLASMA_TRUE);
        if (status != PLASMA_SUCCESS) {
            plasma_error("plasma_DGELS", "plasma_alloc_aux_bdl() failed");
            return status;
        }
        bdl_mem = plasma_aux.bdl_mem;
        Abdl = bdl_mem; bdl_mem += MT*NT*plasma_cntrl.NBNBSIZE;
        Tbdl = bdl_mem; bdl_mem += MT*NT*plasma_cntrl.IBNBSIZE;
        Bbdl = bdl_mem; bdl_mem += max(MT, NT)*NTRHS*plasma_cntrl.NBNBSIZE;
    }

    if (M >= N) {
        /* Convert A from LAPACK to BDL */
        /* Set arguments */
        plasma_args.F77 = A;
        plasma_args.A = Abdl;
        plasma_args.M = M;
        plasma_args.N = N;
        plasma_args.LDA = LDA;
        plasma_args.NB = plasma_cntrl.NB;
        plasma_args.MT = MT;
        plasma_args.NT = NT;
        plasma_args.NBNBSIZE = plasma_cntrl.NBNBSIZE;
        /* Signal workers */
        pthread_mutex_lock(&plasma_cntrl.action_mutex);
        plasma_cntrl.action = PLASMA_ACT_F77_TO_BDL;
        pthread_mutex_unlock(&plasma_cntrl.action_mutex);
        pthread_cond_broadcast(&plasma_cntrl.action_condt);
        /* Call for master */
        plasma_barrier(0, plasma_cntrl.cores_num);
        plasma_cntrl.action = PLASMA_ACT_STAND_BY;
        plasma_lapack_to_bdl(plasma_args.F77, plasma_args.A, plasma_args.M, plasma_args.N,
                             plasma_args.LDA, plasma_args.NB, plasma_args.MT, plasma_args.NT,
                             plasma_args.NBNBSIZE, plasma_cntrl.cores_num, 0);
        plasma_barrier(0, plasma_cntrl.cores_num);

        /* Convert B from LAPACK to BDL */
        /* Set arguments */
        plasma_args.F77 = B;
        plasma_args.A = Bbdl;
        plasma_args.M = M;
        plasma_args.N = NRHS;
        plasma_args.LDA = LDB;
        plasma_args.NB = plasma_cntrl.NB;
        plasma_args.MT = MT;
        plasma_args.NT = NTRHS;
        plasma_args.NBNBSIZE = plasma_cntrl.NBNBSIZE;
        /* Signal workers */
        pthread_mutex_lock(&plasma_cntrl.action_mutex);
        plasma_cntrl.action = PLASMA_ACT_F77_TO_BDL;
        pthread_mutex_unlock(&plasma_cntrl.action_mutex);
        pthread_cond_broadcast(&plasma_cntrl.action_condt);
        /* Call for master */
        plasma_barrier(0, plasma_cntrl.cores_num);
        plasma_cntrl.action = PLASMA_ACT_STAND_BY;
        plasma_lapack_to_bdl(plasma_args.F77, plasma_args.A, plasma_args.M, plasma_args.N,
                             plasma_args.LDA, plasma_args.NB, plasma_args.MT, plasma_args.NT,
                             plasma_args.NBNBSIZE, plasma_cntrl.cores_num, 0);
        plasma_barrier(0, plasma_cntrl.cores_num);

        /* Use QR factorization */
        /* Call parallel DGEQRF */
        /* Set arguments */
        plasma_args.M = M;
        plasma_args.N = N;
        plasma_args.A = Abdl;
        plasma_args.NB = plasma_cntrl.NB;
        plasma_args.NBNBSIZE = plasma_cntrl.NBNBSIZE;
        plasma_args.IBNBSIZE = plasma_cntrl.IBNBSIZE;
        plasma_args.IB = plasma_cntrl.IB;
        plasma_args.MT = MT;
        plasma_args.NT = NT;
        plasma_args.T = Tbdl;
        /* Clear progress table */
        plasma_clear_aux_progress(MT*NT, -1);
        /* Signal workers */
        pthread_mutex_lock(&plasma_cntrl.action_mutex);
        plasma_cntrl.action = PLASMA_ACT_DGEQRF;
        pthread_mutex_unlock(&plasma_cntrl.action_mutex);
        pthread_cond_broadcast(&plasma_cntrl.action_condt);
        /* Call for master */
        plasma_barrier(0, plasma_cntrl.cores_num);
        plasma_cntrl.action = PLASMA_ACT_STAND_BY;
        plasma_pDGEQRF(plasma_args.M, plasma_args.N, plasma_args.A, plasma_args.NB,
                    plasma_args.NBNBSIZE, plasma_args.IBNBSIZE, plasma_args.IB,
                    plasma_args.MT, plasma_args.NT, plasma_args.T, &plasma_args.INFO,
                    plasma_cntrl.cores_num, 0);
        plasma_barrier(0, plasma_cntrl.cores_num);

        /* Return T to the user */
        memcpy(T, Tbdl, MT*NT*plasma_args.IBNBSIZE*sizeof(double));

        /* Call parallel DORMQR */
        /* Set arguments */
        plasma_args.M = M;
        plasma_args.N = N;
        plasma_args.NRHS = NRHS;
        plasma_args.A = Abdl;
        plasma_args.NB = plasma_cntrl.NB;
        plasma_args.NBNBSIZE = plasma_cntrl.NBNBSIZE;
        plasma_args.IBNBSIZE = plasma_cntrl.IBNBSIZE;
        plasma_args.IB = plasma_cntrl.IB;
        plasma_args.MT = MT;
        plasma_args.NTRHS = NTRHS;
        plasma_args.NT = NT;
        plasma_args.T = Tbdl;
        plasma_args.B = Bbdl;
        /* Clear progress table */
        plasma_clear_aux_progress(MT*NTRHS, -1);
        /* Signal workers */
        pthread_mutex_lock(&plasma_cntrl.action_mutex);
        plasma_cntrl.action = PLASMA_ACT_DORMQR;
        pthread_mutex_unlock(&plasma_cntrl.action_mutex);
        pthread_cond_broadcast(&plasma_cntrl.action_condt);
        /* Call for master */
        plasma_barrier(0, plasma_cntrl.cores_num);
        plasma_cntrl.action = PLASMA_ACT_STAND_BY;
        plasma_pDORMQR(plasma_args.M, plasma_args.NRHS, plasma_args.N, plasma_args.A,
                    plasma_args.NB, plasma_args.NBNBSIZE, plasma_args.IBNBSIZE,
                    plasma_args.IB, plasma_args.MT, plasma_args.NTRHS, plasma_args.NT,
                    plasma_args.T, plasma_args.B, &plasma_args.INFO, plasma_cntrl.cores_num, 0);
        plasma_barrier(0, plasma_cntrl.cores_num);

        /* Call parallel DTRSM */
        /* Set arguments */
        plasma_args.uplo = PlasmaUpper;
        plasma_args.trans = PlasmaNoTrans;
        plasma_args.diag = PlasmaNonUnit;
        plasma_args.N = N;
        plasma_args.NRHS = NRHS;
        plasma_args.A = Abdl;
        plasma_args.NB = plasma_cntrl.NB;
        plasma_args.NBNBSIZE = plasma_cntrl.NBNBSIZE;
        plasma_args.NT = NT;
        plasma_args.MT = MT;
        plasma_args.B = Bbdl;
        plasma_args.MTB = MT;
        plasma_args.NTRHS = NTRHS;
        /* Clear progress table */
        plasma_clear_aux_progress(NT*NTRHS, -1);
        /* Signal workers */
        pthread_mutex_lock(&plasma_cntrl.action_mutex);
        plasma_cntrl.action = PLASMA_ACT_DTRSM;
        pthread_mutex_unlock(&plasma_cntrl.action_mutex);
        pthread_cond_broadcast(&plasma_cntrl.action_condt);
        /* Call for master */
        plasma_barrier(0, plasma_cntrl.cores_num);
        plasma_cntrl.action = PLASMA_ACT_STAND_BY;
        plasma_pDTRSM(PlasmaLeft, plasma_args.uplo, plasma_args.trans, plasma_args.diag,
                      plasma_args.N, plasma_args.NRHS, 1.0, plasma_args.A, plasma_args.NB,
                      plasma_args.NBNBSIZE, plasma_args.NT, plasma_args.MT, plasma_args.B,
                      plasma_args.MTB, plasma_args.NTRHS, plasma_cntrl.cores_num, 0);
        plasma_barrier(0, plasma_cntrl.cores_num);

        /* Convert A from BDL to LAPACK */
        /* Set arguments */
        plasma_args.A = Abdl;
        plasma_args.F77 = A;
        plasma_args.M = M;
        plasma_args.N = N;
        plasma_args.LDA = LDA;
        plasma_args.NB = plasma_cntrl.NB;
        plasma_args.MT = MT;
        plasma_args.NT = NT;
        plasma_args.NBNBSIZE = plasma_cntrl.NBNBSIZE;
        /* Signal workers */
        pthread_mutex_lock(&plasma_cntrl.action_mutex);
        plasma_cntrl.action = PLASMA_ACT_BDL_TO_F77;
        pthread_mutex_unlock(&plasma_cntrl.action_mutex);
        pthread_cond_broadcast(&plasma_cntrl.action_condt);
        /* Call for master */
        plasma_barrier(0, plasma_cntrl.cores_num);
        plasma_cntrl.action = PLASMA_ACT_STAND_BY;
        plasma_bdl_to_lapack(plasma_args.A, plasma_args.F77, plasma_args.M, plasma_args.N,
                             plasma_args.LDA, plasma_args.NB, plasma_args.MT, plasma_args.NT,
                             plasma_args.NBNBSIZE, plasma_cntrl.cores_num, 0);
        plasma_barrier(0, plasma_cntrl.cores_num);

        /* Convert B from BDL to LAPACK */
        /* Set arguments */
        plasma_args.A = Bbdl;
        plasma_args.F77 = B;
        plasma_args.M = M;
        plasma_args.N = NRHS;
        plasma_args.LDA = LDB;
        plasma_args.NB = plasma_cntrl.NB;
        plasma_args.MT = MT;
        plasma_args.NT = NTRHS;
        plasma_args.NBNBSIZE = plasma_cntrl.NBNBSIZE;
        /* Signal workers */
        pthread_mutex_lock(&plasma_cntrl.action_mutex);
        plasma_cntrl.action = PLASMA_ACT_BDL_TO_F77;
        pthread_mutex_unlock(&plasma_cntrl.action_mutex);
        pthread_cond_broadcast(&plasma_cntrl.action_condt);
        /* Call for master */
        plasma_barrier(0, plasma_cntrl.cores_num);
        plasma_cntrl.action = PLASMA_ACT_STAND_BY;
        plasma_bdl_to_lapack(plasma_args.A, plasma_args.F77, plasma_args.M, plasma_args.N,
                             plasma_args.LDA, plasma_args.NB, plasma_args.MT, plasma_args.NT,
                             plasma_args.NBNBSIZE, plasma_cntrl.cores_num, 0);
        plasma_barrier(0, plasma_cntrl.cores_num);
    }
    else {
        /* Convert arrays from LAPACK to BDL */
        plasma_lapack_to_bdl(A, Abdl, M, N, LDA, plasma_cntrl.NB,
                            MT, NT, plasma_cntrl.NBNBSIZE, 1, 0);
        plasma_lapack_to_bdl(B, Bbdl, N, NRHS, LDB, plasma_cntrl.NB,
                             NT, NTRHS, plasma_cntrl.NBNBSIZE, 1, 0);

        /* Use LQ factorization */
        /* Call parallel DGEQRF */
        /* Set arguments */
        plasma_args.M = M;
        plasma_args.N = N;
        plasma_args.A = Abdl;
        plasma_args.NB = plasma_cntrl.NB;
        plasma_args.NBNBSIZE = plasma_cntrl.NBNBSIZE;
        plasma_args.IBNBSIZE = plasma_cntrl.IBNBSIZE;
        plasma_args.IB = plasma_cntrl.IB;
        plasma_args.MT = MT;
        plasma_args.NT = NT;
        plasma_args.T = Tbdl;
        /* Clear progress table */
        plasma_clear_aux_progress(MT*NT, -1);
        /* Signal workers */
        pthread_mutex_lock(&plasma_cntrl.action_mutex);
        plasma_cntrl.action = PLASMA_ACT_DGELQF;
        pthread_mutex_unlock(&plasma_cntrl.action_mutex);
        pthread_cond_broadcast(&plasma_cntrl.action_condt);
        /* Call for master */
        plasma_barrier(0, plasma_cntrl.cores_num);
        plasma_cntrl.action = PLASMA_ACT_STAND_BY;
        plasma_pDGELQF(plasma_args.M, plasma_args.N, plasma_args.A, plasma_args.NB,
                    plasma_args.NBNBSIZE, plasma_args.IBNBSIZE, plasma_args.IB,
                    plasma_args.MT, plasma_args.NT, plasma_args.T, &plasma_args.INFO,
                    plasma_cntrl.cores_num, 0);
        plasma_barrier(0, plasma_cntrl.cores_num);

        /* Call parallel DTRSM */
        /* Set arguments */
        plasma_args.uplo = PlasmaLower;
        plasma_args.trans = PlasmaNoTrans;
        plasma_args.diag = PlasmaNonUnit;
        plasma_args.N = M;
        plasma_args.NRHS = NRHS;
        plasma_args.A = Abdl;
        plasma_args.NB = plasma_cntrl.NB;
        plasma_args.NBNBSIZE = plasma_cntrl.NBNBSIZE;
        plasma_args.NT = MT;
        plasma_args.MT = MT;
        plasma_args.B = Bbdl;
        plasma_args.MTB = NT;
        plasma_args.NTRHS = NTRHS;
        /* Clear progress table */
        plasma_clear_aux_progress(MT*NTRHS, -1);
        /* Signal workers */
        pthread_mutex_lock(&plasma_cntrl.action_mutex);
        plasma_cntrl.action = PLASMA_ACT_DTRSM;
        pthread_mutex_unlock(&plasma_cntrl.action_mutex);
        pthread_cond_broadcast(&plasma_cntrl.action_condt);
        /* Call for master */
        plasma_barrier(0, plasma_cntrl.cores_num);
        plasma_cntrl.action = PLASMA_ACT_STAND_BY;
        plasma_pDTRSM(PlasmaLeft, plasma_args.uplo, plasma_args.trans, plasma_args.diag,
                      plasma_args.N, plasma_args.NRHS, 1.0, plasma_args.A, plasma_args.NB,
                      plasma_args.NBNBSIZE, plasma_args.NT, plasma_args.MT, plasma_args.B,
                      plasma_args.MTB, plasma_args.NTRHS, plasma_cntrl.cores_num, 0);
        plasma_barrier(0, plasma_cntrl.cores_num);

        /* Here LAPACK sets B(M+1:N,1:NRHS) to zero
         */

        /* Call parallel DORMQR */

        /* Convert arrays from BDL to LAPACK */
        plasma_bdl_to_lapack(Abdl, A, M, N, LDA, plasma_cntrl.NB,
                             MT, NT, plasma_cntrl.NBNBSIZE, 1, 0);
        plasma_bdl_to_lapack(Bbdl, B, N, NRHS, LDB, plasma_cntrl.NB,
                             NT, NTRHS, plasma_cntrl.NBNBSIZE, 1, 0);
    }

    return PLASMA_SUCCESS;
}
