# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Functions for common linear algebra operations.

Note: Many of these functions will eventually be migrated to core TensorFlow.
"""

import numpy as np

import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import custom_gradient as tfp_custom_gradient
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow_probability.python.math import generic


__all__ = [
    'cholesky_concat',
    'cholesky_update',
    'fill_triangular',
    'fill_triangular_inverse',
    'hpsd_logdet',
    'hpsd_quadratic_form_solve',
    'hpsd_quadratic_form_solvevec',
    'hpsd_solve',
    'hpsd_solvevec',
    'lu_matrix_inverse',
    'lu_reconstruct',
    'lu_reconstruct_assertions',  # Internally visible for MatvecLU.
    'lu_solve',
    'low_rank_cholesky',
    'pivoted_cholesky',
    'sparse_or_dense_matmul',
    'sparse_or_dense_matvecmul',
]


def cholesky_concat(chol, cols, name=None):
  """Concatenates `chol @ chol.T` with additional rows and columns.

  This operation is conceptually identical to:

  ```python
  def cholesky_concat_slow(chol, cols):  # cols shaped (n + m) x m = z x m
    mat = tf.matmul(chol, chol, adjoint_b=True)  # shape of n x n
    # Concat columns.
    mat = tf.concat([mat, cols[..., :tf.shape(mat)[-2], :]], axis=-1)  # n x z
    # Concat rows.
    mat = tf.concat([mat, tf.linalg.matrix_transpose(cols)], axis=-2)  # z x z
    return tf.linalg.cholesky(mat)
  ```
  but whereas `cholesky_concat_slow` would cost `O(z**3)` work,
  `cholesky_concat` only costs `O(z**2 + m**3)` work.

  The resulting (implicit) matrix must be symmetric and positive definite.
  Thus, the bottom right `m x m` must be self-adjoint, and we do not require a
  separate `rows` argument (which can be inferred from `conj(cols.T)`).
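
For intuition, the same block factorization can be written in NumPy. This is a
minimal real-valued, unbatched sketch, not how this function is implemented;
`cholesky_concat_np` is a hypothetical name, and `np.linalg.solve` stands in
for the triangular solve that produces `solved_nm` in this function's body.

```python
import numpy as np


def cholesky_concat_np(chol, cols):
  # Hypothetical helper (not part of TFP). `chol` is n x n lower triangular,
  # `cols` is (n + m) x m; real-valued and unbatched for simplicity.
  n = chol.shape[-1]
  mat_nm, mat_mm = cols[:n, :], cols[n:, :]
  # Solve chol @ solved_nm = mat_nm (a general solver, used for brevity).
  solved_nm = np.linalg.solve(chol, mat_nm)
  # Cholesky factor of the m x m Schur complement.
  lower_right_mm = np.linalg.cholesky(mat_mm - solved_nm.T @ solved_nm)
  top = np.concatenate([chol, np.zeros_like(solved_nm)], axis=-1)
  bottom = np.concatenate([solved_nm.T, lower_right_mm], axis=-1)
  return np.concatenate([top, bottom], axis=-2)


# Compare with the slow path on a random symmetric positive definite matrix.
rng = np.random.default_rng(0)
n, m = 4, 2
b = rng.normal(size=(n + m, n + m))
full = b @ b.T + (n + m) * np.eye(n + m)
chol = np.linalg.cholesky(full[:n, :n])
cols = full[:, n:]  # shape (n + m) x m
fast = cholesky_concat_np(chol, cols)
slow = np.linalg.cholesky(full)
```

Because a Cholesky factor with a positive diagonal is unique, `fast` and `slow`
agree up to rounding.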

  Args:
    chol: Cholesky decomposition of `mat = chol @ chol.T`.
    cols: The new columns whose first `n` rows we would like concatenated to the
      right of `mat = chol @ chol.T`, and whose conjugate transpose we would
      like concatenated to the bottom of `concat(mat, cols[:n,:])`. A `Tensor`
      with final dims `(n+m, m)`. The first `n` rows are the top right rectangle
      (their conjugate transpose forms the bottom left), and the bottom `m x m`
      is self-adjoint.
    name: Optional name for this op.

  Returns:
    chol_concat: The Cholesky decomposition of:
      ```
      [ [ mat  cols[:n, :] ]
        [   conj(cols.T)   ] ]
      ```
  """
  with tf.name_scope(name or 'cholesky_extend'):
    dtype = dtype_util.common_dtype([chol, cols], dtype_hint=tf.float32)
    chol = tf.convert_to_tensor(chol, name='chol', dtype=dtype)
    cols = tf.convert_to_tensor(cols, name='cols', dtype=dtype)
    n = ps.shape(chol)[-1]
    mat_nm, mat_mm = cols[..., :n, :], cols[..., n:, :]
    solved_nm = tf.linalg.triangular_solve(chol, mat_nm)
    lower_right_mm = tf.linalg.cholesky(
        mat_mm - tf.matmul(solved_nm, solved_nm, adjoint_a=True))
    lower_left_mn = tf.math.conj(tf.linalg.matrix_transpose(solved_nm))
    out_batch = ps.shape(solved_nm)[:-2]
    chol = tf.broadcast_to(
        chol, ps.concat([out_batch, ps.shape(chol)[-2:]], axis=0))
    top_right_zeros_nm = tf.zeros_like(solved_nm)
    return tf.concat([tf.concat([chol, top_right_zeros_nm], axis=-1),
                      tf.concat([lower_left_mn, lower_right_mm], axis=-1)],
                     axis=-2)


def cholesky_update(chol, update_vector, multiplier=1., name=None):
  """Returns cholesky of chol @ chol.T + multiplier * u @ u.T.

  Given a (batch of) lower triangular cholesky factor(s) `chol`, along with a
  (batch of) vector(s) `update_vector` (written `u` here), compute the lower
  triangular cholesky factor of the rank-1 update
  `chol @ chol.T + multiplier * u @ u.T`, where `multiplier` is a (batch of)
  scalar(s).

  If `chol` has shape `[L, L]`, this has complexity `O(L^2)` compared to the
  naive algorithm which has complexity `O(L^3)`.
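
The column-by-column recurrence of Algorithm 3.1 in [1] can be sketched in
NumPy as follows. This is a minimal unbatched illustration, not the TFP code
path: `cholesky_update_np` is a hypothetical name, and its loop over `j` plays
the role of the `tf.scan` over columns in the implementation.

```python
import numpy as np


def cholesky_update_np(chol, u, multiplier=1.0):
  # Hypothetical helper (not part of TFP): returns the lower Cholesky factor
  # of chol @ chol.T + multiplier * outer(u, u) for a single L x L matrix.
  n = chol.shape[-1]
  new_chol = np.array(chol, dtype=float)
  omega = np.array(u, dtype=float)
  b = 1.0
  for j in range(n):
    d = chol[j, j]
    col = chol[:, j]  # column j of the *original* factor
    omega_j = omega[j]
    # Line 4: the updated diagonal entry.
    new_d = np.sqrt(d**2 + multiplier / b * omega_j**2)
    # Line 5: `gamma` (called `scaling_factor` in the TF code).
    gamma = d**2 * b + multiplier * omega_j**2
    # Lines 6-8, with the inner loop over rows j+1.. written as vector ops.
    omega = omega - (omega_j / d) * col
    new_chol[j + 1:, j] = new_d * (
        col[j + 1:] / d + (multiplier * omega_j / gamma) * omega[j + 1:])
    new_chol[j, j] = new_d
    b = b + multiplier * (omega_j / d)**2
  return new_chol


rng = np.random.default_rng(1)
a = rng.normal(size=(5, 5))
mat = a @ a.T + 5.0 * np.eye(5)
chol = np.linalg.cholesky(mat)
u = rng.normal(size=5)
updated = cholesky_update_np(chol, u, multiplier=2.0)
# Downdating by the same vector should recover the original factor.
restored = cholesky_update_np(updated, u, multiplier=-2.0)
```

Each column costs `O(L)` work, which gives the `O(L^2)` total quoted above.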

  Args:
    chol: Floating-point `Tensor` with shape `[B1, ..., Bn, L, L]`.
      Cholesky decomposition of `mat = chol @ chol.T`. Batch dimensions
      must be broadcastable with `update_vector` and `multiplier`.
    update_vector: Floating-point `Tensor` with shape `[B1, ..., Bn, L]`. Vector
      defining rank-one update. Batch dimensions must be broadcastable with
      `chol` and `multiplier`.
    multiplier: Floating-point `Tensor` with shape `[B1, ..., Bn]`. Scalar
      multiplier to rank-one update. Batch dimensions must be broadcastable
      with `chol` and `update_vector`. Note that updates where `multiplier` is
      positive are numerically stable, while when `multiplier` is negative
      (downdating), the update will only work if the new resulting matrix is
      still positive definite.
    name: Optional name for this op.

  Returns:
    new_chol: Floating-point `Tensor` with shape `[B1, ..., Bn, L, L]` (the
      broadcast batch shape of the inputs): the lower triangular Cholesky
      factor of `chol @ chol.T + multiplier * u @ u.T`.

  #### References

  [1]: Oswin Krause, Christian Igel. A More Efficient Rank-one Covariance
       Matrix Update for Evolution Strategies. 2015 ACM Conference.
       https://www.researchgate.net/publication/300581419_A_More_Efficient_Rank-one_Covariance_Matrix_Update_for_Evolution_Strategies
  """
  # TODO(b/154638092): Move this functionality into TensorFlow.
  with tf.name_scope(name or 'cholesky_update'):
    dtype = dtype_util.common_dtype(
        [chol, update_vector, multiplier], dtype_hint=tf.float32)
    chol = tf.convert_to_tensor(chol, name='chol', dtype=dtype)
    update_vector = tf.convert_to_tensor(
        update_vector, name='update_vector', dtype=dtype)
    multiplier = tf.convert_to_tensor(
        multiplier, name='multiplier', dtype=dtype)

    batch_shape = ps.broadcast_shape(
        ps.broadcast_shape(
            ps.shape(chol)[:-2],
            ps.shape(update_vector)[:-1]), ps.shape(multiplier))
    chol = tf.broadcast_to(
        chol, ps.concat([batch_shape, ps.shape(chol)[-2:]], axis=0))
    update_vector = tf.broadcast_to(
        update_vector,
        ps.concat([batch_shape, ps.shape(update_vector)[-1:]], axis=0))
    multiplier = tf.broadcast_to(multiplier, batch_shape)

    chol_diag = tf.linalg.diag_part(chol)

    # The algorithm in [1] is implemented as a double for loop. We can treat
    # the inner loop in Algorithm 3.1 as a vector operation, and thus the
    # whole algorithm as a single for loop, and hence can use a `tf.scan`
    # on it.

    # We accumulate omega and b, as defined in Algorithm 3.1, since these are
    # updated per iteration.

    def compute_new_column(accumulated_quantities, state):
      """Computes the next column of the updated cholesky."""
      _, _, omega, b = accumulated_quantities
      index, diagonal_member, col, col_mask = state
      omega_at_index = omega[..., index]

      # Line 4
      new_diagonal_member = tf.math.sqrt(
          tf.math.square(diagonal_member) + multiplier / b * tf.math.square(
              omega_at_index))
      # `scaling_factor` is the same as `gamma` on Line 5.
      scaling_factor = (tf.math.square(diagonal_member) * b +
                        multiplier * tf.math.square(omega_at_index))

      # The following updates are the same as the for loop in lines 6-8.
      omega = omega - (omega_at_index / diagonal_member)[..., tf.newaxis] * col
      new_col = new_diagonal_member[..., tf.newaxis] * (
          col / diagonal_member[..., tf.newaxis] +
          (multiplier * omega_at_index / scaling_factor)[
              ..., tf.newaxis] * omega * col_mask)
      b = b + multiplier * tf.math.square(omega_at_index / diagonal_member)
      return new_diagonal_member, new_col, omega, b

    # We will scan over the columns.
    cols_mask = distribution_util.move_dimension(
        tf.linalg.band_part(tf.ones_like(chol), -1, 0),
        source_idx=-1, dest_idx=0)
    chol = distribution_util.move_dimension(chol, source_idx=-1, dest_idx=0)
    chol_diag = distribution_util.move_dimension(
        chol_diag, source_idx=-1, dest_idx=0)

    new_diag, new_chol, _, _ = tf.scan(
        fn=compute_new_column,
        elems=(tf.range(0, ps.shape(chol)[0]), chol_diag, chol, cols_mask),
        initializer=(
            tf.zeros_like(multiplier),
            tf.zeros_like(chol[0, ...]),
            update_vector,
            tf.ones_like(multiplier)))
    new_chol = distribution_util.move_dimension(
        new_chol, source_idx=0, dest_idx=-1)
    new_diag = distribution_util.move_dimension(
        new_diag, source_idx=0, dest_idx=-1)
    new_chol = tf.linalg.set_diag(new_chol, new_diag)
    return new_chol


def _swap_m_with_i(vecs, m, i):
  """Swaps `m` and `i` on axis -1. (Helper for pivoted_cholesky.)

  Given a batch of int64 vectors `vecs`, scalar index `m`, and compatibly shaped
  per-vector indices `i`, this function swaps elements `m` and `i` in each
  vector. For the use-case below, these are permutation vectors.

  Args:
    vecs: Vectors on which we perform the swap, int64 `Tensor`.
    m: Scalar int64 `Tensor`, the index into which the `i`th element is going.
    i: Batch int64 `Tensor`, shaped like vecs.shape[:-1] + [1]; the index into
      which the `m`th element is going.

  Returns:
    vecs: The updated vectors.
  """
  vecs = tf.convert_to_tensor(vecs, dtype=tf.int64, name='vecs')
  m = tf.convert_to_tensor(m, dtype=tf.int64, name='m')
  i = tf.convert_to_tensor(i, dtype=tf.int64, name='i')
  trailing_elts = tf.broadcast_to(
      tf.range(m + 1, ps.shape(vecs, out_type=tf.int64)[-1]),
      ps.shape(vecs[..., m + 1:]))
  trailing_elts = tf.where(
      tf.equal(trailing_elts, i),
      tf.gather(vecs, [m], axis=-1),
      vecs[..., m + 1:])
  # TODO(bjp): Could we use tensor_scatter_nd_update?
  vecs_shape = vecs.shape
  vecs_rank = tensorshape_util.rank(vecs_shape)
  if vecs_rank is None:
    raise NotImplementedError(
        'Input vector to swap must have statically known rank. If you cannot '
        'provide static rank please contact core TensorFlow team and request '
        'that `tf.gather` `batch_dims` argument support `tf.Tensor`-valued '
        'inputs.')
  vecs = tf.concat([
      vecs[..., :m],
      tf.gather(vecs, i, batch_dims=int(vecs_rank) - 1),
      trailing_elts
  ], axis=-1)
  tensorshape_util.set_shape(vecs, vecs_shape)
  return vecs


def _invert_permutation(perm):  # TODO(b/130217510): Remove this function.
  return tf.cast(tf.argsort(perm, axis=-1), perm.dtype)


def pivoted_cholesky(matrix,
                     max_rank,
                     diag_rtol=1e-3,
                     return_pivoting_order=False,
                     name=None):
  """Computes the (partial) pivoted cholesky decomposition of `matrix`.

  The pivoted Cholesky is a low rank approximation of the Cholesky decomposition
  of `matrix`, as described in [(Harbrecht et al., 2012)][1]. The
  currently-worst-approximated diagonal element is selected as the pivot at each
  iteration. This yields from a `[B1...Bn, N, N]` shaped `matrix` a `[B1...Bn,
  N, K]` shaped rank-`K` approximation `lr` such that `lr @ lr.T ~= matrix`.
  Note that, unlike the Cholesky decomposition, `lr` is not triangular even in
  a rectangular-matrix sense. However, under a permutation it could be made
  triangular (it has one more zero in each column as you move to the right).

  Such a matrix can be useful as a preconditioner for conjugate gradient
  optimization, e.g. as in [(Wang et al. 2019)][2], as matmuls and solves can be
  cheaply done via the Woodbury matrix identity, as implemented by
  `tf.linalg.LinearOperatorLowRankUpdate`.
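
A minimal, unbatched NumPy sketch of the greedy loop, following the numbered
steps in the comments of the implementation (and Algorithm 1 of [1]). The name
`pivoted_cholesky_np` is hypothetical. Unlike the TensorFlow code, which fills
a `[max_rank, N]` buffer and transposes it at the end, the sketch writes the
`[N, max_rank]` result directly; it uses the same early-stopping rule.

```python
import numpy as np


def pivoted_cholesky_np(matrix, max_rank, diag_rtol=1e-3):
  # Hypothetical helper (not part of TFP). `matrix` is N x N symmetric
  # positive (semi)definite; returns `lr` of shape [N, max_rank] and `perm`.
  n = matrix.shape[-1]
  max_rank = min(max_rank, n)
  diag = np.array(np.diag(matrix), dtype=float)  # residual diagonal
  orig_error = diag.max()
  perm = np.arange(n)
  lr = np.zeros((n, max_rank))
  for m in range(max_rank):
    # Same rule as `cond`: always take the first step, then stop once the
    # summed residual diagonal is small relative to the largest entry.
    if m > 0 and np.abs(diag).sum() / orig_error <= diag_rtol:
      break
    # Steps 1-3: pivot on the largest remaining residual diagonal entry.
    i = m + np.argmax(diag[perm[m:]])
    perm[[m, i]] = perm[[i, m]]
    p, rest = perm[m], perm[m + 1:]
    # Steps 4-5: residual row of the pivot, restricted to unpivoted indices.
    row = matrix[p, rest] - lr[rest, :m] @ lr[p, :m]
    # Steps 6-7: scale by the square root of the pivot value.
    pivot = np.sqrt(diag[p])
    lr[p, m] = pivot
    lr[rest, m] = row / pivot
    # Step 8: update the residual diagonal.
    diag[p] = 0.0
    diag[rest] -= (row / pivot)**2
  return lr, perm


# A rank-2 positive semidefinite matrix: the loop stops after two pivots.
b = np.array([[1., 0.], [0., 1.], [1., 1.], [1., -1.], [2., 1.]])
low_rank = b @ b.T
lr, perm = pivoted_cholesky_np(low_rank, max_rank=5)
# A full-rank SPD matrix with diag_rtol=0 runs all N steps.
a = np.array([[4., 2., 0.], [2., 5., 1.], [0., 1., 3.]])
lr_full, perm_full = pivoted_cholesky_np(a, max_rank=3, diag_rtol=0.)
```

In the rank-2 case the unused columns of `lr` stay zero, and in the full-rank
case `lr_full[perm_full]` is lower triangular, matching the "triangular under a
permutation" remark above.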

  Args:
    matrix: Floating point `Tensor` batch of symmetric, positive definite
      matrices.
    max_rank: Scalar `int` `Tensor`, the rank at which to truncate the
      approximation.
    diag_rtol: Scalar floating point `Tensor` (same dtype as `matrix`). If the
      sum of the absolute residual diagonal errors, `diag(matrix - lr @ lr.T)`,
      is at most `diag_rtol` times the largest diagonal element of `matrix`
      (for every batch member), iteration is permitted to terminate early.
    return_pivoting_order: If `True`, return an `int` `Tensor` indicating the
      pivoting order used to produce `lr` (in addition to `lr`).
    name: Optional name for the op.

  Returns:
    lr: Low rank pivoted Cholesky approximation of `matrix`.
    perm: (Optional) pivoting order used to produce `lr`.

  #### References

  [1]: H Harbrecht, M Peters, R Schneider. On the low-rank approximation by the
       pivoted Cholesky decomposition. _Applied numerical mathematics_,
       62(4):428-440, 2012.

  [2]: K. A. Wang et al. Exact Gaussian Processes on a Million Data Points.
       _arXiv preprint arXiv:1903.08114_, 2019. https://arxiv.org/abs/1903.08114
  """
  with tf.name_scope(name or 'pivoted_cholesky'):
    dtype = dtype_util.common_dtype([matrix, diag_rtol],
                                    dtype_hint=tf.float32)
    if not isinstance(matrix, tf.linalg.LinearOperator):
      matrix = tf.convert_to_tensor(matrix, name='matrix', dtype=dtype)
    if tensorshape_util.rank(matrix.shape) is None:
      raise NotImplementedError('Rank of `matrix` must be known statically')
    if isinstance(matrix, tf.linalg.LinearOperator):
      matrix_shape = tf.cast(matrix.shape_tensor(), tf.int64)
    else:
      matrix_shape = ps.shape(matrix, out_type=tf.int64)

    max_rank = tf.convert_to_tensor(
        max_rank, name='max_rank', dtype=tf.int64)
    max_rank = tf.minimum(max_rank, matrix_shape[-1])
    diag_rtol = tf.convert_to_tensor(
        diag_rtol, dtype=dtype, name='diag_rtol')
    matrix_diag = tf.linalg.diag_part(matrix)
    # matrix is P.D., therefore all matrix_diag > 0, so we don't need abs.
    orig_error = tf.reduce_max(matrix_diag, axis=-1)

    def cond(m, pchol, perm, matrix_diag):
      """Condition for `tf.while_loop` continuation."""
      del pchol
      del perm
      error = tf.linalg.norm(matrix_diag, ord=1, axis=-1)
      max_err = tf.reduce_max(error / orig_error)
      return (m < max_rank) & (tf.equal(m, 0) | (max_err > diag_rtol))

    batch_dims = tensorshape_util.rank(matrix.shape) - 2
    def batch_gather(params, indices, axis=-1):
      return tf.gather(params, indices, axis=axis, batch_dims=batch_dims)

    def body(m, pchol, perm, matrix_diag):
      """Body of a single `tf.while_loop` iteration."""
      # Here is roughly a numpy, non-batched version of what's going to happen.
      # (See also Algorithm 1 of Harbrecht et al.)
      # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
      # 2: maxval = matrix_diag[perm][maxi]
      # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
      # 4: row = matrix[perm[m]][perm[m + 1:]]
      # 5: row -= np.sum(pchol[:m, perm[m + 1:]] * pchol[:m, perm[m:m + 1]],
      #                  axis=-2)
      # 6: pivot = np.sqrt(maxval); row /= pivot
      # 7: row = np.concatenate([[[pivot]], row], -1)
      # 8: matrix_diag[perm[m:]] -= row**2
      # 9: pchol[m, perm[m:]] = row

      # Find the maximal position of the (remaining) permuted diagonal.
      # Steps 1, 2 above.
      permuted_diag = batch_gather(matrix_diag, perm[..., m:])
      maxi = tf.argmax(
          permuted_diag, axis=-1, output_type=tf.int64)[..., tf.newaxis]
      maxval = batch_gather(permuted_diag, maxi)
      maxi = maxi + m
      maxval = maxval[..., 0]
      # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
      perm = _swap_m_with_i(perm, m, maxi)
      # Step 4.
      if callable(getattr(matrix, 'row', None)):
        row = matrix.row(perm[..., m])[..., tf.newaxis, :]
      else:
        row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
      row = batch_gather(row, perm[..., m + 1:])
      # Step 5.
      prev_rows = pchol[..., :m, :]
      prev_rows_perm_m_onward = batch_gather(prev_rows, perm[..., m + 1:])
      prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
      row -= tf.reduce_sum(
          prev_rows_perm_m_onward * prev_rows_pivot_col,
          axis=-2)[..., tf.newaxis, :]
      # Step 6.
      pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
      # Step 7.
      row = tf.concat([pivot, row / pivot], axis=-1)
      # TODO(b/130899118): Pad grad fails with int64 paddings.
      # Step 8.
      paddings = tf.concat([
          tf.zeros([ps.rank(pchol) - 1, 2], dtype=tf.int32),
          [[tf.cast(m, tf.int32), 0]]], axis=0)
      diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
      reverse_perm = _invert_permutation(perm)
      matrix_diag = matrix_diag - batch_gather(diag_update, reverse_perm)
      # Step 9.
      row = tf.pad(row, paddings=paddings)
      # TODO(bjp): Defer the reverse permutation all-at-once at the end?
      row = batch_gather(row, reverse_perm)
      pchol_shape = pchol.shape
      pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                        axis=-2)
      tensorshape_util.set_shape(pchol, pchol_shape)
      return m + 1, pchol, perm, matrix_diag

    m = np.int64(0)
    pchol = tf.zeros(matrix_shape, dtype=matrix.dtype)[..., :max_rank, :]
    perm = tf.broadcast_to(
        ps.range(matrix_shape[-1]), matrix_shape[:-1])
    _, pchol, _, _ = tf.while_loop(
        cond=cond, body=body, loop_vars=(m, pchol, perm, matrix_diag))
    pchol = tf.linalg.matrix_transpose(pchol)
    tensorshape_util.set_shape(
        pchol, tensorshape_util.concatenate(matrix_diag.shape, [None]))
    if return_pivoting_order:
      return pchol, perm
    else:
      return pchol


def low_rank_cholesky(matrix, max_rank, trace_atol=0, trace_rtol=0, name=None):
  """Computes a low-rank approximation to the Cholesky decomposition.

  This routine is similar to `pivoted_cholesky`, but works under JAX, at
  the cost of being slightly less numerically stable.
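
A minimal, unbatched NumPy sketch of the same greedy loop. It mainly shows the
JAX-friendly update from step 3 of the implementation: column `i` of `lr` is
written by adding `v` times a one-hot row instead of by slice assignment. The
name `low_rank_cholesky_np` is hypothetical. It takes a dense symmetric
positive definite array (not a `LinearOperator`), assumes `max_rank <= m`, and
uses a Python loop in place of `tf.while_loop`.

```python
import numpy as np


def low_rank_cholesky_np(matrix, max_rank, trace_atol=0., trace_rtol=0.):
  # Hypothetical helper (not part of TFP); `matrix` is a dense m x m array.
  n = matrix.shape[-1]
  mtrace = np.trace(matrix)
  lr = np.zeros((n, max_rank))
  residual_diag = np.array(np.diag(matrix), dtype=float)
  i = 0
  while i < max_rank:
    # As in `lr_cholesky_cond`: always take the first step, then stop early
    # once the residual trace meets an enabled tolerance.
    residual_trace = residual_diag.sum()
    terminate = ((trace_atol > 0 and residual_trace < trace_atol) or
                 (trace_rtol > 0 and residual_trace < trace_rtol * mtrace))
    if i > 0 and terminate:
      break
    # 1. Largest entry of the residual diagonal.
    max_j = np.argmax(residual_diag)
    # 2. v = residual_matrix[max_j, :] / sqrt(residual_matrix[max_j, max_j]),
    #    masked to zero on indices that were already eliminated.
    v = (matrix[max_j] - lr @ lr[max_j]) / np.sqrt(residual_diag[max_j])
    v = v * np.sign(residual_diag)
    # 3. Write v into column i without slice assignment.
    lr = lr + v[:, np.newaxis] * np.eye(max_rank)[i][np.newaxis, :]
    # 4. Update the residual diagonal and zero the chosen entry explicitly.
    residual_diag = (residual_diag - v * v) * (np.arange(n) != max_j)
    i += 1
  return lr, i, residual_diag


a = np.array([[4., 2., 0.], [2., 5., 1.], [0., 1., 3.]])
lr, r, residual_diag = low_rank_cholesky_np(a, max_rank=3)
lr2, r2, residual_diag2 = low_rank_cholesky_np(a, max_rank=2)
```

With `max_rank` equal to the matrix size the residual vanishes; with a smaller
`max_rank`, `residual_diag` is the diagonal of `matrix - lr @ lr.T`.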

  Args:
    matrix: Floating point `Tensor` batch of symmetric, positive definite
      matrices, or a `tf.linalg.LinearOperator`.
    max_rank: Scalar `int` `Tensor`, the rank at which to truncate the
      approximation.
    trace_atol: Scalar floating point `Tensor` (same dtype as `matrix`). If
      trace_atol > 0 and trace(matrix - LR * LR^t) < trace_atol, the output
      LR matrix is allowed to be of rank less than max_rank.
    trace_rtol: Scalar floating point `Tensor` (same dtype as `matrix`). If
      trace_rtol > 0 and trace(matrix - LR * LR^t) < trace_rtol * trace(matrix),
      the output LR matrix is allowed to be of rank less than max_rank.
    name: Optional name for the op.

  Returns:
    A triple `(LR, r, residual_diag)` where
    LR: a matrix such that LR * LR^t is approximately the input matrix.
      If matrix is of shape (b1, ..., bn, m, m), then LR is of shape
      (b1, ..., bn, m, max_rank), and only its first r columns can be nonzero.
    r: the rank of LR, i.e. the number of leading columns that were filled in.
      If r < max_rank, then trace(matrix - LR * LR^t) < trace_atol (when
      trace_atol > 0) or trace(matrix - LR * LR^t) < trace_rtol * trace(matrix)
      (when trace_rtol > 0).
    residual_diag: The diagonal entries of matrix - LR * LR^t. This is
      returned because together with LR, it is useful for preconditioning
      the input matrix.
  """
  with tf.name_scope(name or 'low_rank_cholesky'):
    dtype = dtype_util.common_dtype([matrix, trace_atol, trace_rtol],
                                    dtype_hint=tf.float32)
    if not isinstance(matrix, tf.linalg.LinearOperator):
      matrix = tf.convert_to_tensor(matrix, name='matrix', dtype=dtype)
      matrix = tf.linalg.LinearOperatorFullMatrix(matrix)

    mtrace = matrix.trace()
    mrank = tensorshape_util.rank(matrix.shape)
    batch_dims = mrank - 2

    def lr_cholesky_cond(i, _, residual_diag):
      """Condition for `tf.while_loop` continuation."""
      residual_trace = tf.math.reduce_sum(residual_diag, axis=-1)
      atol_terminate = (trace_atol > 0) & tf.reduce_all(
          residual_trace < trace_atol)
      rtol_terminate = (trace_rtol > 0) & tf.reduce_all(
          residual_trace < trace_rtol * mtrace)
      terminate = atol_terminate | rtol_terminate
      # TODO(thomaswc): Return false even if i == 0 when mtrace == 0.0 to
      # avoid division by zero errors.
      return (i == 0) | ~terminate

    def lr_cholesky_body(i, lr, residual_diag):
      # 1. Find the maximum entry of the residual diagonal.
      max_j = tf.argmax(
          residual_diag, axis=-1, output_type=tf.int64)[..., tf.newaxis]

      # 2. Construct vector v that kills that diagonal entry and its row & col.
      # v = residual_matrix[max_j, :] / sqrt(residual_matrix[max_j, max_j])
      maxval = tf.gather(
          residual_diag, max_j, axis=-1, batch_dims=batch_dims)[..., 0]
      normalizer = tf.sqrt(maxval)
      if callable(getattr(matrix, 'row', None)):
        matrix_row = tf.squeeze(matrix.row(max_j), axis=-2)
      else:
        matrix_row = tf.gather(
            matrix.to_dense(), max_j, axis=-1, batch_dims=batch_dims)[..., 0]
      # residual_matrix[max_j, :] = matrix[max_j, :] - (lr * lr^t)[max_j, :]
      # And (lr * lr^t)[max_j, :] = lr[max_j, :] * lr^t
      lr_row_maxj = tf.gather(lr, max_j, axis=-2, batch_dims=batch_dims)
      lr_lrt_row = tf.matmul(lr_row_maxj, lr, transpose_b=True)
      lr_lrt_row = tf.squeeze(lr_lrt_row, axis=-2)
      unnormalized_v = matrix_row - lr_lrt_row
      v = unnormalized_v / normalizer[..., tf.newaxis]

      # Mask v so that it is zero in rows/columns we've already zeroed.
      # We can use the sign of the residual_diag as the mask because the input
      # matrix being positive definite implies that the diag starts off
      # positive, and only becomes zero on the entries that we've chosen
      # in previous iterations.
      v = v * tf.math.sign(residual_diag)

      # 3. Add v to lr.
      # Conceptually the same as
      #   new_lr = lr
      #   new_lr[..., i] = v
      # but without using assignment or dynamic slices, both of which don't
      # work under JAX.
      # v[..., tf.newaxis] is of shape (batch1, ..., batchn, m, 1) and
      # the one_hot term is of shape (1, max_rank) so their broadcasted product
      # will be of shape (batch1, ..., batchn, m, max_rank), the same as lr.
      new_lr = lr + v[..., tf.newaxis] * tf.one_hot(
          indices=i, depth=max_rank, dtype=matrix.dtype)[tf.newaxis, :]

      # 4. Compute the new residual_diag = old_residual_diag - v * v
      new_residual_diag = residual_diag - v * v
      # Explicitly set new_residual_diag[max_j] = 0 (both to guarantee we never
      # choose its index again, and to let us use the tf.math.sign of the
      # residual as a mask.)
      n = new_residual_diag.shape[-1]
      oh = tf.one_hot(
          indices=max_j[..., 0], depth=n, on_value=0.0, off_value=1.0,
          dtype=new_residual_diag.dtype
      )
      new_residual_diag = new_residual_diag * oh

      return i + 1, new_lr, new_residual_diag

    lr = tf.zeros(matrix.shape, dtype=matrix.dtype)[..., :max_rank]
    mdiag = matrix.diag_part()
    i, lr, residual_diag = tf.while_loop(
        cond=lr_cholesky_cond,
        body=lr_cholesky_body,
        loop_vars=(0, lr, mdiag),
        maximum_iterations=max_rank
    )
    return lr, i, residual_diag


def lu_solve(lower_upper, perm, rhs,
             validate_args=False,
             name=None):
  """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.
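
The solve permutes the rows of `rhs` by `perm` and then runs a unit lower and
an upper triangular solve (see the end of this function). A minimal NumPy
sketch of that recipe, assuming the convention the implementation relies on,
`X[perm] == L @ U`. `lu_solve_np` is a hypothetical name, `np.linalg.solve`
stands in for the two triangular solves, and the test matrix is built from
chosen factors rather than from `tf.linalg.lu`.

```python
import numpy as np


def lu_solve_np(lower_upper, perm, rhs):
  # Hypothetical helper (not part of TFP). Unbatched; assumes x[perm] == L @ U,
  # with unit lower triangular L and upper triangular U packed together.
  n = lower_upper.shape[-1]
  lower = np.tril(lower_upper, -1) + np.eye(n)  # restore the unit diagonal
  upper = np.triu(lower_upper)
  # Permute the rows of rhs, then forward- and back-substitute.
  y = np.linalg.solve(lower, rhs[perm])
  return np.linalg.solve(upper, y)


# Build x from chosen factors so that x[perm] == L @ U.
lower = np.array([[1., 0., 0.], [0.5, 1., 0.], [0.25, -0.5, 1.]])
upper = np.array([[4., 1., 2.], [0., 3., -1.], [0., 0., 2.]])
perm = np.array([2, 0, 1])
x = np.empty((3, 3))
x[perm] = lower @ upper
lower_upper = lower + upper - np.eye(3)
rhs = np.array([[1., 0.], [2., 1.], [3., -1.]])
sol = lu_solve_np(lower_upper, perm, rhs)
```

If `L @ U @ sol == rhs[perm]`, then `(x @ sol)[perm] == rhs[perm]`, so `sol`
solves the original system.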

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
      actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.debugging.assert_near(tf.linalg.inv(x), inv_x)
  # No error is raised: the two results agree.
  ```
  """
  with tf.name_scope(name or 'lu_solve'):
    lower_upper = tf.convert_to_tensor(
        lower_upper, dtype_hint=tf.float32, name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
    rhs = tf.convert_to_tensor(
        rhs, dtype_hint=lower_upper.dtype, name='rhs')

    assertions = _lu_solve_assertions(lower_upper, perm, rhs, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)
        rhs = tf.identity(rhs)

    if (tensorshape_util.rank(rhs.shape) == 2 and
        tensorshape_util.rank(perm.shape) == 1):
      # Both rhs and perm have scalar batch_shape.
      permuted_rhs = tf.gather(rhs, perm, axis=-2)
    else:
      # Either rhs or perm has non-scalar batch_shape or we can't determine
      # this information statically.
      rhs_shape = tf.shape(rhs)
      broadcast_batch_shape = tf.broadcast_dynamic_shape(
          rhs_shape[:-2],
          tf.shape(perm)[:-1])
      d, m = rhs_shape[-2], rhs_shape[-1]
      rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]], axis=0)

      # Tile out rhs.
      broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape)
      broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m])

      # Tile out perm and add batch indices.
      broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1])
      broadcast_perm = tf.reshape(broadcast_perm, [-1, d])
      broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape)
      broadcast_batch_indices = tf.broadcast_to(
          tf.range(broadcast_batch_size)[:, tf.newaxis],
          [broadcast_batch_size, d])
      broadcast_perm = tf.stack([broadcast_batch_indices, broadcast_perm],
                                axis=-1)

      permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm)
      permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape)

    lower = tf.linalg.set_diag(
        tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
        tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
    return tf.linalg.triangular_solve(
        lower_upper,  # Only upper is accessed.
        tf.linalg.triangular_solve(lower, permuted_rhs),
        lower=False)


def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None):
  """Computes a matrix inverse given the matrix's LU decomposition.

  Given `lower_upper, perm = tf.linalg.lu(X)`, this op is conceptually
  identical to `tf.linalg.inv(X)`:

  ```python
  inv_X = tfp.math.lu_matrix_inverse(*tf.linalg.lu(X))
  tf.debugging.assert_near(tf.linalg.inv(X), inv_X)
  # No error is raised: the two results agree.
  ```

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
      actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_matrix_inverse').

  Returns:
    inv_x: The matrix inverse, i.e.,
      `tf.linalg.inv(tfp.math.lu_reconstruct(lower_upper, perm))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  inv_x = tfp.math.lu_matrix_inverse(*tf.linalg.lu(x))
  tf.debugging.assert_near(tf.linalg.inv(x), inv_x)
  # No error is raised: the two results agree.
  ```
  """
  with tf.name_scope(name or 'lu_matrix_inverse'):
    lower_upper = tf.convert_to_tensor(
        lower_upper, dtype_hint=tf.float32, name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)
    shape = tf.shape(lower_upper)
    return lu_solve(
        lower_upper, perm,
        rhs=tf.eye(shape[-1], batch_shape=shape[:-2], dtype=lower_upper.dtype),
        validate_args=False)


def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
  """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.
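
In NumPy terms, the reconstruction multiplies the unpacked factors and then
undoes the row permutation by gathering with the inverse permutation;
`np.argsort(perm)` plays the role of `tf.math.invert_permutation` here. A
minimal unbatched sketch, assuming the convention `X[perm] == L @ U` that this
function and `lu_solve` rely on (`lu_reconstruct_np` is a hypothetical name):

```python
import numpy as np


def lu_reconstruct_np(lower_upper, perm):
  # Hypothetical helper (not part of TFP). Unbatched; assumes x[perm] == L @ U.
  n = lower_upper.shape[-1]
  lower = np.tril(lower_upper, -1) + np.eye(n)  # unit lower triangular
  upper = np.triu(lower_upper)
  inverse_perm = np.argsort(perm)  # inverse of the permutation `perm`
  return (lower @ upper)[inverse_perm]


lower = np.array([[1., 0., 0.], [0.5, 1., 0.], [0.25, -0.5, 1.]])
upper = np.array([[4., 1., 2.], [0., 3., -1.], [0., 0., 2.]])
perm = np.array([2, 0, 1])  # a 3-cycle, so perm differs from its inverse
x = np.empty((3, 3))
x[perm] = lower @ upper
x_reconstructed = lu_reconstruct_np(lower + upper - np.eye(3), perm)
```

Using a 3-cycle for `perm` matters: gathering with `perm` itself instead of its
inverse would return the rows of `x` in the wrong order.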

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
  tf.debugging.assert_near(x, x_reconstructed)
  # No error is raised: the two results agree.
  ```
  """
  with tf.name_scope(name or 'lu_reconstruct'):
    lower_upper = tf.convert_to_tensor(
        lower_upper, dtype_hint=tf.float32, name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)

    shape = tf.shape(lower_upper)

    lower = tf.linalg.set_diag(
        tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
        tf.ones(shape[:-1], dtype=lower_upper.dtype))
    upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
    x = tf.matmul(lower, upper)

    if (tensorshape_util.rank(lower_upper.shape) is None or
        tensorshape_util.rank(lower_upper.shape) != 2):
      # We either don't know the batch rank or there are >0 batch dims.
      batch_size = tf.reduce_prod(shape[:-2])
      d = shape[-1]
      x = tf.reshape(x, [batch_size, d, d])
      perm = tf.reshape(perm, [batch_size, d])
      perm = tf.map_fn(tf.math.invert_permutation, perm)
      batch_indices = tf.broadcast_to(
          tf.range(batch_size)[:, tf.newaxis],
          [batch_size, d])
      x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1))
      x = tf.reshape(x, shape)
    else:
      x = tf.gather(x, tf.math.invert_permutation(perm))

    tensorshape_util.set_shape(x, lower_upper.shape)
    return x
def lu_reconstruct_assertions(lower_upper, perm, validate_args):
"""Returns list of assertions related to `lu_reconstruct` assumptions."""
assertions = []
message = 'Input `lower_upper` must have at least 2 dimensions.'
if tensorshape_util.rank(lower_upper.shape) is not None:
if tensorshape_util.rank(lower_upper.shape) < 2:
raise ValueError(message)
elif validate_args:
assertions.append(
assert_util.assert_rank_at_least(lower_upper, rank=2, message=message))
message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
if (tensorshape_util.rank(lower_upper.shape) is not None and
tensorshape_util.rank(perm.shape) is not None):
if (tensorshape_util.rank(lower_upper.shape) !=
tensorshape_util.rank(perm.shape) + 1):
raise ValueError(message)
elif validate_args:
assertions.append(
assert_util.assert_rank(
lower_upper, rank=tf.rank(perm) + 1, message=message))
message = '`lower_upper` must be square.'
if tensorshape_util.is_fully_defined(lower_upper.shape[-2:]):
if lower_upper.shape[-2] != lower_upper.shape[-1]:
raise ValueError(message)
elif validate_args:
m, n = tf.split(tf.shape(lower_upper)[-2:], num_or_size_splits=2)
assertions.append(assert_util.assert_equal(m, n, message=message))
return assertions
def _lu_solve_assertions(lower_upper, perm, rhs, validate_args):
"""Returns list of assertions related to `lu_solve` assumptions."""
assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
message = 'Input `rhs` must have at least 2 dimensions.'
if tensorshape_util.rank(rhs.shape) is not None:
if tensorshape_util.rank(rhs.shape) < 2:
raise ValueError(message)
elif validate_args:
assertions.append(
assert_util.assert_rank_at_least(rhs, rank=2, message=message))
message = '`lower_upper.shape[-1]` must equal `rhs.shape[-2]`.'
if (tf.compat.dimension_value(lower_upper.shape[-1]) is not None and
tf.compat.dimension_value(rhs.shape[-2]) is not None):
if lower_upper.shape[-1] != rhs.shape[-2]:
raise ValueError(message)
elif validate_args:
assertions.append(
assert_util.assert_equal(
tf.shape(lower_upper)[-1],
tf.shape(rhs)[-2],
message=message))
return assertions
def sparse_or_dense_matmul(sparse_or_dense_a,
dense_b,
validate_args=False,
name=None,
**kwargs):
"""Returns (batched) matmul of a SparseTensor (or Tensor) with a Tensor.
Args:
sparse_or_dense_a: `SparseTensor` or `Tensor` representing a (batch of)
matrices.
dense_b: `Tensor` representing a (batch of) matrices, with the same batch
shape as `sparse_or_dense_a`. The shape must be compatible with the shape
of `sparse_or_dense_a` and kwargs.
validate_args: When `True`, additional assertions might be embedded in the
graph.
Default value: `False` (i.e., no graph assertions are added).
name: Python `str` prefixed to ops created by this function.
Default value: 'sparse_or_dense_matmul'.
**kwargs: Keyword arguments to `tf.sparse.sparse_dense_matmul` or
`tf.matmul`.
Returns:
product: A dense (batch of) matrix-shaped Tensor of the same batch shape and
dtype as `sparse_or_dense_a` and `dense_b`. If `sparse_or_dense_a` or
`dense_b` is adjointed through `kwargs` then the shape is adjusted
accordingly.
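When `sparse_or_dense_a` is sparse, only its stored entries take part in the
product. The sketch below illustrates this for one unbatched matrix in COO form
(`indices`, `values`, `dense_shape`, the fields a `tf.SparseTensor` carries),
using plain NumPy. The function `coo_dense_matmul_np` is hypothetical and is
not the implementation of `_sparse_tensor_dense_matmul`; it only shows that the
sparse and dense branches should agree.

```python
import numpy as np


def coo_dense_matmul_np(indices, values, dense_shape, dense_b):
  # Hypothetical COO sparse @ dense product for one matrix: each stored
  # entry a[i, j] adds a[i, j] * dense_b[j, :] to output row i.
  dense_b = np.asarray(dense_b, dtype=np.float64)
  out = np.zeros((dense_shape[0], dense_b.shape[1]))
  for (i, j), v in zip(indices, values):
    out[i] += v * dense_b[j]
  return out


indices = [(0, 1), (2, 0), (2, 2)]
values = [2., -1., 3.]
dense_shape = (3, 3)
b = np.arange(6.).reshape(3, 2)

# Densify the same matrix to compare against an ordinary dense matmul.
a_dense = np.zeros(dense_shape)
for (i, j), v in zip(indices, values):
  a_dense[i, j] = v
product = coo_dense_matmul_np(indices, values, dense_shape, b)
```

Row 1 of `product` stays zero because that row of the sparse matrix has no
stored entries.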
"""
with tf.name_scope(name or 'sparse_or_dense_matmul'):
dense_b = tf.convert_to_tensor(
dense_b, dtype_hint=tf.float32, name='dense_b')
if validate_args:
assert_a_rank_at_least_2 = assert_util.assert_rank_at_least(
sparse_or_dense_a,
rank=2,
message='Input `sparse_or_dense_a` must have at least 2 dimensions.')
assert_b_rank_at_least_2 = assert_util.assert_rank_at_least(
dense_b,
rank=2,
message='Input `dense_b` must have at least 2 dimensions.')
with tf.control_dependencies(
[assert_a_rank_at_least_2, assert_b_rank_at_least_2]):
sparse_or_dense_a = tf.identity(sparse_or_dense_a)
dense_b = tf.identity(dense_b)
if isinstance(sparse_or_dense_a, (tf.SparseTensor, tf1.SparseTensorValue)):
return _sparse_tensor_dense_matmul(sparse_or_dense_a, dense_b, **kwargs)
else:
return tf.matmul(sparse_or_dense_a, dense_b, **kwargs)
def sparse_or_dense_matvecmul(sparse_or_dense_matrix,
dense_vector,
validate_args=False,
name=None,
**kwargs):
"""Returns (batched) matmul of a (sparse) matrix with a column vector.
Args:
sparse_or_dense_matrix: `SparseTensor` or `Tensor` representing a (batch of)
matrices.
dense_vector: `Tensor` representing a (batch of) vectors, with the same
batch shape as `sparse_or_dense_matrix`. The shape must be compatible with
the shape of `sparse_or_dense_matrix` and kwargs.
validate_args: When `True`, additional assertions might be embedded in the
graph.
Default value: `False` (i.e., no graph assertions are added).
name: Python `str` prefixed to ops created by this function.
Default value: 'sparse_or_dense_matvecmul'.
**kwargs: Keyword arguments to `tf.sparse.sparse_dense_matmul` or
`tf.matmul`.
Returns:
product: A dense (batch of) vector-shaped Tensor of the same batch shape and
dtype as `sparse_or_dense_matrix` and `dense_vector`.
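This function reduces matrix-vector multiplication to matrix-matrix
multiplication: it appends a trailing axis so each vector becomes a `d x 1`
column, multiplies, and squeezes that axis away. A dense NumPy sketch of the
same shape manipulation, assuming a batch of two `3 x 3` matrices with random
entries chosen purely for illustration:

```python
import numpy as np

# A batch of two 3x3 matrices and a matching batch of 3-vectors.
rng = np.random.default_rng(0)
matrix = rng.normal(size=(2, 3, 3))
vector = rng.normal(size=(2, 3))

# Treat each vector as a 3x1 column, multiply, then drop the trailing axis,
# mirroring the `dense_vector[..., tf.newaxis]` and
# `tf.squeeze(..., axis=[-1])` steps in this function's implementation.
product = np.squeeze(np.matmul(matrix, vector[..., np.newaxis]), axis=-1)

# Reference: an explicit batched matrix-vector contraction.
reference = np.einsum('bij,bj->bi', matrix, vector)
```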
"""
with tf.name_scope(name or 'sparse_or_dense_matvecmul'):
dense_vector = tf.convert_to_tensor(
dense_vector, dtype_hint=tf.float32, name='dense_vector')
return tf.squeeze(
sparse_or_dense_matmul(
sparse_or_dense_matrix,
dense_vector[..., tf.newaxis],
validate_args=validate_args,
**kwargs),
axis=[-1])
def fill_triangular(x, upper=False, name=None):
"""Creates a (batch of) triangular matrix from a vector of inputs.
The created matrix can be lower- or upper-triangular. (Requesting the desired
orientation via `upper` is more efficient than transposing the result.)
Triangular matrix elements are filled in a clockwise spiral; see the example
below.
If `x.shape` is `[b1, b2, ..., bB, d]` then the output shape is
`[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
`n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.
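Solving `d = n(n+1)/2`, i.e. `n**2 + n - 2d = 0`, for the positive root gives
`n = sqrt(2d + 1/4) - 1/2 = (sqrt(8d + 1) - 1) / 2`. An integer-only sketch of
that computation follows; the helper name is hypothetical, and unlike a bare
float formula (which silently truncates) it rejects lengths that are not
triangular numbers.

```python
import math


def tri_size_from_vector_len(d):
  # Hypothetical helper: return n with n * (n + 1) // 2 == d, or raise.
  # n = (sqrt(8d + 1) - 1) / 2; math.isqrt avoids float rounding for large d.
  n = (math.isqrt(8 * d + 1) - 1) // 2
  if n * (n + 1) // 2 != d:
    raise ValueError('Vector length {} is not a triangular number.'.format(d))
  return n
```

For example, `d = 6` gives `n = 3` and `d = 15` gives `n = 5`, while `d = 7`
raises.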
Example:
```python
fill_triangular([1, 2, 3, 4, 5, 6])
# ==> [[4, 0, 0],
# [6, 5, 0],
# [3, 2, 1]]
fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
# ==> [[1, 2, 3],
# [0, 5, 6],
# [0, 0, 4]]
```
The key trick is to create an upper triangular matrix by concatenating `x`
and a tail of itself, then reshaping.
Suppose that we are filling the upper triangle of an `n`-by-`n` matrix `M`
from a vector `x`. The matrix `M` contains n**2 entries total. The vector `x`
contains `n * (n+1) / 2` entries. For concreteness, we'll consider `n = 5`
(so `x` has `15` entries and `M` has `25`). We'll concatenate `x` and `x` with
the first (`n = 5`) elements removed and reversed:
```python
import numpy as np
x = np.arange(15) + 1
xc = np.concatenate([x, x[5:][::-1]])
# ==> array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 14, 13,
# 12, 11, 10, 9, 8, 7, 6])
# (We add one to the arange result to disambiguate the zeros below the
# diagonal of our upper-triangular matrix from the first entry in `x`.)
# Now, reshape this into a 5x5 matrix:
y = np.reshape(xc, [5, 5])
# ==> array([[ 1, 2, 3, 4, 5],
# [ 6, 7, 8, 9, 10],
# [11, 12, 13, 14, 15],
# [15, 14, 13, 12, 11],
# [10, 9, 8, 7, 6]])
# Finally, zero the elements below the diagonal:
y = np.triu(y, k=0)
# ==> array([[ 1, 2, 3, 4, 5],
# [ 0, 7, 8, 9, 10],
# [ 0, 0, 13, 14, 15],
# [ 0, 0, 0, 12, 11],
# [ 0, 0, 0, 0, 6]])
```
From this example we see that the resulting matrix is upper-triangular, and
contains all the entries of x, as desired. The rest is details:
- If `n` is even, `x` doesn't fill a whole number of rows (it fills
`n / 2` rows and half of an additional row), but the whole scheme still
works.
- If we want a lower triangular matrix instead of an upper triangular,
we remove the first `n` elements from `x` rather than from the reversed
`x`.
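Putting the pieces together, the upper and lower variants can be sketched in
NumPy as follows, applied over any leading batch axes. The name
`fill_triangular_np` is hypothetical; this sketch is not guaranteed to match
the test-suite helper line for line and is checked only against the two
examples given earlier in this docstring.

```python
import numpy as np


def fill_triangular_np(x, upper=False):
  # Hypothetical NumPy sketch of the concatenate, reshape, mask trick.
  x = np.asarray(x)
  d = x.shape[-1]
  # Round before truncating so float error cannot push n one too low.
  n = int(round(np.sqrt(0.25 + 2. * d) - 0.5))
  if n * (n + 1) // 2 != d:
    raise ValueError('Last dimension must be a triangular number.')
  if upper:
    # All of x, followed by its tail (first n entries dropped), reversed.
    xc = np.concatenate([x, x[..., n:][..., ::-1]], axis=-1)
  else:
    # The tail of x, followed by all of x, reversed.
    xc = np.concatenate([x[..., n:], x[..., ::-1]], axis=-1)
  y = np.reshape(xc, x.shape[:-1] + (n, n))
  # Zero the half that must be empty; triu/tril act on the last two axes.
  return np.triu(y) if upper else np.tril(y)


lower = fill_triangular_np([1, 2, 3, 4, 5, 6])
# ==> [[4, 0, 0], [6, 5, 0], [3, 2, 1]]
upper = fill_triangular_np([1, 2, 3, 4, 5, 6], upper=True)
# ==> [[1, 2, 3], [0, 5, 6], [0, 0, 4]]
```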
For comparison, a pure NumPy version of this function can be found
in `distribution_util_test.py`, function `_fill_triangular`.
Args: