/* -*- C -*- */
/* slktstr.c
 */

#include <assert.h>
#include <lfc/lfci.h>

#include "lfctests.h"

/* Debug print helper: writes "<file>(<line>):<expression text>=<value>" to
 * stdout, with the value cast to double and formatted with %g, then flushes
 * stdout. */
#define DP(v) do{printf("%s(%d):%s=%g\n",__FILE__,__LINE__,#v,(double)(v));fflush(stdout);}while(0)

/* Directory prefixes for the matrix files written by the file-based test
 * paths in this file: PREFIX is used for the system-matrix files and PREFIX0
 * for the right-hand-side files.
 * PREFIX0 expands to a getenv() call, so it is evaluated at every use and
 * yields NULL when HOME is not set in the environment.
 */
#define PREFIX0 getenv( "HOME" )
#define PREFIX "/tmp"

LFC_BEGIN_C_DECLS

/* Generates reproducible pseudo-random test data.
 *
 * Fills the m-by-n column-major matrix 'a' (leading dimension 'lda') with
 * values rand()/RAND_MAX, adds 'diagAdd' to its first min(m,n) diagonal
 * entries, and then fills the mrhs-by-nrhs column-major matrix 'b' (leading
 * dimension 'ldb') with further values rand()/RAND_MAX.
 *
 * The generator is re-seeded with the same fixed seed on every call, so a
 * repeated call with the same dimensions regenerates the same data.
 * Elements outside the m (resp. mrhs) leading rows of each column are not
 * touched.  Always returns 0.
 */
static int
saxbgen(int m, int n, float *a, int lda, int mrhs, int nrhs, float *b,
	int ldb, int diagAdd) {
  const float scale = 1.0f / RAND_MAX;
  int row, col, d, ndiag;

  srand( 1313 );

  /* system matrix: drawn column after column, top to bottom */
  for (col = 0; col < n; col++) {
    int base = col * lda;
    for (row = 0; row < m; row++)
      a[base + row] = scale * rand();
  }

  /* shift the leading diagonal */
  ndiag = (n < m) ? n : m;
  for (d = 0; d < ndiag; d++)
    a[d * lda + d] += diagAdd;

  /* second matrix continues the same random sequence */
  for (col = 0; col < nrhs; col++) {
    int base = col * ldb;
    for (row = 0; row < mrhs; row++)
      b[base + row] = scale * rand();
  }

  return 0;
}	/* saxbgen */

/* Computes Frobenius norm of a matrix. */
static int
smatnmf(int m, int n, float *a, int lda, float *res) {
  int i, j; float anorm, v;

  anorm = 0.0f;
  for (j = 0; j < n; j++)
    for (i = 0; i < m; i++) {
      v = a[i + j * lda];
      anorm += v * v;
    }
  anorm = sqrt( anorm );

  if (res) *res = anorm;

  return 0;
}

/* Sets every entry of the m-by-n column-major matrix 'a' (leading dimension
 * 'lda') to the value 'v'.  Rows m..lda-1 of each column are left untouched.
 * Always returns 0.
 */
static int
sgemset(int m, int n, float *a, int lda, float v) {
  int row, col;

  for (col = 0; col < n; col++) {
    int base = col * lda;
    for (row = 0; row < m; row++) {
      a[base + row] = v;
    }
  }

  return 0;
}

/* Copies the m-by-n column-major matrix 'src' (leading dimension 'lds') into
 * 'dst' (leading dimension 'ldd'), element by element.  Entries of 'dst'
 * outside the leading m rows of each column are left untouched.
 * Always returns 0.
 */
static int
smatcpy(int m, int n, float *dst, int ldd, float *src, int lds) {
  int row, col;

  for (col = 0; col < n; col++) {
    int to = col * ldd, from = col * lds;
    for (row = 0; row < m; row++) {
      dst[to + row] = src[from + row];
    }
  }

  return 0;
}

/* Tests the single-precision least-squares solver on a generated m-by-n
 * system with nrhs right-hand sides.
 *
 * trans  - a first character of 'T' or 't' selects the transposed problem
 *          (LFC_Trans); anything else selects LFC_NoTrans.
 * opts   - when opts->use_file is nonzero the matrices are written to files
 *          and solved with LFC_sgels_file_all(); otherwise LFC_sgels() is
 *          called on the in-memory data.
 * Uresd  - when non-null, receives the scaled residual.
 *
 * The right-hand side is built as b := op(A)*y from a known generated y, the
 * system is solved for x, and the scaled residual
 *   ||op(A)*y - op(A)*x||_F / ||x||_F / ||A||_F / max(m,n) / eps
 * is computed.
 *
 * Returns 0 when the scaled residual is below 10.0 and 1 otherwise.
 */
