PLASMA  2.4.5
PLASMA - Parallel Linear Algebra for Scalable Multi-core Architectures
 All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Macros Groups
core_spamm.c
Go to the documentation of this file.
1 
15 #include <cblas.h>
16 #include <lapacke.h>
17 #include "common.h"
18 
/*
 * Forward declarations of the two file-local kernels that CORE_spamm()
 * dispatches to.  Both receive the transpose / triangle flags and the two
 * offsets into V (vi2, vi3) already resolved by the driver.
 */

/* W = A1 + op(V) * A2   or   W = A1 + A2 * op(V) */
static int CORE_spamm_w(int side, int trans, int uplo,
                        int M, int N, int K, int L,
                        int vi2, int vi3,
                        float *A1, int LDA1,
                        float *A2, int LDA2,
                        float *V, int LDV,
                        float *W, int LDW);

/* A2 = A2 + op(V) * W   or   A2 = A2 + W * op(V) */
static int CORE_spamm_a2(int side, int trans, int uplo,
                         int M, int N, int K, int L,
                         int vi2, int vi3,
                         float *A2, int LDA2,
                         float *V, int LDV,
                         float *W, int LDW);
33 /***************************************************************************/
173 int
174 CORE_spamm(int op, int side, int storev,
175  int M, int N, int K, int L,
176  float *A1, int LDA1,
177  float *A2, int LDA2,
178  float *V, int LDV,
179  float *W, int LDW)
180 {
181 
182 
183  int vi2, vi3, uplo, trans, info;
184 
185  /* Check input arguments */
186  if ((op != PlasmaW) && (op != PlasmaA2)) {
187  coreblas_error(1, "Illegal value of op");
188  return -1;
189  }
190  if ((side != PlasmaLeft) && (side != PlasmaRight)) {
191  coreblas_error(2, "Illegal value of side");
192  return -2;
193  }
194  if ((storev != PlasmaColumnwise) && (storev != PlasmaRowwise)) {
195  coreblas_error(3, "Illegal value of storev");
196  return -3;
197  }
198  if (M < 0) {
199  coreblas_error(4, "Illegal value of M");
200  return -4;
201  }
202  if (N < 0) {
203  coreblas_error(5, "Illegal value of N");
204  return -5;
205  }
206  if (K < 0) {
207  coreblas_error(6, "Illegal value of K");
208  return -6;
209  }
210  if (L < 0) {
211  coreblas_error(7, "Illegal value of L");
212  return -7;
213  }
214  if (LDA1 < 0) {
215  coreblas_error(9, "Illegal value of LDA1");
216  return -9;
217  }
218  if (LDA2 < 0) {
219  coreblas_error(11, "Illegal value of LDA2");
220  return -11;
221  }
222  if (LDV < 0) {
223  coreblas_error(13, "Illegal value of LDV");
224  return -13;
225  }
226  if (LDW < 0) {
227  coreblas_error(15, "Illegal value of LDW");
228  return -15;
229  }
230 
231  /* Quick return */
232  if ((M == 0) || (N == 0) || (K == 0))
233  return PLASMA_SUCCESS;
234 
235  /*
236  * TRANS is set as:
237  *
238  * -------------------------------------
239  * side direct PlasmaW PlasmaA2
240  * -------------------------------------
241  * left colwise T N
242  * rowwise N T
243  * right colwise N T
244  * rowwise T N
245  * -------------------------------------
246  */
247 
248  /* Columnwise*/
249  if (storev == PlasmaColumnwise) {
250  uplo = CblasUpper;
251  if (side == PlasmaLeft) {
252  trans = op == PlasmaA2 ? PlasmaNoTrans : PlasmaTrans;
253  vi2 = trans == PlasmaNoTrans ? M - L : K - L;
254  }
255  else {
256  trans = op == PlasmaW ? PlasmaNoTrans : PlasmaTrans;
257  vi2 = trans == PlasmaNoTrans ? K - L : N - L;
258  }
259  vi3 = LDV * L;
260  }
261 
262  /* Rowwise */
263  else {
264  uplo = CblasLower;
265  if (side == PlasmaLeft) {
266  trans = op == PlasmaW ? PlasmaNoTrans : PlasmaTrans;
267  vi2 = trans == PlasmaNoTrans ? K - L : M - L;
268  }
269  else {
270  trans = op == PlasmaA2 ? PlasmaNoTrans : PlasmaTrans;
271  vi2 = trans == PlasmaNoTrans ? N - L : K - L;
272  }
273  vi2 *= LDV;
274  vi3 = L;
275  }
276 
277 
278  if (op==PlasmaW) {
279  info = CORE_spamm_w(
280  side, trans, uplo, M, N, K, L, vi2, vi3,
281  A1, LDA1, A2, LDA2, V, LDV, W, LDW);
282  if (info != 0)
283  return info;
284  } else if (op==PlasmaA2) {
285  info = CORE_spamm_a2(
286  side, trans, uplo, M, N, K, L, vi2, vi3,
287  A2, LDA2, V, LDV, W, LDW);
288  if (info != 0)
289  return info;
290  }
291 
292  return PLASMA_SUCCESS;
293 }
294 
295 /***************************************************************************/
296 static int
297 CORE_spamm_w(int side, int trans, int uplo,
298  int M, int N, int K, int L,
299  int vi2, int vi3,
300  float *A1, int LDA1,
301  float *A2, int LDA2,
302  float *V, int LDV,
303  float *W, int LDW)
304 {
305 
306  /*
307  * W = A1 + op(V) * A2 or W = A1 + A2 * op(V)
308  */
309 
310  int j;
311  static float zone = 1.0;
312  static float zzero = 0.0;
313 
314  if (side == PlasmaLeft) {
315 
316  if (((trans == PlasmaTrans) && (uplo == CblasUpper)) ||
317  ((trans == PlasmaNoTrans) && (uplo == CblasLower))) {
318 
319  /*
320  * W = A1 + V' * A2
321  */
322 
323  /* W = A2_2 */
324  LAPACKE_slacpy_work(LAPACK_COL_MAJOR,
326  L, N,
327  &A2[K-L], LDA2, W, LDW);
328 
329  /* W = V_2' * W + V_1' * A2_1 (ge+tr, top L rows of V') */
330  if (L > 0) {
331  /* W = V_2' * W */
332  cblas_strmm(
334  (CBLAS_TRANSPOSE)trans, CblasNonUnit, L, N,
335  (zone), &V[vi2], LDV,
336  W, LDW);
337 
338  /* W = W + V_1' * A2_1 */
339  if (K > L) {
340  cblas_sgemm(
342  L, N, K-L,
343  (zone), V, LDV,
344  A2, LDA2,
345  (zone), W, LDW);
346  }
347  }
348 
349  /* W_2 = V_3' * A2: (ge, bottom M-L rows of V') */
350  if (M > L) {
351  cblas_sgemm(
353  (M-L), N, K,
354  (zone), &V[vi3], LDV,
355  A2, LDA2,
356  (zzero), &W[L], LDW);
357  }
358 
359  /* W = A1 + W */
360  for(j = 0; j < N; j++) {
361  cblas_saxpy(
362  M, (zone),
363  &A1[LDA1*j], 1,
364  &W[LDW*j], 1);
365  }
366  }
367  else {
368  printf("Left Upper/NoTrans & Lower/ConjTrans not implemented yet\n");
370 
371  }
372  }
373  else { //side right
374 
375  if (((trans == PlasmaTrans) && (uplo == CblasUpper)) ||
376  ((trans == PlasmaNoTrans) && (uplo == CblasLower))) {
377  printf("Right Upper/ConjTrans & Lower/NoTrans not implemented yet\n");
379 
380  }
381  else {
382 
383  /*
384  * W = A1 + A2 * V
385  */
386 
387  if (L > 0) {
388 
389  /* W = A2_2 */
390  LAPACKE_slacpy_work(LAPACK_COL_MAJOR,
392  M, L,
393  &A2[LDA2*(K-L)], LDA2, W, LDW);
394 
395  /* W = W * V_2 --> W = A2_2 * V_2 */
396  cblas_strmm(
398  (CBLAS_TRANSPOSE)trans, CblasNonUnit, M, L,
399  (zone), &V[vi2], LDV,
400  W, LDW);
401 
402  /* W = W + A2_1 * V_1 */
403  if (K > L) {
404  cblas_sgemm(
406  M, L, K-L,
407  (zone), A2, LDA2,
408  V, LDV,
409  (zone), W, LDW);
410  }
411 
412  }
413 
414  /* W = W + A2 * V_3 */
415  if (N > L) {
416  cblas_sgemm(
418  M, N-L, K,
419  (zone), A2, LDA2,
420  &V[vi3], LDV,
421  (zzero), &W[LDW*L], LDW);
422  }
423 
424  /* W = A1 + W */
425  for (j = 0; j < N; j++) {
426  cblas_saxpy(
427  M, (zone),
428  &A1[LDA1*j], 1,
429  &W[LDW*j], 1);
430  }
431  }
432  }
433 
434  return PLASMA_SUCCESS;
435 }
436 
437 /***************************************************************************/
438 static int
439 CORE_spamm_a2(int side, int trans, int uplo,
440  int M, int N, int K, int L,
441  int vi2, int vi3,
442  float *A2, int LDA2,
443  float *V, int LDV,
444  float *W, int LDW)
445 {
446 
447  /*
448  * A2 = A2 + op(V) * W or A2 = A2 + W * op(V)
449  */
450 
451  int j;
452  static float zone = 1.0;
453  static float mzone = -1.0;
454 
455  if (side == PlasmaLeft) {
456 
457  if (((trans == PlasmaTrans) && (uplo == CblasUpper)) ||
458  ((trans == PlasmaNoTrans) && (uplo == CblasLower))) {
459 
460  printf("Left Upper/ConjTrans & Lower/NoTrans not implemented yet\n");
462 
463  }
464  else { //trans
465 
466  /*
467  * A2 = A2 - V * W
468  */
469 
470  /* A2_1 = A2_1 - V_1 * W_1 */
471  if (M > L) {
472  cblas_sgemm(
474  M-L, N, L,
475  (mzone), V, LDV,
476  W, LDW,
477  (zone), A2, LDA2);
478  }
479 
480  /* W_1 = V_2 * W_1 */
481  cblas_strmm(
483  (CBLAS_TRANSPOSE)trans, CblasNonUnit, L, N,
484  (zone), &V[vi2], LDV,
485  W, LDW);
486 
487  /* A2_2 = A2_2 - W_1 */
488  for(j = 0; j < N; j++) {
489  cblas_saxpy(
490  L, (mzone),
491  &W[LDW*j], 1,
492  &A2[LDA2*j+(M-L)], 1);
493  }
494 
495  /* A2 = A2 - V_3 * W_2 */
496  if (K > L) {
497  cblas_sgemm(
499  M, N, (K-L),
500  (mzone), &V[vi3], LDV,
501  &W[L], LDW,
502  (zone), A2, LDA2);
503  }
504 
505  }
506  }
507  else { //side right
508 
509  if (((trans == PlasmaTrans) && (uplo == CblasUpper)) ||
510  ((trans == PlasmaNoTrans) && (uplo == CblasLower))) {
511 
512  /*
513  * A2 = A2 - W * V'
514  */
515 
516  /* A2 = A2 - W_2 * V_3' */
517  if (K > L) {
518  cblas_sgemm(
520  M, N, K-L,
521  (mzone), &W[LDW*L], LDW,
522  &V[vi3], LDV,
523  (zone), A2, LDA2);
524  }
525 
526  /* A2_1 = A2_1 - W_1 * V_1' */
527  if (N > L) {
528  cblas_sgemm(
530  M, N-L, L,
531  (mzone), W, LDW,
532  V, LDV,
533  (zone), A2, LDA2);
534  }
535 
536  /* A2_2 = A2_2 - W_1 * V_2' */
537  if (L > 0) {
538  cblas_strmm(
540  (CBLAS_TRANSPOSE)trans, CblasNonUnit, M, L,
541  (mzone), &V[vi2], LDV,
542  W, LDW);
543 
544  for (j = 0; j < L; j++) {
545  cblas_saxpy(
546  M, (zone),
547  &W[LDW*j], 1,
548  &A2[LDA2*(N-L+j)], 1);
549  }
550  }
551 
552  }
553  else {
554  printf("Right Upper/NoTrans & Lower/ConjTrans not implemented yet\n");
556  }
557  }
558 
559  return PLASMA_SUCCESS;
560 }
561 
562 /***************************************************************************/
563 
564 
565 /***************************************************************************/
568 void
570  int op, int side, int storev,
571  int m, int n, int k, int l,
572  float *A1, int lda1,
573  float *A2, int lda2,
574  float *V, int ldv,
575  float *W, int ldw)
576 {
577  QUARK_Insert_Task(quark, CORE_spamm_quark, task_flags,
578  sizeof(int), &op, VALUE,
579  sizeof(PLASMA_enum), &side, VALUE,
580  sizeof(PLASMA_enum), &storev, VALUE,
581  sizeof(int), &m, VALUE,
582  sizeof(int), &n, VALUE,
583  sizeof(int), &k, VALUE,
584  sizeof(int), &l, VALUE,
585  sizeof(float)*m*k, A1, INPUT,
586  sizeof(int), &lda1, VALUE,
587  sizeof(float)*k*n, A2, INPUT,
588  sizeof(int), &lda2, VALUE,
589  sizeof(float)*m*n, V, INPUT,
590  sizeof(int), &ldv, VALUE,
591  sizeof(float)*m*n, W, INOUT,
592  sizeof(int), &ldw, VALUE,
593  0);
594 }
595 
596 /***************************************************************************/
599 void
601 {
602  int op;
603  int side;
604  int storev;
605  int M;
606  int N;
607  int K;
608  int L;
609  float *A1;
610  int LDA1;
611  float *A2;
612  int LDA2;
613  float *V;
614  int LDV;
615  float *W;
616  int LDW;
617 
618  quark_unpack_args_15(quark, op, side, storev, M, N, K, L,
619  A1, LDA1, A2, LDA2, V, LDV, W, LDW);
620 
621  CORE_spamm( op, side, storev, M, N, K, L, A1, LDA1, A2, LDA2, V, LDV, W, LDW);
622 }
623 
624 /***************************************************************************/