Skip to content

Allow freezing of FunctionGraph for hashing#1908

Merged
ricardoV94 merged 18 commits into
pymc-devs:v3from
jessegrabowski:hashable-inner-graphs
Apr 8, 2026
Merged

Allow freezing of FunctionGraph for hashing#1908
ricardoV94 merged 18 commits into
pymc-devs:v3from
jessegrabowski:hashable-inner-graphs

Conversation

@jessegrabowski

Copy link
Copy Markdown
Member

Closes #1606

LLM disclosure: this PR made heavy use of Claude in the planning and first cut stages, though I was heavily involved. Still, the code should be subject to extra scrutiny as a result.

The purpose of the PR is to refactor Ops with inner graphs to allow comparison. The linked issue has an exhaustive discussion of the factors at play. There was an attempt in the aesara days to attack this, but it was perhaps too aggressive: it cons-hashed all Apply nodes, which necessitated changes across the codebase. @ricardoV94 suggested a weakref dict approach for subgraphs. This is implemented at the Op level. The plan is for Ops that have inner graphs (Composite, ScalarLoop, Scan, OpFromGraph, etc) to have a _cache class attribute, and implement the op-specific logic for caching, pickling, unpickling, etc. It didn't look super generalizable to me at first blush, but we can argue about it maybe.

Changes to FunctionGraph:

  • FunctionGraph now has a method freeze that returns a FrozenFunctionGraph.
  • The FrozenFunctionGraph does cons-hashing of Apply nodes within its scope only
  • It generates a hash based on its inner graph
  • Two FrozenFunctionGraphs with the same inner graph with evaluate to equal, but their Apply nodes won't be references to the same objects (this is the "conservatism" of my approach)

Specific implementation details:

  • The structural_hash of a FrozenFunctionGraph is built from a list of 3-tuples: (name, type, inputs), plus the outputs. For constants, inputs is replaced with the hash of the input data.
  • Equality between FrozenFunctionGraphs is done by comparing hashes, then falling back to equal_computation if the hash misses.

A consequence of the cons-hashing in this approach is that the inner graph is de-duplicated when we call fg.freeze(). So a MergeOptimizer pass is no longer required. Usage is demonstrated on the Composite Op. If we like the approach I can move forward with refactoring other Ops, but I wanted to stop here and discuss the approach.

Code example:

import pytensor.tensor as pt
import pytensor

a, b, c, d = pt.dscalars('a', 'b', 'c', 'd')
eq1 = pt.sin(a) * b ** 2
eq2 = pt.sin(c) * d ** 2

with pytensor.config.change_flags(optimizer_verbose=True):
    f = pytensor.function([a, b, c, d], [eq1, eq2])

f.dprint()

Result:

Composite{(sin(*0-<float64>) * sqr(*1-<float64>))} [id A] 1
 ├─ a [id B]
 └─ b [id C]
Composite{(sin(*0-<float64>) * sqr(*1-<float64>))} [id D] 0
 ├─ c [id E]
 └─ d [id F]

Inner graphs:

Composite{(sin(*0-<float64>) * sqr(*1-<float64>))} [id A]
 ← mul [id G]
    ├─ sin [id H]
    │  └─ *0-<float64> [id I]
    └─ sqr [id J]
       └─ *1-<float64> [id K]

Composite{(sin(*0-<float64>) * sqr(*1-<float64>))} [id D]
 ← mul [id G]
    └─ ···

@ricardoV94 ricardoV94 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you not go all out?

