001: /* ///////////////////////////// P /// L /// A /// S /// M /// A /////////////////////////////// */
002: /* ///                    PLASMA computational routines (version 2.1.0)                      ///
003:  * ///                    Author: Jakub Kurzak                                               ///
004:  * ///                    Release Date: November, 15th 2009                                  ///
005:  * ///                    PLASMA is a software package provided by Univ. of Tennessee,       ///
006:  * ///                    Univ. of California Berkeley and Univ. of Colorado Denver          /// */
007: /* ///////////////////////////////////////////////////////////////////////////////////////////// */
#include <stddef.h>

#include "common.h"
009: 
010: /* /////////////////////////// P /// U /// R /// P /// O /// S /// E /////////////////////////// */
011: // PLASMA_sgetrs - Solves a system of linear equations A * X = B, with a general N-by-N matrix A
012: // using the tile LU factorization computed by PLASMA_sgetrf.
013: 
014: /* ///////////////////// A /// R /// G /// U /// M /// E /// N /// T /// S ///////////////////// */
015: // trans    PLASMA_enum (IN)
016: //          Intended to specify the form of the system of equations:
017: //          = PlasmaNoTrans:   A * X = B     (No transpose)
018: //          = PlasmaTrans:     A**T * X = B  (Transpose)
019: //          = PlasmaTrans:     A**T * X = B  (Conjugate transpose; same as transpose for real data)
020: //          Currently only PlasmaNoTrans is supported.
021: //
022: // N        int (IN)
023: //          The order of the matrix A.  N >= 0.
024: //
025: // NRHS     int (IN)
026: //          The number of right hand sides, i.e., the number of columns of the matrix B.
027: //          NRHS >= 0.
028: //
029: // A        float* (IN)
030: //          The tile factors L and U from the factorization, computed by PLASMA_sgetrf.
031: //
032: // LDA      int (IN)
033: //          The leading dimension of the array A. LDA >= max(1,N).
034: //
035: // L        float* (IN)
036: //          Auxiliary factorization data, related to the tile L factor, computed by PLASMA_sgetrf.
037: //
038: // IPIV     int* (IN)
039: //          The pivot indices from PLASMA_sgetrf (not equivalent to LAPACK).
040: //
041: // B        float* (INOUT)
042: //          On entry, the N-by-NRHS right hand side matrix B.
043: //          On exit, the solution matrix X.
044: //
045: // LDB      int (IN)
046: //          The leading dimension of the array B. LDB >= max(1,N).
047: 
048: /* ///////////// R /// E /// T /// U /// R /// N /////// V /// A /// L /// U /// E ///////////// */
049: //          = 0: successful exit
050: //          < 0: if -i, the i-th argument had an illegal value
051: 
052: /* //////////////////////////////////// C /// O /// D /// E //////////////////////////////////// */
053: int PLASMA_sgetrs(PLASMA_enum trans, int N, int NRHS, float *A, int LDA,
054:                   float *L, int *IPIV, float *B, int LDB)
055: {
056:     int NB, NT, NTRHS;
057:     int status;
058:     float *Abdl;
059:     float *Bbdl;
060:     float *Lbdl;
061:     plasma_context_t *plasma;
062: 
063:     plasma = plasma_context_self();
064:     if (plasma == NULL) {
065:         plasma_fatal_error("PLASMA_sgetrs", "PLASMA not initialized");
066:         return PLASMA_ERR_NOT_INITIALIZED;
067:     }
068:     /* Check input arguments */
069:     if (trans != PlasmaNoTrans) {
070:         plasma_error("PLASMA_sgetrs", "only PlasmaNoTrans supported");
071:         return PLASMA_ERR_NOT_SUPPORTED;
072:     }
073:     if (N < 0) {
074:         plasma_error("PLASMA_sgetrs", "illegal value of N");
075:         return -2;
076:     }
077:     if (NRHS < 0) {
078:         plasma_error("PLASMA_sgetrs", "illegal value of NRHS");
079:         return -3;
080:     }
081:     if (LDA < max(1, N)) {
082:         plasma_error("PLASMA_sgetrs", "illegal value of LDA");
083:         return -5;
084:     }
085:     if (LDB < max(1, N)) {
086:         plasma_error("PLASMA_sgetrs", "illegal value of LDB");
087:         return -9;
088:     }
089:     /* Quick return */
090:     if (min(N, NRHS) == 0)
091:         return PLASMA_SUCCESS;
092: 
093:     /* Tune NB & IB depending on N & NRHS; Set NBNBSIZE */
094:     status = plasma_tune(PLASMA_FUNC_SGESV, N, N, NRHS);
095:     if (status != PLASMA_SUCCESS) {
096:         plasma_error("PLASMA_sgetrs", "plasma_tune() failed");
097:         return status;
098:     }
099: 
100:     /* Set NT & NTRHS */
101:     NB = PLASMA_NB;
102:     NT = (N%NB==0) ? (N/NB) : (N/NB+1);
103:     NTRHS = (NRHS%NB==0) ? (NRHS/NB) : (NRHS/NB+1);
104: 
105:     /* Allocate memory for matrices in block layout */
106:     Abdl = (float *)plasma_shared_alloc(plasma, NT*NT*PLASMA_NBNBSIZE, PlasmaRealFloat);
107:     Lbdl = (float *)plasma_shared_alloc(plasma, NT*NT*PLASMA_IBNBSIZE, PlasmaRealFloat);
108:     Bbdl = (float *)plasma_shared_alloc(plasma, NT*NTRHS*PLASMA_NBNBSIZE, PlasmaRealFloat);
109:     if (Abdl == NULL || Lbdl == NULL || Bbdl == NULL) {
110:         plasma_error("PLASMA_sgetrs", "plasma_shared_alloc() failed");
111:         plasma_shared_free(plasma, Abdl);
112:         plasma_shared_free(plasma, Lbdl);
113:         plasma_shared_free(plasma, Bbdl);
114:         return PLASMA_ERR_OUT_OF_RESOURCES;
115:     }
116: 
117:     PLASMA_desc descA = plasma_desc_init(
118:         Abdl, PlasmaRealFloat,
119:         PLASMA_NB, PLASMA_NB, PLASMA_NBNBSIZE,
120:         N, N, 0, 0, N, N);
121: 
122:     PLASMA_desc descB = plasma_desc_init(
123:         Bbdl, PlasmaRealFloat,
124:         PLASMA_NB, PLASMA_NB, PLASMA_NBNBSIZE,
125:         N, NRHS, 0, 0, N, NRHS);
126: 
127:     PLASMA_desc descL = plasma_desc_init(
128:         Lbdl, PlasmaRealFloat,
129:         PLASMA_IB, PLASMA_NB, PLASMA_IBNBSIZE,
130:         N, N, 0, 0, N, N);
131: 
132:     plasma_parallel_call_3(plasma_lapack_to_tile,
133:         float*, A,
134:         int, LDA,
135:         PLASMA_desc, descA);
136: 
137:     plasma_parallel_call_3(plasma_lapack_to_tile,
138:         float*, B,
139:         int, LDB,
140:         PLASMA_desc, descB);
141: 
142:     /* Receive L from the user */
143:     plasma_memcpy(Lbdl, L, NT*NT*PLASMA_IBNBSIZE, PlasmaRealFloat);
144: 
145:     /* Call the native interface */
146:     status = PLASMA_sgetrs_Tile(&descA, &descL, IPIV, &descB);
147: 
148:     if (status == PLASMA_SUCCESS)
149:         plasma_parallel_call_3(plasma_tile_to_lapack,
150:             PLASMA_desc, descB,
151:             float*, B,
152:             int, LDB);
153: 
154:     plasma_shared_free(plasma, Abdl);
155:     plasma_shared_free(plasma, Lbdl);
156:     plasma_shared_free(plasma, Bbdl);
157:     return status;
158: }
159: 
160: /* /////////////////////////// P /// U /// R /// P /// O /// S /// E /////////////////////////// */
161: // PLASMA_sgetrs_Tile - Solves a system of linear equations A * X = B, with a general N-by-N
162: // matrix A using the tile LU factorization computed by PLASMA_sgetrf.
163: // All matrices are passed through descriptors. All dimensions are taken from the descriptors.
164: 
165: /* ///////////////////// A /// R /// G /// U /// M /// E /// N /// T /// S ///////////////////// */
166: // A        PLASMA_desc* (IN)
167: //          The tile factors L and U from the factorization, computed by PLASMA_sgetrf.
168: //
169: // L        PLASMA_desc* (IN)
170: //          Auxiliary factorization data, related to the tile L factor, computed by PLASMA_sgetrf.
171: //
172: // IPIV     int* (IN)
173: //          The pivot indices from PLASMA_sgetrf (not equivalent to LAPACK).
174: //
175: // B        PLASMA_desc* (INOUT)
176: //          On entry, the N-by-NRHS right hand side matrix B.
177: //          On exit, the solution matrix X.
178: 
179: /* ///////////// R /// E /// T /// U /// R /// N /////// V /// A /// L /// U /// E ///////////// */
180: //          = 0: successful exit; otherwise PLASMA_ERR_NOT_INITIALIZED or PLASMA_ERR_ILLEGAL_VALUE
181: 
182: /* //////////////////////////////////// C /// O /// D /// E //////////////////////////////////// */
183: int PLASMA_sgetrs_Tile(PLASMA_desc *A, PLASMA_desc *L, int *IPIV, PLASMA_desc *B)
184: {
185:     PLASMA_desc descA = *A;
186:     PLASMA_desc descL = *L;
187:     PLASMA_desc descB = *B;
188:     plasma_context_t *plasma;
189: 
190:     plasma = plasma_context_self();
191:     if (plasma == NULL) {
192:         plasma_fatal_error("PLASMA_sgetrs_Tile", "PLASMA not initialized");
193:         return PLASMA_ERR_NOT_INITIALIZED;
194:     }
195:     /* Check descriptors for correctness */
196:     if (plasma_desc_check(&descA) != PLASMA_SUCCESS) {
197:         plasma_error("PLASMA_sgetrs_Tile", "invalid first descriptor");
198:         return PLASMA_ERR_ILLEGAL_VALUE;
199:     }
200:     if (plasma_desc_check(&descL) != PLASMA_SUCCESS) {
201:         plasma_error("PLASMA_sgetrs_Tile", "invalid second descriptor");
202:         return PLASMA_ERR_ILLEGAL_VALUE;
203:     }
204:     if (plasma_desc_check(&descB) != PLASMA_SUCCESS) {
205:         plasma_error("PLASMA_sgetrs_Tile", "invalid third descriptor");
206:         return PLASMA_ERR_ILLEGAL_VALUE;
207:     }
208:     /* Check input arguments */
209:     if (descA.nb != descA.mb || descB.nb != descB.mb) {
210:         plasma_error("PLASMA_sgetrs_Tile", "only square tiles supported");
211:         return PLASMA_ERR_ILLEGAL_VALUE;
212:     }
213:     /* Quick return */
214: /*
215:     if (min(N, NRHS) == 0)
216:         return PLASMA_SUCCESS;
217: */
218:     plasma_parallel_call_4(plasma_pstrsmpl,
219:         PLASMA_desc, descA,
220:         PLASMA_desc, descB,
221:         PLASMA_desc, descL,
222:         int*, IPIV);
223: 
224:     plasma_parallel_call_7(plasma_pstrsm,
225:         PLASMA_enum, PlasmaLeft,
226:         PLASMA_enum, PlasmaUpper,
227:         PLASMA_enum, PlasmaNoTrans,
228:         PLASMA_enum, PlasmaNonUnit,
229:         float, 1.0,
230:         PLASMA_desc, descA,
231:         PLASMA_desc, descB);
232: 
233:     return PLASMA_SUCCESS;
234: }
235: