Skip to content

Unconstrain transform for Wishart#8246

Merged
ricardoV94 merged 2 commits into
pymc-devs:v6from
ricardoV94:Wishart
May 4, 2026
Merged

Unconstrain transform for Wishart#8246
ricardoV94 merged 2 commits into
pymc-devs:v6from
ricardoV94:Wishart

Conversation

@ricardoV94
Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 commented Apr 9, 2026

Similar idea as #7380 (but this is actually simpler). Almost the same rewrite as LKJCholeskyCov, except here we unconstrain to the full dense matrix.

Closes #8196 (it's now usable)

@ricardoV94 ricardoV94 changed the title Unconstraint Wishart Unconstrain Wishart Apr 9, 2026
@codecov
Copy link
Copy Markdown

codecov Bot commented Apr 9, 2026

Codecov Report

❌ Patch coverage is 90.72165% with 9 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (v6@5fa428c). Learn more about missing BASE report.

Files with missing lines Patch % Lines
pymc/distributions/multivariate.py 86.95% 9 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@          Coverage Diff          @@
##             v6    #8246   +/-   ##
=====================================
  Coverage      ?   91.97%           
=====================================
  Files         ?      124           
  Lines         ?    20272           
  Branches      ?        0           
=====================================
  Hits          ?    18645           
  Misses        ?     1627           
  Partials      ?        0           
Files with missing lines Coverage Δ
pymc/distributions/transforms.py 100.00% <100.00%> (ø)
pymc/testing.py 90.87% <100.00%> (ø)
pymc/distributions/multivariate.py 95.37% <86.95%> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@ricardoV94
Copy link
Copy Markdown
Member Author

ricardoV94 commented Apr 9, 2026

Took a look at the compiled logp+dlogp, and we pay some price for the whole matrix construction.

Edit: Graph is pretty clean now (see below), collapsing

Details

For an n × n Wishart the unconstrained vector has length n(n+1)/2, with diagonal positions at the cumulative-sum sequence [0, 2, 5, 9, …] of length n.

Full logp+dlogp graph

# logp
Composite{(2.079441547393799 + (0.5 * ((-14.159198660192542 + (2.0 * i2)) - i1)) + i0)} [id A] 17
 ├─ Sum{axes=None} [id B] 5
 │  └─ Mul [id C] 3
 │     ├─ [4. 3. 2.] [id D]
 │     └─ AdvancedSubtensor1 [id E] 0
 │        ├─ Sigma_cholesky-cov__ [id F]
 │        └─ [0 2 5] [id G]
 ├─ Sum{axes=None} [id H] 15
 │  └─ Sqr [id I] 13
 │     └─ SolveTriangular{unit_diagonal=False, lower=True, b_ndim=2, overwrite_b=True} [id J] 10
 │        ├─ [[1. 0. 0. ... 0. 0. 1.]] [id K]
 │        └─ AdvancedSetSubtensor [id L] 6
 │           ├─ Alloc [id M] 1
 │           │  ├─ 0.0 [id N]
 │           │  ├─ 3 [id O]
 │           │  └─ 3 [id O]
 │           ├─ AdvancedIncSubtensor1{no_inplace,set} [id P] 4
 │           │  ├─ Sigma_cholesky-cov__ [id F]
 │           │  ├─ Exp [id Q] 2
 │           │  │  └─ AdvancedSubtensor1 [id E] 0
 │           │  │     └─ ···
 │           │  └─ [0 2 5] [id G]
 │           ├─ [0 1 1 2 2 2] [id R]
 │           └─ [0 0 1 0 1 2] [id S]
 └─ Sum{axes=None} [id T] 12
    └─ Log [id U] 9
       └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=True} [id V] 7
          └─ AdvancedSetSubtensor [id L] 6
             └─ ···
# dlogp
AdvancedIncSubtensor1{inplace,inc} [id W] 'Sigma_cholesky-cov___grad' 23
 ├─ AdvancedIncSubtensor1{inplace,set} [id X] 21
 │  ├─ AdvancedSubtensor{idx_list=(0, 1)} [id Y] 19
 │  │  ├─ Add [id Z] 18
 │  │  │  ├─ AdvancedSetSubtensor [id BA] 11
 │  │  │  │  ├─ Alloc [id M] 1
 │  │  │  │  │  └─ ···
 │  │  │  │  ├─ Reciprocal [id BB] 8
 │  │  │  │  │  └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=True} [id V] 7
 │  │  │  │  │     └─ ···
 │  │  │  │  ├─ [0 1 2] [id BC]
 │  │  │  │  └─ [0 1 2] [id BC]
 │  │  │  └─ SolveTriangular{unit_diagonal=False, lower=False, b_ndim=2, overwrite_b=True} [id BD] 16
 │  │  │     ├─ [[1. 0. 0. ... 0. 0. 1.]] [id K]
 │  │  │     └─ Neg [id BE] 14
 │  │  │        └─ SolveTriangular{unit_diagonal=False, lower=True, b_ndim=2, overwrite_b=True} [id J] 10
 │  │  │           └─ ···
 │  │  ├─ [0 1 1 2 2 2] [id R]
 │  │  └─ [0 0 1 0 1 2] [id S]
 │  ├─ [0. 0. 0.] [id BF]
 │  └─ [0 2 5] [id G]
 ├─ Composite{((i1 * i2) + i0)} [id BG] 22
 │  ├─ [4. 3. 2.] [id D]
 │  ├─ AdvancedSubtensor1 [id BH] 20
 │  │  ├─ AdvancedSubtensor{idx_list=(0, 1)} [id Y] 19
 │  │  │  └─ ···
 │  │  └─ [0 2 5] [id G]
 │  └─ Exp [id Q] 2
 │     └─ ···
 └─ [0 2 5] [id G]

