[inductor] Fix FakeTensorUpdater handling of HOPs #159523
base: gh/benjaminglass1/97/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159523
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Cancelled Job, 1 Unrelated Failure as of commit 621684a with merge base 20cae80.
CANCELLED JOB - The following job was cancelled. Please retry:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@zou3519 said he would review - deferring to him on this one.
Spoke with @zou3519 offline, and concluded that the requests for better testing can be done in follow-up PRs, given the existing lack of testing for FakeTensorUpdater.
zou3519 left a comment:
The example_val thing is still suspicious to me. If you could send an example test case where you saw it, that would be helpful.
Pushed a version that updates subgraphs with new placeholder arguments when needed, at the time we process the subgraph invocation. It appears to work, but I've run into a new wrinkle: it looks like in some cases (and I've modified the test to reflect this) the subgraph can be called multiple times. I think this may invalidate the idea that we can change the shape and stride of inputs and outputs, at least trivially. Minimally, we need to skip doing this to subgraphs that are invoked repeatedly; maximally, perhaps we should just throw a loud error when changing placeholders or outputs on a graph?
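A minimal sketch of that wrinkle, assuming the `invoke_subgraph` convention of `(subgraph, identifier, *operands)`; the graph and names below are illustrative only, not the actual test:

```python
# Hypothetical illustration: one subgraph GraphModule referenced from two
# invoke_subgraph call sites. Its placeholders carry a single set of
# metadata, so rewriting them to match one call site can be wrong for the
# other (e.g. if the two inputs end up with different strides).
import torch
import torch.fx as fx

def body(x):
    return x + 1

subgraph = fx.symbolic_trace(body)  # one GraphModule, one placeholder

g = fx.Graph()
a = g.placeholder("a")
b = g.placeholder("b")
sub = g.get_attr("subgraph")
y1 = g.call_function(torch.ops.higher_order.invoke_subgraph, (sub, "subgraph_0", a))
y2 = g.call_function(torch.ops.higher_order.invoke_subgraph, (sub, "subgraph_0", b))
g.output((y1, y2))
print(g)  # two call sites, both pointing at the same get_attr'd subgraph
```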
Writing up where this is at, for visibility:
Update: this is very close now; the only remaining obstacle (pending the test run working, obviously) is handling HOPs that did not initially appear to utilize subgraphs.

EDIT: @zou3519 re-requesting your review to look at the current approach and make sure there are no obvious issues I overlooked.
def extract_subgraphs_and_args(
    node: torch.fx.Node, *args: Any, **kwargs: Any
) -> tuple[tuple[torch.fx.GraphModule, ...], tuple[Any, ...] | None]:
    """HOPs that invoke subgraphs take a number of different forms. This
    function regularizes them, returning a tuple of subgraphs contained in the
    args and a tuple of the args for the subgraphs. This function assumes all
    subgraphs share a set of common arguments.

    This function assumes that node_invokes_subgraph(node, *args, **kwargs) is
    True.

    If the second return value is None, this function was unable to determine
    what args to pass to the subgraph(s)."""
    if node.target is torch.ops.higher_order.cond:
        return tuple(args[1:3]), tuple(args[3])
    if node.target is torch.ops.higher_order.foreach_map:
        return (args[0],), tuple(args[1:])
    if node.target in (
        torch.ops.higher_order.invoke_quant_packed,
        torch.ops.higher_order.invoke_quant,
    ):
        return (args[0],), tuple(args[1:])
    if node.target is torch.ops.higher_order.invoke_subgraph:
        return (args[0],), tuple(args[2:])
    if node.target is torch.ops.higher_order.map_impl:
        # map is applied over slices from the first dimension of each value in
        # args[1].
        return (args[0],), (*(a[0] for a in args[1]), *args[2:])
    if node.target in (
        torch.ops.higher_order.while_loop,
        torch.ops.higher_order.while_loop_stack_output,
    ):
        return tuple(args[:2]), (*args[2], *args[3])
    if node.target is control_deps:
        assert not kwargs, (
            "Subgraph arguments can be renamed, so we cannot consistently "
            "handle kwargs at this point in the stack."
        )
        return (args[1],), tuple(args[2:])
    # These functions don't have clean mappings from node arguments to subgraph
    # inputs, since those mappings are dependent on details of the original
    # invocation that are not preserved. Skip them intentionally.
    if node.target not in (
        torch.ops.higher_order.associative_scan,
        torch.ops.higher_order.flex_attention,
        torch.ops.higher_order.scan,
    ):
        warnings.warn(
            f"Please add support for subgraph args to function {node.target}!"
        )

    # By default, just return the detected list of subgraphs so that we can run
    # updates on all of them.
    return tuple(
        s
        for s in pytree.tree_flatten(args)[0]  # iterate the flattened leaves
        if isinstance(s, torch.fx.GraphModule) and s in self.subgraph_updaters
    ), None
I do not like how this turned out, but I'm not sure of any other way to do this. Every HOP seems to pass the subgraph args through the call in a different way.
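For concreteness, here is my reading of a few of the conventions handled above (argument names are informal and inferred from the indexing, not taken from the PR):

```python
# cond: node args = (pred, true_gm, false_gm, operands)
#   subgraphs     -> args[1:3] = (true_gm, false_gm)
#   subgraph args -> args[3]   = operands (already a sequence)
#
# invoke_subgraph: node args = (subgraph_gm, identifier, *operands)
#   subgraphs     -> (args[0],) = (subgraph_gm,)
#   subgraph args -> args[2:]   = operands
#
# while_loop: node args = (cond_gm, body_gm, carried_inputs, additional_inputs)
#   subgraphs     -> args[:2]             = (cond_gm, body_gm)
#   subgraph args -> (*args[2], *args[3]) = carried + additional inputs
```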
I see the problem
I'm not really sure what to do about this. Maybe we can ship a version of this PR first where we don't update the inside of the HOP if the outside of the HOP changes. If we really need this, then we need a way for each HOP to register how to do a FakeTensorUpdater on it, which is really annoying. Thoughts @eellison?
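Purely as a hypothetical sketch of the "each HOP registers how to do a FakeTensorUpdater on it" idea (none of these names exist in the PR or in PyTorch):

```python
from typing import Any, Callable

import torch

# Hypothetical registry: HOP -> function mapping node args to
# (subgraphs, subgraph_args), mirroring extract_subgraphs_and_args above.
_SUBGRAPH_ARG_EXTRACTORS: dict[Any, Callable[..., tuple]] = {}

def register_subgraph_arg_extractor(hop):
    def decorator(fn):
        _SUBGRAPH_ARG_EXTRACTORS[hop] = fn
        return fn
    return decorator

@register_subgraph_arg_extractor(torch.ops.higher_order.invoke_subgraph)
def _invoke_subgraph_extractor(node, *args, **kwargs):
    # invoke_subgraph(subgraph, identifier, *operands)
    return (args[0],), tuple(args[2:])
```

Each HOP would then opt in next to its own definition instead of growing the if-chain, at the cost mentioned above of every HOP author having to care about FakeTensorUpdater.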
Actually, I think this is fine. Let's say we have y = invoke_subgraph(subgraph, x). The question is whether we need to correct the fake tensors in subgraph given a new x.
- At the very least, we need to change y, given a new x.
- What does "a new x" mean? I'm assuming it has the same shape, but different strides. If it has a different shape, then I'd expect the user needs to change the subgraph themselves, or the subgraph is trivial.
- The subgraph should be resilient to changes in strides. If there is a custom operator that depends on the stride being a certain way, then it will emit some code that will coerce the strides to be what it expects.
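A small illustration of the last point, assuming "coerce the strides" means something like a `.contiguous()` call before the stride-sensitive computation (my assumption, not a quote from the PR):

```python
import torch

def stride_sensitive_op(x: torch.Tensor) -> torch.Tensor:
    # Coerce to the layout the kernel expects before computing, so the result
    # does not depend on the caller's strides.
    return x.contiguous().view(-1).sum()

x_contig = torch.randn(4, 8)               # strides (8, 1)
x_strided = x_contig.t().contiguous().t()  # same shape and values, strides (1, 4)

# Same shape, different strides -- the op still agrees.
assert torch.allclose(stride_sensitive_op(x_contig), stride_sensitive_op(x_strided))
```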
    strict: disabling this flag will cause this function to only evaluate size,
    layout, stride, and device. This is used to validate that arguments are
    equivalent enough for updating subgraphs."""
When do you use strict vs not strict? The comment could be clearer on this.
@zou3519 I'll clarify the comment.
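For reference, a hedged sketch of what the non-strict comparison described in that docstring might look like (the helper name and exact checks are my assumption, not the PR's code):

```python
import torch

def metadata_equivalent(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Non-strict comparison: only size, layout, stride, and device, i.e. just
    # enough to decide whether subgraph placeholders can be safely updated.
    if a.size() != b.size() or a.layout != b.layout or a.device != b.device:
        return False
    if a.layout is torch.strided:  # stride() is only defined for strided tensors
        return a.stride() == b.stride()
    return True
```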
Stack from ghstack (oldest at bottom):
Ensures that subgraphs nested in a GraphModule are updated by FakeTensorUpdater, so HOPs that invoke subgraphs (such as invoke_subgraph) also get updated appropriately in FakeTensorUpdater-managed graphs.

Fixes #156819
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @mlazos