#include "relapack.h" #if XSYTRF_ALLOW_MALLOC #include #endif static void RELAPACK_ssytrf_rec(const char *, const blasint *, const blasint *, blasint *, float *, const blasint *, blasint *, float *, const blasint *, blasint *); /** SSYTRF computes the factorization of a complex symmetric matrix A using the Bunch-Kaufman diagonal pivoting method. * * This routine is functionally equivalent to LAPACK's ssytrf. * For details on its interface, see * http://www.netlib.org/lapack/explore-html/da/de9/ssytrf_8f.html * */ void RELAPACK_ssytrf( const char *uplo, const blasint *n, float *A, const blasint *ldA, blasint *ipiv, float *Work, const blasint *lWork, blasint *info ) { // Required work size const blasint cleanlWork = *n * (*n / 2); blasint minlWork = cleanlWork; #if XSYTRF_ALLOW_MALLOC minlWork = 1; #endif // Check arguments const blasint lower = LAPACK(lsame)(uplo, "L"); const blasint upper = LAPACK(lsame)(uplo, "U"); *info = 0; if (!lower && !upper) *info = -1; else if (*n < 0) *info = -2; else if (*ldA < MAX(1, *n)) *info = -4; else if ((*lWork <1 || *lWork < minlWork) && *lWork != -1) *info = -7; else if (*lWork == -1) { // Work size query *Work = cleanlWork; return; } // Ensure Work size float *cleanWork = Work; #if XSYTRF_ALLOW_MALLOC if (!*info && *lWork < cleanlWork) { cleanWork = malloc(cleanlWork * sizeof(float)); if (!cleanWork) *info = -7; } #endif if (*info) { const blasint minfo = -*info; LAPACK(xerbla)("SSYTRF", &minfo, strlen("SSYTRF")); return; } // Clean char * arguments const char cleanuplo = lower ? 'L' : 'U'; // Dummy arguments blasint nout; // Recursive kernel if (*n != 0) RELAPACK_ssytrf_rec(&cleanuplo, n, n, &nout, A, ldA, ipiv, cleanWork, n, info); #if XSYTRF_ALLOW_MALLOC if (cleanWork != Work) free(cleanWork); #endif } /** ssytrf's recursive compute kernel */ static void RELAPACK_ssytrf_rec( const char *uplo, const blasint *n_full, const blasint *n, blasint *n_out, float *A, const blasint *ldA, blasint *ipiv, float *Work, const blasint *ldWork, blasint *info ) { // top recursion level? const blasint top = *n_full == *n; if (*n <= MAX(CROSSOVER_SSYTRF, 3)) { // Unblocked if (top) { LAPACK(ssytf2)(uplo, n, A, ldA, ipiv, info); *n_out = *n; } else RELAPACK_ssytrf_rec2(uplo, n_full, n, n_out, A, ldA, ipiv, Work, ldWork, info); return; } blasint info1, info2; // Constants const float ONE[] = { 1. }; const float MONE[] = { -1. }; const blasint iONE[] = { 1 }; // Loop iterator blasint i; const blasint n_rest = *n_full - *n; if (*uplo == 'L') { // Splitting (setup) blasint n1 = SREC_SPLIT(*n); blasint n2 = *n - n1; // Work_L * float *const Work_L = Work; // recursion(A_L) blasint n1_out; RELAPACK_ssytrf_rec(uplo, n_full, &n1, &n1_out, A, ldA, ipiv, Work_L, ldWork, &info1); n1 = n1_out; // Splitting (continued) n2 = *n - n1; const blasint n_full2 = *n_full - n1; // * * // A_BL A_BR // A_BL_B A_BR_B float *const A_BL = A + n1; float *const A_BR = A + *ldA * n1 + n1; float *const A_BL_B = A + *n; float *const A_BR_B = A + *ldA * n1 + *n; // * * // Work_BL Work_BR // * * // (top recursion level: use Work as Work_BR) float *const Work_BL = Work + n1; float *const Work_BR = top ? Work : Work + *ldWork * n1 + n1; const blasint ldWork_BR = top ? n2 : *ldWork; // ipiv_T // ipiv_B blasint *const ipiv_B = ipiv + n1; // A_BR = A_BR - A_BL Work_BL' RELAPACK_sgemmt(uplo, "N", "T", &n2, &n1, MONE, A_BL, ldA, Work_BL, ldWork, ONE, A_BR, ldA); BLAS(sgemm)("N", "T", &n_rest, &n2, &n1, MONE, A_BL_B, ldA, Work_BL, ldWork, ONE, A_BR_B, ldA); // recursion(A_BR) blasint n2_out; RELAPACK_ssytrf_rec(uplo, &n_full2, &n2, &n2_out, A_BR, ldA, ipiv_B, Work_BR, &ldWork_BR, &info2); if (n2_out != n2) { // undo 1 column of updates const blasint n_restp1 = n_rest + 1; // last column of A_BR float *const A_BR_r = A_BR + *ldA * n2_out + n2_out; // last row of A_BL float *const A_BL_b = A_BL + n2_out; // last row of Work_BL float *const Work_BL_b = Work_BL + n2_out; // A_BR_r = A_BR_r + A_BL_b Work_BL_b' BLAS(sgemv)("N", &n_restp1, &n1, ONE, A_BL_b, ldA, Work_BL_b, ldWork, ONE, A_BR_r, iONE); } n2 = n2_out; // shift pivots for (i = 0; i < n2; i++) if (ipiv_B[i] > 0) ipiv_B[i] += n1; else ipiv_B[i] -= n1; *info = info1 || info2; *n_out = n1 + n2; } else { // Splitting (setup) blasint n2 = SREC_SPLIT(*n); blasint n1 = *n - n2; // * Work_R // (top recursion level: use Work as Work_R) float *const Work_R = top ? Work : Work + *ldWork * n1; // recursion(A_R) blasint n2_out; RELAPACK_ssytrf_rec(uplo, n_full, &n2, &n2_out, A, ldA, ipiv, Work_R, ldWork, &info2); const blasint n2_diff = n2 - n2_out; n2 = n2_out; // Splitting (continued) n1 = *n - n2; const blasint n_full1 = *n_full - n2; // * A_TL_T A_TR_T // * A_TL A_TR // * * * float *const A_TL_T = A + *ldA * n_rest; float *const A_TR_T = A + *ldA * (n_rest + n1); float *const A_TL = A + *ldA * n_rest + n_rest; float *const A_TR = A + *ldA * (n_rest + n1) + n_rest; // Work_L * // * Work_TR // * * // (top recursion level: Work_R was Work) float *const Work_L = Work; float *const Work_TR = Work + *ldWork * (top ? n2_diff : n1) + n_rest; const blasint ldWork_L = top ? n1 : *ldWork; // A_TL = A_TL - A_TR Work_TR' RELAPACK_sgemmt(uplo, "N", "T", &n1, &n2, MONE, A_TR, ldA, Work_TR, ldWork, ONE, A_TL, ldA); BLAS(sgemm)("N", "T", &n_rest, &n1, &n2, MONE, A_TR_T, ldA, Work_TR, ldWork, ONE, A_TL_T, ldA); // recursion(A_TL) blasint n1_out; RELAPACK_ssytrf_rec(uplo, &n_full1, &n1, &n1_out, A, ldA, ipiv, Work_L, &ldWork_L, &info1); if (n1_out != n1) { // undo 1 column of updates const blasint n_restp1 = n_rest + 1; // A_TL_T_l = A_TL_T_l + A_TR_T Work_TR_t' BLAS(sgemv)("N", &n_restp1, &n2, ONE, A_TR_T, ldA, Work_TR, ldWork, ONE, A_TL_T, iONE); } n1 = n1_out; *info = info2 || info1; *n_out = n1 + n2; } }