#include "relapack.h"
#include <math.h>   // sqrt
#include <string.h> // strlen

static void RELAPACK_stgsyl_rec(
    const char *trans, const blasint *ifunc, const blasint *m, const blasint *n,
    const float *A, const blasint *ldA, const float *B, const blasint *ldB,
    float *C, const blasint *ldC,
    const float *D, const blasint *ldD, const float *E, const blasint *ldE,
    float *F, const blasint *ldF,
    float *scale, float *dsum, float *dscale,
    blasint *iWork, blasint *pq, blasint *info
);


/** STGSYL solves the generalized Sylvester equation.
 *
 * This routine is functionally equivalent to LAPACK's stgsyl.
 * For details on its interface, see
 * http://www.netlib.org/lapack/explore-html/dc/d67/stgsyl_8f.html
 *
 * What this wrapper does itself:
 *  - validates the arguments; on failure *info is set to the negated index
 *    of the offending argument (-1, -2, -3, -4, -6, -8, -10, -12, -14, -16
 *    or -20), xerbla is called, and the routine returns;
 *  - answers a workspace query (*lWork == -1) by storing the minimal
 *    workspace size in Work[0] and returning;
 *  - drives one or two rounds of the recursive kernel.  Two rounds are used
 *    for trans = 'N' with *ijob = 1 or 2: the C and F produced by the first
 *    round are parked in Work (2 * m * n floats), C and F are zeroed for the
 *    second round, and the parked C, F and scale are restored afterwards.
 *
 * *dif is only written when the kernel returns a nonzero dscale.
 * */
void RELAPACK_stgsyl(
    const char *trans, const blasint *ijob, const blasint *m, const blasint *n,
    const float *A, const blasint *ldA, const float *B, const blasint *ldB,
    float *C, const blasint *ldC,
    const float *D, const blasint *ldD, const float *E, const blasint *ldE,
    float *F, const blasint *ldF,
    float *scale, float *dif,
    float *Work, const blasint *lWork, blasint *iWork, blasint *info
) {

    // Parse arguments
    const blasint notran = LAPACK(lsame)(trans, "N");
    const blasint tran = LAPACK(lsame)(trans, "T");

    // Compute work buffer size: room for a copy of C and a copy of F
    blasint lwmin = 1;
    if (notran && (*ijob == 1 || *ijob == 2))
        lwmin = MAX(1, 2 * *m * *n);
    *info = 0;

    // Check arguments
    if (!tran && !notran)
        *info = -1;
    else if (notran && (*ijob < 0 || *ijob > 4))
        *info = -2;
    else if (*m <= 0)
        *info = -3;
    else if (*n <= 0)
        *info = -4;
    else if (*ldA < MAX(1, *m))
        *info = -6;
    else if (*ldB < MAX(1, *n))
        *info = -8;
    else if (*ldC < MAX(1, *m))
        *info = -10;
    else if (*ldD < MAX(1, *m))
        *info = -12;
    else if (*ldE < MAX(1, *n))
        *info = -14;
    else if (*ldF < MAX(1, *m))
        *info = -16;
    else if (*lWork < lwmin && *lWork != -1)
        *info = -20;
    if (*info) {
        const blasint minfo = -*info;
        LAPACK(xerbla)("STGSYL", &minfo, strlen("STGSYL"));
        return;
    }

    if (*lWork == -1) {
        // Work size query
        *Work = lwmin;
        return;
    }

    // Clean char * arguments
    const char cleantrans = notran ? 'N' : 'T';

    // Constant
    const float ZERO[] = { 0. };

    // isolve: number of kernel rounds; ifunc: job code handed to the kernel
    blasint isolve = 1;
    blasint ifunc  = 0;
    if (notran) {
        if (*ijob >= 3) {
            ifunc = *ijob - 2;
            LAPACK(slaset)("F", m, n, ZERO, ZERO, C, ldC);
            LAPACK(slaset)("F", m, n, ZERO, ZERO, F, ldF);
        } else if (*ijob >= 1)
            isolve = 2;
    }

    // scale of the first round, restored after the second one
    float scale2 = 1;
    blasint iround;
    for (iround = 1; iround <= isolve; iround++) {
        *scale = 1;
        float dscale = 0;
        float dsum   = 1;
        blasint pq   = 0;
        RELAPACK_stgsyl_rec(&cleantrans, &ifunc, m, n, A, ldA, B, ldB, C, ldC, D, ldD, E, ldE, F, ldF, scale, &dsum, &dscale, iWork, &pq, info);
        if (dscale != 0) {
            if (*ijob == 1 || *ijob == 3)
                *dif = sqrt(2 * *m * *n) / (dscale * sqrt(dsum));
            else
                *dif = sqrt(pq) / (dscale * sqrt(dsum));
        }
        if (isolve == 2) {
            if (iround == 1) {
                // Park the first round's C and F in Work and clear them
                // before the second round runs with ifunc = ijob.
                if (notran)
                    ifunc = *ijob;
                scale2 = *scale;
                LAPACK(slacpy)("F", m, n, C, ldC, Work, m);
                LAPACK(slacpy)("F", m, n, F, ldF, Work + *m * *n, m);
                LAPACK(slaset)("F", m, n, ZERO, ZERO, C, ldC);
                LAPACK(slaset)("F", m, n, ZERO, ZERO, F, ldF);
            } else {
                // Restore the first round's C, F and scale
                LAPACK(slacpy)("F", m, n, Work, m, C, ldC);
                LAPACK(slacpy)("F", m, n, Work + *m * *n, m, F, ldF);
                *scale = scale2;
            }
        }
    }
}


