Skip to content

Fix broadcast check on log_jac_det#8174

Merged
ricardoV94 merged 1 commit into
pymc-devs:mainfrom
velochy:StudentTBroadcastFix
Mar 7, 2026
Merged

Fix broadcast check on log_jac_det#8174
ricardoV94 merged 1 commit into
pymc-devs:mainfrom
velochy:StudentTBroadcastFix

Conversation

@velochy
Copy link
Copy Markdown
Contributor

@velochy velochy commented Mar 6, 2026

Description

Fix for #8173

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@github-actions github-actions Bot added the bug label Mar 6, 2026
@velochy velochy changed the base branch from v6 to main March 6, 2026 12:36
Comment thread pymc/logprob/transform_value.py Outdated
f"Univariate transform {transform} cannot be applied to multivariate {rv_op}"
)
# Check there is no broadcasting between logp and jacobian
# Check there is no broadcasting between logp and jacobian.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment is too verbose and specific

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad. AIs tend to be verbose with these...

Comment thread pymc/logprob/transform_value.py
@read-the-docs-community
Copy link
Copy Markdown

read-the-docs-community Bot commented Mar 6, 2026

Documentation build overview

📚 pymc | 🛠️ Build #31707242 | 📁 Comparing 262f73c against latest (c9cd0c8)


🔍 Preview build

Show files changed (2 files in total): 📝 2 modified | ➕ 0 added | ➖ 0 deleted
File Status
glossary.html 📝 modified
_modules/pytensor/tensor/basic.html 📝 modified

@codecov
Copy link
Copy Markdown

codecov Bot commented Mar 6, 2026

Codecov Report

❌ Patch coverage is 75.00000% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 84.55%. Comparing base (c9cd0c8) to head (262f73c).
⚠️ Report is 110 commits behind head on main.

Files with missing lines Patch % Lines
pymc/logprob/transform_value.py 75.00% 2 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #8174      +/-   ##
==========================================
- Coverage   84.55%   84.55%   -0.01%     
==========================================
  Files         124      124              
  Lines       19865    19872       +7     
==========================================
+ Hits        16797    16802       +5     
- Misses       3068     3070       +2     
Files with missing lines Coverage Δ
pymc/logprob/transform_value.py 96.61% <75.00%> (-1.59%) ⬇️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@velochy velochy force-pushed the StudentTBroadcastFix branch from 964a67c to 12e4efe Compare March 6, 2026 12:53
@velochy
Copy link
Copy Markdown
Contributor Author

velochy commented Mar 6, 2026

@ricardoV94 failing tests look pre-existing. Should I rebase on a branch other than main?

Comment thread pymc/logprob/transform_value.py Outdated
"There is a bug in the implementation of either one."
)

def broadcastable_axes(a, b):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can be simplified.
broadcastable_axes = [i for i, (ai, bi) in enumerate(zip(...)) if ai or bi]

And then use in both.

There's no extra cost of specifying a broadcastable_axes that is already known to be broadcastable.

Only thing to make sure is that ndim matches

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure I understand the ndim question here. Is it possible they have different numbers of dimensions? If so, how should they be matched?

Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may already have logic above that handles different ndim.

If not we should fail, it's a guaranteed form of broadcasting. But the current checks may miss on it

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I just add a length equality check and throw an exception if there is a mismatch?
Or no action needed here?

Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing to do, it's handled above with if log_jac_det.ndim < logp.ndim and elif branches above. Just put a strict=True in the zip in case we modify those and forget to handle it here

Comment thread pymc/logprob/transform_value.py
@ricardoV94
Copy link
Copy Markdown
Member

@ricardoV94 failing tests look pre-existing. Should I rebase on a branch other than main?

I think main is still failing, it's addressed by a PR that's still open. We can ignore

@velochy velochy force-pushed the StudentTBroadcastFix branch from 12e4efe to 7ae9b5c Compare March 6, 2026 13:07
Comment thread pymc/logprob/transform_value.py Outdated
Comment on lines +124 to +128
except ValueError:
raise ValueError(
f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
"There is a bug in the implementation of either one."
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
except ValueError:
raise ValueError(
f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
"There is a bug in the implementation of either one."
)
except ValueError as err:
raise ValueError(
f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
"There is a bug in the implementation of either one."
) from err

Comment thread pymc/logprob/transform_value.py Outdated
f"Univariate transform {transform} cannot be applied to multivariate {rv_op}"
)
# Check there is no broadcasting between logp and jacobian
# Axes where one side is broadcastable and the other isn't must be size-1
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still prefer the original comment

@velochy velochy force-pushed the StudentTBroadcastFix branch from 7ae9b5c to 189873a Compare March 6, 2026 23:37
Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good, other than nit about comment.

More importantly, can you add a regression test? Your thing with frozen length 1 dim is fine, just calling model.logp(sum=False) and checking it comes back with the right static shape (and implicitly that it didn't raise)

Comment thread pymc/logprob/transform_value.py Outdated
f"Univariate transform {transform} cannot be applied to multivariate {rv_op}"
)
# Check there is no broadcasting between logp and jacobian
# Check there is no broadcasting difference between logp and jacobian
Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 Mar 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to hold us on this comment, but broadcasting is an action between two variables, not a property of each variable variables (hence why it's called type.broadcastable). We are checking/enforcing that the variables won't broadcast together.

So original comment is still more accurate than current one

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noted. Kept the original comment unchanged now.

@velochy velochy force-pushed the StudentTBroadcastFix branch from 189873a to 39909c6 Compare March 7, 2026 15:18
@ricardoV94
Copy link
Copy Markdown
Member

Looks like this test needa updating FAILED tests/distributions/test_transform.py::test_invalid_jacobian_broadcast_raises - Failed: DID NOT RAISE <class 'ValueError'>

@velochy
Copy link
Copy Markdown
Contributor Author

velochy commented Mar 7, 2026

Error moved from compile time to run time, so changed the test to compile and then call the logp instead. Hope that works for you @ricardoV94

@velochy velochy force-pushed the StudentTBroadcastFix branch from 39909c6 to 262f73c Compare March 7, 2026 20:31
@ricardoV94 ricardoV94 changed the title Fix broadcast dimensions on log_jac_det Fix broadcast check on log_jac_det Mar 7, 2026
@ricardoV94 ricardoV94 merged commit 9082a04 into pymc-devs:main Mar 7, 2026
36 of 42 checks passed
@ricardoV94
Copy link
Copy Markdown
Member

Thanks @velochy. There's another PR I want to get merged and then I'll cut a release

@velochy
Copy link
Copy Markdown
Contributor Author

velochy commented Mar 7, 2026

Much appreciated, and thank you again @ricardoV94

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

BUG: freeze_dims_and_data breaks HalfStudentT logp when a dimension has size 1

2 participants