/* -*- C -*- */
/* dusrlpck.c
 */

#include <lfc/lfci.h>

#define TSTBIT(v,ib) (((v)&(ib))==(ib))

LFC_BEGIN_C_DECLS

/*
 * Serial (1x1 process grid) driver for the file-based general solver.
 *
 * Reads a matrix from apath (must be n-by-n) and a right-hand side from
 * bpath (must be n-by-nrhs), passes both to LFC_dgesv() and stores the info
 * code it reports in status->info.  Afterwards the matrix buffer is written
 * to luf_fname and the right-hand-side buffer to sol_fname, each as left by
 * LFC_dgesv().  An output is skipped when its name pointer is identical to
 * nul_fname (pointer comparison, not string comparison).  The results of
 * the write calls are not checked.
 *
 * Returns LFC_FAILURE when an input cannot be read or has unexpected
 * dimensions (status is not touched in that case); otherwise returns the
 * value returned by LFC_dgesv().
 */
static int
LFC_dgesv_file_all_1x1(int n, int nrhs, const char *apath, int *piv,
  const char *bpath, const char *luf_fname, const char *sol_fname,
  const char *nul_fname, LFC_Status *status) {

  int rows, cols, solver_info, result = LFC_FAILURE;
  double *mat, *rhs;

  if (LFC_Matrix_dread( apath, 0, 0, &mat, &rows, &cols ))
    return LFC_FAILURE;
  if (rows != n || cols != n)
    goto free_mat;

  if (LFC_Matrix_dread( bpath, 0, 0, &rhs, &rows, &cols ))
    goto free_mat;
  if (rows != n || cols != nrhs)
    goto free_rhs;

  result = LFC_dgesv( LFC_ColMajor, n, nrhs, mat, n, piv, rhs, n, &solver_info );
  status->info = solver_info;

  /* the outputs are written regardless of what LFC_dgesv() returned */
  if (luf_fname != nul_fname)
    LFC_Matrix_dwrite( luf_fname, n, n, mat, n );
  if (sol_fname != nul_fname)
    LFC_Matrix_dwrite( sol_fname, n, nrhs, rhs, n );

free_rhs:
  LFC_FREE( rhs );
free_mat:
  LFC_FREE( mat );

  return result;
}	/* LFC_dgesv_file_all_1x1 */

/*
 * Placeholder for the serial (1x1 process grid) file-based dposv driver.
 * It only prints its own name to stderr and reports LFC_FAILURE; none of
 * the arguments are used.
 */
static int
LFC_dposv_file_all_1x1(enum LFC_UpLo uplo, int n, int nrhs, const char *apath,
  const char *bpath, const char *llt_fname, const char *sol_fname,
  const char *nul_fname, int opcode, LFC_Status *status) {
  fputs( "LFC_dposv_file_all_1x1\n", stderr );
  return LFC_FAILURE;
}	/* LFC_dposv_file_all_1x1 */

/*
 * Placeholder for the serial (1x1 process grid) file-based dgels driver.
 * It only prints its own name to stderr and reports LFC_FAILURE; none of
 * the arguments are used.
 */
static int
LFC_dgels_file_all_1x1(enum LFC_Transpose trans, int m, int n, int nrhs,
  const char *apath, const char *bpath, const char *llt_fname,
  const char *sol_fname, const char *nul_fname, int opcode, LFC_Status *status) {
  fputs( "LFC_dgels_file_all_1x1\n", stderr );
  return LFC_FAILURE;
}	/* LFC_dgels_file_all_1x1 */

/*
 * Dispose of a machine file: delete the file named by `machinefile` with
 * remove() and release the string itself with LFC_FREE.  A NULL argument is
 * a no-op.  Always returns 0; the result of remove() is not checked.
 */
static int
ditchMF(char *machinefile) {
  if (! machinefile)
    return 0;

  remove( machinefile );
  LFC_FREE( machinefile );
  return 0;
}