/** stgsyl's recursive compute kernel
 *
 * If both *m and *n are at most MAX(CROSSOVER_STGSYL, 1) the problem is
 * handed to LAPACK's stgsy2.  Otherwise the larger dimension is split with
 * SREC_SPLIT (rows of A/D/C/F when m > n, columns of B/E/C/F otherwise); the
 * split point is moved by one when the subdiagonal entry at the split is
 * nonzero (presumably so that a 2x2 diagonal block is not cut).  The two
 * halves are solved recursively with sgemm updates in between.
 *
 * Each half reports its own scale factor.  The first half's factor is
 * applied to the not-yet-solved half (through the sgemm beta or slascl), the
 * second half's factor is applied to the already-solved half (slascl), and
 * *scale receives the product of both.
 *
 * dsum, dscale, iWork and pq are passed through unchanged to the leaves.
 * *info is set to 1 if either half reported a nonzero info, else 0.
 * */
static void RELAPACK_stgsyl_rec(
    const char *trans, const blasint *ifunc, const blasint *m, const blasint *n,
    const float *A, const blasint *ldA, const float *B, const blasint *ldB,
    float *C, const blasint *ldC,
    const float *D, const blasint *ldD, const float *E, const blasint *ldE,
    float *F, const blasint *ldF,
    float *scale, float *dsum, float *dscale,
    blasint *iWork, blasint *pq, blasint *info
) {

    if (*m <= MAX(CROSSOVER_STGSYL, 1) && *n <= MAX(CROSSOVER_STGSYL, 1)) {
        // Unblocked
        LAPACK(stgsy2)(trans, ifunc, m, n, A, ldA, B, ldB, C, ldC, D, ldD, E, ldE, F, ldF, scale, dsum, dscale, iWork, pq, info);
        return;
    }

    // Constants
    const float ONE[]  = { 1. };
    const float MONE[] = { -1. };
    const blasint iONE[] = { 1 };

    // Outputs
    float scale1[] = { 1. };
    float scale2[] = { 1. };
    blasint info1[] = { 0 };
    blasint info2[] = { 0 };

    if (*m > *n) {
        // Splitting
        blasint m1 = SREC_SPLIT(*m);
        if (A[m1 + *ldA * (m1 - 1)])
            m1++;
        const blasint m2 = *m - m1;

        // A_TL A_TR
        // 0    A_BR
        const float *const A_TL = A;
        const float *const A_TR = A + *ldA * m1;
        const float *const A_BR = A + *ldA * m1 + m1;

        // C_T
        // C_B
        float *const C_T = C;
        float *const C_B = C + m1;

        // D_TL D_TR
        // 0    D_BR
        const float *const D_TL = D;
        const float *const D_TR = D + *ldD * m1;
        const float *const D_BR = D + *ldD * m1 + m1;

        // F_T
        // F_B
        float *const F_T = F;
        float *const F_B = F + m1;

        if (*trans == 'N') {
            // recursion(A_BR, B, C_B, D_BR, E, F_B)
            RELAPACK_stgsyl_rec(trans, ifunc, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, D_BR, ldD, E, ldE, F_B, ldF, scale1, dsum, dscale, iWork, pq, info1);
            // C_T = C_T - A_TR * C_B
            BLAS(sgemm)("N", "N", &m1, n, &m2, MONE, A_TR, ldA, C_B, ldC, scale1, C_T, ldC);
            // F_T = F_T - D_TR * C_B
            BLAS(sgemm)("N", "N", &m1, n, &m2, MONE, D_TR, ldD, C_B, ldC, scale1, F_T, ldF);
            // recursion(A_TL, B, C_T, D_TL, E, F_T)
            RELAPACK_stgsyl_rec(trans, ifunc, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, D_TL, ldD, E, ldE, F_T, ldF, scale2, dsum, dscale, iWork, pq, info2);
            // apply scale
            if (scale2[0] != 1) {
                LAPACK(slascl)("G", iONE, iONE, ONE, scale2, &m2, n, C_B, ldC, info);
                LAPACK(slascl)("G", iONE, iONE, ONE, scale2, &m2, n, F_B, ldF, info);
            }
        } else {
            // recursion(A_TL, B, C_T, D_TL, E, F_T)
            RELAPACK_stgsyl_rec(trans, ifunc, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, D_TL, ldD, E, ldE, F_T, ldF, scale1, dsum, dscale, iWork, pq, info1);
            // apply scale (C_B is scaled through the sgemm beta below)
            if (scale1[0] != 1)
                LAPACK(slascl)("G", iONE, iONE, ONE, scale1, &m2, n, F_B, ldF, info);
            // C_B = C_B - A_TR^H * C_T
            BLAS(sgemm)("T", "N", &m2, n, &m1, MONE, A_TR, ldA, C_T, ldC, scale1, C_B, ldC);
            // C_B = C_B - D_TR^H * F_T
            // F_T is a block of F, so its leading dimension is ldF (not ldC)
            BLAS(sgemm)("T", "N", &m2, n, &m1, MONE, D_TR, ldD, F_T, ldF, ONE, C_B, ldC);
            // recursion(A_BR, B, C_B, D_BR, E, F_B)
            RELAPACK_stgsyl_rec(trans, ifunc, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, D_BR, ldD, E, ldE, F_B, ldF, scale2, dsum, dscale, iWork, pq, info2);
            // apply scale
            if (scale2[0] != 1) {
                LAPACK(slascl)("G", iONE, iONE, ONE, scale2, &m1, n, C_T, ldC, info);
                LAPACK(slascl)("G", iONE, iONE, ONE, scale2, &m1, n, F_T, ldF, info);
            }
        }
    } else {
        // Splitting
        blasint n1 = SREC_SPLIT(*n);
        if (B[n1 + *ldB * (n1 - 1)])
            n1++;
        const blasint n2 = *n - n1;

        // B_TL B_TR
        // 0    B_BR
        const float *const B_TL = B;
        const float *const B_TR = B + *ldB * n1;
        const float *const B_BR = B + *ldB * n1 + n1;

        // C_L C_R
        float *const C_L = C;
        float *const C_R = C + *ldC * n1;

        // E_TL E_TR
        // 0    E_BR
        const float *const E_TL = E;
        const float *const E_TR = E + *ldE * n1;
        const float *const E_BR = E + *ldE * n1 + n1;

        // F_L F_R
        float *const F_L = F;
        float *const F_R = F + *ldF * n1;

        if (*trans == 'N') {
            // recursion(A, B_TL, C_L, D, E_TL, F_L)
            RELAPACK_stgsyl_rec(trans, ifunc, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, D, ldD, E_TL, ldE, F_L, ldF, scale1, dsum, dscale, iWork, pq, info1);
            // C_R = C_R + F_L * B_TR
            BLAS(sgemm)("N", "N", m, &n2, &n1, ONE, F_L, ldF, B_TR, ldB, scale1, C_R, ldC);
            // F_R = F_R + F_L * E_TR
            BLAS(sgemm)("N", "N", m, &n2, &n1, ONE, F_L, ldF, E_TR, ldE, scale1, F_R, ldF);
            // recursion(A, B_BR, C_R, D, E_BR, F_R)
            RELAPACK_stgsyl_rec(trans, ifunc, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, D, ldD, E_BR, ldE, F_R, ldF, scale2, dsum, dscale, iWork, pq, info2);
            // apply scale
            if (scale2[0] != 1) {
                LAPACK(slascl)("G", iONE, iONE, ONE, scale2, m, &n1, C_L, ldC, info);
                LAPACK(slascl)("G", iONE, iONE, ONE, scale2, m, &n1, F_L, ldF, info);
            }
        } else {
            // recursion(A, B_BR, C_R, D, E_BR, F_R)
            RELAPACK_stgsyl_rec(trans, ifunc, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, D, ldD, E_BR, ldE, F_R, ldF, scale1, dsum, dscale, iWork, pq, info1);
            // apply scale (F_L is scaled through the sgemm beta below)
            if (scale1[0] != 1)
                LAPACK(slascl)("G", iONE, iONE, ONE, scale1, m, &n1, C_L, ldC, info);
            // F_L = F_L + C_R * B_TR^H
            BLAS(sgemm)("N", "T", m, &n1, &n2, ONE, C_R, ldC, B_TR, ldB, scale1, F_L, ldF);
            // F_L = F_L + F_R * E_TR^H
            // E_TR is a block of E, so its leading dimension is ldE (not ldB)
            BLAS(sgemm)("N", "T", m, &n1, &n2, ONE, F_R, ldF, E_TR, ldE, ONE, F_L, ldF);
            // recursion(A, B_TL, C_L, D, E_TL, F_L)
            RELAPACK_stgsyl_rec(trans, ifunc, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, D, ldD, E_TL, ldE, F_L, ldF, scale2, dsum, dscale, iWork, pq, info2);
            // apply scale
            if (scale2[0] != 1) {
                LAPACK(slascl)("G", iONE, iONE, ONE, scale2, m, &n2, C_R, ldC, info);
                LAPACK(slascl)("G", iONE, iONE, ONE, scale2, m, &n2, F_R, ldF, info);
            }
        }
    }

    // Combine the two halves: product of scale factors, logical OR of infos
    *scale = scale1[0] * scale2[0];
    *info  = info1[0] || info2[0];
}