Skip to content

Commit f0a57a4

Browse files
authored
lapack/gonum: add Dgghrd and its test
1 parent 7bed099 commit f0a57a4

File tree

7 files changed

+296
-2
lines changed

7 files changed

+296
-2
lines changed

blas/blas64/blas64.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ func Use(b blas.Float64) {
2020

2121
// Implementation returns the current BLAS float64 implementation.
2222
//
23-
// Implementation allows direct calls to the current the BLAS float64 implementation
23+
// Implementation allows direct calls to the current BLAS float64 implementation
2424
// giving finer control of parameters.
2525
func Implementation() blas.Float64 {
2626
return blas64

lapack/gonum/dgghrd.go

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// Copyright ©2023 The Gonum Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package gonum
6+
7+
import (
8+
"gonum.org/v1/gonum/blas"
9+
"gonum.org/v1/gonum/blas/blas64"
10+
"gonum.org/v1/gonum/lapack"
11+
)
12+
13+
// Dgghrd reduces a pair of real matrices (A,B) to generalized upper
14+
// Hessenberg form using orthogonal transformations, where A is a
15+
// general matrix and B is upper triangular. The form of the
16+
// generalized eigenvalue problem is
17+
//
18+
// A*x = lambda*B*x,
19+
//
20+
// and B is typically made upper triangular by computing its QR
21+
// factorization and moving the orthogonal matrix Q to the left side
22+
// of the equation.
23+
// This subroutine simultaneously reduces A to a Hessenberg matrix H:
24+
//
25+
// Qᵀ*A*Z = H
26+
//
27+
// and transforms B to another upper triangular matrix T:
28+
//
29+
// Qᵀ*B*Z = T
30+
//
31+
// in order to reduce the problem to its standard form
32+
//
33+
// H*y = lambda*T*y
34+
//
35+
// where y = Zᵀ*x.
36+
//
37+
// The orthogonal matrices Q and Z are determined as products of Givens
38+
// rotations. They may either be formed explicitly, or they may be
39+
// postmultiplied into input matrices Q1 and Z1, so that
40+
//
41+
// Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ
42+
// Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ
43+
//
44+
// If Q1 is the orthogonal matrix from the QR factorization of B in the
45+
// original equation A*x = lambda*B*x, then Dgghrd reduces the original
46+
// problem to generalized Hessenberg form.
47+
//
48+
// Dgghrd is an internal routine. It is exported for testing purposes.
49+
func (impl Implementation) Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int, a []float64, lda int, b []float64, ldb int, q []float64, ldq int, z []float64, ldz int) {
50+
switch {
51+
case compq != lapack.OrthoNone && compq != lapack.OrthoEntry && compq != lapack.OrthoUnit:
52+
panic(badOrthoComp)
53+
case compz != lapack.OrthoNone && compz != lapack.OrthoEntry && compz != lapack.OrthoUnit:
54+
panic(badOrthoComp)
55+
case len(a) < (n-1)*lda+n:
56+
panic(shortA)
57+
case len(b) < (n-1)*ldb+n:
58+
panic(shortB)
59+
case n < 0:
60+
panic(nLT0)
61+
case ilo < 0:
62+
panic(badIlo)
63+
case ihi < ilo-1 || ihi >= n:
64+
panic(badIhi)
65+
case lda < max(1, n):
66+
panic(badLdA)
67+
case ldb < max(1, n):
68+
panic(badLdB)
69+
case (compq != lapack.OrthoNone && ldq < n) || ldq < 1:
70+
panic(badLdQ)
71+
case (compz != lapack.OrthoNone && ldz < n) || ldz < 1:
72+
panic(badLdZ)
73+
case compq != lapack.OrthoNone && len(q) < (n-1)*ldq+n:
74+
panic(shortQ)
75+
case compz != lapack.OrthoNone && len(z) < (n-1)*ldz+n:
76+
panic(shortZ)
77+
}
78+
79+
if compq == lapack.OrthoUnit {
80+
impl.Dlaset(blas.All, n, n, 0, 1, q, ldq)
81+
}
82+
if compz == lapack.OrthoUnit {
83+
impl.Dlaset(blas.All, n, n, 0, 1, z, ldz)
84+
}
85+
if n <= 1 {
86+
return // Quick return if possible.
87+
}
88+
89+
// Zero out lower triangle of B.
90+
for i := 1; i < n; i++ {
91+
for j := 0; j < i; j++ {
92+
b[i*ldb+j] = 0
93+
}
94+
}
95+
bi := blas64.Implementation()
96+
// Reduce A and B.
97+
for jcol := ilo; jcol <= ihi-2; jcol++ {
98+
for jrow := ihi; jrow >= jcol+2; jrow-- {
99+
// Step 1: rotate rows JROW-1, JROW to kill A(JROW,JCOL).
100+
var c, s float64
101+
c, s, a[(jrow-1)*lda+jcol] = impl.Dlartg(a[(jrow-1)*lda+jcol], a[jrow*lda+jcol])
102+
a[jrow*lda+jcol] = 0
103+
bi.Drot(n-jcol-1, a[(jrow-1)*lda+jcol+1:], 1,
104+
a[jrow*lda+jcol+1:], 1, c, s)
105+
106+
bi.Drot(n+2-jrow-1, b[(jrow-1)*ldb+jrow-1:], 1,
107+
b[jrow*ldb+jrow-1:], 1, c, s)
108+
109+
if compq != lapack.OrthoNone {
110+
bi.Drot(n, q[jrow-1:], ldq, q[jrow:], ldq, c, s)
111+
}
112+
113+
// Step 2: rotate columns JROW, JROW-1 to kill B(JROW,JROW-1).
114+
c, s, b[jrow*ldb+jrow] = impl.Dlartg(b[jrow*ldb+jrow], b[jrow*ldb+jrow-1])
115+
b[jrow*ldb+jrow-1] = 0
116+
117+
bi.Drot(ihi+1, a[jrow:], lda, a[jrow-1:], lda, c, s)
118+
bi.Drot(jrow, b[jrow:], ldb, b[jrow-1:], ldb, c, s)
119+
120+
if compz != lapack.OrthoNone {
121+
bi.Drot(n, z[jrow:], ldz, z[jrow-1:], ldz, c, s)
122+
}
123+
}
124+
}
125+
}

lapack/gonum/errors.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ const (
2121
badMatrixType = "lapack: bad MatrixType"
2222
badMaximizeNormXJob = "lapack: bad MaximizeNormXJob"
2323
badNorm = "lapack: bad Norm"
24+
badOrthoComp = "lapack: bad OrthoComp"
2425
badPivot = "lapack: bad Pivot"
2526
badRightEVJob = "lapack: bad RightEVJob"
2627
badSVDJob = "lapack: bad SVDJob"

lapack/gonum/lapack_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,11 @@ func TestDgetrs(t *testing.T) {
148148
testlapack.DgetrsTest(t, impl)
149149
}
150150

151+
func TestDgghrd(t *testing.T) {
152+
t.Parallel()
153+
testlapack.DgghrdTest(t, impl)
154+
}
155+
151156
func TestDggsvd3(t *testing.T) {
152157
t.Parallel()
153158
testlapack.Dggsvd3Test(t, impl)

lapack/lapack.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,12 @@ const (
226226
LocalLookAhead MaximizeNormXJob = 0 // Solve Z*x=h-f where h is a vector of ±1.
227227
NormalizedNullVector MaximizeNormXJob = 2 // Compute an approximate null-vector e of Z, normalize e and solve Z*x=±e-f.
228228
)
229+
230+
// OrthoComp specifies whether and how the orthogonal matrix is computed in Dgghrd.
231+
type OrthoComp byte
232+
233+
const (
234+
OrthoNone OrthoComp = 'N' // Do not compute orthogonal matrix.
235+
OrthoUnit OrthoComp = 'I' // Argument is initialized to the unit matrix and the orthogonal matrix is returned.
236+
OrthoEntry OrthoComp = 'V' // Argument Q contains orthogonal matrix Q1 on entry and the product Q1*Q is returned.
237+
)

lapack/testlapack/dgghrd.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
// Copyright ©2023 The Gonum Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package testlapack
6+
7+
import (
8+
"math"
9+
"testing"
10+
11+
"golang.org/x/exp/rand"
12+
13+
"gonum.org/v1/gonum/blas"
14+
"gonum.org/v1/gonum/blas/blas64"
15+
"gonum.org/v1/gonum/lapack"
16+
)
17+
18+
type Dgghrder interface {
19+
Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int, a []float64, lda int, b []float64, ldb int, q []float64, ldq int, z []float64, ldz int)
20+
}
21+
22+
func DgghrdTest(t *testing.T, impl Dgghrder) {
23+
const tol = 1e-13
24+
const ldAdd = 5
25+
rnd := rand.New(rand.NewSource(1))
26+
comps := []lapack.OrthoComp{lapack.OrthoUnit, lapack.OrthoNone, lapack.OrthoEntry}
27+
for _, compq := range comps {
28+
for _, compz := range comps {
29+
for _, n := range []int{2, 0, 1, 4, 15} {
30+
ldMin := max(1, n)
31+
for _, lda := range []int{ldMin, ldMin + ldAdd} {
32+
for _, ldb := range []int{ldMin, ldMin + ldAdd} {
33+
for _, ldq := range []int{ldMin, ldMin + ldAdd} {
34+
for _, ldz := range []int{ldMin, ldMin + ldAdd} {
35+
testDgghrd(t, impl, rnd, tol, compq, compz, n, 0, n-1, lda, ldb, ldq, ldz)
36+
}
37+
}
38+
}
39+
}
40+
}
41+
}
42+
}
43+
}
44+
45+
func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, tol float64, compq, compz lapack.OrthoComp, n, ilo, ihi, lda, ldb, ldq, ldz int) {
46+
a := randomGeneral(n, n, lda, rnd)
47+
b := blockedUpperTriGeneral(n, n, 0, n, ldb, false, rnd)
48+
var q, q1, z, z1 blas64.General
49+
if compq == lapack.OrthoEntry {
50+
q = randomOrthogonal(n, rnd)
51+
q1 = cloneGeneral(q)
52+
} else {
53+
q = nanGeneral(n, n, ldq)
54+
}
55+
if compz == lapack.OrthoEntry {
56+
z = randomOrthogonal(n, rnd)
57+
z1 = cloneGeneral(z)
58+
} else {
59+
z = nanGeneral(n, n, ldz)
60+
}
61+
62+
hGot := cloneGeneral(a)
63+
tGot := cloneGeneral(b)
64+
for i := 1; i < n; i++ {
65+
for j := 0; j < i; j++ {
66+
// Set all lower tri elems to NaN to catch bad implementations.
67+
tGot.Data[i*tGot.Stride+j] = math.NaN()
68+
}
69+
}
70+
impl.Dgghrd(compq, compz, n, ilo, ihi, hGot.Data, hGot.Stride, tGot.Data, tGot.Stride, q.Data, q.Stride, z.Data, z.Stride)
71+
if n == 0 {
72+
return
73+
}
74+
if !isUpperHessenberg(hGot) {
75+
t.Error("H is not upper Hessenberg")
76+
}
77+
if !isNaNFree(tGot) || !isNaNFree(hGot) {
78+
t.Error("T or H is/or not NaN free")
79+
}
80+
if !isUpperTriangular(tGot) {
81+
t.Error("T is not upper triangular")
82+
}
83+
if compq == lapack.OrthoNone {
84+
if !isAllNaN(q.Data) {
85+
t.Errorf("Q is not NaN")
86+
}
87+
return
88+
}
89+
if compz == lapack.OrthoNone {
90+
if !isAllNaN(z.Data) {
91+
t.Errorf("Z is not NaN")
92+
}
93+
return
94+
}
95+
if compq != compz {
96+
return // Do not handle mixed case
97+
}
98+
comp := compq
99+
aux := zeros(n, n, n)
100+
101+
switch comp {
102+
case lapack.OrthoUnit:
103+
// Qᵀ*A*Z = H
104+
hCalc := zeros(n, n, n)
105+
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, a, 0, aux)
106+
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, hCalc)
107+
if !equalApproxGeneral(hGot, hCalc, tol) {
108+
t.Errorf("Qᵀ*A*Z != H")
109+
}
110+
111+
// Qᵀ*B*Z = T
112+
tCalc := zeros(n, n, n)
113+
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, b, 0, aux)
114+
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, tCalc)
115+
if !equalApproxGeneral(hGot, hCalc, tol) {
116+
t.Errorf("Qᵀ*B*Z != T")
117+
}
118+
case lapack.OrthoEntry:
119+
// Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ
120+
lhs := zeros(n, n, n)
121+
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, a, 0, aux)
122+
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs) // lhs = Q1 * A * Z1ᵀ
123+
124+
rhs := zeros(n, n, n)
125+
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, hGot, 0, aux)
126+
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs)
127+
if !equalApproxGeneral(lhs, rhs, tol) {
128+
t.Errorf("Q1 * A * Z1ᵀ != (Q1*Q) * H * (Z1*Z)ᵀ")
129+
}
130+
131+
// Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ
132+
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, b, 0, aux)
133+
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs)
134+
135+
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, tGot, 0, aux)
136+
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs)
137+
if !equalApproxGeneral(lhs, rhs, tol) {
138+
t.Errorf("Q1 * B * Z1ᵀ != (Q1*Q) * T * (Z1*Z)ᵀ")
139+
}
140+
}
141+
}

lapack/testlapack/general.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1201,7 +1201,20 @@ func isUpperTriangular(a blas64.General) bool {
12011201
n := a.Rows
12021202
for i := 1; i < n; i++ {
12031203
for j := 0; j < i; j++ {
1204-
if a.Data[i*a.Stride+j] != 0 {
1204+
v := a.Data[i*a.Stride+j]
1205+
if v != 0 || math.IsNaN(v) {
1206+
return false
1207+
}
1208+
}
1209+
}
1210+
return true
1211+
}
1212+
1213+
// isNaNFree returns whether a does not contain NaN elements in reachable elements.
1214+
func isNaNFree(a blas64.General) bool {
1215+
for i := 0; i < a.Rows; i++ {
1216+
for j := 0; j < a.Cols; j++ {
1217+
if math.IsNaN(a.Data[i*a.Stride+j]) {
12051218
return false
12061219
}
12071220
}

0 commit comments

Comments
 (0)