Fix broadcast check on log_jac_det#8174
Conversation
| f"Univariate transform {transform} cannot be applied to multivariate {rv_op}" | ||
| ) | ||
| # Check there is no broadcasting between logp and jacobian | ||
| # Check there is no broadcasting between logp and jacobian. |
There was a problem hiding this comment.
comment is too verbose and specific
There was a problem hiding this comment.
My bad. AIs tend to be verbose with these...
Documentation build overview
Show files changed (2 files in total): 📝 2 modified | ➕ 0 added | ➖ 0 deleted
|
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #8174 +/- ##
==========================================
- Coverage 84.55% 84.55% -0.01%
==========================================
Files 124 124
Lines 19865 19872 +7
==========================================
+ Hits 16797 16802 +5
- Misses 3068 3070 +2
🚀 New features to boost your workflow:
|
964a67c to
12e4efe
Compare
|
@ricardoV94 failing tests look pre-existing. Should I rebase on a branch other than main? |
| "There is a bug in the implementation of either one." | ||
| ) | ||
|
|
||
| def broadcastable_axes(a, b): |
There was a problem hiding this comment.
can be simplified.
broadcastable_axes = [i for i, (ai, bi) in enumerate(zip(...)) if ai or bi]
And then use in both.
There's no extra cost of specifying a broadcastable_axes that is already known to be broadcastable.
Only thing to make sure is that ndim matches
There was a problem hiding this comment.
not sure I understand the ndim question here. Is it possible they have different numbers of dimensions? If so, how should they be matched?
There was a problem hiding this comment.
We may already have logic above that handles different ndim.
If not we should fail, it's a guaranteed form of broadcasting. But the current checks may miss on it
There was a problem hiding this comment.
So I just add a length equality check and throw an exception if there is a mismatch?
Or no action needed here?
There was a problem hiding this comment.
Nothing to do, it's handled above with if log_jac_det.ndim < logp.ndim and elif branches above. Just put a strict=True in the zip in case we modify those and forget to handle it here
I think main is still failing, it's addressed by a PR that's still open. We can ignore |
12e4efe to
7ae9b5c
Compare
| except ValueError: | ||
| raise ValueError( | ||
| f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. " | ||
| "There is a bug in the implementation of either one." | ||
| ) |
There was a problem hiding this comment.
| except ValueError: | |
| raise ValueError( | |
| f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. " | |
| "There is a bug in the implementation of either one." | |
| ) | |
| except ValueError as err: | |
| raise ValueError( | |
| f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. " | |
| "There is a bug in the implementation of either one." | |
| ) from err |
| f"Univariate transform {transform} cannot be applied to multivariate {rv_op}" | ||
| ) | ||
| # Check there is no broadcasting between logp and jacobian | ||
| # Axes where one side is broadcastable and the other isn't must be size-1 |
There was a problem hiding this comment.
Still prefer the original comment
7ae9b5c to
189873a
Compare
ricardoV94
left a comment
There was a problem hiding this comment.
looks good, other than nit about comment.
More importantly, can you add a regression test? Your thing with frozen length 1 dim is fine, just calling model.logp(sum=False) and checking it comes back with the right static shape (and implicitly that it didn't raise)
| f"Univariate transform {transform} cannot be applied to multivariate {rv_op}" | ||
| ) | ||
| # Check there is no broadcasting between logp and jacobian | ||
| # Check there is no broadcasting difference between logp and jacobian |
There was a problem hiding this comment.
I don't want to hold us on this comment, but broadcasting is an action between two variables, not a property of each variable variables (hence why it's called type.broadcastable). We are checking/enforcing that the variables won't broadcast together.
So original comment is still more accurate than current one
There was a problem hiding this comment.
noted. Kept the original comment unchanged now.
189873a to
39909c6
Compare
|
Looks like this test needa updating FAILED tests/distributions/test_transform.py::test_invalid_jacobian_broadcast_raises - Failed: DID NOT RAISE <class 'ValueError'> |
|
Error moved from compile time to run time, so changed the test to compile and then call the logp instead. Hope that works for you @ricardoV94 |
39909c6 to
262f73c
Compare
|
Thanks @velochy. There's another PR I want to get merged and then I'll cut a release |
|
Much appreciated, and thank you again @ricardoV94 |
Description
Fix for #8173
Related Issue
Checklist
Type of change