Skip to content

Conversation

@knwng
Copy link
Contributor

@knwng knwng commented May 14, 2023

@pytorch-bot
Copy link

pytorch-bot bot commented May 14, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/101357

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c0f0308:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zou3519 zou3519 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label May 15, 2023

return deep(x)

torch._dynamo.export(indirect, torch.randn(3))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please check if the generated code is correct, by (say) running eager and comparing?

if isinstance(sym_local, variables.NewCellVariable):
self.push(side_effects.load_cell(sym_local))
else:
cell_var = side_effects.track_cell_new()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I think you need a test with cond in this case. That was the original motivation, I simplified the example to avoid cond. Specifically, I'm worried that if two branches read different closure variables (even though nobody is writing to them), we'll end up with different side effects.

In any case, some comments here documenting the issues and the plan to resolve them would be nice.

Also, I'm not sure why these new cell variables cannot be pre-allocated before we reach here. Naively I'd assume that we know all the closure variables before attempting to trace.

williamwen42 added a commit that referenced this pull request Jun 27, 2023
Fix #99639 by handling the case in `InliningInstructionTranslator`'s `LOAD_CLOSURE` definition when the requested cell is not in `self.closure_cells`.

My intuition is that the behavior of `LOAD_DEREF` and `STORE_DEREF` on a cell/freevar should not depend on whether or not we called `LOAD_CLOSURE` (that is, we shouldn't create a new cell var in `LOAD_CLOSURE` like in #101357). But we need a way to push cells created by the inlined function that were not present in the caller - `InlinedClosureVariable` is used to differentiate these cells from other cells.

Adding this test causes an error though:
```python
    def test_closure_out_of_scope_cell_with_cond(self):
        from functorch.experimental.control_flow import cond
        cell1 = torch.rand(3, 3)
        cell2 = torch.rand(3, 3)
        orig3 = torch.rand(3, 3)
        def test(x):
            cell3 = orig3.clone()
            def then():
                nonlocal cell3
                cell3 += cell1
                return cell3
            def els():
                nonlocal cell3
                cell3 += cell2
                return cell3
            return cond(x > 0, then, els, [])
        opt_fn = torch._dynamo.optimize("eager")(test)
        result1 = opt_fn(1)
        self.assertTrue(torch.allclose(result1, orig3 + cell1))
        result2 = opt_fn(-1)
        self.assertTrue(torch.allclose(result1, orig3 + cell1 + cell2))
```
```
Traceback (most recent call last):
  File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1768, in test_closure_out_of_scope_cell_with_cond
    result1 = opt_fn(1)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 295, in _fn
    return fn(*args, **kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 448, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 526, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 127, in _fn
    return fn(*args, **kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert
    return _compile(
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/utils.py", line 180, in time_wrapper
    r = func(*args, **kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 430, in _compile
    out_code = transform_code_object(code, transform)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
    transformations(instructions, code_options)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 415, in transform
    tracer.run()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2029, in run
    super().run()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run
    and self.step()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step
    getattr(self, inst.opname)(inst)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 391, in wrapper
    return inner_fn(self, inst)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 1100, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 559, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1061, in call_function
    (false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1044, in speculate_branch
    ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 850, in speculate_subgraph
    output = f.call_function(tx, args, {})
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/functions.py", line 121, in call_function
    return tx.inline_user_function_return(
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 595, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2134, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2231, in inline_call_
    tracer.run()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run
    and self.step()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step
    getattr(self, inst.opname)(inst)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 162, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/builtin.py", line 497, in call_function
    proxy = tx.output.create_proxy(
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 345, in create_proxy
    return self.current_tracer.create_proxy(*args, **kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1109, in create_proxy
    new_arg = self.lift_tracked_freevar_to_input(arg)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1226, in lift_tracked_freevar_to_input
    self.parent.lift_tracked_freevar_to_input(proxy)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1219, in lift_tracked_freevar_to_input
    assert (
AssertionError: lift_tracked_freevar_to_input on root SubgraphTracer

from user code:
   File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1766, in test
    return cond(x > 0, then, els, [])
  File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1764, in els
    cell3 += cell2
```

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78

[ghstack-poisoned]
williamwen42 added a commit that referenced this pull request Jun 27, 2023
Fix #99639 by handling the case in `InliningInstructionTranslator`'s `LOAD_CLOSURE` definition when the requested cell is not in `self.closure_cells`.

My intuition is that the behavior of `LOAD_DEREF` and `STORE_DEREF` on a cell/freevar should not depend on whether or not we called `LOAD_CLOSURE` (that is, we shouldn't create a new cell var in `LOAD_CLOSURE` like in #101357). But we need a way to push cells created by the inlined function that were not present in the caller - `InlinedClosureVariable` is used to differentiate these cells from other cells.

Adding this test causes an error though:
```python
    def test_closure_out_of_scope_cell_with_cond(self):
        from functorch.experimental.control_flow import cond
        cell1 = torch.rand(3, 3)
        cell2 = torch.rand(3, 3)
        orig3 = torch.rand(3, 3)
        def test(x):
            cell3 = orig3.clone()
            def then():
                nonlocal cell3
                cell3 += cell1
                return cell3
            def els():
                nonlocal cell3
                cell3 += cell2
                return cell3
            return cond(x > 0, then, els, [])
        opt_fn = torch._dynamo.optimize("eager")(test)
        result1 = opt_fn(1)
        self.assertTrue(torch.allclose(result1, orig3 + cell1))
        result2 = opt_fn(-1)
        self.assertTrue(torch.allclose(result1, orig3 + cell1 + cell2))
```
```
Traceback (most recent call last):
  File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1768, in test_closure_out_of_scope_cell_with_cond
    result1 = opt_fn(1)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 295, in _fn
    return fn(*args, **kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 448, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 526, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 127, in _fn
    return fn(*args, **kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert
    return _compile(
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/utils.py", line 180, in time_wrapper
    r = func(*args, **kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 430, in _compile
    out_code = transform_code_object(code, transform)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
    transformations(instructions, code_options)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 415, in transform
    tracer.run()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2029, in run
    super().run()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run
    and self.step()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step
    getattr(self, inst.opname)(inst)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 391, in wrapper
    return inner_fn(self, inst)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 1100, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 559, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1061, in call_function
    (false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1044, in speculate_branch
    ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 850, in speculate_subgraph
    output = f.call_function(tx, args, {})
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/functions.py", line 121, in call_function
    return tx.inline_user_function_return(
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 595, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2134, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2231, in inline_call_
    tracer.run()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run
    and self.step()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step
    getattr(self, inst.opname)(inst)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 162, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/builtin.py", line 497, in call_function
    proxy = tx.output.create_proxy(
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 345, in create_proxy
    return self.current_tracer.create_proxy(*args, **kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1109, in create_proxy
    new_arg = self.lift_tracked_freevar_to_input(arg)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1226, in lift_tracked_freevar_to_input
    self.parent.lift_tracked_freevar_to_input(proxy)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1219, in lift_tracked_freevar_to_input
    assert (
AssertionError: lift_tracked_freevar_to_input on root SubgraphTracer

from user code:
   File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1766, in test
    return cond(x > 0, then, els, [])
  File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1764, in els
    cell3 += cell2
```

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Jun 28, 2023
Fix #99639 by handling the case in `InliningInstructionTranslator`'s `LOAD_CLOSURE` definition when the requested cell is not in `self.closure_cells`.

My intuition is that the behavior of `LOAD_DEREF` and `STORE_DEREF` on a cell/freevar should not depend on whether or not we called `LOAD_CLOSURE` (that is, we shouldn't create a new cell var in `LOAD_CLOSURE` like in #101357). But we need a way to push cells created by the inlined function that were not present in the caller - `InlinedClosureVariable` is used to differentiate these cells from other cells.

Adding this test causes an error though (EDIT: this test is not relevant to this PR and instead just reveals that `cond` with Python side effects is still broken):
```python
    def test_closure_out_of_scope_cell_with_cond(self):
        from functorch.experimental.control_flow import cond
        cell1 = torch.rand(3, 3)
        cell2 = torch.rand(3, 3)
        orig3 = torch.rand(3, 3)
        def test(x):
            cell3 = orig3.clone()
            def then():
                nonlocal cell3
                cell3 += cell1
                return cell3
            def els():
                nonlocal cell3
                cell3 += cell2
                return cell3
            return cond(x > 0, then, els, [])
        opt_fn = torch._dynamo.optimize("eager")(test)
        result1 = opt_fn(1)
        self.assertTrue(torch.allclose(result1, orig3 + cell1))
        result2 = opt_fn(-1)
        self.assertTrue(torch.allclose(result1, orig3 + cell1 + cell2))
```
```
Traceback (most recent call last):
  File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1768, in test_closure_out_of_scope_cell_with_cond
    result1 = opt_fn(1)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 295, in _fn
    return fn(*args, **kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 448, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 526, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 127, in _fn
    return fn(*args, **kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert
    return _compile(
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/utils.py", line 180, in time_wrapper
    r = func(*args, **kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 430, in _compile
    out_code = transform_code_object(code, transform)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
    transformations(instructions, code_options)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 415, in transform
    tracer.run()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2029, in run
    super().run()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run
    and self.step()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step
    getattr(self, inst.opname)(inst)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 391, in wrapper
    return inner_fn(self, inst)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 1100, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 559, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1061, in call_function
    (false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1044, in speculate_branch
    ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 850, in speculate_subgraph
    output = f.call_function(tx, args, {})
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/functions.py", line 121, in call_function
    return tx.inline_user_function_return(
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 595, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2134, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2231, in inline_call_
    tracer.run()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run
    and self.step()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step
    getattr(self, inst.opname)(inst)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 162, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/builtin.py", line 497, in call_function
    proxy = tx.output.create_proxy(
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 345, in create_proxy
    return self.current_tracer.create_proxy(*args, **kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1109, in create_proxy
    new_arg = self.lift_tracked_freevar_to_input(arg)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1226, in lift_tracked_freevar_to_input
    self.parent.lift_tracked_freevar_to_input(proxy)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1219, in lift_tracked_freevar_to_input
    assert (
AssertionError: lift_tracked_freevar_to_input on root SubgraphTracer

from user code:
   File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1766, in test
    return cond(x > 0, then, els, [])
  File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1764, in els
    cell3 += cell2
```

Pull Request resolved: #104222
Approved by: https://github.com/jansel
@github-actions
Copy link
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jul 14, 2023
@github-actions github-actions bot closed this Aug 13, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor module: dynamo open source Stale triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Sufficiently deep nesting + inlining + closure fails with KeyError

4 participants