Rewrites for consecutive advanced read-write operations#2061
Conversation
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
35413ef to
438beac
Compare
| .. code:: | ||
|
|
||
| sub(set(x, v, idx), idx) -> v | ||
| sub(inc(x, v, idx), idx) -> x[idx] + v (idx must be duplicate-free) |
There was a problem hiding this comment.
TODO: arange(n) (or any other regular index creator) that we now doesn't make repeated indices should also be supported
438beac to
bc19b4f
Compare
6e9d9dd to
6059281
Compare
5e93f56 to
c7ce1df
Compare
Baseline integration test that builds the gradient of a Cholesky-parameterised log-density. The forward graph scatters a packed vector into a lower-triangular matrix, exponentiates the diagonal, and computes sum(log(diag(L))) + sum(L @ L.T); its gradient produces nested inc/set chains and diag-of-scatter patterns that subsequent subtensor rewrites target. The number of indexing ops in the compiled gradient is asserted so later commits can tighten it as rewrites reduce the count.
Previously only the set and zero-inc cases were collapsed.
Replace local_adv_sub1_adv_inc_sub1 (which only handled AdvancedSubtensor1) with local_read_of_write_same_indices, which covers Subtensor, AdvancedSubtensor1, and AdvancedSubtensor at the same indices. The runtime Assert is dropped in favour of the shape_unsafe tag. The inc case is now also supported, gated on duplicate-free indices via the new _constant_has_unique_indices helper. The basic-Subtensor rewrite local_subtensor_inc_subtensor is removed since the generalized rewrite subsumes it.
c7ce1df to
daa43e0
Compare
|
I'm happy with the state of this one |
Absorb a sparse write into a surrounding add, avoiding materialising the dense form of the sparse update before the addition.
Multi-axis read-after-write where read and write may have different constant indices. Handles full / no / partial coverage for both set and inc; respects numpy's adv-axis placement (consecutive vs hoisted) via _non_consecutive_adv_indexing; rejects cross-sign indices that might alias; casts v to the buffer dtype so v[lookup] matches the output type.
Fold nested write ops sharing the same index variables. Typically arises from gradient accumulation or user code that writes then updates the same slice.
advanced_read_of_write_constant_indices
Lower extract_diag(advanced_inc_subtensor(...)) to an arange-pattern AdvancedSubtensor and delegate to local_advanced_read_of_write_constant_indices. Re-emits ExtractDiag(base) in the no-coverage case to keep the zero-copy view semantics. Handles any offset. Replaces extract_diag_of_diagonal_set_subtensor, which only handled arange-only set patterns and didn't cover inc; the new rewrite delegates to the more general constant-indices machinery. Also drops the unused 'full' import that was only used by the old rewrite.
daa43e0 to
3bdc156
Compare
jessegrabowski
left a comment
There was a problem hiding this comment.
Approving with some comments. We had a call going over this code and it hasn't changed too much so I'm comfortable with it broadly. I did a re-read focusing mostly on the tests, happy to go into detail somewhere else if you have concerns.
| sparse_candidate.owner | ||
| and isinstance( | ||
| sparse_candidate.owner.op, | ||
| IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor, |
There was a problem hiding this comment.
what do you think about tracking add vs tracking these? I guess they're both super common...
There was a problem hiding this comment.
yeah, I would keep it simple. It's not one of those obscure linalg ops that only you play with
| return False | ||
| cached = getattr(idx.tag, "unique_indices", None) | ||
| if cached is not None: | ||
| return bool(cached) |
There was a problem hiding this comment.
is it possible for something to change underneath the node and for this tag to no longer apply? I guess not because it's just a Constant, but thinking out loud.
There was a problem hiding this comment.
if data changes inside the same constant we have way bigger problems to worry about
| topo = f.maker.fgraph.toposort() | ||
| assert not any( | ||
| isinstance(n.op, AdvancedIncSubtensor | AdvancedIncSubtensor1) for n in topo | ||
| ) |
There was a problem hiding this comment.
assert_equal_computation against the expected graph?
There was a problem hiding this comment.
Same question for the other topo + check presence tests below
There was a problem hiding this comment.
When the whole function is being compiled I don't like to do that because it can change a lot of other stuff (fusion - excluded here), inplace, ... index normalization... and what not. If I'm just testing a single rewrite or a single op function I usually go that route. Also some of these were pre-existing tests.
I wouldn't bother extra cycles on this right now, but if you insist I can fire up the bot
There was a problem hiding this comment.
you're the lead dev, i only flagged it because you've flagged it on me. I don't think what you have here is bad.
There was a problem hiding this comment.
I'm on the fence, if this PR wasn't 9 commits already and didn't add so many tests I would definitely have jumped on it
| isinstance(n.op, AdvancedIncSubtensor | AdvancedIncSubtensor1) for n in topo | ||
| ) | ||
|
|
||
| def test_inc_symbolic_idx_not_rewritten(self): |
There was a problem hiding this comment.
how do you feel about "rewrite doesn't happen" tests in general? The bot always generates them but I usually remove because it's exercises a code path that's pretty easy to reason about. I can put them in if you like them
There was a problem hiding this comment.
I asked for them, it's as important to test the cases that apply as the cases that don't. And the inc being possibly duplicate is both sneaky and may be hard to hit in practice that it can go a long time without noticing
Handling simplifications that show up in "sparse"-like graphs (lower triangular with exponentiated diagonals). The gradient / logp can end up with reading entries from a latent vector that was written/updated into a matrix. The rewrites in this PR not materializing the intermediate matrix, by mapping the reads/updates directly into the original latent vector.