/**
 *
 * @file pcpotrf.c
 *
 *  PLASMA auxiliary routines
 *  PLASMA is a software package provided by Univ. of Tennessee,
 *  Univ. of California Berkeley and Univ. of Colorado Denver
 *
 * @version 2.2.0
 * @author Jakub Kurzak
 * @author Hatem Ltaief
 * @date 2009-11-15
 *
 **/
#include "common.h"

/* Shorthand used throughout this file: expands to BLKADDR applied to the
 * descriptor named `A` in the enclosing scope, with PLASMA_Complex32_t as the
 * element type. Presumably yields the address of tile (m, n); BLKADDR is
 * defined elsewhere. Requires a variable literally named `A` to be in scope. */
#define A(m,n) BLKADDR(A, PLASMA_Complex32_t, m, n)
/***************************************************************************//**
 *  Parallel Cholesky factorization - static scheduling
 *
 *  Tile Cholesky factorization of the matrix described by A. The arguments
 *  (uplo, A, sequence, request) are unpacked from the context with
 *  plasma_unpack_args_4.
 *
 *  Each thread walks a sequence of (k, m, n) work items:
 *    - k is the tile column (lower) / tile row (upper) being finalized,
 *    - m is the tile being produced within it (m >= k),
 *    - n runs from 0 to k; steps with n < k apply the update contributed by
 *      the earlier panel n, and the step n == k performs the final
 *      factorization (m == k) or triangular solve (m != k).
 *  The starting m is PLASMA_RANK and m advances by PLASMA_SIZE, wrapping into
 *  the next k when it runs past A.nt.
 *
 *  Ordering between threads is enforced with ss_cond_wait / ss_cond_set on
 *  (tile row, tile column) entries; those primitives are defined elsewhere
 *  and presumably implement a shared progress table sized by ss_init.
 *
 *  If the sequence is already in error on entry, the routine returns without
 *  doing any work. A non-zero info from CORE_cpotrf is reported through
 *  plasma_request_fail as info + A.nb*k, and ss_abort() is called.
 *
 *  @param[in] plasma  Context from which the arguments are unpacked; also
 *                     referenced implicitly by the scheduling macros
 *                     (NOTE(review): inferred from usage, confirm in common.h).
 **/