/*
 * File-based general solve on a caller-specified gp-by-gq process grid.
 *
 * apath names the system matrix file and bpath the right-hand-side file.
 * When LFC_RDONLY is set in aopts (resp. bopts) the corresponding output
 * name is replaced by the placeholder "/../../"; otherwise the output name
 * is the input path itself.  A machine file is obtained from LFC_mfile(),
 * an argument vector is built with LFC_argv_mpi() (operation code 1) and
 * run through LFC_run_mpi(); the argument vector and the machine file are
 * disposed of afterwards.
 *
 * The info field of the status (ustatus, or a local one when ustatus is
 * NULL) is set to -1 on entry and to 0 on success.  piv is not used here.
 *
 * Returns LFC_SUCCESS or LFC_FAILURE.
 */
int
LFC_dgesv_file_all_x(int n, int nrhs, char *apath, int aopts, int *piv,
  char *bpath, int bopts, int gp, int gq, LFC_Status *ustatus) {
  char placeholder[] = "/../../";
  char dtype = e2chDtype(LFC_DOUBLE);
  char *lu_out, *x_out, *mfile, **args;
  LFC_Status scratch, *st;
  int rc;

  st = ustatus ? ustatus : &scratch;
  st->info = -1;

  /* read-only inputs must not be overwritten: redirect to the placeholder */
  lu_out = TSTBIT( aopts, LFC_RDONLY) ? placeholder : apath;
  x_out  = TSTBIT( bopts, LFC_RDONLY) ? placeholder : bpath;

  if (LFC_FAILURE == LFC_mfile( gp, gq, &mfile ))
    return LFC_FAILURE;

  rc = LFC_argv_mpi( dtype, apath, dtype, bpath, dtype, x_out, dtype, lu_out, mfile,
		     n, n, nrhs, 80, gp, gq, 1, NULL, &args );
  if (rc) {
    ditchMF( mfile );
    return LFC_FAILURE;
  }

  rc = LFC_run_mpi( args );

  LFC_argvDelete( &args );
  ditchMF( mfile );

  if (rc)
    return LFC_FAILURE;

  st->info = 0;
  return LFC_SUCCESS;
}	/* LFC_dgesv_file_all_x */

/*
 * Common back end of the file-based solvers.
 *
 * opcode selects the operation; the wrappers in this file pass 1 for dgesv,
 * 3 for dposv and 5 for dgels.  apath names the system matrix file and
 * bpath the right-hand-side file.  When LFC_RDONLY is set in alout->opts
 * (resp. blout->opts) the corresponding output name is replaced by the
 * placeholder "/../../"; otherwise the output name is the input path.
 *
 * LFC_schedule() supplies the process grid (gp x gq) and the machine file.
 * For a 1x1 grid the machine file is discarded and the matching serial
 * *_1x1 helper is called.  Otherwise an argument vector is built with
 * LFC_argv_mpi() (with "uplo=" and "trans=" as extra arguments) and run
 * through LFC_run_mpi().
 *
 * The info field of the status (ustatus, or a local one when ustatus is
 * NULL) is set to -1 on entry and to 0 when the MPI run succeeds.
 *
 * Returns LFC_SUCCESS or LFC_FAILURE on the MPI path, and the helper's
 * return value on the serial path.
 */
