-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[dynamo] fix deep nested closure cell KeyError #104222
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104222
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New FailureAs of commit eebc7de: NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Fix #99639 by handling the case in `InliningInstructionTranslator`'s `LOAD_CLOSURE` definition when the requested cell is not in `self.closure_cells`. My intuition is that the behavior of `LOAD_DEREF` and `STORE_DEREF` on a cell/freevar should not depend on whether or not we called `LOAD_CLOSURE` (that is, we shouldn't create a new cell var in `LOAD_CLOSURE` like in #101357). But we need a way to push cells created by the inlined function that were not present in the caller - `InlinedClosureVariable` is used to differentiate these cells from other cells. Adding this test causes an error though: ```python def test_closure_out_of_scope_cell_with_cond(self): from functorch.experimental.control_flow import cond cell1 = torch.rand(3, 3) cell2 = torch.rand(3, 3) orig3 = torch.rand(3, 3) def test(x): cell3 = orig3.clone() def then(): nonlocal cell3 cell3 += cell1 return cell3 def els(): nonlocal cell3 cell3 += cell2 return cell3 return cond(x > 0, then, els, []) opt_fn = torch._dynamo.optimize("eager")(test) result1 = opt_fn(1) self.assertTrue(torch.allclose(result1, orig3 + cell1)) result2 = opt_fn(-1) self.assertTrue(torch.allclose(result1, orig3 + cell1 + cell2)) ``` ``` Traceback (most recent call last): File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1768, in test_closure_out_of_scope_cell_with_cond result1 = opt_fn(1) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 295, in _fn return fn(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 448, in catch_errors return callback(frame, cache_size, hooks, frame_state) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 526, in _convert_frame result = inner_convert(frame, cache_size, hooks, frame_state) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 127, in _fn return fn(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert return _compile( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/utils.py", line 180, in time_wrapper r = func(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 430, in _compile out_code = transform_code_object(code, transform) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object transformations(instructions, code_options) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 415, in transform tracer.run() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2029, in run super().run() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run and self.step() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step getattr(self, inst.opname)(inst) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 391, in wrapper return inner_fn(self, inst) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 1100, in CALL_FUNCTION self.call_function(fn, args, {}) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 559, in call_function self.push(fn.call_function(self, args, kwargs)) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1061, in call_function (false_r, false_graph, false_lifted_freevars) = speculate_branch(False) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1044, in speculate_branch ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 850, in speculate_subgraph output = f.call_function(tx, args, {}) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/functions.py", line 121, in call_function return tx.inline_user_function_return( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 595, in inline_user_function_return result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2134, in inline_call return cls.inline_call_(parent, func, args, kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2231, in inline_call_ tracer.run() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run and self.step() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step getattr(self, inst.opname)(inst) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 162, in impl self.push(fn_var.call_function(self, self.popn(nargs), {})) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/builtin.py", line 497, in call_function proxy = tx.output.create_proxy( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 345, in create_proxy return self.current_tracer.create_proxy(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1109, in create_proxy new_arg = self.lift_tracked_freevar_to_input(arg) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1226, in lift_tracked_freevar_to_input self.parent.lift_tracked_freevar_to_input(proxy) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1219, in lift_tracked_freevar_to_input assert ( AssertionError: lift_tracked_freevar_to_input on root SubgraphTracer from user code: File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1766, in test return cond(x > 0, then, els, []) File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1764, in els cell3 += cell2 ``` cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 [ghstack-poisoned]
Fix #99639 by handling the case in `InliningInstructionTranslator`'s `LOAD_CLOSURE` definition when the requested cell is not in `self.closure_cells`. My intuition is that the behavior of `LOAD_DEREF` and `STORE_DEREF` on a cell/freevar should not depend on whether or not we called `LOAD_CLOSURE` (that is, we shouldn't create a new cell var in `LOAD_CLOSURE` like in #101357). But we need a way to push cells created by the inlined function that were not present in the caller - `InlinedClosureVariable` is used to differentiate these cells from other cells. Adding this test causes an error though: ```python def test_closure_out_of_scope_cell_with_cond(self): from functorch.experimental.control_flow import cond cell1 = torch.rand(3, 3) cell2 = torch.rand(3, 3) orig3 = torch.rand(3, 3) def test(x): cell3 = orig3.clone() def then(): nonlocal cell3 cell3 += cell1 return cell3 def els(): nonlocal cell3 cell3 += cell2 return cell3 return cond(x > 0, then, els, []) opt_fn = torch._dynamo.optimize("eager")(test) result1 = opt_fn(1) self.assertTrue(torch.allclose(result1, orig3 + cell1)) result2 = opt_fn(-1) self.assertTrue(torch.allclose(result1, orig3 + cell1 + cell2)) ``` ``` Traceback (most recent call last): File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1768, in test_closure_out_of_scope_cell_with_cond result1 = opt_fn(1) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 295, in _fn return fn(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 448, in catch_errors return callback(frame, cache_size, hooks, frame_state) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 526, in _convert_frame result = inner_convert(frame, cache_size, hooks, frame_state) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 127, in _fn return fn(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert return _compile( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/utils.py", line 180, in time_wrapper r = func(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 430, in _compile out_code = transform_code_object(code, transform) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object transformations(instructions, code_options) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 415, in transform tracer.run() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2029, in run super().run() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run and self.step() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step getattr(self, inst.opname)(inst) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 391, in wrapper return inner_fn(self, inst) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 1100, in CALL_FUNCTION self.call_function(fn, args, {}) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 559, in call_function self.push(fn.call_function(self, args, kwargs)) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1061, in call_function (false_r, false_graph, false_lifted_freevars) = speculate_branch(False) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1044, in speculate_branch ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 850, in speculate_subgraph output = f.call_function(tx, args, {}) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/functions.py", line 121, in call_function return tx.inline_user_function_return( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 595, in inline_user_function_return result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2134, in inline_call return cls.inline_call_(parent, func, args, kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2231, in inline_call_ tracer.run() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run and self.step() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step getattr(self, inst.opname)(inst) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 162, in impl self.push(fn_var.call_function(self, self.popn(nargs), {})) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/builtin.py", line 497, in call_function proxy = tx.output.create_proxy( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 345, in create_proxy return self.current_tracer.create_proxy(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1109, in create_proxy new_arg = self.lift_tracked_freevar_to_input(arg) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1226, in lift_tracked_freevar_to_input self.parent.lift_tracked_freevar_to_input(proxy) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1219, in lift_tracked_freevar_to_input assert ( AssertionError: lift_tracked_freevar_to_input on root SubgraphTracer from user code: File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1766, in test return cond(x > 0, then, els, []) File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1764, in els cell3 += cell2 ``` cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 [ghstack-poisoned]
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 jobs have failed, first few of them are: linux-binary-libtorch-pre-cxx11 / libtorch-cpu-shared-with-deps-pre-cxx11-test / test Details for Dev Infra teamRaised by workflow job |
|
@pytorchbot merge -f 'unrelated test failure' |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…nother function See #136814 for more context. In #90286, we introduced an optimization so that for captured cells that are unmodified during a Dynamo trace, `UserFunctionVariable` will represent them as variable of the cell's actual value, rather than a `NewCellVariable`. Later on we introduced more mechanisms to model such cells across function calls (#104222), and across function calls where `NestedUserFunctionVariable::bind_args` need to look up further in the parent frames (#106491) to find these cells' values. This patch removes `InlinedClosureVariable` in favor of a simpler modelling which is also more consistent with what was introduced in #90286, i.e., just model these cells as their contents, in `symbolic_locals`. This fixes #136814 because resolution of `InlinedClosureVariable` to the underlying cell content value happens in `NestedUserFunctionVariable::bind_args`, which requires Dynamo to have the value in scope at the function call site (when Dynamo starts inlining), but's not always the case (as the test case shows). However, if we model the cells in `symbolic_locals`, we never need such resolution, and the values are directly stored into the `NestedUserFunctionVariable::closure` upon the function creation, at which point Dynamo always has the cell value in `symbolic_locals` for look up. Fixes #136814.
…nother function (#137510) See `test_inline_closure_returned_by_another_function_and_captures` and #136814 for more context. In #90286, we introduced an optimization so that for captured cells that are unmodified during a Dynamo trace, `UserFunctionVariable` will represent them as variable of the cell's actual value, rather than a `NewCellVariable`. Later on we introduced more mechanisms to model such cells across function calls (#104222), and across function calls where `NestedUserFunctionVariable::bind_args` need to look up further in the parent frames (#106491) to find these cells' values. This patch removes `InlinedClosureVariable` in favor of a simpler modelling, which is also more consistent with what was introduced in #90286, i.e., just model these cells as their contents, in `symbolic_locals`. This fixes #136814 because resolution of `InlinedClosureVariable` to the underlying cell content value happens in `NestedUserFunctionVariable::bind_args`, which requires Dynamo to have the value in scope at the function call site (when Dynamo does inlining), but's not always the case (as the test case shows). However, if we model the cells in `symbolic_locals`, we never need such resolution, and the values are directly stored into the `NestedUserFunctionVariable::closure` upon the function creation, at which point Dynamo always has the cell value in `symbolic_locals` for look up. Fixes #136814. Pull Request resolved: #137510 Approved by: https://github.com/williamwen42
Stack from ghstack (oldest at bottom):
Fix #99639 by handling the case in
InliningInstructionTranslator'sLOAD_CLOSUREdefinition when the requested cell is not inself.closure_cells.My intuition is that the behavior of
LOAD_DEREFandSTORE_DEREFon a cell/freevar should not depend on whether or not we calledLOAD_CLOSURE(that is, we shouldn't create a new cell var inLOAD_CLOSURElike in #101357). But we need a way to push cells created by the inlined function that were not present in the caller -InlinedClosureVariableis used to differentiate these cells from other cells.Adding this test causes an error though (EDIT: this test is not relevant to this PR and instead just reveals that
condwith Python side effects is still broken):cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @chenyang78