Unconstrain transform for Wishart#8246
Conversation
Documentation build overview
259 files changed ·
|
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## v6 #8246 +/- ##
=====================================
Coverage ? 91.97%
=====================================
Files ? 124
Lines ? 20272
Branches ? 0
=====================================
Hits ? 18645
Misses ? 1627
Partials ? 0
🚀 New features to boost your workflow:
|
|
Took a look at the compiled logp+dlogp, and we pay some price for the whole matrix construction. Edit: Graph is pretty clean now (see below), collapsing DetailsFor an Full logp+dlogp graph1. Diagonal gradient routed through an
|
|
With a few general rewrites (already upstreamed in pymc-devs/pytensor#2061), got it down to this form: import pymc as pm
import numpy as np
rng = np.random.default_rng(1)
A = rng.normal(size=(n, n))
V = A @ A.T + n * np.eye(n)
with pm.Model() as m:
pm.Wishart("Sigma", nu=4, V=V)
m.logp_dlogp_function()._pytensor_function.dprint(print_shape=True, print_memory_map=True)So as good as I can think of |
fdcfffe to
9f081e9
Compare
| # re-interpreting the value. | ||
| raise NotImplementedError( | ||
| "`initval` is no longer supported in the WishartBartlett shim. " | ||
| "Pass `initval` (as an SPD matrix) to `pm.Wishart` directly." |
There was a problem hiding this comment.
| "Pass `initval` (as an SPD matrix) to `pm.Wishart` directly." | |
| "Pass `initval` (as a SPD matrix) to `pm.Wishart` directly." |
There was a problem hiding this comment.
I always say PSD not SPD, amd I systematically making a fool of myself?
There was a problem hiding this comment.
Positive semi definite vs symmetric positive definite?
There was a problem hiding this comment.
isn't a positive definite matrix necessarily symmetric?
| return pt.sum(value[..., self.diag_idxs], axis=-1) | ||
|
|
||
|
|
||
| class CholeskyCovTransform(Transform): |
There was a problem hiding this comment.
Do you have a sense of why this is so much simpler than the Correlation transform? Was I just doing something dumb over there?
There was a problem hiding this comment.
Maybe all we needed tril_indices(n, k=-1)? We copied the idea from tfp maybe they didn't have that?
Or they may have been worried about batch/dims inverse/autograd that wasn't well supported?
There was a problem hiding this comment.
It doesn't have any indexing so there's that... It does still have a tril zeroing out... We should just bench / check the graph. If there's no downside I would jump
There was a problem hiding this comment.
i'll revisit it later when I have the time/inclination. It's nice to have this as a simpler alternative.
There was a problem hiding this comment.
More interestingly the Stan implementation that provides nicer geometry (IIRC)
There was a problem hiding this comment.
graph looks much more clean with pt.tril indexing... but still shows some other simplifications (including your extract_diag(elemwise), and dumb stuff like sqr(abs(x)), when x is not complex (shows up in linalg.norm). So more low hanging fruit to squeeze when we switch. Not in this PR obviously
There was a problem hiding this comment.
which one looks more clean i'm confused
There was a problem hiding this comment.
an approach like the one here, instead of the spiral thing
There was a problem hiding this comment.
Yeah ok, that's what I was thinking
Similar idea as #7380 (but this is actually simpler). Almost the same rewrite as LKJCholeskyCov, except here we unconstrain to the full dense matrix.
Closes #8196 (it's now usable)