void plasma_pcpotrf(plasma_context_t *plasma)
{
    PLASMA_enum uplo;
    PLASMA_desc A;
    PLASMA_sequence *sequence;
    PLASMA_request *request;

    int k, m, n;
    int next_k;
    int next_m;
    int next_n;
    int info;

    plasma_unpack_args_4(uplo, A, sequence, request);
    /* Nothing to do if an earlier operation in the sequence already failed. */
    if (sequence->status != PLASMA_SUCCESS)
        return;

    ss_init(A.nt, A.nt, 0);

    /* Locate this thread's first work item: start at tile index PLASMA_RANK
     * and, while that index is past the last tile, wrap into the next k.
     * The wrap adds k so that the resulting m never falls below k. */
    k = 0;
    m = PLASMA_RANK;
    while (m >= A.nt) {
        k++;
        m = m-A.nt+k;
    }
    n = 0;

    while (k < A.nt && m < A.nt && !ss_aborted()) {
        /* Compute the successor of (k, m, n) up front; it is committed at the
         * bottom of the loop after the current step has been executed. */
        next_n = n;
        next_m = m;
        next_k = k;

        next_n++;
        if (next_n > next_k) {
            /* All n in 0..k done for this tile: move to this thread's next
             * tile, wrapping into later k values as needed. */
            next_m += PLASMA_SIZE;
            while (next_m >= A.nt && next_k < A.nt) {
                next_k++;
                next_m = next_m-A.nt+next_k;
            }
            next_n = 0;
        }

        if (m == k) {
            /* Diagonal tile A(k, k). */
            if (n == k) {
                /* Final step: factorize the diagonal tile. The last tile may
                 * be smaller than A.nb, hence the A.n-k*A.nb order. */
                if (uplo == PlasmaLower)
                    CORE_cpotrf(
                        PlasmaLower,
                        k == A.nt-1 ? A.n-k*A.nb : A.nb,
                        A(k, k), A.nb,
                        &info);
                else
                    CORE_cpotrf(
                        PlasmaUpper,
                        k == A.nt-1 ? A.n-k*A.nb : A.nb,
                        A(k, k), A.nb,
                        &info);
                if (info != 0) {
                    /* Offset the kernel's tile-local info by the tile's
                     * starting index before reporting it. */
                    plasma_request_fail(sequence, request, info + A.nb*k);
                    ss_abort();
                }
                /* Signalled even after a failure, so waiters on (k, k) are
                 * not left blocked on this entry. */
                ss_cond_set(k, k, 1);
            }
            else {
                /* Update step: apply the contribution of panel n to A(k, k)
                 * once the tile (k, n) has been marked ready. */
                ss_cond_wait(k, n, 1);
                if (uplo == PlasmaLower)
                    CORE_cherk(
                         PlasmaLower, PlasmaNoTrans,
                         k == A.nt-1 ? A.n-k*A.nb : A.nb,
                         A.nb,
                        -1.0, A(k, n), A.nb,
                         1.0, A(k, k), A.nb);
                else
                    CORE_cherk(
                         PlasmaUpper, PlasmaConjTrans,
                         k == A.nt-1 ? A.n-k*A.nb : A.nb,
                         A.nb,
                        -1.0, A(n, k), A.nb,
                         1.0, A(k, k), A.nb);
            }
        }
        else {
            /* Off-diagonal tile: A(m, k) for lower, A(k, m) for upper. */
            if (n == k) {
                /* Final step: triangular solve against the factorized
                 * diagonal tile, then publish the result as entry (m, k). */
                ss_cond_wait(k, k, 1);
                if (uplo == PlasmaLower)
                    CORE_ctrsm(
                        PlasmaRight, PlasmaLower, PlasmaConjTrans, PlasmaNonUnit,
                        m == A.nt-1 ? A.n-m*A.nb : A.nb,
                        A.nb,
                        1.0, A(k, k), A.nb,
                             A(m, k), A.nb);
                else
                    CORE_ctrsm(
                        PlasmaLeft, PlasmaUpper, PlasmaConjTrans, PlasmaNonUnit,
                        A.nb,
                        m == A.nt-1 ? A.n-m*A.nb : A.nb,
                        1.0, A(k, k), A.nb,
                             A(k, m), A.nb);
                ss_cond_set(m, k, 1);
            }
            else {
                /* Update step: needs both tiles of panel n, (k, n) and (m, n),
                 * to be ready before subtracting their product. */
                ss_cond_wait(k, n, 1);
                ss_cond_wait(m, n, 1);
                if (uplo == PlasmaLower)
                    CORE_cgemm(
                        PlasmaNoTrans, PlasmaConjTrans,
                        m == A.nt-1 ? A.n-m*A.nb : A.nb,
                        A.nb,
                        A.nb,
                       -1.0, A(m, n), A.nb,
                             A(k, n), A.nb,
                        1.0, A(m, k), A.nb);
                else
                    CORE_cgemm(
                        PlasmaConjTrans, PlasmaNoTrans,
                        A.nb,
                        m == A.nt-1 ? A.n-m*A.nb : A.nb,
                        A.nb,
                       -1.0, A(n, k), A.nb,
                             A(n, m), A.nb,
                        1.0, A(k, m), A.nb);
            }
        }
        /* Commit the successor computed at the top of the loop. */
        n = next_n;
        m = next_m;
        k = next_k;
    }
    ss_finalize();
}

