|
| 1 | +// Copyright ©2023 The Gonum Authors. All rights reserved. |
| 2 | +// Use of this source code is governed by a BSD-style |
| 3 | +// license that can be found in the LICENSE file. |
| 4 | + |
| 5 | +package testlapack |
| 6 | + |
| 7 | +import ( |
| 8 | + "math" |
| 9 | + "testing" |
| 10 | + |
| 11 | + "golang.org/x/exp/rand" |
| 12 | + |
| 13 | + "gonum.org/v1/gonum/blas" |
| 14 | + "gonum.org/v1/gonum/blas/blas64" |
| 15 | + "gonum.org/v1/gonum/lapack" |
| 16 | +) |
| 17 | + |
| 18 | +type Dgghrder interface { |
| 19 | + Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int, a []float64, lda int, b []float64, ldb int, q []float64, ldq int, z []float64, ldz int) |
| 20 | +} |
| 21 | + |
| 22 | +func DgghrdTest(t *testing.T, impl Dgghrder) { |
| 23 | + const tol = 1e-13 |
| 24 | + const ldAdd = 5 |
| 25 | + rnd := rand.New(rand.NewSource(1)) |
| 26 | + comps := []lapack.OrthoComp{lapack.OrthoUnit, lapack.OrthoNone, lapack.OrthoEntry} |
| 27 | + for _, compq := range comps { |
| 28 | + for _, compz := range comps { |
| 29 | + for _, n := range []int{2, 0, 1, 4, 15} { |
| 30 | + ldMin := max(1, n) |
| 31 | + for _, lda := range []int{ldMin, ldMin + ldAdd} { |
| 32 | + for _, ldb := range []int{ldMin, ldMin + ldAdd} { |
| 33 | + for _, ldq := range []int{ldMin, ldMin + ldAdd} { |
| 34 | + for _, ldz := range []int{ldMin, ldMin + ldAdd} { |
| 35 | + testDgghrd(t, impl, rnd, tol, compq, compz, n, 0, n-1, lda, ldb, ldq, ldz) |
| 36 | + } |
| 37 | + } |
| 38 | + } |
| 39 | + } |
| 40 | + } |
| 41 | + } |
| 42 | + } |
| 43 | +} |
| 44 | + |
| 45 | +func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, tol float64, compq, compz lapack.OrthoComp, n, ilo, ihi, lda, ldb, ldq, ldz int) { |
| 46 | + a := randomGeneral(n, n, lda, rnd) |
| 47 | + b := blockedUpperTriGeneral(n, n, 0, n, ldb, false, rnd) |
| 48 | + var q, q1, z, z1 blas64.General |
| 49 | + if compq == lapack.OrthoEntry { |
| 50 | + q = randomOrthogonal(n, rnd) |
| 51 | + q1 = cloneGeneral(q) |
| 52 | + } else { |
| 53 | + q = nanGeneral(n, n, ldq) |
| 54 | + } |
| 55 | + if compz == lapack.OrthoEntry { |
| 56 | + z = randomOrthogonal(n, rnd) |
| 57 | + z1 = cloneGeneral(z) |
| 58 | + } else { |
| 59 | + z = nanGeneral(n, n, ldz) |
| 60 | + } |
| 61 | + |
| 62 | + hGot := cloneGeneral(a) |
| 63 | + tGot := cloneGeneral(b) |
| 64 | + for i := 1; i < n; i++ { |
| 65 | + for j := 0; j < i; j++ { |
| 66 | + // Set all lower tri elems to NaN to catch bad implementations. |
| 67 | + tGot.Data[i*tGot.Stride+j] = math.NaN() |
| 68 | + } |
| 69 | + } |
| 70 | + impl.Dgghrd(compq, compz, n, ilo, ihi, hGot.Data, hGot.Stride, tGot.Data, tGot.Stride, q.Data, q.Stride, z.Data, z.Stride) |
| 71 | + if n == 0 { |
| 72 | + return |
| 73 | + } |
| 74 | + if !isUpperHessenberg(hGot) { |
| 75 | + t.Error("H is not upper Hessenberg") |
| 76 | + } |
| 77 | + if !isNaNFree(tGot) || !isNaNFree(hGot) { |
| 78 | + t.Error("T or H is/or not NaN free") |
| 79 | + } |
| 80 | + if !isUpperTriangular(tGot) { |
| 81 | + t.Error("T is not upper triangular") |
| 82 | + } |
| 83 | + if compq == lapack.OrthoNone { |
| 84 | + if !isAllNaN(q.Data) { |
| 85 | + t.Errorf("Q is not NaN") |
| 86 | + } |
| 87 | + return |
| 88 | + } |
| 89 | + if compz == lapack.OrthoNone { |
| 90 | + if !isAllNaN(z.Data) { |
| 91 | + t.Errorf("Z is not NaN") |
| 92 | + } |
| 93 | + return |
| 94 | + } |
| 95 | + if compq != compz { |
| 96 | + return // Do not handle mixed case |
| 97 | + } |
| 98 | + comp := compq |
| 99 | + aux := zeros(n, n, n) |
| 100 | + |
| 101 | + switch comp { |
| 102 | + case lapack.OrthoUnit: |
| 103 | + // Qᵀ*A*Z = H |
| 104 | + hCalc := zeros(n, n, n) |
| 105 | + blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, a, 0, aux) |
| 106 | + blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, hCalc) |
| 107 | + if !equalApproxGeneral(hGot, hCalc, tol) { |
| 108 | + t.Errorf("Qᵀ*A*Z != H") |
| 109 | + } |
| 110 | + |
| 111 | + // Qᵀ*B*Z = T |
| 112 | + tCalc := zeros(n, n, n) |
| 113 | + blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, b, 0, aux) |
| 114 | + blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, tCalc) |
| 115 | + if !equalApproxGeneral(hGot, hCalc, tol) { |
| 116 | + t.Errorf("Qᵀ*B*Z != T") |
| 117 | + } |
| 118 | + case lapack.OrthoEntry: |
| 119 | + // Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ |
| 120 | + lhs := zeros(n, n, n) |
| 121 | + blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, a, 0, aux) |
| 122 | + blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs) // lhs = Q1 * A * Z1ᵀ |
| 123 | + |
| 124 | + rhs := zeros(n, n, n) |
| 125 | + blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, hGot, 0, aux) |
| 126 | + blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs) |
| 127 | + if !equalApproxGeneral(lhs, rhs, tol) { |
| 128 | + t.Errorf("Q1 * A * Z1ᵀ != (Q1*Q) * H * (Z1*Z)ᵀ") |
| 129 | + } |
| 130 | + |
| 131 | + // Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ |
| 132 | + blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, b, 0, aux) |
| 133 | + blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs) |
| 134 | + |
| 135 | + blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, tGot, 0, aux) |
| 136 | + blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs) |
| 137 | + if !equalApproxGeneral(lhs, rhs, tol) { |
| 138 | + t.Errorf("Q1 * B * Z1ᵀ != (Q1*Q) * T * (Z1*Z)ᵀ") |
| 139 | + } |
| 140 | + } |
| 141 | +} |
0 commit comments