-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Find and create out-of-scope cell in LOAD_CLOSURE #101357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/101357
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit c0f0308: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
|
||
| return deep(x) | ||
|
|
||
| torch._dynamo.export(indirect, torch.randn(3)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please check if the generated code is correct, by (say) running eager and comparing?
| if isinstance(sym_local, variables.NewCellVariable): | ||
| self.push(side_effects.load_cell(sym_local)) | ||
| else: | ||
| cell_var = side_effects.track_cell_new() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I think you need a test with cond in this case. That was the original motivation, I simplified the example to avoid cond. Specifically, I'm worried that if two branches read different closure variables (even though nobody is writing to them), we'll end up with different side effects.
In any case, some comments here documenting the issues and the plan to resolve them would be nice.
Also, I'm not sure why these new cell variables cannot be pre-allocated before we reach here. Naively I'd assume that we know all the closure variables before attempting to trace.
Fix #99639 by handling the case in `InliningInstructionTranslator`'s `LOAD_CLOSURE` definition when the requested cell is not in `self.closure_cells`. My intuition is that the behavior of `LOAD_DEREF` and `STORE_DEREF` on a cell/freevar should not depend on whether or not we called `LOAD_CLOSURE` (that is, we shouldn't create a new cell var in `LOAD_CLOSURE` like in #101357). But we need a way to push cells created by the inlined function that were not present in the caller - `InlinedClosureVariable` is used to differentiate these cells from other cells. Adding this test causes an error though: ```python def test_closure_out_of_scope_cell_with_cond(self): from functorch.experimental.control_flow import cond cell1 = torch.rand(3, 3) cell2 = torch.rand(3, 3) orig3 = torch.rand(3, 3) def test(x): cell3 = orig3.clone() def then(): nonlocal cell3 cell3 += cell1 return cell3 def els(): nonlocal cell3 cell3 += cell2 return cell3 return cond(x > 0, then, els, []) opt_fn = torch._dynamo.optimize("eager")(test) result1 = opt_fn(1) self.assertTrue(torch.allclose(result1, orig3 + cell1)) result2 = opt_fn(-1) self.assertTrue(torch.allclose(result1, orig3 + cell1 + cell2)) ``` ``` Traceback (most recent call last): File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1768, in test_closure_out_of_scope_cell_with_cond result1 = opt_fn(1) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 295, in _fn return fn(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 448, in catch_errors return callback(frame, cache_size, hooks, frame_state) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 526, in _convert_frame result = inner_convert(frame, cache_size, hooks, frame_state) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 127, in _fn return fn(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert return _compile( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/utils.py", line 180, in time_wrapper r = func(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 430, in _compile out_code = transform_code_object(code, transform) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object transformations(instructions, code_options) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 415, in transform tracer.run() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2029, in run super().run() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run and self.step() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step getattr(self, inst.opname)(inst) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 391, in wrapper return inner_fn(self, inst) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 1100, in CALL_FUNCTION self.call_function(fn, args, {}) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 559, in call_function self.push(fn.call_function(self, args, kwargs)) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1061, in call_function (false_r, false_graph, false_lifted_freevars) = speculate_branch(False) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1044, in speculate_branch ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 850, in speculate_subgraph output = f.call_function(tx, args, {}) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/functions.py", line 121, in call_function return tx.inline_user_function_return( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 595, in inline_user_function_return result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2134, in inline_call return cls.inline_call_(parent, func, args, kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2231, in inline_call_ tracer.run() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run and self.step() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step getattr(self, inst.opname)(inst) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 162, in impl self.push(fn_var.call_function(self, self.popn(nargs), {})) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/builtin.py", line 497, in call_function proxy = tx.output.create_proxy( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 345, in create_proxy return self.current_tracer.create_proxy(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1109, in create_proxy new_arg = self.lift_tracked_freevar_to_input(arg) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1226, in lift_tracked_freevar_to_input self.parent.lift_tracked_freevar_to_input(proxy) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1219, in lift_tracked_freevar_to_input assert ( AssertionError: lift_tracked_freevar_to_input on root SubgraphTracer from user code: File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1766, in test return cond(x > 0, then, els, []) File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1764, in els cell3 += cell2 ``` cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 [ghstack-poisoned]
Fix #99639 by handling the case in `InliningInstructionTranslator`'s `LOAD_CLOSURE` definition when the requested cell is not in `self.closure_cells`. My intuition is that the behavior of `LOAD_DEREF` and `STORE_DEREF` on a cell/freevar should not depend on whether or not we called `LOAD_CLOSURE` (that is, we shouldn't create a new cell var in `LOAD_CLOSURE` like in #101357). But we need a way to push cells created by the inlined function that were not present in the caller - `InlinedClosureVariable` is used to differentiate these cells from other cells. Adding this test causes an error though: ```python def test_closure_out_of_scope_cell_with_cond(self): from functorch.experimental.control_flow import cond cell1 = torch.rand(3, 3) cell2 = torch.rand(3, 3) orig3 = torch.rand(3, 3) def test(x): cell3 = orig3.clone() def then(): nonlocal cell3 cell3 += cell1 return cell3 def els(): nonlocal cell3 cell3 += cell2 return cell3 return cond(x > 0, then, els, []) opt_fn = torch._dynamo.optimize("eager")(test) result1 = opt_fn(1) self.assertTrue(torch.allclose(result1, orig3 + cell1)) result2 = opt_fn(-1) self.assertTrue(torch.allclose(result1, orig3 + cell1 + cell2)) ``` ``` Traceback (most recent call last): File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1768, in test_closure_out_of_scope_cell_with_cond result1 = opt_fn(1) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 295, in _fn return fn(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 448, in catch_errors return callback(frame, cache_size, hooks, frame_state) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 526, in _convert_frame result = inner_convert(frame, cache_size, hooks, frame_state) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 127, in _fn return fn(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert return _compile( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/utils.py", line 180, in time_wrapper r = func(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 430, in _compile out_code = transform_code_object(code, transform) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object transformations(instructions, code_options) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 415, in transform tracer.run() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2029, in run super().run() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run and self.step() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step getattr(self, inst.opname)(inst) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 391, in wrapper return inner_fn(self, inst) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 1100, in CALL_FUNCTION self.call_function(fn, args, {}) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 559, in call_function self.push(fn.call_function(self, args, kwargs)) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1061, in call_function (false_r, false_graph, false_lifted_freevars) = speculate_branch(False) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1044, in speculate_branch ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 850, in speculate_subgraph output = f.call_function(tx, args, {}) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/functions.py", line 121, in call_function return tx.inline_user_function_return( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 595, in inline_user_function_return result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2134, in inline_call return cls.inline_call_(parent, func, args, kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2231, in inline_call_ tracer.run() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run and self.step() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step getattr(self, inst.opname)(inst) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 162, in impl self.push(fn_var.call_function(self, self.popn(nargs), {})) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/builtin.py", line 497, in call_function proxy = tx.output.create_proxy( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 345, in create_proxy return self.current_tracer.create_proxy(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1109, in create_proxy new_arg = self.lift_tracked_freevar_to_input(arg) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1226, in lift_tracked_freevar_to_input self.parent.lift_tracked_freevar_to_input(proxy) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1219, in lift_tracked_freevar_to_input assert ( AssertionError: lift_tracked_freevar_to_input on root SubgraphTracer from user code: File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1766, in test return cond(x > 0, then, els, []) File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1764, in els cell3 += cell2 ``` cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 [ghstack-poisoned]
Fix #99639 by handling the case in `InliningInstructionTranslator`'s `LOAD_CLOSURE` definition when the requested cell is not in `self.closure_cells`. My intuition is that the behavior of `LOAD_DEREF` and `STORE_DEREF` on a cell/freevar should not depend on whether or not we called `LOAD_CLOSURE` (that is, we shouldn't create a new cell var in `LOAD_CLOSURE` like in #101357). But we need a way to push cells created by the inlined function that were not present in the caller - `InlinedClosureVariable` is used to differentiate these cells from other cells. Adding this test causes an error though (EDIT: this test is not relevant to this PR and instead just reveals that `cond` with Python side effects is still broken): ```python def test_closure_out_of_scope_cell_with_cond(self): from functorch.experimental.control_flow import cond cell1 = torch.rand(3, 3) cell2 = torch.rand(3, 3) orig3 = torch.rand(3, 3) def test(x): cell3 = orig3.clone() def then(): nonlocal cell3 cell3 += cell1 return cell3 def els(): nonlocal cell3 cell3 += cell2 return cell3 return cond(x > 0, then, els, []) opt_fn = torch._dynamo.optimize("eager")(test) result1 = opt_fn(1) self.assertTrue(torch.allclose(result1, orig3 + cell1)) result2 = opt_fn(-1) self.assertTrue(torch.allclose(result1, orig3 + cell1 + cell2)) ``` ``` Traceback (most recent call last): File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1768, in test_closure_out_of_scope_cell_with_cond result1 = opt_fn(1) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 295, in _fn return fn(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 448, in catch_errors return callback(frame, cache_size, hooks, frame_state) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 526, in _convert_frame result = inner_convert(frame, cache_size, hooks, frame_state) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 127, in _fn return fn(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert return _compile( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/utils.py", line 180, in time_wrapper r = func(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 430, in _compile out_code = transform_code_object(code, transform) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object transformations(instructions, code_options) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 415, in transform tracer.run() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2029, in run super().run() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run and self.step() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step getattr(self, inst.opname)(inst) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 391, in wrapper return inner_fn(self, inst) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 1100, in CALL_FUNCTION self.call_function(fn, args, {}) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 559, in call_function self.push(fn.call_function(self, args, kwargs)) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1061, in call_function (false_r, false_graph, false_lifted_freevars) = speculate_branch(False) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1044, in speculate_branch ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 850, in speculate_subgraph output = f.call_function(tx, args, {}) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/functions.py", line 121, in call_function return tx.inline_user_function_return( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 595, in inline_user_function_return result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2134, in inline_call return cls.inline_call_(parent, func, args, kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2231, in inline_call_ tracer.run() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run and self.step() File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step getattr(self, inst.opname)(inst) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 162, in impl self.push(fn_var.call_function(self, self.popn(nargs), {})) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/builtin.py", line 497, in call_function proxy = tx.output.create_proxy( File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 345, in create_proxy return self.current_tracer.create_proxy(*args, **kwargs) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1109, in create_proxy new_arg = self.lift_tracked_freevar_to_input(arg) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1226, in lift_tracked_freevar_to_input self.parent.lift_tracked_freevar_to_input(proxy) File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1219, in lift_tracked_freevar_to_input assert ( AssertionError: lift_tracked_freevar_to_input on root SubgraphTracer from user code: File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1766, in test return cond(x > 0, then, els, []) File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1764, in els cell3 += cell2 ``` Pull Request resolved: #104222 Approved by: https://github.com/jansel
|
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as |
Fixes #99639
cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @aakhundov @soumith @desertfire