int
tsgels(char *trans, int m, int n, int nrhs, Options *opts, double *Uresd) {
  float *a, *b, *x, *y;
  enum LFC_Transpose lfctr = ((trans[0] == 'T' || trans[0] == 't') ? LFC_Trans : LFC_NoTrans);
  enum LFC_CBLAS_TRANSPOSE tr;
  float cd_1 = 1.0f, cd_0 = 0.0f, cd_n1 = -1.0f; int c__0 = 0;
  int n1, n2, mn, lda, ldb, ldx, ldy, info;
  float seps, anorm, xnorm, rnorm, resd, v;

  /* op(A) is n1-by-n2; if/else (rather than a switch without default)
   * guarantees that n1, n2 and tr are always assigned */
  if (lfctr == LFC_Trans) {
    n1 = n; n2 = m; tr = LFC_CblasTrans;
  } else {
    n1 = m; n2 = n; tr = LFC_CblasNoTrans;
  }

  mn = MAX( m, n );

  /* leading dimensions are deliberately larger than the row counts */
  lda = m + 10;
  ldb = mn + 11;
  ldx = mn + 13;
  ldy = mn + 17;

  a = malloc( (sizeof *a) * lda * n );
  b = malloc( (sizeof *b) * ldb * nrhs );
  x = malloc( (sizeof *x) * ldx * nrhs );
  y = malloc( (sizeof *y) * ldy * nrhs );

  assert( a && b && x && y );

  /* The right-hand side 'b' is generated from a known solution 'x':
   * b := op(A)*x.  This is necessary only for an overdetermined system (more
   * equations than unknowns) so that an accurate solution to A*x=b exists
   * (otherwise 'x' might be minimal in 2-norm but still far off in terms of
   * the scaled residual).
   */

  saxbgen( m, n, a, lda, n2, nrhs, x, ldx, 0 );

  /* initialize 'b' to zero (prevents garbage from entering later calculations) */
  sgemset( mn, nrhs, b, ldb, 0.0f );

  lfc_cblas_sgemm( LFC_CblasColMajor, tr, LFC_CblasNoTrans, n1, nrhs, n2,
		   cd_1, a, lda, x, ldx, cd_0, b, ldb );

  if (opts->use_file) {
    LFC_Status s; char *apath, *bpath; int i, j;
    LFC_Layout alout, blout;

    alout.row = alout.col = alout.opts =
    blout.row = blout.col = blout.opts = 0;
    alout.opts = LFC_RDONLY;

    /* NOTE(review): the strings returned by LFC_StrSubst1() are never
     * released here; confirm who owns them. */
    apath = LFC_StrSubst1( "%s/qrf-mat", PREFIX );
    bpath = LFC_StrSubst1( "%s/rhs-mat", PREFIX0 );

    LFC_Matrix_swrite( apath, m, n, a, lda );
    LFC_Matrix_swrite( bpath, n1, nrhs, b, ldb );

    LFC_sgels_file_all( lfctr, m, n, nrhs, apath, &alout, bpath, &blout, &s );

    /* read the solution back; LFC_Matrix_sread() stores a new buffer in 'y' */
    free( y );
    LFC_Matrix_sread( bpath, 0, 0, &y, &i, &j );
    ldy = i;
    assert( j == nrhs );

    remove( apath );
    remove( bpath );

    smatcpy( i, nrhs, b, ldb, y, ldy );

    /* release the buffer obtained from LFC_Matrix_sread() before 'y' is
     * re-allocated with its original shape (it used to be leaked) */
    free( y );

    ldy = mn + 17;
    y = malloc( (sizeof *y) * ldy * nrhs );
    assert( y );
  } else {
    LFC_sgels( LFC_ColMajor, lfctr, m, n, nrhs, a, lda, b, ldb, &v, &c__0, &info );
    if (info) printf( "LFC_sgels() -> %d\n", info );
  }

  /* 'b' now holds the computed solution: keep it in 'x' */
  smatcpy( mn, nrhs, x, ldx, b, ldb );

  /* regenerate the system matrix and the known solution (into 'y'), then
   * rebuild the original right-hand side b := op(A)*y */
  saxbgen( m, n, a, lda, n2, nrhs, y, ldy, 0 );
  lfc_cblas_sgemm( LFC_CblasColMajor, tr, LFC_CblasNoTrans, n1, nrhs, n2,
		   cd_1, a, lda, y, ldy, cd_0, b, ldb );

  /* residual: b := b - op(A)*x */
  lfc_cblas_sgemm( LFC_CblasColMajor, tr, LFC_CblasNoTrans, n1, nrhs, n2,
		   cd_n1, a, lda, x, ldx, cd_1, b, ldb );

  seps = LFC_slamch( "Epsilon" );

  smatnmf( m,  n,    a, lda, &anorm );
  smatnmf( n1, nrhs, b, ldb, &rnorm );
  smatnmf( n2, nrhs, x, ldx, &xnorm );

  resd = rnorm / xnorm / anorm / MAX( m, n ) / seps;

  free( y ); free( x ); free( b ); free( a );

  if (Uresd) *Uresd = resd;

  if (resd < 10.0)
    return 0;

  return 1;
}	/* tsgels */

