
Conversation


benjaminglass1 (Collaborator) commented Jul 30, 2025

Stack from ghstack (oldest at bottom):

  1. Ensures that any subgraphs on a GraphModule are updated by FakeTensorUpdater.
  2. Ensures that any users of those subgraphs (i.e. invoke_subgraph) also get updated appropriately.
  3. Enables processing of HOPs by FakeTensorUpdater.
  4. Adds tests for the use of HOPs within FakeTensorUpdater-managed graphs (a rough sketch of the pattern follows below).
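
For intuition, a rough sketch of the kind of pattern such a test can exercise (a sketch only: torch.cond is used because it is a convenient public HOP that produces subgraphs; the tests added in this PR may use invoke_subgraph and a different harness):

import torch

@torch.compile(backend="inductor")
def f(x):
    def plus_one(t):
        return t.relu() + 1

    def minus_one(t):
        return t - 1

    # torch.cond lowers to a higher_order.cond node whose branch subgraphs live
    # on the GraphModule that FakeTensorUpdater walks.
    return torch.cond(x.sum() > 0, plus_one, minus_one, (x,))

print(f(torch.randn(8)))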

Fixes #156819

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @mlazos

[ghstack-poisoned]

pytorch-bot bot commented Jul 30, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159523

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Cancelled Job, 1 Unrelated Failure

As of commit 621684a with merge base 20cae80:

CANCELLED JOB - The following job was cancelled. Please retry:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

benjaminglass1 added a commit that referenced this pull request Jul 30, 2025
Fixes #156819

ghstack-source-id: 1bd19d6
Pull Request resolved: #159523
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Jul 30, 2025
Fixes #156819

ghstack-source-id: e9e0d9a
Pull Request resolved: #159523
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Jul 31, 2025
Fixes #156819

ghstack-source-id: 8f96c2c
Pull Request resolved: #159523
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Aug 1, 2025
Fixes #156819

ghstack-source-id: 814e074
Pull Request resolved: #159523
@benjaminglass1 benjaminglass1 self-assigned this Aug 1, 2025
benjaminglass1 changed the title from "[inductor] Fix FakeTensorUpdater handling of HOPs" to "[WIP][inductor] Fix FakeTensorUpdater handling of HOPs" on Aug 2, 2025
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Aug 2, 2025
Fixes #156819

ghstack-source-id: f244405
Pull Request resolved: #159523
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Aug 4, 2025
Fixes #156819

ghstack-source-id: e8c0ac5
Pull Request resolved: #159523
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Aug 14, 2025
Fixes #156819

ghstack-source-id: 2aff2b5
Pull Request resolved: #159523
@benjaminglass1 benjaminglass1 marked this pull request as ready for review August 14, 2025 22:30
benjaminglass1 changed the title from "[WIP][inductor] Fix FakeTensorUpdater handling of HOPs" to "[inductor] Fix FakeTensorUpdater handling of HOPs" on Aug 14, 2025
@eellison eellison requested a review from zou3519 August 15, 2025 21:24
@eellison (Contributor) commented:

@zou3519 said he would review - deferring to him on this one.

@eellison eellison removed their request for review August 22, 2025 20:02
benjaminglass1 added a commit that referenced this pull request Oct 17, 2025
Fixes #156819

ghstack-source-id: 7a45b59
Pull Request resolved: #159523
@benjaminglass1 (Collaborator, Author) commented:

Spoke with @zou3519 offline; we concluded that the requests for better testing can be addressed in follow-up PRs, given the existing lack of testing for FakeTensorUpdater, so that we can get the UB fixed sooner.

@zou3519 (Contributor) left a comment:

The example_val thing is still suspicious to me. If you could send an example test case where you saw it, that would be helpful.

[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Oct 29, 2025
Fixes #156819

ghstack-source-id: 08a1878
Pull Request resolved: #159523
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Nov 1, 2025
Fixes #156819

ghstack-source-id: 1f3208c
Pull Request resolved: #159523
@benjaminglass1 (Collaborator, Author) commented:

Pushed a version that updates subgraphs with new placeholder arguments when needed, at the time we process the subgraph invocation. It appears to work, but I've run into a new wrinkle: in some cases the same subgraph can be called multiple times (I've modified the test to reflect this). I think this may invalidate the idea that we can trivially change the shape and stride of inputs and outputs. Minimally, we need to skip doing this for subgraphs that are invoked repeatedly; maximally, perhaps we should just throw a loud error when changing placeholders or outputs on a graph?
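
For context, here is a minimal plain-torch.fx illustration of the reuse problem (illustration only: it uses ordinary call_module nodes rather than the HOP path this PR handles, purely to show why a shared body cannot be rewritten to match a single caller):

import torch
import torch.fx

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = torch.nn.ReLU()  # stands in for a shared subgraph body

    def forward(self, a, b):
        # One body, two call sites: rewriting the body's recorded input
        # metadata to match `a` would silently be wrong for `b`.
        return self.shared(a), self.shared(b)

gm = torch.fx.symbolic_trace(Outer())
call_sites = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "shared"]
print(len(call_sites))  # 2 callers, one shared body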

@benjaminglass1 (Collaborator, Author) commented:

Writing up where this is at, for visibility:

  1. I've found cases where the same subgraph member gets reused multiple times within a graph. This means that we cannot unconditionally update inputs to a subgraph with differently shaped FakeTensors, since that could happen with different shapes at different points in the graph.
  2. I've (locally) added checks that we aren't updating the same subgraph twice, but these checks have been growing more complicated as I try to avoid triggering them for tensors with the same shape/stride/etc (since in those cases it's entirely valid to reuse the subgraph).
  3. The final obstacle is that simply reusing the subgraph.output_node().meta["val"] tensors to represent the output of the subgraph results in apparent aliasing where there should not be any. This may mean that we always need to re-run FakeTensor propagation through subgraphs, which I would prefer to avoid for computational reasons. I'm still working on a reliable way to copy those updated tensors without losing real aliasing relationships (a rough sketch of one generic way to do the copy follows below).
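
One generic way to copy such tensors, sketched below for reference (not necessarily the approach this PR ends up taking): rebuild each output FakeTensor with identical metadata but fresh storage instead of reusing the subgraph's meta["val"] objects directly. Note that this drops any real aliasing relationships and ignores symbolic shapes, which is exactly the part that still needs care.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def fresh_fake_like(fake_mode: FakeTensorMode, t: torch.Tensor) -> torch.Tensor:
    # empty_strided under the fake mode preserves size/stride/dtype/device but
    # allocates new (fake) storage, so reusing the result as a node's meta["val"]
    # cannot introduce spurious aliasing with the subgraph's own output tensor.
    with fake_mode:
        return torch.empty_strided(
            t.size(), t.stride(), dtype=t.dtype, device=t.device
        )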

[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Nov 8, 2025
Fixes #156819

ghstack-source-id: 334ed1d
Pull Request resolved: #159523

benjaminglass1 commented Nov 8, 2025

Update: this is very close now; the only remaining obstacle (pending the test run passing, obviously) is handling HOPs that did not initially appear to use subgraphs, such as torch.cond. These also need to be updated, so this code needs to be generalized to handle different arg-passing formats.

EDIT: @zou3519 re-requesting your review to look at the current approach and make sure there are no obvious issues I overlooked.
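
For reference, a small standalone check of how torch.cond packs its subgraphs into a traced node (make_fx is used here purely for illustration rather than the inductor pipeline; the differing args layout is what the helper added later in this PR has to special-case):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.cond(x.sum() > 0, lambda t: t + 1, lambda t: t - 1, (x,))

gm = make_fx(f)(torch.randn(4))
cond_node = next(n for n in gm.graph.nodes if n.target is torch.ops.higher_order.cond)
# args are roughly (pred, true_graph, false_graph, operands), unlike
# invoke_subgraph's (subgraph, identifier, *operands).
print(cond_node.args)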

[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Nov 12, 2025
Fixes #156819

ghstack-source-id: ebb2943
Pull Request resolved: #159523
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Nov 14, 2025
Fixes #156819

ghstack-source-id: 56d4ded
Pull Request resolved: #159523
@benjaminglass1 benjaminglass1 added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 14, 2025
Comment on lines +272 to +329
def extract_subgraphs_and_args(
    node: torch.fx.Node, *args: Any, **kwargs: Any
) -> tuple[tuple[torch.fx.GraphModule, ...], tuple[Any, ...] | None]:
    """HOPs that invoke subgraphs take a number of different forms. This
    function regularizes them, returning a tuple of subgraphs contained in the
    args and a tuple of the args for the subgraphs. This function assumes all
    subgraphs share a set of common arguments.

    This function assumes that node_invokes_subgraph(node, *args, **kwargs) is
    True.

    If the second return value is None, this function was unable to determine
    what args to pass to the subgraph(s)."""
    if node.target is torch.ops.higher_order.cond:
        return tuple(args[1:3]), tuple(args[3])
    if node.target is torch.ops.higher_order.foreach_map:
        return (args[0],), tuple(args[1:])
    if node.target in (
        torch.ops.higher_order.invoke_quant_packed,
        torch.ops.higher_order.invoke_quant,
    ):
        return (args[0],), tuple(args[1:])
    if node.target is torch.ops.higher_order.invoke_subgraph:
        return (args[0],), tuple(args[2:])
    if node.target is torch.ops.higher_order.map_impl:
        # map is applied over slices from the first dimension of each value in
        # args[1].
        return (args[0],), (*(a[0] for a in args[1]), *args[2:])
    if node.target in (
        torch.ops.higher_order.while_loop,
        torch.ops.higher_order.while_loop_stack_output,
    ):
        return tuple(args[:2]), (*args[2], *args[3])
    if node.target is control_deps:
        assert not kwargs, (
            "Subgraph arguments can be renamed, so we cannot consistently "
            "handle kwargs at this point in the stack."
        )
        return (args[1],), tuple(args[2:])
    # These functions don't have clean mappings from node arguments to subgraph
    # inputs, since those mappings are dependent on details of the original
    # invocation that are not preserved. Skip them intentionally.
    if node.target not in (
        torch.ops.higher_order.associative_scan,
        torch.ops.higher_order.flex_attention,
        torch.ops.higher_order.scan,
    ):
        warnings.warn(
            f"Please add support for subgraph args to function {node.target}!"
        )

    # By default, just return the detected list of subgraphs so that we can run
    # updates on all of them.
    return tuple(
        s
        for s in pytree.tree_leaves(args)
        if isinstance(s, torch.fx.GraphModule) and s in self.subgraph_updaters
    ), None
benjaminglass1 (Collaborator, Author) commented:

I do not like how this turned out, but I'm not sure of any other way to do this. Every HOP seems to pass the subgraph args through the call in a different way.

Contributor reply:

I see the problem

Contributor reply:

I'm not really sure what to do about this. Maybe we can ship a version of this PR first where we don't update the inside of the HOP if the outside of the HOP changes. If we really need this, then we need a way for each HOP to register how to do a FakeTensorUpdater on it, which is really annoying. Thoughts @eellison?
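
To make the "register how to do a FakeTensorUpdater on it" idea concrete, a purely hypothetical sketch (no such registry exists in this PR or in PyTorch; all names here are invented):

from typing import Any, Callable

import torch

# Hypothetical registry: maps a HOP to a callback that knows how to re-propagate
# FakeTensors through that HOP's subgraph(s) when the node's inputs change.
_FAKE_TENSOR_UPDATE_RULES: dict[Any, Callable[[torch.fx.Node], None]] = {}

def register_fake_tensor_update_rule(hop: Any):
    def wrap(fn: Callable[[torch.fx.Node], None]) -> Callable[[torch.fx.Node], None]:
        _FAKE_TENSOR_UPDATE_RULES[hop] = fn
        return fn
    return wrap

@register_fake_tensor_update_rule(torch.ops.higher_order.cond)
def _update_cond(node: torch.fx.Node) -> None:
    # A real rule would re-run fake propagation through true_graph/false_graph here.
    ...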

Contributor reply:

Actually, I think this is fine. Let's say we have y = invoke_subgraph(subgraph, x). The question is whether we need to correct the FakeTensors in subgraph given a new x.

  • At the very least, we need to change y, given a new x.
  • What does "a new x" mean? I'm assuming it has the same shape but different strides (see the small illustration below). If it has a different shape, then I'd expect the user needs to change the subgraph themselves, or the subgraph is trivial.
  • The subgraph should be resilient to changes in strides. If there is a custom operator that depends on the stride being a certain way, then it will emit some code that coerces the strides to be what it expects.
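
A tiny concrete version of the "new x" case above, for illustration only:

import torch

x_contig = torch.randn(4, 8)        # strides (8, 1)
x_strided = torch.randn(8, 4).t()   # same shape (4, 8), strides (1, 4)
assert x_contig.shape == x_strided.shape
assert x_contig.stride() != x_strided.stride()
# The subgraph traced with x_contig stays semantically valid for x_strided; only
# the recorded FakeTensor strides (and thus y's metadata) need refreshing.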

[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Nov 18, 2025
Fixes #156819

ghstack-source-id: 039368d
Pull Request resolved: #159523
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Nov 19, 2025
Fixes #156819

ghstack-source-id: 3da1a42
Pull Request resolved: #159523
@pytorch-bot pytorch-bot bot added ciflow/b200 ciflow/h100 ciflow/rocm Trigger "default" config CI on ROCm labels Nov 19, 2025
Comment on lines +131 to +133
strict: disabling this flag will cause this function to only evaluate size,
layout, stride, and device. This is used to validate that arguments are
equivalent enough for updating subgraphs."""
@zou3519 (Contributor) commented:

When do you use strict vs. not strict? The comment could be clearer on this.

benjaminglass1 (Collaborator, Author) replied:

@zou3519 I'll clarify the comment.
