/**
 *
 * @file pstrsm.c
 *
 *  PLASMA auxiliary routines
 *  PLASMA is a software package provided by Univ. of Tennessee,
 *  Univ. of California Berkeley and Univ. of Colorado Denver
 *
 * @version 2.2.0
 * @author Jakub Kurzak
 * @author Hatem Ltaief
 * @date 2009-11-15
 *
 **/
#include "common.h"

/* Tile accessors: A(m,n) / B(m,n) yield the address of tile (m,n) of the
 * local descriptor named A / B, passed as float tile pointers to the CORE_
 * kernels below.  BLKADDR itself is defined outside this file (presumably
 * in common.h) -- NOTE(review): confirm its exact pointer type there. */
#define A(m,n) BLKADDR(A, float, m, n)
#define B(m,n) BLKADDR(B, float, m, n)
/***************************************************************************//**
 *  Parallel triangular solve - static scheduling
 *
 *  Left-side tile triangular solve: solves op(A) * X = alpha * B and
 *  overwrites B with X, where A is an A.nt x A.nt tile triangular matrix
 *  and op(A) is A (transA == PlasmaNoTrans) or A^T (any other transA).
 *  Only the left-side case is implemented; `side` is unpacked but not
 *  consulted.
 *
 *  The work is a triangle of steps (k, m), k <= m < A.nt, each applied to
 *  every tile column n of B:
 *    - m == k : solve the diagonal block against B's block row k;
 *    - m >  k : subtract the contribution of solved block row k from
 *               block row m.
 *  Threads walk that triangle round-robin, PLASMA_SIZE tasks apart; the
 *  ss_cond_wait/ss_cond_set calls order the steps applied to each (m, n)
 *  tile of B.
 *
 *  Lower/NoTrans and Upper/Trans are processed top-down (tile index equals
 *  step index); Lower/Trans and Upper/NoTrans are processed bottom-up
 *  (tile index = A.nt-1 - step index).
 *
 *  alpha is applied exactly once to every block row of B, by the step
 *  k == 0 (as the trsm alpha for the first row, as the gemm beta for the
 *  others); all later steps use 1.
 *
 *  Arguments unpacked from the context: side, uplo, transA, diag, alpha,
 *  A, B, sequence, request.
 **/