static int
LFC_dsolve_file_all(int m, int n, int nrhs, /*const*/ char *apath,
  LFC_Layout *alout, int *piv, /*const*/ char *bpath,
  LFC_Layout *blout, int opcode, enum LFC_Transpose trans,
  enum LFC_UpLo uplo, LFC_Status *ustatus) {
  char *mat_fname, *dmt_fname, *rhs_fname, *sol_fname, **argv, *machinefile;
  char nul_fname[] = "/../../";
  char tmat = e2chDtype(LFC_DOUBLE);
  int rv, gp, gq;
  char *addargv[3], s_uplo[] = "uplo=U", s_trans[] = "trans=N";
  LFC_Status *status, lstatus;

  if (ustatus) status = ustatus;
  else status = &lstatus;

  status->info = -1;

  mat_fname = apath; /* linear system matrix */
  rhs_fname = bpath; /* right-hand side matrix */

  if (TSTBIT( alout->opts, LFC_RDONLY)) {
    dmt_fname = nul_fname;
  } else {
    dmt_fname = apath;
  }

  if (TSTBIT( blout->opts, LFC_RDONLY)) {
    sol_fname = nul_fname;
  } else {
    sol_fname = bpath;
  }

  LFC_schedule( m, n, nrhs, sizeof(double), &gp, &gq, &machinefile );

  if (gp == 1 && gq == 1) {
    /* serial path: no MPI launch, so the machine file is not needed */
    ditchMF( machinefile );
    switch (opcode) {
      case 1:
	return LFC_dgesv_file_all_1x1( n, nrhs, apath, piv, bpath, dmt_fname,
				       sol_fname, nul_fname, status );
      case 3:
	return LFC_dposv_file_all_1x1( uplo, n, nrhs, apath, bpath, dmt_fname,
				       sol_fname, nul_fname, opcode, status );
      case 5:
	return LFC_dgels_file_all_1x1( trans, m, n, nrhs, apath, bpath,
				       dmt_fname, sol_fname, nul_fname, opcode,
				       status );
      default:
	/* The machine file has already been removed and freed above, so
	 * falling through to the MPI path would use a dangling pointer and
	 * free it a second time. */
	return LFC_FAILURE;
    }
  }

  /* FIXME: return pivot */

  /* patch the last character of "uplo=U" / "trans=N" with the actual value */
  s_uplo[5]  = e2chUpLo(uplo);
  s_trans[6] = e2chTrans(trans);
  addargv[0] = s_uplo;
  addargv[1] = s_trans;
  addargv[2] = NULL;

  rv = LFC_argv_mpi( tmat, mat_fname, tmat, rhs_fname, tmat, sol_fname, tmat, dmt_fname, machinefile,
		     m, n, nrhs, 80, gp, gq, opcode, addargv, &argv );

  if (rv) {
    ditchMF( machinefile );
    return LFC_FAILURE;
  }

  rv = LFC_run_mpi( argv );

  LFC_argvDelete( &argv );
  ditchMF( machinefile );

  if (rv) return LFC_FAILURE;

  status->info = 0;
  return LFC_SUCCESS;
}	/* LFC_dsolve_file_all */

/*
 * File-based general solve: forwards to LFC_dsolve_file_all() with
 * operation code 1, m = n, LFC_NoTrans and LFC_Upper, and returns its
 * result unchanged.
 */
int
LFC_dgesv_file_all(int n, int nrhs, /*const*/ char *apath,
  LFC_Layout *alout, int *piv, /*const*/ char *bpath,
  LFC_Layout *blout, LFC_Status *ustatus) {
  const int gesv_opcode = 1;
  int rc;

  rc = LFC_dsolve_file_all( n, n, nrhs, apath, alout, piv, bpath, blout,
			    gesv_opcode, LFC_NoTrans, LFC_Upper, ustatus );
  return rc;
}	/* LFC_dgesv_file_all */

/*
 * File-based dposv: forwards to LFC_dsolve_file_all() with operation code
 * 3, m = n, no pivot array (NULL), LFC_NoTrans and the caller's uplo, and
 * returns its result unchanged.
 */
int
LFC_dposv_file_all(enum LFC_UpLo uplo, int n, int nrhs, /*const*/ char *apath,
  LFC_Layout *alout, /*const*/ char *bpath, LFC_Layout *blout,
  LFC_Status *ustatus) {
  const int posv_opcode = 3;
  int rc;

  rc = LFC_dsolve_file_all( n, n, nrhs, apath, alout, NULL, bpath, blout,
			    posv_opcode, LFC_NoTrans, uplo, ustatus );
  return rc;
}	/* LFC_dposv_file_all */

/*
 * File-based dgels: forwards to LFC_dsolve_file_all() with operation code
 * 5, no pivot array (NULL), the caller's trans and LFC_Upper, and returns
 * its result unchanged.
 */
int
LFC_dgels_file_all(enum LFC_Transpose trans, int m, int n, int nrhs,
  /*const*/ char *apath, LFC_Layout *alout, /*const*/ char *bpath,
  LFC_Layout *blout, LFC_Status *ustatus) {
  const int gels_opcode = 5;
  int rc;

  rc = LFC_dsolve_file_all( m, n, nrhs, apath, alout, NULL, bpath, blout,
			    gels_opcode, trans, LFC_Upper, ustatus );
  return rc;
}	/* LFC_dgels_file_all */