If you already deduplicate and do internal hash-cons you are one step away from getting hashing for free across different FunctionGraphs. Just do the hash-cons globally. Then FrozenFunctionGrahp([x, y], [foo(x, y)] is equal to another functiongraph if and only if fgraph.outputs == other_fgraph.outputs. No need for recursive hashing or expensive equal_computations.

As it stands you are not doing much better sneaking a default MergeOptimizer at __init__ and adding a FunctionGraph class that has no replace mode.

And cheap hashing/ equality is not just a nice to have, it's really valuable to not slow down compilation. In some of my benchmarks on previous work, some graphs could spend inordinate time on equality checks.

Comments regardless of whether we go:

  • Don't create FrozenFunctionGraph as a subclass of FrozenGraph, let's push the general principle, shared abstract classes, no-subclass of actually realized objects. Then you don't need check_frozen , the methods just don't exist for the frozen subclass.
  • You could create a frozenApply that uses tuple for input/outputs instead of list. That will help ensuring the immutability because all our current rewrite machinery works on the idea of overriding entries in those lists. Accidentally trying to mutate a graph would 99% fail there.

Comment thread pytensor/scalar/basic.py Outdated

@ricardoV94 ricardoV94 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is starting to look good, how are you feeling about it?

Notes:

  • Add a FrozenFunctionGraph.unfreeze(), that yields a FunctionGraph?
  • Really try to avoid the FrozenConstant stuff
  • Ops with inner graph (at least the ones you touched now) should only have a FrozenFunctionGraph internally (not a mutable one as well). Maybe that's already the case.

We need some follow-up issues open:

  • Optimizing OpFromGraph: There should be an explicit rewrite that creates a new OpFromGraph with its updated frozen graph, (so it is also reflected immediately in dprint). We should never do any further rewrites of the internal fgraph during compilation.
  • Scan/Minimize/Root: Use the new FrozenFunctionGraph as well. This should immediately address #1601
  • When compiling OpFromGraph in jitted contexts we should try to avoid recreating inner numba/jax functions when the same OFG is compiled multiple times in a function, this will likely speedup compilation. In the C-backend that already happens due to the caching of _fn. That's how we can deliver on the promised compilations speedups and it's specially relevant for a library like pytensor-ml that may want to chains hundreds of the same "LayerOp"s in sequence

Comment thread pytensor/graph/basic.py Outdated
Comment thread pytensor/graph/basic.py Outdated
Comment thread pytensor/graph/basic.py Outdated
Comment thread pytensor/graph/fg.py Outdated
Comment thread pytensor/graph/fg.py Outdated
Comment thread pytensor/scalar/basic.py Outdated
Comment thread pytensor/scalar/basic.py Outdated
Comment thread pytensor/scalar/basic.py Outdated
Comment thread pytensor/scalar/basic.py Outdated
Comment thread tests/compile/test_builders.py
Comment thread pytensor/graph/fg.py Outdated
@jessegrabowski jessegrabowski force-pushed the hashable-inner-graphs branch 2 times, most recently from 78ee1a9 to eda51d2 Compare March 8, 2026 19:18
Comment thread pytensor/compile/builders.py
Comment thread pytensor/graph/fg.py Outdated
Comment thread pytensor/graph/fg.py Outdated
Comment thread pytensor/graph/fg.py Outdated
Comment thread pytensor/graph/fg.py Outdated
Comment thread pytensor/graph/fg.py Outdated
Comment thread pytensor/tensor/rewriting/elemwise.py Outdated
@ricardoV94

ricardoV94 commented Mar 8, 2026

Copy link
Copy Markdown
Member

I left some comments as I checked the changes. I need to think/discuss a bit about the spec thing, and the desire to have a consistent hashing across runtimes. If you remove that the complexity of this PR drops quite a bit, but maybe this is also fine.

Can you confirm this was only needed for the C-backend, and that it would also work if whatever relies on that called something like __stable_hash__ instead of __hash__, that does the fingerprint / spec thing?

Besides that this PR look amazing, and it's a game changer to working with inner graph ops. We really need those to work well

@jessegrabowski jessegrabowski force-pushed the hashable-inner-graphs branch from eda51d2 to 7202ca3 Compare March 9, 2026 00:15
@jessegrabowski

Copy link
Copy Markdown
Member Author

I removed the spec stuff and simplified the PR down somewhat.

@jessegrabowski jessegrabowski force-pushed the hashable-inner-graphs branch 3 times, most recently from 4a7bea8 to 445731f Compare March 9, 2026 00:48
Comment thread pytensor/graph/basic.py
Comment thread pytensor/graph/fg.py Outdated
Comment thread pytensor/graph/fg.py Outdated
Comment thread pytensor/graph/fg.py Outdated
Comment thread pytensor/graph/op.py Outdated
Comment thread tests/compile/test_builders.py
Comment thread pytensor/graph/fg.py Outdated
Comment on lines +1016 to +1017
for i, out in enumerate(frozen_outputs):
out.name = f"o{i}"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is wrong? The same variable could be output0 in one graph and output 2 in another? Or are these the dummy Output Ops we put in clients?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

@ricardoV94 ricardoV94 force-pushed the hashable-inner-graphs branch 5 times, most recently from a9ac55f to ae73a91 Compare April 7, 2026 09:18
Comment thread pytensor/graph/fg.py Outdated
Comment on lines +1017 to +1020
self.variables: frozenset[Variable] = frozenset(memo.values())
self.apply_nodes: frozenset[Apply] = frozenset(sorted_apply_nodes)
self._clients: dict[Variable, list[ClientType]] | None = None
self._toposort: tuple[Apply, ...] = tuple(sorted_apply_nodes)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pre-computed these (except for clients), because we basically have everything we needed already from our loop.

I made them frozenset/tuple instead.

Comment thread pytensor/graph/fg.py
@property
def clients(self) -> dict[Variable, list[ClientType]]: # type: ignore[override]
if self._clients is None:
clients: dict[Variable, list[ClientType]] = {v: [] for v in self.variables}

@ricardoV94 ricardoV94 Apr 7, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got rid of the setdefault in the inner loop, speeds up things a bit. We may end with more clients that before, for variables without nodes. I think this is much more robust.

One big difference though between this and the regular FunctionGraph is we don't have the dummy Output Apply in the clients of output vars. I think we should add

@jessegrabowski jessegrabowski force-pushed the hashable-inner-graphs branch from bafdb84 to 8a77a26 Compare April 8, 2026 01:00
@ricardoV94 ricardoV94 merged commit e2f36d1 into pymc-devs:v3 Apr 8, 2026
64 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Equality of Ops with InnerGraph

2 participants