/***************************************************************************//**
 *  Parallel Cholesky factorization - dynamic scheduling
 *
 *  Inserts the tasks of a tile Cholesky factorization into the QUARK
 *  scheduler attached to the calling thread's PLASMA context. For each step
 *  k the following tasks are inserted, in this order:
 *    1. CORE_cpotrf on the diagonal tile A(k, k);
 *    2. CORE_ctrsm on every tile of panel k below (lower) / right of (upper)
 *       the diagonal;
 *    3. for every m > k, CORE_cherk updating A(m, m), followed by CORE_cgemm
 *       updating the tiles between panel k and the diagonal tile of m.
 *  Tile arguments are tagged INPUT / INOUT; scalar arguments are passed as
 *  VALUE. `temp` and `iinfo` are overwritten between insertions, which relies
 *  on QUARK capturing VALUE arguments at insertion time (NOTE(review): QUARK
 *  is defined elsewhere; confirm against its documentation).
 *
 *  If the sequence is already in error on entry, no task is inserted.
 *
 *  @param[in]     uplo      PlasmaLower selects the lower-triangular variant;
 *                           any other value takes the upper-triangular branch.
 *  @param[in]     A         Descriptor of the tile matrix to factorize.
 *  @param[in,out] sequence  Sequence the tasks belong to; its quark_sequence
 *                           is attached to every inserted task.
 *  @param[in,out] request   Request handle forwarded to CORE_cpotrf_quark.
 **/