Inner graphs:

Composite{(2.079441547393799 + (0.5 * ((-14.159198660192542 + (2.0 * i2)) - i1)) + i0)} [id A]
 ← add [id BI]
    ├─ 2.079441547393799 [id BJ]
    ├─ mul [id BK]
    │  ├─ 0.5 [id BL]
    │  └─ sub [id BM]
    │     ├─ add [id BN]
    │     │  ├─ -14.159198660192542 [id BO]
    │     │  └─ mul [id BP]
    │     │     ├─ 2.0 [id BQ]
    │     │     └─ i2 [id BR]
    │     └─ i1 [id BS]
    └─ i0 [id BT]

Composite{((i1 * i2) + i0)} [id BG]
 ← add [id BU]
    ├─ mul [id BV]
    │  ├─ i1 [id BS]
    │  └─ i2 [id BR]
    └─ i0 [id BT]

1. Diagonal gradient routed through an (n, n) scatter

AdvancedSubtensor{idx_list=(0, 1)} [id Y]        ← read length-n(n+1)/2 packed lower-tri
 ├─ Add [id Z]                                    ← add (n,n) matrices
 │  ├─ AdvancedSetSubtensor [id BA]              ← scatter 1/L_kk onto diag of (n,n) zeros
 │  │  ├─ Alloc [id M]                            ← (n,n) zeros
 │  │  ├─ Reciprocal [id BB]                      ← 1/L_kk, length n
 │  │  ├─ [0 1 2]
 │  │  └─ [0 1 2]
 │  └─ SolveTriangular{lower=False} [id BD]       ← full (n,n) −V⁻¹·L gradient term
 ├─ [0 1 1 2 2 2]
 └─ [0 0 1 0 1 2]

The two gradient contributions (1/L_kk on the diagonal, −V⁻¹ L everywhere) are added as full (n, n) matrices, then only the n(n+1)/2 lower-triangular entries are read out. The upper triangle is dead work. The diagonal scatter writes n values into zeros solely so they align with the (n, n) layout of the triangular-solve result.

2. Extracting the diagonal we just placed

ExtractDiag [id V]
 └─ AdvancedSetSubtensor [id L]      ← scatter packed vec into (n,n) zeros
    ├─ Alloc [id M]                   ← (n,n) zeros
    ├─ AdvancedIncSubtensor1 [id P]  ← packed vec with Exp on diag slots
    │  ├─ Sigma_cholesky-cov__ [id F]
    │  ├─ Exp [id Q]                  ← exp(unc[diag_idxs]) = diag(L)
    │  └─ [0 2 5]
    ├─ [0 1 1 2 2 2]
    └─ [0 0 1 0 1 2]

ExtractDiag(L) recovers exactly Exp [id Q], the values we scattered onto the diagonal in the first place. Recognizing this identity simplifies both consumers:

  • logp: Sum(Log(ExtractDiag(L))) = Sum(Log(Exp(unc[diag_idxs]))) = Sum(unc[diag_idxs]).
    The Log, ExtractDiag, and Exp all cancel.

  • dlogp: Reciprocal(ExtractDiag(L)) = 1/Exp(unc[diag_idxs]).
    This is multiplied by L_kk = Exp(unc[diag_idxs]) via the chain rule in the per-diagonal Composite, giving (1/L_kk) · L_kk = 1. The diagonal gradient from the log-det term becomes a constant +1 per diagonal slot, absorbable into the existing log-Jacobian coefficients [n+1, n, …, 2][n+2, n+1, …, 3].

3. Set-then-inc on the same diagonal positions

