Skip to content

Rewrites for consecutive advanced read-write operations#2061

Merged
ricardoV94 merged 9 commits into
pymc-devs:v3from
ricardoV94:inc_subtensor_rewrite
Apr 22, 2026
Merged

Rewrites for consecutive advanced read-write operations#2061
ricardoV94 merged 9 commits into
pymc-devs:v3from
ricardoV94:inc_subtensor_rewrite

Conversation

@ricardoV94
Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 commented Apr 19, 2026

Handling simplifications that show up in "sparse"-like graphs (lower triangular with exponentiated diagonals). The gradient / logp can end up with reading entries from a latent vector that was written/updated into a matrix. The rewrites in this PR not materializing the intermediate matrix, by mapping the reads/updates directly into the original latent vector.

  • Generalizes symbolic x[idx].set/inc(y)[idx] -> y (or x[idx] + y) to AdvancedIncSubtensor, not just basic and AdvancedSubtensor1 like before
  • Constant read of constant write like x[const_idx1].set/inc(y)[const_idx2], get's converted into a read of y (when no entry of uwritten x was kept), or a read of x (when no entry of written y was kept) or a mix read of x with some entries still updated by y. Overlap / translation of indices needed, so only applicable to constant indices.
  • Apply the last rewrite to x[idx].set/inc(y).diagonal() (i.e., ExtractDiagonal), by converting to advanced indexing and calling the rewrite
  • Flatten successive writes into the same indices x[idx].set/inc(y)[idx].set/inc(z) -> x[idx].set(y) or (x[idx].inc(y + z))
  • x + zeros()[idx].set(y) -> x[idx].inc(y). Don't materialize the zeros + wasteful addition, update directly into x.

@review-notebook-app
Copy link
Copy Markdown

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@ricardoV94 ricardoV94 force-pushed the inc_subtensor_rewrite branch from 35413ef to 438beac Compare April 19, 2026 13:53
Comment thread pytensor/tensor/rewriting/subtensor.py Outdated
.. code::

sub(set(x, v, idx), idx) -> v
sub(inc(x, v, idx), idx) -> x[idx] + v (idx must be duplicate-free)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: arange(n) (or any other regular index creator) that we now doesn't make repeated indices should also be supported

@ricardoV94 ricardoV94 force-pushed the inc_subtensor_rewrite branch from 438beac to bc19b4f Compare April 19, 2026 21:17
@ricardoV94 ricardoV94 force-pushed the inc_subtensor_rewrite branch 2 times, most recently from 6e9d9dd to 6059281 Compare April 19, 2026 21:46
@ricardoV94 ricardoV94 force-pushed the inc_subtensor_rewrite branch 2 times, most recently from 5e93f56 to c7ce1df Compare April 20, 2026 22:27
@ricardoV94 ricardoV94 changed the title Rewrites for nested advanced read-write operations Rewrites for consecutive advanced read-write operations Apr 20, 2026
Baseline integration test that builds the gradient of a
Cholesky-parameterised log-density.  The forward graph scatters a packed
vector into a lower-triangular matrix, exponentiates the diagonal, and
computes sum(log(diag(L))) + sum(L @ L.T); its gradient produces nested
inc/set chains and diag-of-scatter patterns that subsequent subtensor
rewrites target.  The number of indexing ops in the compiled gradient is
asserted so later commits can tighten it as rewrites reduce the count.
Previously only the set and zero-inc cases were collapsed.
Replace local_adv_sub1_adv_inc_sub1 (which only handled
AdvancedSubtensor1) with local_read_of_write_same_indices, which
covers Subtensor, AdvancedSubtensor1, and AdvancedSubtensor at the
same indices.  The runtime Assert is dropped in favour of the
shape_unsafe tag.  The inc case is now also supported, gated on
duplicate-free indices via the new _constant_has_unique_indices helper.