void plasma_pstrsm(plasma_context_t *plasma)
{
    PLASMA_enum side;
    PLASMA_enum uplo;
    PLASMA_enum transA;
    PLASMA_enum diag;
    float alpha;
    PLASMA_desc A;
    PLASMA_desc B;
    PLASMA_sequence *sequence;
    PLASMA_request *request;

    PLASMA_enum core_uplo;
    PLASMA_enum core_trans;
    int notrans, forward;
    int k, m, n;
    int next_k;
    int next_m;
    int next_n;
    int kk, mm;                 /* actual tile indices for steps k and m */
    int tempkn, tempmn, tempnn; /* tile extents (last tile may be partial) */
    float lalpha;
    float mone = -1.0f;

    plasma_unpack_args_9(side, uplo, transA, diag, alpha, A, B, sequence, request);
    (void)side;     /* left side assumed */
    (void)request;
    if (sequence->status != PLASMA_SUCCESS)
        return;
    ss_init(B.mt, B.nt, -1);

    /* Any uplo other than PlasmaLower is treated as upper, and any transA
     * other than PlasmaNoTrans as a transpose. */
    notrans    = (transA == PlasmaNoTrans);
    core_uplo  = (uplo == PlasmaLower) ? PlasmaLower : PlasmaUpper;
    core_trans = notrans ? PlasmaNoTrans : PlasmaTrans;
    /* Lower/NoTrans and Upper/Trans run top-down, the others bottom-up. */
    forward = ((uplo == PlasmaLower) == notrans);

    /* Locate this thread's first task in the (k, m) triangle.  The
     * k < A.nt guard (as in the advance loop below) ends the walk when
     * there are more threads than tasks; without it m grows forever. */
    k = 0;
    m = PLASMA_RANK;
    while (m >= A.nt && k < A.nt) {
        k++;
        m = m-A.nt+k;
    }
    n = 0;

    while (k < A.nt && m < A.nt) {
        /* Precompute this thread's next task before doing the current one. */
        next_n = n;
        next_m = m;
        next_k = k;

        next_n++;
        if (next_n >= B.nt) {
            next_m += PLASMA_SIZE;
            while (next_m >= A.nt && next_k < A.nt) {
                next_k++;
                next_m = next_m-A.nt+next_k;
            }
            next_n = 0;
        }

        kk = forward ? k : A.nt-1-k;
        mm = forward ? m : A.nt-1-m;
        tempkn = kk == A.nt-1 ? A.n-kk*A.nb : A.nb;
        tempmn = mm == A.nt-1 ? A.n-mm*A.nb : A.nb;
        tempnn = n  == B.nt-1 ? B.n-n*B.nb  : B.nb;
        /* Scale by alpha only on the first step touching each block row. */
        lalpha = k == 0 ? alpha : 1.0f;

        if (m == k) {
            /* Diagonal step: B(kk,n) := lalpha * op(A(kk,kk))^{-1} * B(kk,n) */
            ss_cond_wait(m, n, k-1);
            CORE_strsm(
                PlasmaLeft, core_uplo,
                core_trans, diag,
                tempkn,
                tempnn,
                lalpha, A(kk, kk), A.nb,
                        B(kk, n), B.nb);
            ss_cond_set(k, n, k);
        }
        else {
            /* Update step: B(mm,n) := lalpha*B(mm,n) - op(A(.,.)) * B(kk,n).
             * op(A) block (mm,kk) is A(mm,kk) or A(kk,mm)^T. */
            ss_cond_wait(k, n, k);
            ss_cond_wait(m, n, k-1);
            CORE_sgemm(
                core_trans, PlasmaNoTrans,
                tempmn,
                tempnn,
                tempkn,
                mone,   notrans ? A(mm, kk) : A(kk, mm), A.nb,
                        B(kk, n), B.nb,
                lalpha, B(mm, n), B.nb);
            ss_cond_set(m, n, k);
        }
        n = next_n;
        m = next_m;
        k = next_k;
    }
    ss_finalize();
}

/***************************************************************************//**
 *  Parallel triangular solve - dynamic scheduling
 *
 *  QUARK version of plasma_pstrsm.  Submits the left-side tile solve
 *  op(A) * X = alpha * B (B overwritten with X) as tasks on plasma->quark,
 *  tagged with sequence->quark_sequence.  op(A) is A when
 *  transA == PlasmaNoTrans and A^T otherwise; any uplo other than
 *  PlasmaLower is treated as upper.
 *
 *  For each step k, the diagonal block is solved against every tile column
 *  of B (CORE_strsm_quark).  Then every later block row m is updated with
 *  the solved row (CORE_sgemm_quark).  Lower/NoTrans and Upper/Trans walk
 *  the blocks top-down; Lower/Trans and Upper/NoTrans walk them bottom-up.
 *
 *  alpha is applied exactly once to every block row of B, during step
 *  k == 0 (trsm alpha / gemm beta); later steps use 1.
 *
 *  @param[in] side      Not consulted; the left-side case is assumed.
 *  @param[in] uplo      PlasmaLower, or anything else for upper.
 *  @param[in] transA    PlasmaNoTrans, or anything else for transpose.
 *  @param[in] diag      Forwarded unchanged to the trsm kernel.
 *  @param[in] alpha     Scalar multiplying B.
 *  @param[in] A         Triangular tile matrix, A.nt x A.nt tiles.
 *  @param[in,out] B     Right-hand sides on entry, solution on exit.
 *  @param[in] sequence  Task sequence; nothing is submitted unless its
 *                       status is PLASMA_SUCCESS.
 *  @param[in] request   Not consulted here.
 **/