/*
 * Column-major matrix multiply-accumulate without transposition:
 *
 *     C := alpha * A * B + beta * C
 *
 * A is m-by-k with leading dimension lda, B is k-by-n with leading
 * dimension ldb, and C is m-by-n with leading dimension ldc.
 *
 * beta may be any non-NaN value.  beta == 0 overwrites C without reading
 * its previous contents; beta == 1 leaves C untouched before accumulation.
 *
 * Returns 0 on success, or a negative code identifying the first invalid
 * argument: -4 (m < 0), -5 (n < 0), -6 (k < 0), -7 (alpha is NaN),
 * -8 (a is NULL), -9 (lda < m), -10 (b is NULL), -11 (ldb < k),
 * -12 (beta is NaN), -13 (c is NULL), -14 (ldc < m).
 */
static int
BLAS_dgemmNN(int m, int n, int k, double alpha, const double *a, int lda,
	     const double *b, int ldb, double beta, double *c, int ldc) {
  int i, j, l;
  double temp;

  if (m < 0) return -4;
  if (n < 0) return -5;
  if (k < 0) return -6;
  if (alpha != alpha) return -7; /* NaN value */
  if (! a) return -8;
  if (lda < m) return -9;
  if (! b) return -10;
  if (ldb < k) return -11;
  if (beta != beta) return -12; /* NaN value */
  if (! c) return -13;
  if (ldc < m) return -14;

  for (j = 0; j < n; j++) {
    /* apply beta to column j of C first */
    if (0.0 == beta) {
      /* assign rather than multiply so that stale NaN/Inf do not survive */
      for (i = 0; i < m; i++) {
	c[i + j * ldc] = 0.0;
      }
    } else if (1.0 != beta) {
      for (i = 0; i < m; i++) {
	c[i + j * ldc] *= beta;
      }
    }

    /* then accumulate alpha * A * B(:,j) one column of A at a time */
    for (l = 0; l < k; l++) {
      temp = alpha * b[l + j * ldb];
      if (temp != 0.0) { /* a zero multiplier (either sign) adds nothing */
	for (i = 0; i < m; i++) {
	  c[i + j * ldc] += temp * a[i + l * lda];
	}
      }
    }
  }

  return 0;
}	/* BLAS_dgemmNN */

/*
 * Serial (1x1 process grid) residual report for a solved system A*x = b.
 *
 * Reads A from apath (must be n-by-n), x from xpath and b from bpath (both
 * must be n-by-nrhs), forms b - A*x in place of b, and prints to stderr:
 * the sum of absolute values of all entries of A, the same sum for the
 * residual, the constant 1e-16 used as epsilon, and the ratio
 * residual_sum / (A_sum * epsilon * n).
 *
 * Returns LFC_FAILURE when a file cannot be read or has unexpected
 * dimensions; otherwise LFC_SUCCESS.  On success ustatus->info is set to 0
 * when ustatus is not NULL; on failure ustatus is left untouched.
 */
static int
LFC_dgesv_rsdl_1x1(int n, int nrhs, const char *apath, const char *xpath,
		   const char *bpath, LFC_Status *ustatus) {
  int i, j;
  int rv, mm, nn, lda = n, ldb = n;
  double *a, *b, *x;
  double anrm, rnrm, deps = 1e-16, sres;

  rv = LFC_Matrix_dread( apath, 0, 0, &a, &mm, &nn );
  if (rv) return LFC_FAILURE;
  if (mm != n || nn != n) {
    LFC_FREE( a );
    return LFC_FAILURE;
  }

  rv = LFC_Matrix_dread( xpath, 0, 0, &x, &mm, &nn );
  if (rv) {
    LFC_FREE( a );
    return LFC_FAILURE;
  }
  if (mm != n || nn != nrhs) {
    LFC_FREE( x );
    LFC_FREE( a );
    return LFC_FAILURE;
  }

  rv = LFC_Matrix_dread( bpath, 0, 0, &b, &mm, &nn );
  if (rv) {
    LFC_FREE( x );
    LFC_FREE( a );
    return LFC_FAILURE;
  }
  if (mm != n || nn != nrhs) {
    LFC_FREE( b );
    LFC_FREE( x );
    LFC_FREE( a );
    return LFC_FAILURE;
  }

  /* b := b - A*x; only magnitudes of the entries are used below */
  BLAS_dgemmNN( n, nrhs, n, -1.0, a, lda, x, ldb, 1.0, b, ldb );

  /* sum of absolute values over all residual entries */
  rnrm = 0.0;
  for (j = 0; j < nrhs; j++)
    for (i = 0; i < n; i++)
      rnrm += fabs( b[i + j*ldb] );

  /* same sum over A, indexed with A's own leading dimension */
  anrm = 0.0;
  for (j = 0; j < n; j++)
    for (i = 0; i < n; i++)
      anrm += fabs( a[i + j*lda] );

  fprintf( stderr, "||A(%dx%d)||=%g\n", n, n, anrm );
  fprintf( stderr, "||A(%dx%d)x-b||=%g\n", n, n, rnrm );
  fprintf( stderr, "epsilon=%g\n", deps );

  sres = rnrm / anrm / deps / n;

  fprintf( stderr, "||Ax-b||/(||A|| n epsilon)=%g\n", sres );

  /* NOTE(review): earlier code flagged sres > 10.0 in a local variable but
   * never returned it, so the ratio has always been informational only.
   * That behaviour is kept; confirm whether a large ratio should become
   * LFC_FAILURE before enforcing it. */

  LFC_FREE( b );
  LFC_FREE( x );
  LFC_FREE( a );

  /* report completion the same way the MPI path of the caller does */
  if (ustatus) ustatus->info = 0;

  return LFC_SUCCESS;
}	/* LFC_dgesv_rsdl_1x1 */

