[dynamo] Support control flow map() operator. #91939
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91939
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 3cfcad6. This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/_dynamo/variables/torch.py
Outdated
tx.output.register_attr_or_module(gm, next_name, source=src)
return next_name

def make_subgraph(f, sub_args, graph_checkpoint, checkpoint):
This should have "speculate" in its name as it doesn't actually apply the changes to Dynamo state
torch/_dynamo/variables/torch.py
Outdated
body_nn_modules,
body_cmp,
) = make_subgraph(args[0], [
wrap_fx_proxy(tx, args[1].as_proxy()[0], **VariableTracker.propagate(args[1])),
This looks suspicious. What's going on here?
torch/_dynamo/variables/torch.py
Outdated
| "body", torch.fx.GraphModule(body_nn_modules, body_graph) | ||
| ) | ||
|
|
||
| # Apply side effects (guaranteed to be equal) |
This comment is out of date; you only ran one branch.
torch/_dynamo/variables/torch.py
Outdated
*(arg.as_proxy() for arg in args[1:])
)
r = body_r.as_proxy().node.meta["example_value"]
example_value = r.new_empty([get_fake_value(args[1].as_proxy().node, tx).shape[0], *r.shape])
This looks wrong.
ezyang left a comment:
Broadly speaking, the implementation here looks like it's cargo culted from cond's implementation, but this is not really appropriate, because map is quite different from cond in two respects:
- There's only one lambda, but it can get run multiple times
- The input/output of the lambda diverge from the outer context
(1) means that you need to, for example, assert that there are no side effects in the lambda (or that the side effects are idempotent or something). Suppose that inside the lambda you call nonlocal x; x += 1. The eager mode map semantics will increment this counter every iteration of the loop. But what you have implemented here applies the side effects once, and then is done. You will have the wrong semantics in this case.
(2) appears to be what is going on with the weirdness of making the fx proxy tensor and then the stuff with new_empty to make the example value. I can believe that the implementation as is works in most cases, but it needs to be documented far better. In particular, when you say x[0] you are leaning on the fact that make_subgraph doesn't actually use the passed-in proxies in an interesting way; instead, it converts them into nodes. Otherwise, it would be incorrect to say that you had called the lambda with x[0]. In fact, I think I might prefer that we not pass in a proxy (except maybe to make the names nicer), because the implementation here will not work if you map over a tensor with size (0, *sizes) (since x[0] will fail in this case).
I will expect tests for all of these edge cases.
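As a minimal sketch of the side-effect concern in (1), assuming plain eager semantics (the counter and body below are illustrative, not the PR's code): eager map() runs the body once per slice of the mapped tensor, so a captured graph that replays the speculated side effects only once would diverge.

```python
import torch

counter = 0

def body(x):
    # Side effect: under eager map() semantics this runs once per mapped slice.
    global counter
    counter += 1
    return x.sin()

xs = torch.randn(4, 3)

# Eager-style reference semantics: the body executes xs.shape[0] times.
out = torch.stack([body(x) for x in xs])
assert counter == 4  # replaying the speculated side effects only once would leave counter == 1
```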
@ezyang All of these make a lot of sense, thanks!
It's easy to say "oh, I will just check that there are no side effects in the loop body", but Dynamo does model some operations (like accessing closed-over variables) as side effects. In any case, try asserting no side effects and see if it errors or not on the models you care about. Re (2), you understand correctly. You actually raise a good point though, which is that zero-size input cannot work anyway, as you MUST run the lambda to actually get the output shape, but the lambda is unrunnable if you have no samples. So I guess map() as defined here cannot work with zero size, and so maybe accessing the first element is fine. Better add a check for that though...
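A small sketch of the shape-inference point above (the helper name is illustrative, not the PR's actual code): the mapped output shape is (xs.shape[0], *body_out.shape), and body_out can only be discovered by running the body on one sample such as xs[0], which is exactly what fails for a zero-size leading dimension.

```python
import torch

def infer_map_output_shape(body, xs):
    # Hypothetical helper mirroring the example_value construction discussed above.
    if xs.shape[0] == 0:
        # No sample to run the body on, so shape inference is impossible here;
        # this is the explicit check suggested in the review.
        raise RuntimeError("map() cannot handle a zero-size leading dimension")
    body_out = body(xs[0])
    return (xs.shape[0], *body_out.shape)

print(infer_map_output_shape(lambda x: x.sum(dim=-1), torch.randn(5, 3, 2)))  # (5, 3)
```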
Force-pushed from 43707dc to 2c42b59
Sorry for the late update. Addressed comments with a few updates:
cc @ezyang
torch/_dynamo/variables/torch.py
Outdated
I hope this access doesn't induce a write to the FX graph
torch/_dynamo/variables/torch.py
Outdated
The rest of body_cmp is ignored here. Maybe it is safer to factor out the comparable state into its own function, and then you can get the comparable state prior to running the lambda and then do a full comparison there.
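A hypothetical sketch of the suggested refactor (the structure is illustrative; the real fields live on Dynamo's tracer and are not named here): collect the comparable tracer state in one helper, snapshot it before speculating the body, and compare the full snapshot afterwards instead of inspecting a single field.

```python
from copy import deepcopy

def comparable_state(tracer):
    # Hypothetical: gather every piece of tracer state that must be unchanged
    # after speculating the body (a toy dict stands in for the real tracer).
    return deepcopy({"side_effects": tracer.get("side_effects"),
                     "guards": tracer.get("guards")})

def speculate_with_state_check(tracer, speculate_body):
    before = comparable_state(tracer)
    result = speculate_body()
    if comparable_state(tracer) != before:
        raise RuntimeError("speculating the body mutated tracer state")
    return result

# Toy usage with a dict standing in for the tracer:
tracer = {"side_effects": [], "guards": set()}
print(speculate_with_state_check(tracer, lambda: 42))  # 42
```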
Force-pushed from 2c42b59 to 3cfcad6
updates:
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.
@pytorchbot merge
Merge failed. Reason: Not merging any PRs at the moment because there is a merge blocking https://github.com/pytorch/pytorch/labels/ci:%20sev issue open at:
Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #ISSUE_NUMBER
We want to add support for the control flow map() operator at the dynamo level, to unblock some internal models that will have to use the map() operator in the captured graph. Basically, I replicate the pattern used to implement the cond() op in #90286.
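A hedged usage sketch of what this unblocks (the import path is an assumption; the control-flow ops have moved around, and in this era the op lived under functorch.experimental.control_flow): map(body, xs, *args) applies body to each slice of xs along dim 0 inside a dynamo-captured graph.

```python
import torch
import torch._dynamo as dynamo
from functorch.experimental.control_flow import map as cf_map  # assumed import path

def body(x, y):
    return x.sin() + y

@dynamo.optimize("eager")
def f(xs, y):
    # body is applied to each slice xs[i]; the results are stacked along dim 0.
    return cf_map(body, xs, y)

print(f(torch.randn(3, 4), torch.ones(4)).shape)  # expected: torch.Size([3, 4])
```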
cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire