-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[dynamo] resolve InlinedClosureVariable in InstructionTranslator stack #106491
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106491
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit ef6821f: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…nother function See #136814 for more context. In #90286, we introduced an optimization so that for captured cells that are unmodified during a Dynamo trace, `UserFunctionVariable` will represent them as variable of the cell's actual value, rather than a `NewCellVariable`. Later on we introduced more mechanisms to model such cells across function calls (#104222), and across function calls where `NestedUserFunctionVariable::bind_args` need to look up further in the parent frames (#106491) to find these cells' values. This patch removes `InlinedClosureVariable` in favor of a simpler modelling which is also more consistent with what was introduced in #90286, i.e., just model these cells as their contents, in `symbolic_locals`. This fixes #136814 because resolution of `InlinedClosureVariable` to the underlying cell content value happens in `NestedUserFunctionVariable::bind_args`, which requires Dynamo to have the value in scope at the function call site (when Dynamo starts inlining), but's not always the case (as the test case shows). However, if we model the cells in `symbolic_locals`, we never need such resolution, and the values are directly stored into the `NestedUserFunctionVariable::closure` upon the function creation, at which point Dynamo always has the cell value in `symbolic_locals` for look up. Fixes #136814.
…nother function (#137510) See `test_inline_closure_returned_by_another_function_and_captures` and #136814 for more context. In #90286, we introduced an optimization so that for captured cells that are unmodified during a Dynamo trace, `UserFunctionVariable` will represent them as variable of the cell's actual value, rather than a `NewCellVariable`. Later on we introduced more mechanisms to model such cells across function calls (#104222), and across function calls where `NestedUserFunctionVariable::bind_args` need to look up further in the parent frames (#106491) to find these cells' values. This patch removes `InlinedClosureVariable` in favor of a simpler modelling, which is also more consistent with what was introduced in #90286, i.e., just model these cells as their contents, in `symbolic_locals`. This fixes #136814 because resolution of `InlinedClosureVariable` to the underlying cell content value happens in `NestedUserFunctionVariable::bind_args`, which requires Dynamo to have the value in scope at the function call site (when Dynamo does inlining), but's not always the case (as the test case shows). However, if we model the cells in `symbolic_locals`, we never need such resolution, and the values are directly stored into the `NestedUserFunctionVariable::closure` upon the function creation, at which point Dynamo always has the cell value in `symbolic_locals` for look up. Fixes #136814. Pull Request resolved: #137510 Approved by: https://github.com/williamwen42
When inlining a function which loads a closure, its direct parent may not load that closure. So we cannot find the closure name in parent's symbolic locals. In this PR, we fix it by recursively searching the parent instruction translator stack to resolve the closure.
Background
When developing #105679, this corner case is triggered. A small repro is added in the test of this pr, where outer is loaded by deep2 but not by deep.
Running the test, we have the following error before the PR:
Test Plan:
add new test
cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @aakhundov