AdvancedIncSubtensor1{inc} [id W]         ← inc at [0 2 5]
 ├─ AdvancedIncSubtensor1{set} [id X]     ← set [0 2 5] to zero
 │  ├─ [id Y]                              ← length-n(n+1)/2 packed gradient (from §1)
 │  ├─ [0. 0. 0.]
 │  └─ [0 2 5]
 ├─ Composite{(i1 * i2) + i0} [id BG]    ← length-n diagonal contribution
 │  ├─ [4. 3. 2.]                          ← log-Jacobian coefficients
 │  ├─ AdvancedSubtensor1 [id BH]         ← diag slice of [id Y] (read before zeroing)
 │  └─ Exp [id Q]                          ← L_kk
 └─ [0 2 5]

The set zeros the diagonal slots; the inc overwrites them with (Y[diag_idxs] · L_kk) + [4, 3, 2]. The zeroing is an autodiff artifact. With §1–§2 applied, Y[diag_idxs] at the diagonal simplifies (the Reciprocal scatter becomes a constant), and the entire set-then-inc collapses to a single inc_subtensor of one fused length-n vector.

4. Structural lower bound after all three simplifications

AdvancedIncSubtensor1{inc}                ← single inc at packed diag positions
 ├─ AdvancedSubtensor{idx_list=(0, 1)}    ← length-n(n+1)/2 packed lower-tri of −V⁻¹·L
 │  ├─ SolveTriangular{lower=False}       ← (n,n), unchanged
 │  ├─ [0 1 1 2 2 2]
 │  └─ [0 0 1 0 1 2]
 ├─ Composite                             ← length-n fused diagonal contribution
 │  ├─ [n+2, n+1, …, 3]                   ← merged log-Jacobian + log-det constant
 │  └─ Exp(unc[diag_idxs])                ← shared with L construction
 └─ [0 2 5]

What's already good

  • L built once, used three times: forward SolveTriangular, ExtractDiag, and gradient's second triangular solve.
  • Single Alloc: the (n, n) zero buffer is shared between L construction and the §1 diagonal-scatter.
  • Forward solve shared: M = L⁻¹ V serves both ‖M‖²_F (trace term) and −Lᵀ \ M (gradient term).
  • Unconstrained diagonal shared: unc[diag_idxs] and its Exp are CSE'd between L construction and the gradient's chain-rule factor.
  • No Cholesky op: the cholesky_ldotlt rewrite has already eliminated the chol(L Lᵀ) round trip.

@ricardoV94
Copy link
Copy Markdown
Member Author

ricardoV94 commented Apr 9, 2026