void plasma_pstrsm_quark(PLASMA_enum side, PLASMA_enum uplo, PLASMA_enum transA, PLASMA_enum diag,
                          float alpha, PLASMA_desc A, PLASMA_desc B,
                          PLASMA_sequence *sequence, PLASMA_request *request)
{
    plasma_context_t *plasma;
    Quark_Task_Flags task_flags = Quark_Task_Flags_Initializer;
    PLASMA_enum plasma_left = PlasmaLeft;
    PLASMA_enum plasma_no_trans = PlasmaNoTrans;
    PLASMA_enum core_uplo;
    PLASMA_enum core_trans;
    int notrans, forward;
    int k, m, n;
    int kk, mm;                 /* actual tile indices for steps k and m */
    int tempkn, tempmn, tempnn; /* tile extents (last tile may be partial) */
    float lalpha;
    float mone = -1.0f;

    (void)side;     /* left side assumed */
    (void)request;

    plasma = plasma_context_self();
    if (sequence->status != PLASMA_SUCCESS)
        return;
    QUARK_Task_Flag_Set(&task_flags, TASK_SEQUENCE, (intptr_t)sequence->quark_sequence);

    notrans    = (transA == PlasmaNoTrans);
    core_uplo  = (uplo == PlasmaLower) ? PlasmaLower : PlasmaUpper;
    core_trans = notrans ? PlasmaNoTrans : PlasmaTrans;
    /* Lower/NoTrans and Upper/Trans run top-down, the others bottom-up. */
    forward = ((uplo == PlasmaLower) == notrans);

    for (k = 0; k < A.nt; k++)
    {
        kk = forward ? k : A.nt-1-k;
        tempkn = kk == A.nt-1 ? A.n-kk*A.nb : A.nb;
        /* Scale by alpha only during the first step; VALUE args are copied
         * at insertion, so reusing this variable per k is safe. */
        lalpha = k == 0 ? alpha : 1.0f;

        /* B(kk,n) := lalpha * op(A(kk,kk))^{-1} * B(kk,n) */
        for (n = 0; n < B.nt; n++)
        {
            tempnn = n == B.nt-1 ? B.n-n*B.nb : B.nb;
            QUARK_Insert_Task(plasma->quark, CORE_strsm_quark, &task_flags,
                sizeof(PLASMA_enum),     &plasma_left, VALUE,
                sizeof(PLASMA_enum),     &core_uplo,   VALUE,
                sizeof(PLASMA_enum),     &core_trans,  VALUE,
                sizeof(PLASMA_enum),     &diag,        VALUE,
                sizeof(int),             &tempkn,      VALUE,
                sizeof(int),             &tempnn,      VALUE,
                sizeof(float),           &lalpha,      VALUE,
                sizeof(float)*A.mb*A.nb, A(kk, kk),    INPUT,
                sizeof(int),             &A.nb,        VALUE,
                sizeof(float)*B.mb*B.nb, B(kk, n),     INOUT | LOCALITY,
                sizeof(int),             &B.nb,        VALUE,
                0);
        }

        /* B(mm,n) := lalpha*B(mm,n) - op(A) block (mm,kk) * B(kk,n), where
         * that block is A(mm,kk) (NoTrans) or A(kk,mm)^T (Trans). */
        for (m = k+1; m < A.nt; m++)
        {
            mm = forward ? m : A.nt-1-m;
            tempmn = mm == A.nt-1 ? A.n-mm*A.nb : A.nb;
            for (n = 0; n < B.nt; n++)
            {
                tempnn = n == B.nt-1 ? B.n-n*B.nb : B.nb;
                QUARK_Insert_Task(plasma->quark, CORE_sgemm_quark, &task_flags,
                    sizeof(PLASMA_enum),     &core_trans,      VALUE,
                    sizeof(PLASMA_enum),     &plasma_no_trans, VALUE,
                    sizeof(int),             &tempmn,          VALUE,
                    sizeof(int),             &tempnn,          VALUE,
                    sizeof(int),             &tempkn,          VALUE,
                    sizeof(float),           &mone,            VALUE,
                    sizeof(float)*A.mb*A.nb, notrans ? A(mm, kk) : A(kk, mm), INPUT,
                    sizeof(int),             &A.nb,            VALUE,
                    sizeof(float)*B.mb*B.nb, B(kk, n),         INPUT | LOCALITY,
                    sizeof(int),             &B.nb,            VALUE,
                    sizeof(float),           &lalpha,          VALUE,
                    sizeof(float)*B.mb*B.nb, B(mm, n),         INOUT,
                    sizeof(int),             &B.nb,            VALUE,
                    0);
            }
        }
    }
}