The basic-Subtensor rewrite local_subtensor_inc_subtensor is removed
since the generalized rewrite subsumes it.
@ricardoV94 ricardoV94 force-pushed the inc_subtensor_rewrite branch from c7ce1df to daa43e0 Compare April 21, 2026 13:57
@ricardoV94 ricardoV94 requested review from jessegrabowski and removed request for jessegrabowski April 21, 2026 15:09
Comment thread pytensor/tensor/rewriting/subtensor.py Outdated
@ricardoV94 ricardoV94 added the enhancement New feature or request label Apr 21, 2026
@ricardoV94
Copy link
Copy Markdown
Member Author

I'm happy with the state of this one

@ricardoV94 ricardoV94 marked this pull request as ready for review April 21, 2026 15:18
Absorb a sparse write into a surrounding add, avoiding materialising
the dense form of the sparse update before the addition.
Multi-axis read-after-write where read and write may have different
constant indices.  Handles full / no / partial coverage for both set
and inc; respects numpy's adv-axis placement (consecutive vs hoisted)
via _non_consecutive_adv_indexing; rejects cross-sign indices that
might alias; casts v to the buffer dtype so v[lookup] matches the
output type.
Fold nested write ops sharing the same index variables.  Typically
arises from gradient accumulation or user code that writes then
updates the same slice.
advanced_read_of_write_constant_indices
Lower extract_diag(advanced_inc_subtensor(...)) to an arange-pattern
AdvancedSubtensor and delegate to local_advanced_read_of_write_constant_indices.
Re-emits ExtractDiag(base) in the no-coverage case to keep the zero-copy
view semantics.  Handles any offset.

Replaces extract_diag_of_diagonal_set_subtensor, which only handled
arange-only set patterns and didn't cover inc; the new rewrite delegates
to the more general constant-indices machinery.  Also drops the unused
'full' import that was only used by the old rewrite.
@ricardoV94 ricardoV94 force-pushed the inc_subtensor_rewrite branch from daa43e0 to 3bdc156 Compare April 21, 2026 16:34
Copy link
Copy Markdown
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approving with some comments. We had a call going over this code and it hasn't changed too much so I'm comfortable with it broadly. I did a re-read focusing mostly on the tests, happy to go into detail somewhere else if you have concerns.

sparse_candidate.owner
and isinstance(
sparse_candidate.owner.op,
IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you think about tracking add vs tracking these? I guess they're both super common...

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, I would keep it simple. It's not one of those obscure linalg ops that only you play with

return False
cached = getattr(idx.tag, "unique_indices", None)
if cached is not None:
return bool(cached)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible for something to change underneath the node and for this tag to no longer apply? I guess not because it's just a Constant, but thinking out loud.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if data changes inside the same constant we have way bigger problems to worry about

Comment on lines +1270 to +1273
topo = f.maker.fgraph.toposort()
assert not any(
isinstance(n.op, AdvancedIncSubtensor | AdvancedIncSubtensor1) for n in topo
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert_equal_computation against the expected graph?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question for the other topo + check presence tests below

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the whole function is being compiled I don't like to do that because it can change a lot of other stuff (fusion - excluded here), inplace, ... index normalization... and what not. If I'm just testing a single rewrite or a single op function I usually go that route. Also some of these were pre-existing tests.

I wouldn't bother extra cycles on this right now, but if you insist I can fire up the bot

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you're the lead dev, i only flagged it because you've flagged it on me. I don't think what you have here is bad.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm on the fence, if this PR wasn't 9 commits already and didn't add so many tests I would definitely have jumped on it

isinstance(n.op, AdvancedIncSubtensor | AdvancedIncSubtensor1) for n in topo
)

def test_inc_symbolic_idx_not_rewritten(self):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how do you feel about "rewrite doesn't happen" tests in general? The bot always generates them but I usually remove because it's exercises a code path that's pretty easy to reason about. I can put them in if you like them

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I asked for them, it's as important to test the cases that apply as the cases that don't. And the inc being possibly duplicate is both sneaky and may be hard to hit in practice that it can go a long time without noticing

@ricardoV94 ricardoV94 merged commit a78437a into pymc-devs:v3 Apr 22, 2026
66 checks passed
@jessegrabowski jessegrabowski mentioned this pull request May 31, 2026
11 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants