Skip to content

Commit 1f29d7b

Browse files
tvknkortschak
authored andcommitted
mat: calculate Q elements lazily when calling QR.At
When a matrix is very tall, calculating Q will currently allocate a large Q at the end of the factorisation, even if it is not going to be used, and a large Q matrix can lead to out of memory issues. For this reason, Q is never eagerly computed unless explicitly required to by the user, with QR.ToQ. To keep fulfilling the Matrix interface, the QR.At method will compute the requested element only, which only require computing a single row of Q.
1 parent f1a62e1 commit 1f29d7b

File tree

2 files changed

+74
-4
lines changed

2 files changed

+74
-4
lines changed

mat/qr.go

+46-2
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,13 @@ func (qr *QR) Dims() (r, c int) {
3131
return qr.qr.Dims()
3232
}
3333

34-
// At returns the element at row i, column j.
34+
// At returns the element at row i, column j. At will panic if the receiver
35+
// does not contain a successful factorization.
3536
func (qr *QR) At(i, j int) float64 {
37+
if !qr.isValid() {
38+
panic(badQR)
39+
}
40+
3641
m, n := qr.Dims()
3742
if uint(i) >= uint(m) {
3843
panic(ErrRowAccess)
@@ -41,13 +46,46 @@ func (qr *QR) At(i, j int) float64 {
4146
panic(ErrColAccess)
4247
}
4348

49+
if qr.q == nil || qr.q.IsEmpty() {
50+
// Calculate Qi, Q i-th row
51+
qi := getFloat64s(m, true)
52+
qr.qRowTo(i, qi)
53+
54+
// Compute QR(i,j)
55+
var val float64
56+
for k := 0; k <= j; k++ {
57+
val += qi[k] * qr.qr.at(k, j)
58+
}
59+
putFloat64s(qi)
60+
return val
61+
}
62+
4463
var val float64
4564
for k := 0; k <= j; k++ {
4665
val += qr.q.at(i, k) * qr.qr.at(k, j)
4766
}
4867
return val
4968
}
5069

70+
// qRowTo extracts the i-th row of the orthonormal matrix Q from a QR
71+
// decomposition.
72+
func (qr *QR) qRowTo(i int, dst []float64) {
73+
c := blas64.General{
74+
Rows: 1,
75+
Cols: len(dst),
76+
Stride: len(dst),
77+
Data: dst,
78+
}
79+
c.Data[i] = 1 // C is the i-th unit vector
80+
81+
// Construct Qi from the elementary reflectors: Qi = C * (H(1) H(2) ... H(nTau))
82+
work := []float64{0}
83+
lapack64.Ormqr(blas.Right, blas.NoTrans, qr.qr.mat, qr.tau, c, work, -1)
84+
work = getFloat64s(int(work[0]), false)
85+
lapack64.Ormqr(blas.Right, blas.NoTrans, qr.qr.mat, qr.tau, c, work, len(work))
86+
putFloat64s(work)
87+
}
88+
5189
// T performs an implicit transpose by returning the receiver inside a
5290
// Transpose.
5391
func (qr *QR) T() Matrix {
@@ -98,7 +136,9 @@ func (qr *QR) factorize(a Matrix, norm lapack.MatrixNorm) {
98136
lapack64.Geqrf(qr.qr.mat, qr.tau, work, len(work))
99137
putFloat64s(work)
100138
qr.updateCond(norm)
101-
qr.updateQ()
139+
if qr.q != nil {
140+
qr.q.Reset()
141+
}
102142
}
103143

104144
func (qr *QR) updateQ() {
@@ -192,6 +232,10 @@ func (qr *QR) QTo(dst *Dense) {
192232
panic(ErrShape)
193233
}
194234
}
235+
236+
if qr.q == nil || qr.q.IsEmpty() {
237+
qr.updateQ()
238+
}
195239
dst.Copy(qr.q)
196240
}
197241

mat/qr_test.go

+28-2
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@ func TestQR(t *testing.T) {
1818
rnd := rand.New(rand.NewSource(1))
1919
for _, test := range []struct {
2020
m, n int
21+
big bool
2122
}{
22-
{5, 5},
23-
{10, 5},
23+
{m: 5, n: 5},
24+
{m: 10, n: 5},
25+
{m: 1e5, n: 3, big: true}, // Test that very tall matrices do not OoM.
2426
} {
2527
m := test.m
2628
n := test.n
@@ -35,6 +37,13 @@ func TestQR(t *testing.T) {
3537

3638
var qr QR
3739
qr.Factorize(a)
40+
if test.big {
41+
_ = qr.At(0, 0) // should not panic, even for big matrices
42+
_ = qr.At(m-1, n-1) // should not panic, even for big matrices
43+
// We cannot proceed past here for big matrices.
44+
continue
45+
}
46+
3847
var q, r Dense
3948
qr.QTo(&q)
4049

@@ -56,6 +65,23 @@ func TestQR(t *testing.T) {
5665
if !EqualApprox(&got, &want, 1e-12) {
5766
t.Errorf("QR does not equal original matrix. \nWant: %v\nGot: %v", want, got)
5867
}
68+
69+
// Verify indirect QR.At()
70+
got.Reset()
71+
got.ReuseAs(m, n)
72+
qr.q.Reset() // reset q matrix to force lazy computation
73+
for i := 0; i < m; i++ {
74+
for j := 0; j < n; j++ {
75+
got.set(i, j, qr.At(i, j))
76+
}
77+
}
78+
79+
if !EqualApprox(a, &got, 1e-14) {
80+
t.Errorf("m=%d,n=%d: A and QR (computed with QR.At()) are not equal", m, n)
81+
}
82+
if !EqualApprox(a.T(), got.T(), 1e-14) {
83+
t.Errorf("m=%d,n=%d: Aᵀ and (QR)ᵀ (computed with QR.At()) are not equal", m, n)
84+
}
5985
}
6086
}
6187

0 commit comments

Comments
 (0)