/**
 *
 * @file core_zgetrf_rectil.c
 *
 *  PLASMA core_blas kernel
 *  PLASMA is a software package provided by Univ. of Tennessee,
 *  Univ. of California Berkeley and Univ. of Colorado Denver
 *
 * @version 2.4.0
 * @author Hatem Ltaief
 * @author Mathieu Faverge
 * @author Piotr Luszczek
 * @date 2009-11-15
 *
 * @precisions normal z -> c d s
 * 
 **/

#include <math.h>
#include <cblas.h>
#include <lapacke.h>
#include "common.h"

/* Address of tile (m, n) of the descriptor named A that is in scope at the
 * point of use. BLKADDR comes from the PLASMA headers (presumably common.h). */
#define A(m, n) BLKADDR(A, PLASMA_Complex64_t, m, n)

/* Synchronization helpers shared by the threads that work together on one
 * call of CORE_zgetrf_rectil; both are defined further down in this file. */
static void CORE_zbarrier_thread(const int thidx, const int thcnt);
static void CORE_zamax1_thread(const PLASMA_Complex64_t localamx, 
                               const int thidx, const int thcnt, 
                               int *thwinner, PLASMA_Complex64_t *globalamx,
                               const int pividx, int *ipiv);

/***************************************************************************//**
 *
 * @ingroup CORE_PLASMA_Complex64_t
 *
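 *  CORE_zgetrf_rectil computes an LU factorization of a general M-by-N
 *  matrix A, stored as a single column of tiles, using partial pivoting
 *  with row interchanges. All threads of the calling task cooperate on the
 *  factorization, each one owning a contiguous range of row tiles.
 *
 *  WARNING: You cannot call this kernel on different matrices at the same
 *  time: the threads synchronize through a single static buffer.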
 *
 *  The factorization has the form
 *
 *    A = P * L * U
 *
 *  where P is a permutation matrix, L is lower triangular with unit
 *  diagonal elements (lower trapezoidal if m > n), and U is upper
 *  triangular (upper trapezoidal if m < n).
 *
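 *  This is a recursive, right-looking variant of the algorithm: the columns
 *  are split in two halves, the left half is factored recursively, the right
 *  half is updated with Level 3 BLAS (TRSM and GEMM) and then factored
 *  recursively in turn.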
 *
 *******************************************************************************
 *
 *  @param[in,out] A
 *          Tile descriptor of the M-by-N matrix to factor (M = A.m,
 *          N = A.n). A must consist of a single tile column (A.nt == 1);
 *          leading dimensions are taken from the descriptor.
 *          On entry, the m by n matrix to be factored.
 *          On exit, the factors L and U from the factorization
 *          A = P*L*U; the unit diagonal elements of L are not stored.
 *
 *  @param[out] IPIV
 *          The pivot indices; for 1 <= i <= min(M,N), row i of the
 *          matrix was interchanged with row IPIV(i). The stored indices
 *          are 1-based and include the row offset A.i of the descriptor.
 *
 *  @param[in,out] INFO
 *          Array of three integers.
 *          On entry, INFO[1] is the rank of the calling thread and INFO[2]
 *          the number of threads taking part in the factorization.
 *          On exit, INFO[0]:
 *          = 0: successful exit;
 *          = k: U(k,k) is exactly zero. The factorization stops at that
 *               column, so the factors are incomplete and must not be
 *               used to solve a system of equations.
 *******************************************************************************
 *
 * @return
 *          \retval PLASMA_SUCCESS successful exit
 *          \retval -k, the k-th argument had an illegal value
 *
 */

/*
 * Size of the buffer shared by CORE_zamax1_thread / CORE_zbarrier_thread.
 * Each thread uses one (flag, value) pair of slots, so at most
 * AMAX1BUF_SIZE / 2 threads can take part in one factorization.
 */
#define AMAX1BUF_SIZE (48 << 1)

/* 48 threads should be enough for everybody */
static volatile PLASMA_Complex64_t CORE_zamax1buf[AMAX1BUF_SIZE]; 
/* LAPACK safe minimum, dlamch('S'), cached by CORE_zgetrf_rectil_init.
 * Pivots with a smaller modulus are divided by instead of being inverted. */
static double sfmin;

/*
 * CORE_zgetrf_rectil_init - one-time setup for CORE_zgetrf_rectil.
 *
 * Puts every slot of the shared synchronization buffer in the idle state
 * (-1.0) expected by CORE_zamax1_thread, then caches LAPACK's safe minimum
 * dlamch('S') in sfmin. It must run before the first factorization: the
 * zero-initialized static buffer is not a valid idle state.
 */
void
CORE_zgetrf_rectil_init(void)
{
    int slot = 0;

    while (slot < AMAX1BUF_SIZE) {
        CORE_zamax1buf[slot] = -1.0;
        slot++;
    }

    sfmin = LAPACKE_dlamch_work('S');
}

/*
 * CORE_zamax1_thread - all-thread reduction used for the pivot search.
 *
 * Each of the thcnt threads (ranks thidx = 0 .. thcnt-1) calls this with its
 * local candidate localamx. On return every thread has, in *globalamx, the
 * candidate with the largest modulus (cabs) and, in *thwinner, the rank that
 * contributed it. Ties go to the lowest rank: thread 0 starts from its own
 * value and only replaces it on a strictly larger modulus. Only the winning
 * thread stores pividx into ipiv[0].
 *
 * The threads communicate through the static buffer CORE_zamax1buf:
 *   CORE_zamax1buf[2*i]     state flag of thread i (i >= 1):
 *                             -1.0 idle, -2.0 local value posted,
 *                             -3.0 global value published by thread 0;
 *   CORE_zamax1buf[2*i + 1] value exchanged with thread i;
 *   CORE_zamax1buf[0]       -(winner) - 2.0 while the result is published,
 *                           -1.0 otherwise.
 * All flag slots must read -1.0 on entry (CORE_zgetrf_rectil_init sets this
 * up) and are back to -1.0 once every thread has returned. No thread leaves
 * before all of them have arrived, so the call also acts as a barrier.
 * thcnt must not exceed AMAX1BUF_SIZE / 2, or the buffer is indexed out of
 * bounds.
 *
 * NOTE(review): the handshake relies on spin-waiting on volatile data with
 * no atomics or memory fences; it assumes the hardware keeps these stores
 * ordered and visible to the other cores -- confirm for each target.
 */
static void
CORE_zamax1_thread(PLASMA_Complex64_t localamx, int thidx, int thcnt, int *thwinner, 
                   PLASMA_Complex64_t *globalamx, int pividx, int *ipiv) {
    if (thidx == 0) {
        /* Thread 0: collect every candidate, reduce, publish the result. */
        int i, j = 0;
        PLASMA_Complex64_t curval = localamx, tmp;
        double curamx = cabs(localamx);
        
        /* make sure everybody filled in their value */
        for (i = 1; i < thcnt; ++i) {
            while (CORE_zamax1buf[i << 1] == -1.0) { /* wait for thread i to store its value */
            }
        }
        
        /* better not fuse the loop above and below to make sure data is sync'd */
        
        /* Strict '>' keeps the lowest rank on ties. */
        for (i = 1; i < thcnt; ++i) {
            tmp = CORE_zamax1buf[ (i << 1) + 1];
            if (cabs(tmp) > curamx) {
                curamx = cabs(tmp);
                curval = tmp;
                j = i;
            }
        }
        
        /* Thread 0 won: it records the pivot itself. */
        if (0 == j)
            ipiv[0] = pividx;
        
        /* make sure everybody knows the amax value */
        for (i = 1; i < thcnt; ++i)
            CORE_zamax1buf[ (i << 1) + 1] = curval;
        
        CORE_zamax1buf[0] = -j - 2.0; /* set the index of the winning thread */
        
        *thwinner = j;
        *globalamx = curval;
        
        /* Release the other threads: the result is ready to be read. */
        for (i = 1; i < thcnt; ++i)
            CORE_zamax1buf[i << 1] = -3.0;
        
        /* make sure everybody read the max value */
        for (i = 1; i < thcnt; ++i) {
            while (CORE_zamax1buf[i << 1] != -1.0) {
            }
        }
        
        /* Back to idle; this also lets the other threads return. */
        CORE_zamax1buf[0] = -1.0;
    } else {
        /* Other threads: post the local value, then wait for the result. */
        CORE_zamax1buf[(thidx << 1) + 1] = localamx;
        CORE_zamax1buf[thidx << 1] = -2.0;  /* announce to thread 0 that local amax was stored */
        while (CORE_zamax1buf[0] == -1.0) { /* wait for thread 0 to finish calculating the global amax */
        }
        while (CORE_zamax1buf[thidx << 1] != -3.0) { /* wait for thread 0 to store amax */
        }
        *globalamx = CORE_zamax1buf[(thidx << 1) + 1]; /* read the amax from the location adjacent to the one in the above loop */
        /* Slot 0 holds -(winner) - 2.0; convert it back to a rank. */
        *thwinner = -CORE_zamax1buf[0] - 2.0;
        CORE_zamax1buf[thidx << 1] = -1.0;  /* signal thread 0 that this thread is done reading */

        if (thidx == *thwinner)
            ipiv[0] = pividx;

        while (CORE_zamax1buf[0] != -1.0) { /* wait for thread 0 to finish */
        }
    }
}

/*
 * CORE_zbarrier_thread - barrier among the thcnt threads of one factorization.
 *
 * Reuses the amax reduction: CORE_zamax1_thread does not let any thread
 * return before all of them have entered it. The reduction results are
 * thrown away. A dedicated barrier would probably be faster.
 */
static void
CORE_zbarrier_thread(int thidx, int thcnt)
{
    int                winner;
    int                unused_pivot;
    PLASMA_Complex64_t unused_max;

    CORE_zamax1_thread(1.0, thidx, thcnt,
                       &winner, &unused_max,
                       0, &unused_pivot);
}

/*
 * CORE_zgetrf_rectil_rec - recursive worker of CORE_zgetrf_rectil.
 *
 * Factors columns [column, column + width) of the tile column A with partial
 * pivoting. Every participating thread calls it with the same column/width.
 * Thread thidx works on the row tiles [ft, lt) it owns. Thread 0 owns the
 * top tile A(0,0) (its ft is 0), so it alone applies the row interchanges
 * and the triangular solve that touch that tile.
 *
 *   width > 1 : split the columns into n1 = width/2 and n2 = width - n1,
 *               then:
 *                 - factor the left part recursively;
 *                 - apply its interchanges to the right part (thread 0);
 *                 - solve with the unit-lower diagonal block (thread 0);
 *                 - update the trailing rows with GEMM (every thread on
 *                   its own tiles);
 *                 - factor the right part recursively;
 *                 - apply its interchanges back to the left part
 *                   (thread 0).
 *   width == 1: find the pivot of the column across all threads, scale the
 *               column below the diagonal by it, and exchange the diagonal
 *               and pivot entries of this column. The other columns of
 *               those two rows are swapped by the enclosing recursion
 *               levels.
 *
 * IPIV receives 1-based pivot indices that include the row offset A.i.
 * On an exactly-zero pivot, *info is set to column + 1 and every thread
 * returns without finishing the factorization.
 */
static void
CORE_zgetrf_rectil_rec(const PLASMA_desc A, int *IPIV, int *info,
                       const int thidx, const int thcnt, const int column, const int width,
                       const int ft, const int lt)
{
    int ld, jp, n1, n2, lm;
    int ip, j, it, i, ldft;
    int max_i, max_it, thwin;
    int smallpiv = 0; /* 1 when |pivot| < sfmin: entries are divided by the pivot instead of scaled by its inverse */
    PLASMA_Complex64_t zone  = 1.0;
    PLASMA_Complex64_t mzone = -1.0;
    PLASMA_Complex64_t tmp1;
    PLASMA_Complex64_t tmp2;
    PLASMA_Complex64_t pivval;
    PLASMA_Complex64_t *Atop, *Atop2, *U, *L;
    double             abstmp1;
    int offset = A.i;

    ldft = BLKLDD(A, 0);
    Atop = A(0, 0) + column * ldft;

    /* Assumption: N = min( M, N ); */
    if (width > 1) {
        n1 = width / 2;
        n2 = width - n1;

        Atop2 = Atop + n1 * ldft;

        CORE_zgetrf_rectil_rec( A, IPIV, info,
                                thidx, thcnt, column, n1, ft, lt );
        if ( *info != 0 )
            return;

        CORE_zbarrier_thread( thidx, thcnt );

        if (thidx == 0)
        {
            /* Swap to the right: apply the left part's interchanges to the right part */
            int *lipiv = IPIV+column;
            int idxMax = column+n1;
            for (j = column; j < idxMax; ++j, ++lipiv) {
                ip = (*lipiv) - offset - 1;
                if ( ip != j )
                {
                    it = ip / A.mb;
                    i  = ip % A.mb;
                    ld = BLKLDD(A, it);
                    cblas_zswap(n2, Atop2                     + j, ldft,
                                    A(it, 0) + (column+n1)*ld + i, ld   );
                }
            }

            /* Trsm on the upper part */
            U = Atop2 + column;
            cblas_ztrsm( CblasColMajor, CblasLeft, CblasLower,
                         CblasNoTrans, CblasUnit,
                         n1, n2, CBLAS_SADDR(zone),
                         Atop  + column, ldft,
                         U,              ldft );

            /* need to wait for pivoting and triangular solve to finish */
            CORE_zbarrier_thread( thidx, thcnt );

            /* Update the rows of the top tile below the diagonal block */
            L = Atop + column + n1;
            lm = ft == A.mt-1 ? A.m - ft * A.mb : A.mb;
            cblas_zgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
                         lm-column-n1, n2, n1,
                         CBLAS_SADDR(mzone), L,      ldft,
                                             U,      ldft,
                         CBLAS_SADDR(zone),  U + n1, ldft );

            /* Update the remaining tiles owned by thread 0 */
            for( it = ft+1; it < lt; it++)
            {
                ld = BLKLDD( A, it );
                L  = A( it, 0 ) + column * ld;
                lm = it == A.mt-1 ? A.m - it * A.mb : A.mb;
                cblas_zgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
                             lm, n2, n1,
                             CBLAS_SADDR(mzone), L,          ld,
                                                 U,          ldft,
                             CBLAS_SADDR(zone),  L + n1*ld,  ld );
            }
        }
        else
        {
            /* need to wait for pivoting and triangular solve to finish */
            CORE_zbarrier_thread( thidx, thcnt );

            U = Atop2 + column;

            /* Update the tiles owned by this thread */
            for( it = ft; it < lt; it++)
            {
                ld = BLKLDD( A, it );
                L  = A( it, 0 ) + column * ld;
                lm = it == A.mt-1 ? A.m - it * A.mb : A.mb;
                cblas_zgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
                             lm, n2, n1,
                             CBLAS_SADDR(mzone), L,         ld,
                                                 U,         ldft,
                             CBLAS_SADDR(zone),  L + n1*ld, ld );
            }
        }

        CORE_zgetrf_rectil_rec( A, IPIV, info,
                                thidx, thcnt, column+n1, n2, ft, lt );
        if ( *info != 0 )
            return;

        if ( thidx == 0 )
        {
            /* Swap to the left: apply the right part's interchanges to the left part */
            int *lipiv = IPIV+column+n1;
            int idxMax = column+width;
            for (j = column+n1; j < idxMax; ++j, ++lipiv) {
                ip = (*lipiv) - offset - 1;
                if ( ip != j )
                {
                    it = ip / A.mb;
                    i  = ip % A.mb;
                    ld = BLKLDD(A, it);
                    cblas_zswap(n1, Atop + j,                 ldft,
                                    A(it, 0) + column*ld + i, ld  );
                }
            }
        }

    } else {
        CORE_zbarrier_thread( thidx, thcnt );

        tmp2 = Atop[column]; /* all threads read the pivot element in case they need it */

        /* Local pivot candidate: the first owned tile (only rows >= column for thread 0) */
        ld = BLKLDD(A, ft);
        Atop2  = A( ft, 0 ) + ld * column;
        lm     = ft == A.mt-1 ? A.m - ft * A.mb : A.mb;
        max_it = ft;

        if (thidx == 0) {
            max_i = cblas_izamax( lm-column, Atop2+column, 1 ) + column;
        } else {
            max_i = cblas_izamax( lm,        Atop2,        1 );
        }
        tmp1    = Atop2[max_i];
        abstmp1 = cabs(tmp1);

        /* Search the remaining owned tiles */
        for( it = ft+1; it < lt; it++)
        {
            ld = BLKLDD(A, it);
            Atop2 = A( it, 0 ) + ld * column;
            lm   = it == A.mt-1 ? A.m - it * A.mb : A.mb;
            jp   = cblas_izamax( lm, Atop2, 1 );
            if (  cabs(Atop2[jp]) > abstmp1 ) {
                tmp1 = Atop2[jp];
                abstmp1 = cabs(tmp1);
                max_i  = jp;
                max_it = it;
            }
        }

        /* Global pivot: every thread gets its value, the winner stores IPIV */
        jp = offset + max_it*A.mb + max_i;
        CORE_zamax1_thread( tmp1, thidx, thcnt, &thwin,
                            &pivval, jp + 1, IPIV + column );

        if ( thidx == 0 )
        {
            Atop[column] = pivval; /* all threads have the pivot element: no need for synchronization */
            if ( pivval != 0.0 ) {
                if ( cabs(pivval) >= sfmin ) {
                    pivval = 1.0 / pivval;

                    /*
                     * We guess that we never enter the function with m == A.mt-1
                     * because it means that there is only one thread
                     */
                    lm = ft == A.mt-1 ? A.m - ft * A.mb : A.mb;
                    cblas_zscal( lm - column - 1, CBLAS_SADDR(pivval), Atop+column+1, 1 );

                    for( it = ft+1; it < lt; it++)
                    {
                        ld = BLKLDD(A, it);
                        Atop2 = A( it, 0 ) + column * ld;
                        lm = it == A.mt-1 ? A.m - it * A.mb : A.mb;
                        cblas_zscal( lm, CBLAS_SADDR(pivval), Atop2, 1 );
                    }
                } else {
                    /*
                     * We guess that we never enter the function with m == A.mt-1
                     * because it means that there is only one thread
                     */
                    smallpiv = 1;
                    Atop2 = Atop + column + 1;
                    lm = ft == A.mt-1 ? A.m - ft * A.mb : A.mb;

                    for( i=0; i < lm-column-1; i++, Atop2++)
                        *Atop2 = *Atop2 / pivval;

                    for( it = ft+1; it < lt; it++)
                    {
                        ld = BLKLDD(A, it);
                        Atop2 = A( it, 0 ) + column * ld;
                        lm = it == A.mt-1 ? A.m - it * A.mb : A.mb;

                        for( i=0; i < lm; i++, Atop2++)
                            *Atop2 = *Atop2 / pivval;
                    }
                }
            } else {
                *info = column + 1;
                return;
            }
        }
        else
        {
            if ( pivval != 0.0 ) {
                if ( cabs(pivval) >= sfmin ) {
                    pivval = 1.0 / pivval;

                    for( it = ft; it < lt; it++)
                    {
                        ld = BLKLDD(A, it);
                        Atop2 = A( it, 0 ) + column * ld;
                        lm = it == A.mt-1 ? A.m - it * A.mb : A.mb;
                        cblas_zscal( lm, CBLAS_SADDR(pivval), Atop2, 1 );
                    }
                } else {
                    /*
                     * We guess that we never enter the function with m == A.mt-1
                     * because it means that there is only one thread
                     */
                    smallpiv = 1;
                    for( it = ft; it < lt; it++)
                    {
                        ld = BLKLDD(A, it);
                        Atop2 = A( it, 0 ) + column * ld;
                        lm = it == A.mt-1 ? A.m - it * A.mb : A.mb;

                        for( i=0; i < lm; i++, Atop2++)
                            *Atop2 = *Atop2 / pivval;
                    }
                }
            } else {
                *info = column + 1;
                return;
            }
        }

        if (thwin == thidx) { /* the thread that owns the best pivot */
            if ( jp-offset != column ) /* if there is a need to exchange the pivot */
            {
                ld = BLKLDD(A, max_it);
                Atop2 = A( max_it, 0 ) + column * ld + max_i;
                /*
                 * Move the old diagonal element into the pivot's row, scaled
                 * like the rest of the column. pivval holds the inverse of
                 * the pivot only on the normal path; on the small-pivot path
                 * it is still the pivot itself, so divide there.
                 */
                if ( smallpiv )
                    *Atop2 = tmp2 / pivval;
                else
                    *Atop2 = tmp2 * pivval;
            }
        }

        CORE_zbarrier_thread( thidx, thcnt );
    }
}

#if defined(PLASMA_HAVE_WEAK)
#pragma weak CORE_zgetrf_rectil = PCORE_zgetrf_rectil
#define CORE_zgetrf_rectil PCORE_zgetrf_rectil
#endif
int CORE_zgetrf_rectil(const PLASMA_desc A, int *IPIV, int *info)
{
    int ft, lt; 
    int thidx = info[1];
    int thcnt = min( info[2], A.mt );
    
    if ( A.nt > 1 ) {
        coreblas_error(1, "Illegal value of A.nt");
        return -1;
    }

    if ( thidx >= thcnt )
      return 0;

    int q = A.mt / thcnt;
    int r = A.mt % thcnt;

    if (thidx < r) {
        q++;
        ft = thidx * q;
        lt = ft + q;
    } else {
        ft = r * (q + 1) + (thidx - r) * q;
        lt = ft + q;
        lt = min( lt, A.mt );
    }
    
    info[0] = 0;
    CORE_zgetrf_rectil_rec( A, IPIV, info,
                            thidx, thcnt, 0, A.n, ft, lt);
   
    return info[0];
}

/***************************************************************************//**
 *
 *  QUARK_CORE_zgetrf_rectil inserts into quark a task that runs
 *  CORE_zgetrf_rectil_quark on the tile column described by A.
 *
 *  Amn      start of the tile data (size elements) read and written by the
 *           task;
 *  IPIV     A.n pivot indices written by the task;
 *  sequence, request, check_info, iinfo
 *           error-reporting context forwarded to the task;
 *  nbthread number of threads the kernel may split the work across.
 *
 *  NOTE(review): whether the task really runs on nbthread threads depends
 *  on task_flags -- confirm with the caller. The argument list is variadic,
 *  so the type of each size expression matters; keep them unchanged.
 *
 **/
void QUARK_CORE_zgetrf_rectil(Quark *quark, Quark_Task_Flags *task_flags,
                              PLASMA_desc A, PLASMA_Complex64_t *Amn, int size,
                              int *IPIV,
                              PLASMA_sequence *sequence, PLASMA_request *request,
                              PLASMA_bool check_info, int iinfo,
                              int nbthread)
{
    DAG_CORE_GETRF;
    QUARK_Insert_Task(quark, CORE_zgetrf_rectil_quark, task_flags,
        sizeof(PLASMA_desc),                &A,             VALUE,
        sizeof(PLASMA_Complex64_t)*size,     Amn,               INOUT,
        sizeof(int)*A.n,                     IPIV,              OUTPUT,
        sizeof(PLASMA_sequence*),           &sequence,      VALUE,
        sizeof(PLASMA_request*),            &request,       VALUE,
        sizeof(PLASMA_bool),                &check_info,    VALUE,
        sizeof(int),                        &iinfo,         VALUE,
        sizeof(int),                        &nbthread,      VALUE,
        0);
}

/***************************************************************************//**
 *
 *  CORE_zgetrf_rectil_quark is the QUARK entry point of CORE_zgetrf_rectil.
 *  Each thread of the task unpacks the arguments, passes its rank and the
 *  thread count to the kernel, and rank 0 reports a zero pivot k to the
 *  sequence as iinfo + k when check_info is set.
 *
 **/
#if defined(PLASMA_HAVE_WEAK)
#pragma weak CORE_zgetrf_rectil_quark = PCORE_zgetrf_rectil_quark
#define CORE_zgetrf_rectil_quark PCORE_zgetrf_rectil_quark
#endif
void CORE_zgetrf_rectil_quark(Quark* quark)
{
    PLASMA_desc A;
    PLASMA_Complex64_t *Amn;   /* unpacked only to keep the argument order */
    int *IPIV;
    PLASMA_sequence *sequence;
    PLASMA_request *request;
    PLASMA_bool check_info;
    int iinfo;

    /* Zero-initialized so info[0] is never read uninitialized, even if the
     * kernel returns before setting it. */
    int info[3] = { 0, 0, 0 };
    int maxthreads;

    quark_unpack_args_8(quark, A, Amn, IPIV, sequence, request,
                        check_info, iinfo, maxthreads );

    info[1] = QUARK_Get_RankInTask(quark);
    info[2] = maxthreads;

    CORE_zgetrf_rectil( A, IPIV, info );

    /* Only rank 0 reports the status of the whole task. */
    if (info[1] == 0 && info[0] != PLASMA_SUCCESS && check_info)
        plasma_sequence_flush(quark, sequence, request, iinfo + info[0] );
}