/* Tests the single-precision general linear solver on a generated n-by-n
 * system with nrhs right-hand sides.
 *
 * opts   - when opts->use_file is nonzero the matrices are written to files
 *          and solved with LFC_sgesv_file_all(); otherwise LFC_sgesv() is
 *          called on the in-memory data.
 * Uresd  - when non-null, receives the scaled residual.
 *
 * After the solve the system matrix and right-hand side are regenerated and
 * the scaled residual
 *   ||b - A*x||_F / ||x||_F / ||A||_F / n / eps
 * is computed.
 *
 * Returns 0 when the scaled residual is below 10.0 and 1 otherwise.
 */
int
tsgesv(int n, int nrhs, Options *opts, double *Uresd) {
  float *a, *b, *x;
  int lda, ldb, ldx, info;
  float anorm, xnorm, rnorm, eps, resd;
  int *piv;

  /* leading dimensions are deliberately larger than the row count */
  lda = n + 10;
  ldb = n + 11;
  ldx = n + 13;

  a = malloc( (sizeof *a) * lda * n );
  b = malloc( (sizeof *b) * ldb * nrhs );
  x = malloc( (sizeof *x) * ldx * nrhs );
  piv = malloc( (sizeof *piv) * n );

  /* the pivot array is handed to the solver as well, so check it too */
  assert( a && b && x && piv );

  /* generate system matrix */
  saxbgen( n, n, a, lda, n, nrhs, b, ldb, 0 );

  /* the solver works on a copy of the right-hand side kept in 'x' */
  smatcpy( n, nrhs, x, ldx, b, ldb );

  if (opts->use_file) {
    LFC_Status s; char *apath, *bpath; int i, j;
    LFC_Layout alout, blout;

    alout.row = alout.col = alout.opts =
    blout.row = blout.col = blout.opts = 0;
    alout.opts = LFC_RDONLY;

    /* NOTE(review): the strings returned by LFC_StrSubst1() are never
     * released here; confirm who owns them. */
    apath = LFC_StrSubst1( "%s/luf-mat", PREFIX );
    bpath = LFC_StrSubst1( "%s/rhs-mat", PREFIX0 );

    LFC_Matrix_swrite( apath, n, n, a, lda );
    LFC_Matrix_swrite( bpath, n, nrhs, b, ldb );

    LFC_sgesv_file_all( n, nrhs, apath, &alout, piv, bpath, &blout, &s );

    /* read the solution back; LFC_Matrix_sread() stores a new buffer in 'x' */
    free( x );
    LFC_Matrix_sread( bpath, 0, 0, &x, &i, &j );
    ldx = i;
    /* same sanity check as in tsgels(): one column per right-hand side */
    assert( j == nrhs );

    remove( apath );
    remove( bpath );
  } else {
    LFC_sgesv( LFC_ColMajor, n, nrhs, a, lda, piv, x, ldx, &info );
    if (info) printf( "LFC_sgesv() -> %d\n", info );
  }

  /* regenerate system matrix */
  saxbgen( n, n, a, lda, n, nrhs, b, ldb, 0 );

  /* residual: b := b - A*x */
  lfc_cblas_sgemm( LFC_CblasColMajor, LFC_CblasNoTrans, LFC_CblasNoTrans, n, nrhs, n,
		   -1.0f, a, lda, x, ldx, 1.0f, b, ldb );

  eps = LFC_slamch( "Epsilon" );

  smatnmf( n,  n,   a, lda, &anorm );
  smatnmf( n, nrhs, b, ldb, &rnorm );
  smatnmf( n, nrhs, x, ldx, &xnorm );

  resd = rnorm / xnorm / anorm / n / eps;

  free( piv ); free( x ); free( b ); free( a );

  if (Uresd) *Uresd = resd;

  if (resd < 10.0)
    return 0;

  return 1;
}	/* tsgesv */

