Skip to content

Commit cafcbe4

Browse files
authored
lapack/gonum: add Dgetc2 (#1655)
1 parent 0acd651 commit cafcbe4

File tree

3 files changed

+220
-0
lines changed

3 files changed

+220
-0
lines changed

lapack/gonum/dgetc2.go

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// Copyright ©2021 The Gonum Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package gonum
6+
7+
import (
8+
"math"
9+
10+
"gonum.org/v1/gonum/blas/blas64"
11+
)
12+
13+
// Dgetc2 computes an LU factorization with complete pivoting of the
14+
// n×n matrix A. The factorization has the form
15+
// A = P * L * U * Q,
16+
// where P and Q are permutation matrices, L is lower triangular with
17+
// unit diagonal elements and U is upper triangular.
18+
//
19+
// a is modified to the information to construct L and U.
20+
// The lower triangle of a contains the matrix L (not including diagonal).
21+
// The upper triangle contains the matrix U. The matrices P and Q can
22+
// be constructed from ipiv and jpiv, respectively. k is non-negative if U(k, k)
23+
// is likely to produce overflow when we try to solve for x in Ax = b.
24+
// U is perturbed in this case to avoid the overflow.
25+
//
26+
// Dgetc2 is an internal routine. It is exported for testing purposes.
27+
func (impl Implementation) Dgetc2(n int, a []float64, lda int, ipiv, jpiv []int) (k int) {
28+
switch {
29+
case n < 0:
30+
panic(nLT0)
31+
case lda < max(1, n):
32+
panic(badLdA)
33+
}
34+
35+
// Negative k indicates U was not perturbed.
36+
k = -1
37+
// Quick return if possible.
38+
if n == 0 {
39+
return k
40+
}
41+
42+
switch {
43+
case len(a) < (n-1)*lda+n:
44+
panic(shortA)
45+
case len(ipiv) != n:
46+
panic(badLenIpiv)
47+
case len(jpiv) != n:
48+
panic(badLenJpvt)
49+
}
50+
51+
const (
52+
eps = dlamchP
53+
smlnum = dlamchS / eps
54+
)
55+
if n == 1 {
56+
ipiv[0], jpiv[0] = 0, 0
57+
if math.Abs(a[0]) < smlnum {
58+
a[0] = smlnum
59+
k = 0
60+
}
61+
return k
62+
}
63+
64+
// Factorize A using complete pivoting.
65+
// Set pivots less than lc to lc.
66+
var lc float64
67+
var ipv, jpv int
68+
bi := blas64.Implementation()
69+
for i := 0; i < n-1; i++ {
70+
xmax := 0.0
71+
for ip := i; ip < n; ip++ {
72+
for jp := i; jp < n; jp++ {
73+
if math.Abs(a[ip*lda+jp]) >= xmax {
74+
xmax = math.Abs(a[ip*lda+jp])
75+
ipv = ip
76+
jpv = jp
77+
}
78+
}
79+
}
80+
if i == 0 {
81+
lc = math.Max(eps*xmax, smlnum)
82+
}
83+
84+
// Swap rows.
85+
if ipv != i {
86+
bi.Dswap(n, a[ipv*lda:], 1, a[i*lda:], 1)
87+
}
88+
ipiv[i] = ipv
89+
90+
// Swap columns.
91+
if jpv != i {
92+
bi.Dswap(n, a[jpv:], lda, a[i:], lda)
93+
}
94+
jpiv[i] = jpv
95+
96+
// Check for singularity.
97+
if math.Abs(a[i*lda+i]) < lc {
98+
k = i
99+
a[i*lda+i] = lc
100+
}
101+
102+
for j := i + 1; j < n; j++ {
103+
a[j*lda+i] /= a[i*lda+i]
104+
}
105+
bi.Dger(n-i-1, n-i-1, -1, a[(i+1)*lda+i:], lda, a[i*lda+i+1:], 1, a[(i+1)*lda+i+1:], lda)
106+
}
107+
108+
if math.Abs(a[(n-1)*lda+n-1]) < lc {
109+
k = n - 1
110+
a[(n-1)*lda+(n-1)] = lc
111+
}
112+
113+
// Set last pivots to last index.
114+
ipiv[n-1] = n - 1
115+
jpiv[n-1] = n - 1
116+
return k
117+
}

lapack/gonum/lapack_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,11 @@ func TestDgesvd(t *testing.T) {
118118
testlapack.DgesvdTest(t, impl, tol)
119119
}
120120

121+
func TestDgetc2(t *testing.T) {
122+
t.Parallel()
123+
testlapack.Dgetc2Test(t, impl)
124+
}
125+
121126
func TestDgetri(t *testing.T) {
122127
t.Parallel()
123128
testlapack.DgetriTest(t, impl)

lapack/testlapack/dgetc2.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// Copyright ©2021 The Gonum Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package testlapack
6+
7+
import (
8+
"fmt"
9+
"math"
10+
"testing"
11+
12+
"golang.org/x/exp/rand"
13+
14+
"gonum.org/v1/gonum/blas"
15+
"gonum.org/v1/gonum/blas/blas64"
16+
)
17+
18+
type Dgetc2er interface {
19+
Dgetc2(n int, a []float64, lda int, ipiv, jpiv []int) (k int)
20+
}
21+
22+
func Dgetc2Test(t *testing.T, impl Dgetc2er) {
23+
const tol = 1e-12
24+
rnd := rand.New(rand.NewSource(1))
25+
for _, n := range []int{0, 1, 2, 3, 4, 5, 10, 20} {
26+
for _, lda := range []int{n, n + 5} {
27+
dgetc2Test(t, impl, rnd, n, lda, tol)
28+
}
29+
}
30+
}
31+
32+
func dgetc2Test(t *testing.T, impl Dgetc2er, rnd *rand.Rand, n, lda int, tol float64) {
33+
name := fmt.Sprintf("n=%v,lda=%v", n, lda)
34+
if lda == 0 {
35+
lda = 1
36+
}
37+
// Generate a random general matrix A.
38+
a := randomGeneral(n, n, lda, rnd)
39+
// ipiv and jpiv are outputs.
40+
ipiv := make([]int, n)
41+
jpiv := make([]int, n)
42+
for i := 0; i < n; i++ {
43+
ipiv[i], jpiv[i] = -1, -1 // Set to non-indices.
44+
}
45+
// Copy to store output (LU decomposition).
46+
lu := cloneGeneral(a)
47+
k := impl.Dgetc2(n, lu.Data, lu.Stride, ipiv, jpiv)
48+
if k >= 0 {
49+
t.Logf("%v: matrix was perturbed at %d", name, k)
50+
}
51+
52+
// Verify all indices are set.
53+
for i := 0; i < n; i++ {
54+
if ipiv[i] < 0 {
55+
t.Errorf("%v: ipiv[%d] is negative", name, i)
56+
}
57+
if jpiv[i] < 0 {
58+
t.Errorf("%v: jpiv[%d] is negative", name, i)
59+
}
60+
}
61+
bi := blas64.Implementation()
62+
// Construct L and U triangular matrices from Dgetc2 output.
63+
L := zeros(n, n, lda)
64+
U := zeros(n, n, lda)
65+
for i := 0; i < n; i++ {
66+
for j := 0; j < n; j++ {
67+
idx := i*lda + j
68+
if j >= i { // On upper triangle and setting of L's unit diagonal elements.
69+
U.Data[idx] = lu.Data[idx]
70+
if j == i {
71+
L.Data[idx] = 1.0
72+
}
73+
} else if i > j { // On diagonal or lower triangle.
74+
L.Data[idx] = lu.Data[idx]
75+
}
76+
}
77+
}
78+
work := zeros(n, n, lda)
79+
bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, L.Data, L.Stride, U.Data, U.Stride, 0, work.Data, work.Stride)
80+
81+
// Apply Permutations P and Q to L*U.
82+
for i := n - 1; i >= 0; i-- {
83+
ipv, jpv := ipiv[i], jpiv[i]
84+
if ipv != i {
85+
bi.Dswap(n, work.Data[i*lda:], 1, work.Data[ipv*lda:], 1)
86+
}
87+
if jpv != i {
88+
bi.Dswap(n, work.Data[i:], work.Stride, work.Data[jpv:], work.Stride)
89+
}
90+
}
91+
92+
// A should be reconstructed by now.
93+
for i := range work.Data {
94+
if math.Abs(work.Data[i]-a.Data[i]) > tol {
95+
t.Errorf("%v: matrix %d idx not equal after reconstruction. got %g, expected %g", name, i, work.Data[i], a.Data[i])
96+
}
97+
}
98+
}

0 commit comments

Comments
 (0)