/*
 * Residual check for a solved system A*x = b whose operands live in files.
 *
 * LFC_schedule() supplies the process grid (gp x gq) and the machine file.
 * For a 1x1 grid the machine file is discarded and the serial helper
 * LFC_dgesv_rsdl_1x1() is used.  Otherwise an argument vector is built with
 * LFC_argv_mpi() (operation code 2, with the placeholder "/../../" in the
 * fourth file slot) and run through LFC_run_mpi().
 *
 * The info field of the status (ustatus, or a local one when ustatus is
 * NULL) is set to -1 on entry and to 0 on success, on both paths.
 *
 * Returns LFC_SUCCESS or LFC_FAILURE.
 */
int
LFC_dgesv_rsdl(int n, int nrhs, /*const*/ char *apath, /*const*/ char *xpath,
	       /*const*/ char *bpath, LFC_Status *ustatus) {

  int rv; char luf_fname[] = "/../../", **argv; int gp, gq; char *machinefile;
  char tmat = e2chDtype(LFC_DOUBLE);
  LFC_Status *status, lstatus;

  if (ustatus) status = ustatus;
  else status = &lstatus;

  status->info = -1;

  LFC_schedule( n, n, nrhs, sizeof(double), &gp, &gq, &machinefile );

  /* both grid dimensions must be 1 for the serial path */
  if (1 == gp && 1 == gq) {
    ditchMF( machinefile );
    rv = LFC_dgesv_rsdl_1x1( n, nrhs, apath, xpath, bpath, status );
    if (LFC_SUCCESS == rv) status->info = 0;
    return rv;
  }

  rv = LFC_argv_mpi( tmat, apath, tmat, bpath, tmat, xpath, tmat, luf_fname, machinefile, n, n, nrhs, 80,
		     gp, gq, 2, NULL, &argv );
  if (rv) {
    ditchMF( machinefile );
    return LFC_FAILURE;
  }

  rv = LFC_run_mpi( argv );
  ditchMF( machinefile );
  LFC_argvDelete( &argv );

  if (rv) return LFC_FAILURE;

  status->info = 0;
  return LFC_SUCCESS;
}	/* LFC_dgesv_rsdl */

/*
 * Residual check on a caller-specified gp-by-gq process grid.
 *
 * A machine file is obtained from LFC_mfile(), an argument vector is built
 * with LFC_argv_mpi() (operation code 2, with the placeholder "/../../" in
 * the fourth file slot) and run through LFC_run_mpi(); the machine file and
 * the argument vector are disposed of afterwards.
 *
 * The info field of the status (ustatus, or a local one when ustatus is
 * NULL) is set to -1 on entry and to 0 on success.
 *
 * Returns LFC_SUCCESS or LFC_FAILURE.
 */