With a few general rewrites (already upstreamed in pymc-devs/pytensor#2061), got it down to this form:

import pymc as pm
import numpy as np

rng = np.random.default_rng(1)
A = rng.normal(size=(n, n))
V = A @ A.T + n * np.eye(n)

with pm.Model() as m:
  pm.Wishart("Sigma", nu=4, V=V)

m.logp_dlogp_function()._pytensor_function.dprint(print_shape=True, print_memory_map=True)
Composite{(2.079441547393799 + (0.5 * (-28.789189739493835 - i1)) + i0)} [id A] shape=() d={0: [0]} 12
 ├─ Sum{axes=None} [id B] shape=() 5
 │  └─ Mul [id C] shape=(?,) 3
 │     ├─ [4. 3. 2.] [id D] shape=(3,)
 │     └─ AdvancedSubtensor1 [id E] shape=(3,) 0
 │        ├─ Sigma_cholesky-cov__ [id F] shape=(?,)
 │        └─ [0 2 5] [id G] shape=(3,)
 └─ Sum{axes=None} [id H] shape=() 10
    └─ Sqr [id I] shape=(3, 3) 8
       └─ SolveTriangular{unit_diagonal=False, lower=True, b_ndim=2, overwrite_b=True} [id J] shape=(3, 3) d={0: [1]} 7
          ├─ [[1.975771 ... 84325469]] [id K] shape=(3, 3)
          └─ AdvancedSetSubtensor [id L] shape=(3, 3) d={0: [0]} 6
             ├─ Alloc [id M] shape=(3, 3) 1
             │  ├─ 0.0 [id N] shape=()
             │  ├─ 3 [id O] shape=()
             │  └─ 3 [id O] shape=()
             ├─ AdvancedIncSubtensor1{no_inplace,set} [id P] shape=(?,) 4
             │  ├─ Sigma_cholesky-cov__ [id F] shape=(?,)
             │  ├─ Exp [id Q] shape=(?,) 2
             │  │  └─ AdvancedSubtensor1 [id E] shape=(3,) 0
             │  │     └─ ···
             │  └─ [0 2 5] [id G] shape=(3,)
             ├─ [0 1 1 2 2 2] [id R] shape=(6,)
             └─ [0 0 1 0 1 2] [id S] shape=(6,)
AdvancedIncSubtensor1{inplace,set} [id T] shape=(6,) 'Sigma_cholesky-cov___grad' d={0: [0]} 16
 ├─ AdvancedSubtensor{idx_list=(0, 1)} [id U] shape=(6,) 13
 │  ├─ SolveTriangular{unit_diagonal=False, lower=False, b_ndim=2, overwrite_b=True} [id V] shape=(3, 3) d={0: [1]} 11
 │  │  ├─ [[1.975771 ... 84325469]] [id W] shape=(3, 3)
 │  │  └─ Neg [id X] shape=(3, 3) d={0: [0]} 9
 │  │     └─ SolveTriangular{unit_diagonal=False, lower=True, b_ndim=2, overwrite_b=True} [id J] shape=(3, 3) d={0: [1]} 7
 │  │        └─ ···
 │  ├─ [0 1 1 2 2 2] [id R] shape=(6,)
 │  └─ [0 0 1 0 1 2] [id S] shape=(6,)
 ├─ Composite{((i1 * i2) + i0)} [id Y] shape=(3,) d={0: [1]} 15
 │  ├─ [4. 3. 2.] [id D] shape=(3,)
 │  ├─ AdvancedSubtensor1 [id Z] shape=(3,) 14
 │  │  ├─ AdvancedSubtensor{idx_list=(0, 1)} [id U] shape=(6,) 13
 │  │  │  └─ ···
 │  │  └─ [0 2 5] [id G] shape=(3,)
 │  └─ Exp [id Q] shape=(?,) 2
 │     └─ ···
 └─ [0 2 5] [id G] shape=(3,)

Inner graphs:

Composite{(2.079441547393799 + (0.5 * (-28.789189739493835 - i1)) + i0)} [id A] d={0: [0]}
 ← add [id BA] shape=()
    ├─ 2.079441547393799 [id BB] shape=()
    ├─ mul [id BC] shape=()
    │  ├─ 0.5 [id BD] shape=()
    │  └─ sub [id BE] shape=()
    │     ├─ -28.789189739493835 [id BF] shape=()
    │     └─ i1 [id BG] shape=()
    └─ i0 [id BH] shape=()

Composite{((i1 * i2) + i0)} [id Y] d={0: [1]}
 ← add [id BI] shape=()
    ├─ mul [id BJ] shape=()
    │  ├─ i1 [id BG] shape=()
    │  └─ i2 [id BK] shape=()
    └─ i0 [id BH] shape=()

So as good as I can think of

@ricardoV94 ricardoV94 changed the title Unconstrain Wishart Unconstrain transform Wishart Apr 19, 2026
@ricardoV94 ricardoV94 force-pushed the Wishart branch 2 times, most recently from fdcfffe to 9f081e9 Compare April 23, 2026 17:48
@ricardoV94 ricardoV94 marked this pull request as ready for review April 23, 2026 17:52
@ricardoV94 ricardoV94 changed the title Unconstrain transform Wishart Unconstrain transform for Wishart Apr 23, 2026
Comment thread pymc/distributions/transforms.py Outdated
# re-interpreting the value.
raise NotImplementedError(
"`initval` is no longer supported in the WishartBartlett shim. "
"Pass `initval` (as an SPD matrix) to `pm.Wishart` directly."
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"Pass `initval` (as an SPD matrix) to `pm.Wishart` directly."
"Pass `initval` (as a SPD matrix) to `pm.Wishart` directly."

Copy link
Copy Markdown
Member

@jessegrabowski jessegrabowski Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I always say PSD not SPD, amd I systematically making a fool of myself?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Positive semi definite vs symmetric positive definite?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't a positive definite matrix necessarily symmetric?

Comment thread tests/distributions/test_multivariate.py Outdated
return pt.sum(value[..., self.diag_idxs], axis=-1)


class CholeskyCovTransform(Transform):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a sense of why this is so much simpler than the Correlation transform? Was I just doing something dumb over there?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe all we needed tril_indices(n, k=-1)? We copied the idea from tfp maybe they didn't have that?

Or they may have been worried about batch/dims inverse/autograd that wasn't well supported?

Copy link
Copy Markdown
Member Author

@ricardoV94 ricardoV94 Apr 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't have any indexing so there's that... It does still have a tril zeroing out... We should just bench / check the graph. If there's no downside I would jump

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'll revisit it later when I have the time/inclination. It's nice to have this as a simpler alternative.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More interestingly the Stan implementation that provides nicer geometry (IIRC)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

graph looks much more clean with pt.tril indexing... but still shows some other simplifications (including your extract_diag(elemwise), and dumb stuff like sqr(abs(x)), when x is not complex (shows up in linalg.norm). So more low hanging fruit to squeeze when we switch. Not in this PR obviously

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

which one looks more clean i'm confused

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

an approach like the one here, instead of the spiral thing

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah ok, that's what I was thinking

@ricardoV94 ricardoV94 merged commit fc7493a into pymc-devs:v6 May 4, 2026
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

DOC: Clarify the pm.Wishart warning (“unusable in a PyMC model”)

2 participants