/* Tests the single-precision symmetric positive definite solver on a
 * generated n-by-n system with nrhs right-hand sides.
 *
 * uplo   - a first character of 'U' or 'u' selects the upper triangle
 *          (LFC_Upper / LFC_CblasUpper); anything else selects the lower one.
 * opts   - when opts->use_file is nonzero the matrices are written to files
 *          and solved with LFC_sposv_file_all(); otherwise LFC_sposv() is
 *          called on the in-memory data.
 * Uresd  - when non-null, receives the scaled residual.
 *
 * The system matrix is generated with n added to its diagonal.  After the
 * solve the data is regenerated and the scaled residual
 *   ||b - A*x||_F / ||A||_F / n / eps
 * is computed.
 *
 * Returns 0 when the scaled residual is below 10.0 and 1 otherwise.
 */
int
tsposv(char *uplo, int n, int nrhs, Options *opts, double *Uresd) {
  float *a, *b, *x;
  int lda, ldb, ldx, info;
  float cd_1 = 1.0f, cd_n1 = -1.0f;
  float anorm, rnorm, eps, resd;
  enum LFC_UpLo lfcul = ((uplo[0] == 'U' || uplo[0] == 'u') ? LFC_Upper : LFC_Lower);
  enum LFC_CBLAS_UPLO ul = ((uplo[0] == 'U' || uplo[0] == 'u') ? LFC_CblasUpper : LFC_CblasLower);

  /* leading dimensions are deliberately larger than the row count */
  lda = n + 10;
  ldb = n + 11;
  ldx = n + 13;

  a = malloc( (sizeof *a) * lda * n );
  b = malloc( (sizeof *b) * ldb * nrhs );
  x = malloc( (sizeof *x) * ldx * nrhs );

  assert( a && b && x );

  /* generate system matrix */
  saxbgen( n, n, a, lda, n, nrhs, b, ldb, n );

  /* the solver works on a copy of the right-hand side kept in 'x' */
  smatcpy( n, nrhs, x, ldx, b, ldb );

  if (opts->use_file) {
    LFC_Status s; char *apath, *bpath; int i, j;
    LFC_Layout alout, blout;

    alout.row = alout.col = alout.opts =
    blout.row = blout.col = blout.opts = 0;
    alout.opts = LFC_RDONLY;

    /* NOTE(review): the strings returned by LFC_StrSubst1() are never
     * released here; confirm who owns them. */
    apath = LFC_StrSubst1( "%s/chl-mat", PREFIX );
    bpath = LFC_StrSubst1( "%s/rhs-mat", PREFIX0 );

    LFC_Matrix_swrite( apath, n, n, a, lda );
    LFC_Matrix_swrite( bpath, n, nrhs, b, ldb );

    LFC_sposv_file_all( lfcul, n, nrhs, apath, &alout, bpath, &blout, &s );

    /* read the solution back; LFC_Matrix_sread() stores a new buffer in 'x' */
    free( x );
    LFC_Matrix_sread( bpath, 0, 0, &x, &i, &j );
    ldx = i;
    /* same sanity check as in tsgels(): one column per right-hand side */
    assert( j == nrhs );

    remove( apath );
    remove( bpath );
  } else {
    LFC_sposv( LFC_ColMajor, lfcul, n, nrhs, a, lda, x, ldx, &info );
    if (info) printf( "LFC_sposv() -> %d\n", info );
  }

  /* regenerate system matrix */
  saxbgen( n, n, a, lda, n, nrhs, b, ldb, n );

  /* residual: b := b - A*x, with A referenced through the 'ul' triangle */
  lfc_cblas_ssymm( LFC_CblasColMajor, LFC_CblasLeft, ul, n, nrhs, cd_n1, a,
		   lda, x, ldx, cd_1, b, ldb );

  eps = LFC_slamch( "Epsilon" );
  smatnmf( n, n,    a, lda, &anorm );
  smatnmf( n, nrhs, b, ldb, &rnorm );

  /* NOTE(review): unlike tsgels() and tsgesv(), this residual is not scaled
   * by the norm of the solution -- confirm that this is intended. */
  resd = rnorm / anorm / n / eps;

  free( x ); free( b ); free( a );

  if (Uresd) *Uresd = resd;

  if (resd < 10.0)
    return 0;

  return 1;
}	/* tsposv */

LFC_END_C_DECLS