void plasma_pcpotrf_quark(PLASMA_enum uplo, PLASMA_desc A, PLASMA_sequence *sequence, PLASMA_request *request)
{
    int k, m, n;
    plasma_context_t *plasma;
    /* VALUE arguments are passed by address, so the enum constants need
     * addressable copies. */
    PLASMA_enum plasma_lower = PlasmaLower;
    PLASMA_enum plasma_upper = PlasmaUpper;
    PLASMA_enum plasma_right = PlasmaRight;
    PLASMA_enum plasma_left = PlasmaLeft;
    PLASMA_enum plasma_no_trans = PlasmaNoTrans;
    PLASMA_enum plasma_conjf_trans = PlasmaConjTrans;
    PLASMA_enum plasma_non_unit = PlasmaNonUnit;
    int temp;   // order of the (possibly smaller) last tile
    PLASMA_Complex32_t zone  = (PLASMA_Complex32_t)1.0;
    PLASMA_Complex32_t mzone = (PLASMA_Complex32_t)-1.0;
    /* Real-valued scalars for CORE_cherk. They are handed to QUARK with
     * sizeof(float), so they must be float objects: declaring them complex
     * only worked because the real part is stored first. */
    float done  = (float)1.0;
    float mdone = (float)-1.0;
    int iinfo;  // value to be added to the error code returned from the kernel
    Quark_Task_Flags task_flags = Quark_Task_Flags_Initializer;

    plasma = plasma_context_self();
    /* Nothing to do if an earlier operation in the sequence already failed. */
    if (sequence->status != PLASMA_SUCCESS)
        return;
    QUARK_Task_Flag_Set(&task_flags, TASK_SEQUENCE, (intptr_t)sequence->quark_sequence);

    for (k = 0; k < A.nt; k++)
    {
        /* Step 1: factorize the diagonal tile A(k, k). The last tile may be
         * smaller than A.nb, in which case its order is temp. */
        if (uplo == PlasmaLower) {
            temp = A.n-k*A.nb;
            iinfo = A.nb*k;
            QUARK_Insert_Task(plasma->quark, CORE_cpotrf_quark, &task_flags,
                sizeof(PLASMA_enum),                  &plasma_lower,               VALUE,
                sizeof(int),                          k == A.nt-1 ? &temp : &A.nb, VALUE,
                sizeof(PLASMA_Complex32_t)*A.mb*A.nb, A(k, k),                         INOUT | LOCALITY,
                sizeof(int),                          &A.nb,                       VALUE,
                sizeof(PLASMA_sequence*),             &sequence,                   VALUE,
                sizeof(PLASMA_request*),              &request,                    VALUE,
                sizeof(int),                          &iinfo,                      VALUE,
                0);
        }
        else {
            temp = A.n-k*A.nb;
            iinfo = A.nb*k;
            QUARK_Insert_Task(plasma->quark, CORE_cpotrf_quark, &task_flags,
                sizeof(PLASMA_enum),                  &plasma_upper,               VALUE,
                sizeof(int),                          k == A.nt-1 ? &temp : &A.nb, VALUE,
                sizeof(PLASMA_Complex32_t)*A.mb*A.nb, A(k, k),                         INOUT | LOCALITY,
                sizeof(int),                          &A.nb,                       VALUE,
                sizeof(PLASMA_sequence*),             &sequence,                   VALUE,
                sizeof(PLASMA_request*),              &request,                    VALUE,
                sizeof(int),                          &iinfo,                      VALUE,
                0);
        }
        /* Step 2: triangular solves of panel k against A(k, k). */
        for (m = k+1; m < A.nt; m++)
        {
            if (uplo == PlasmaLower) {
                temp = A.n-m*A.nb;
                QUARK_Insert_Task(plasma->quark, CORE_ctrsm_quark, &task_flags,
                    sizeof(PLASMA_enum),                  &plasma_right,               VALUE,
                    sizeof(PLASMA_enum),                  &plasma_lower,               VALUE,
                    sizeof(PLASMA_enum),                  &plasma_conjf_trans,          VALUE,
                    sizeof(PLASMA_enum),                  &plasma_non_unit,            VALUE,
                    sizeof(int),                          m == A.nt-1 ? &temp : &A.nb, VALUE,
                    sizeof(int),                          &A.nb,                       VALUE,
                    sizeof(PLASMA_Complex32_t),           &zone,                       VALUE,
                    sizeof(PLASMA_Complex32_t)*A.mb*A.nb, A(k, k),                         INPUT | LOCALITY,
                    sizeof(int),                          &A.nb,                       VALUE,
                    sizeof(PLASMA_Complex32_t)*A.mb*A.nb, A(m, k),                         INOUT,
                    sizeof(int),                          &A.nb,                       VALUE,
                    0);
            }
            else {
                temp = A.n-m*A.nb;
                QUARK_Insert_Task(plasma->quark, CORE_ctrsm_quark, &task_flags,
                    sizeof(PLASMA_enum),                  &plasma_left,                VALUE,
                    sizeof(PLASMA_enum),                  &plasma_upper,               VALUE,
                    sizeof(PLASMA_enum),                  &plasma_conjf_trans,          VALUE,
                    sizeof(PLASMA_enum),                  &plasma_non_unit,            VALUE,
                    sizeof(int),                          &A.nb,                       VALUE,
                    sizeof(int),                          m == A.nt-1 ? &temp : &A.nb, VALUE,
                    sizeof(PLASMA_Complex32_t),           &zone,                       VALUE,
                    sizeof(PLASMA_Complex32_t)*A.mb*A.nb, A(k, k),                         INPUT | LOCALITY,
                    sizeof(int),                          &A.nb,                       VALUE,
                    sizeof(PLASMA_Complex32_t)*A.mb*A.nb, A(k, m),                         INOUT,
                    sizeof(int),                          &A.nb,                       VALUE,
                    0);
            }
        }
        /* Step 3: trailing update using panel k. */
        for (m = k+1; m < A.nt; m++)
        {
            /* Hermitian rank-k update of the diagonal tile A(m, m). */
            if (uplo == PlasmaLower) {
                temp = A.n-m*A.nb;
                QUARK_Insert_Task(plasma->quark, CORE_cherk_quark, &task_flags,
                    sizeof(PLASMA_enum),                  &plasma_lower,               VALUE,
                    sizeof(PLASMA_enum),                  &plasma_no_trans,            VALUE,
                    sizeof(int),                          m == A.nt-1 ? &temp : &A.nb, VALUE,
                    sizeof(int),                          &A.nb,                       VALUE,
                    sizeof(float),                        &mdone,                      VALUE,
                    sizeof(PLASMA_Complex32_t)*A.mb*A.nb, A(m, k),                         INPUT | LOCALITY,
                    sizeof(int),                          &A.nb,                       VALUE,
                    sizeof(float),                        &done,                       VALUE,
                    sizeof(PLASMA_Complex32_t)*A.mb*A.nb, A(m, m),                         INOUT,
                    sizeof(int),                          &A.nb,                       VALUE,
                    0);
            }
            else {
                temp = A.n-m*A.nb;
                QUARK_Insert_Task(plasma->quark, CORE_cherk_quark, &task_flags,
                    sizeof(PLASMA_enum),                  &plasma_upper,               VALUE,
                    sizeof(PLASMA_enum),                  &plasma_conjf_trans,          VALUE,
                    sizeof(int),                          m == A.nt-1 ? &temp : &A.nb, VALUE,
                    sizeof(int),                          &A.nb,                       VALUE,
                    sizeof(float),                        &mdone,                      VALUE,
                    sizeof(PLASMA_Complex32_t)*A.mb*A.nb, A(k, m),                         INPUT | LOCALITY,
                    sizeof(int),                          &A.nb,                       VALUE,
                    sizeof(float),                        &done,                       VALUE,
                    sizeof(PLASMA_Complex32_t)*A.mb*A.nb, A(m, m),                         INOUT,
                    sizeof(int),                          &A.nb,                       VALUE,
                    0);
            }
            /* General updates of the tiles strictly between panel k and the
             * diagonal tile of m. Only the m dimension can be a partial tile
             * here because n < m. */
            for (n = k+1; n < m; n++)
            {
                if (uplo == PlasmaLower) {
                    temp = A.n-m*A.nb;
                    QUARK_Insert_Task(plasma->quark, CORE_cgemm_quark, &task_flags,
                        sizeof(PLASMA_enum),                  &plasma_no_trans,            VALUE,
                        sizeof(PLASMA_enum),                  &plasma_conjf_trans,          VALUE,
                        sizeof(int),                          m == A.nt-1 ? &temp : &A.nb, VALUE,
                        sizeof(int),                          &A.nb,                       VALUE,
                        sizeof(int),                          &A.nb,                       VALUE,
                        sizeof(PLASMA_Complex32_t),           &mzone,                      VALUE,
                        sizeof(PLASMA_Complex32_t)*A.mb*A.nb, A(m, k),                         INPUT,
                        sizeof(int),                          &A.nb,                       VALUE,
                        sizeof(PLASMA_Complex32_t)*A.mb*A.nb, A(n, k),                         INPUT | LOCALITY,
                        sizeof(int),                          &A.nb,                       VALUE,
                        sizeof(PLASMA_Complex32_t),           &zone,                       VALUE,
                        sizeof(PLASMA_Complex32_t)*A.mb*A.nb, A(m, n),                         INOUT,
                        sizeof(int),                          &A.nb,                       VALUE,
                        0);
                }
                else {
                    temp = A.n-m*A.nb;
                    QUARK_Insert_Task(plasma->quark, CORE_cgemm_quark, &task_flags,
                        sizeof(PLASMA_enum),                  &plasma_conjf_trans,          VALUE,
                        sizeof(PLASMA_enum),                  &plasma_no_trans,            VALUE,
                        sizeof(int),                          &A.nb,                       VALUE,
                        sizeof(int),                          m == A.nt-1 ? &temp : &A.nb, VALUE,
                        sizeof(int),                          &A.nb,                       VALUE,
                        sizeof(PLASMA_Complex32_t),           &mzone,                      VALUE,
                        sizeof(PLASMA_Complex32_t)*A.mb*A.nb, A(k, n),                         INPUT | LOCALITY,
                        sizeof(int),                          &A.nb,                       VALUE,
                        sizeof(PLASMA_Complex32_t)*A.mb*A.nb, A(k, m),                         INPUT,
                        sizeof(int),                          &A.nb,                       VALUE,
                        sizeof(PLASMA_Complex32_t),           &zone,                       VALUE,
                        sizeof(PLASMA_Complex32_t)*A.mb*A.nb, A(n, m),                         INOUT,
                        sizeof(int),                          &A.nb,                       VALUE,
                        0);
                }
            }
        }
    }
}