int
LFC_dgesv_rsdl_x(int n, int nrhs, char *apath, char *xpath, char *bpath,
  int gp, int gq, LFC_Status *ustatus) {

  char dtype = e2chDtype(LFC_DOUBLE);
  char placeholder[] = "/../../";
  char *mfile, **args;
  LFC_Status scratch, *st;
  int rc;

  st = ustatus ? ustatus : &scratch;
  st->info = -1;

  if (LFC_FAILURE == LFC_mfile( gp, gq, &mfile ))
    return LFC_FAILURE;

  rc = LFC_argv_mpi( dtype, apath, dtype, bpath, dtype, xpath, dtype, placeholder, mfile, n, n, nrhs, 80,
		     gp, gq, 2, NULL, &args );
  if (rc) {
    ditchMF( mfile );
    return LFC_FAILURE;
  }

  rc = LFC_run_mpi( args );
  ditchMF( mfile );
  LFC_argvDelete( &args );

  if (rc)
    return LFC_FAILURE;

  st->info = 0;
  return LFC_SUCCESS;
}	/* LFC_dgesv_rsdl_x */

/*
 * In-memory general solve.  Only LFC_ColMajor storage is accepted: any
 * other order stores -1 in *info (when info is not NULL) and returns
 * LFC_FAILURE.  Otherwise the call is forwarded to LFC_linear_solve() with
 * LFC_GESV, LFC_Upper, LFC_NoTrans and LFC_DOUBLE, the info value it
 * reports is copied to *info (when info is not NULL), and its return value
 * is returned.
 */
int
LFC_dgesv(enum LFC_Order order, int n, int nrhs, double *a, int lda, int *piv, double *b, int ldb,
  int *info) {
  int solver_info = 0;
  int rc;

  if (LFC_ColMajor == order) {
    rc = LFC_linear_solve( LFC_GESV, LFC_Upper, LFC_NoTrans, LFC_DOUBLE,
                           n, n, nrhs, a, lda, piv, b, ldb, &solver_info );
    if (info)
      *info = solver_info;
    return rc;
  }

  /* unsupported storage order */
  if (info)
    *info = -1;
  return LFC_FAILURE;
}	/* LFC_dgesv */

/*
 * In-memory dposv.  Only LFC_ColMajor storage is accepted: any other order
 * stores -1 in *info (when info is not NULL) and returns LFC_FAILURE.
 * Otherwise the call is forwarded to LFC_linear_solve() with LFC_POSV, the
 * caller's uplo, LFC_NoTrans and LFC_DOUBLE; the address of a local int is
 * passed in the pivot slot.  The info value it reports is copied to *info
 * (when info is not NULL) and its return value is returned.
 */
int
LFC_dposv(enum LFC_Order order, enum LFC_UpLo uplo, int n, int nrhs, double *a, int lda,
  double *b, int ldb, int *info) {
  int solver_info = 0;
  int rc, piv_slot;

  if (LFC_ColMajor == order) {
    rc = LFC_linear_solve( LFC_POSV, uplo, LFC_NoTrans, LFC_DOUBLE,
			   n, n, nrhs, a, lda, &piv_slot, b, ldb, &solver_info );
    if (info)
      *info = solver_info;
    return rc;
  }

  /* unsupported storage order */
  if (info)
    *info = -1;
  return LFC_FAILURE;
}	/* LFC_dposv */

/*
 * In-memory dgels.  Only LFC_ColMajor storage is accepted: any other order
 * stores -1 in *info (when info is not NULL) and returns LFC_FAILURE.
 * Otherwise the call is forwarded to LFC_linear_solve() with LFC_GELS,
 * LFC_Upper, the caller's trans and LFC_DOUBLE; the address of a local int
 * is passed in the pivot slot.  The info value it reports is copied to
 * *info (when info is not NULL) and its return value is returned.
 * Uwork and Ulwork are not used.
 */
int
LFC_dgels(enum LFC_Order order, enum LFC_Transpose trans, int m, int n, int nrhs, double *a, int lda,
  double *b, int ldb, double *Uwork, int *Ulwork, int *info) {
  int solver_info = 0;
  int rc, piv_slot;

  if (LFC_ColMajor == order) {
    rc = LFC_linear_solve( LFC_GELS, LFC_Upper, trans, LFC_DOUBLE,
			   m, n, nrhs, a, lda, &piv_slot, b, ldb, &solver_info );
    if (info)
      *info = solver_info;
    return rc;
  }

  /* unsupported storage order */
  if (info)
    *info = -1;
  return LFC_FAILURE;
}	/* LFC_dgels */

LFC_END_C_DECLS
