Skip to content

Conversation

@ydwu4
Copy link
Contributor

@ydwu4 ydwu4 commented Jul 20, 2023

This PR adds a python entry point for cond: which will always use dynamo to inspect udfs i.e. true_fn, false_fn.

Implementation-wise:
The overall plan is:

  1. we use torch.export(cond) to retrieve a computational graph and makes use of exisitng HOO's ability of lifting untracked vars to inputs (i.e. closures) and banning python side-effects.
  2. After we get the graph, we rely on internal infrastructure of dynamo to associate the lifted vars to python variables in user's calling environment.
  3. call torch.ops.higher_order.cond(pred, true_gm, false_gm, args_with_closures)

We make sevaral workarounds to make the above plan work: turn off all dispatch keys that currently don't work with torch.compile: 1. make_fx(torch.compile) 2. functorch related such as: functionalize(torch.compile)

Test Plan:
See added and/or modified tests.

cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @aakhundov

@pytorch-bot
Copy link

pytorch-bot bot commented Jul 20, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/105679

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 31 New Failures, 5 Unrelated Failures

As of commit 292cc85:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following job failed but were present on the merge base 6f036c9:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ydwu4 ydwu4 requested a review from zou3519 July 20, 2023 18:29
@ydwu4 ydwu4 added release notes: export ciflow/trunk Trigger trunk jobs on your pull request labels Jul 20, 2023
@ydwu4 ydwu4 changed the title Make cond always use dynamo to inspect udfs. [WIP] Make cond always use dynamo to inspect udfs. Jul 20, 2023

return pytree.tree_map(proxy_to_python_var, list(lifted_proxies.keys()))


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the InlinedClosureVariable still a problem?

Copy link
Contributor Author

@ydwu4 ydwu4 Jul 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haven't checked the test that triggers the InlinedClosureVariable not found error. Will check it.

Copy link
Contributor Author

@ydwu4 ydwu4 Aug 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Submit a separate pr in #106491.

return cond(args[0], *bind_branch_and_args(gm, pos_args))


def cond_compiled(pred, true_fn, false_fn, args):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

while attempting to test this locally, I have found the following:

  • cond's args must be a list apparently. We should make it so that it can instead be a tuple before it goes public
  • cond doesn't handle pytree inputs yet. seems OK.
  • Is cond limited to a single output?

Copy link
Contributor Author

@ydwu4 ydwu4 Jul 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. yeah, makes sense! 2. agree 3. yeah, currently it's single output.

return sym_int
size = pytree.tree_map(to_hint, t.size())
stride = pytree.tree_map(to_hint, t.stride())
return torch.empty_strided(size, stride, requires_grad=t.requires_grad)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uh, what happens if cond captures a variable that is not a concrete Tensor?

@ydwu4 ydwu4 force-pushed the eager_cond branch 2 times, most recently from b7ea29f to 4f612f1 Compare August 1, 2023 17:02
pytorchmergebot pushed a commit that referenced this pull request Aug 3, 2023
#106491)

When inlining a function which loads a closure, its direct parent may not load that closure. So we cannot find the closure name in parent's symbolic locals. In this PR, we fix it by recursively searching the parent instruction translator stack to resolve the closure.

**Background**
When developing #105679, this corner case is triggered. A small repro is added in the test of this pr, where outer is loaded by deep2 but not by deep.
```python
def test_inline_closure_not_loaded_by_parent(self):
    def outer(a):
        return a + 1

    def indirect(x):
        return direct(x)

    def direct(x):
        def deep2(c):
            return outer(c)

        def deep(c):
            return deep2(c)

        return deep(x)

    x = torch.randn(3)
    eager = indirect(x)
    counter = CompileCounter()
    compiled = torch._dynamo.optimize(counter)(indirect)(x)
```

Running the test, we have the following error before the PR:
```
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6584, in test_inline_closure_not_loaded_by_parent
    compiled = torch._dynamo.optimize(counter)(indirect)(x)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 321, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 481, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 543, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 130, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 362, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 194, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 531, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(e.__traceback__) from None
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 432, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 417, in transform
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2067, in run
    super().run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1116, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2172, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2279, in inline_call_
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1116, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2172, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2279, in inline_call_
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1116, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2172, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2227, in inline_call_
    sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 471, in bind_args
    result[name] = parent.symbolic_locals[name]
torch._dynamo.exc.InternalTorchDynamoError: outer

from user code:
   File "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6570, in indirect
    return direct(x)
  File "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6579, in direct
    return deep(x)
  File "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6577, in deep
    return deep2(c)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

To execute this test, run the following from the base repo dir:
     python test/dynamo/test_misc.py -k test_inline_closure_not_loaded_by_parent

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
---------------------------------------------------------------------------------------------------------------------------- Captured stdout call -----------------------------------------------------------------------------------------------------------------------------
frames [('total', 1)]
inline_call []
---------------------------------------------------------------------------------------------------------------------------- Captured stderr call -----------------------------------------------------------------------------------------------------------------------------
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping helper /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping enable_dynamic /home/yidi/local/pytorch/torch/_dynamo/eval_frame.py
[2023-08-02 15:48:36,561] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing indirect /home/yidi/local/pytorch/test/dynamo/test_misc.py:6569
TRACE starts_line indirect /home/yidi/local/pytorch/test/dynamo/test_misc.py:6569
            def indirect(x):
[2023-08-02 15:48:36,591] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['x'] (3,) [<DimDynamic.STATIC: 2>] [None]
TRACE starts_line indirect /home/yidi/local/pytorch/test/dynamo/test_misc.py:6570
                return direct(x)
[2023-08-02 15:48:36,594] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_DEREF direct []
[2023-08-02 15:48:36,594] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST x [UserFunctionVariable()]
[2023-08-02 15:48:36,594] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 1 [UserFunctionVariable(), TensorVariable()]
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] INLINING <code object direct at 0x7fbe4d366810, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6572>
TRACE starts_line direct /home/yidi/local/pytorch/test/dynamo/test_misc.py:6572 (inline depth: 1)
            def direct(x):
TRACE starts_line direct /home/yidi/local/pytorch/test/dynamo/test_misc.py:6573 (inline depth: 1)
                def deep2(c):
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CLOSURE outer []
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE BUILD_TUPLE 1 [InlinedClosureVariable()]
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST <code object deep2 at 0x7fbe4d3666b0, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6573> [TupleVariable()]
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST MiscTests.test_inline_closure_not_loaded_by_parent.<locals>.direct.<locals>.deep2 [TupleVariable(), ConstantVariable(code)]
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE MAKE_FUNCTION 8 [TupleVariable(), ConstantVariable(code), ConstantVariable(str)]
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_DEREF deep2 [NestedUserFunctionVariable()]
TRACE starts_line direct /home/yidi/local/pytorch/test/dynamo/test_misc.py:6576 (inline depth: 1)
                def deep(c):
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CLOSURE deep2 []
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE BUILD_TUPLE 1 [NewCellVariable()]
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST <code object deep at 0x7fbe4d366760, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6576> [TupleVariable()]
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST MiscTests.test_inline_closure_not_loaded_by_parent.<locals>.direct.<locals>.deep [TupleVariable(), ConstantVariable(code)]
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE MAKE_FUNCTION 8 [TupleVariable(), ConstantVariable(code), ConstantVariable(str)]
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAST deep [NestedUserFunctionVariable()]
TRACE starts_line direct /home/yidi/local/pytorch/test/dynamo/test_misc.py:6579 (inline depth: 1)
                return deep(x)
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST deep []
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST x [NestedUserFunctionVariable()]
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 1 [NestedUserFunctionVariable(), TensorVariable()]
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] INLINING <code object deep at 0x7fbe4d366760, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6576>
TRACE starts_line deep /home/yidi/local/pytorch/test/dynamo/test_misc.py:6576 (inline depth: 2)
                def deep(c):
TRACE starts_line deep /home/yidi/local/pytorch/test/dynamo/test_misc.py:6577 (inline depth: 2)
                    return deep2(c)
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_DEREF deep2 []
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST c [NestedUserFunctionVariable()]
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 1 [NestedUserFunctionVariable(), TensorVariable()]
[2023-08-02 15:48:36,599] torch._dynamo.output_graph: [DEBUG] restore_graphstate: removed 0 nodes
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] FAILED INLINING <code object deep at 0x7fbe4d366760, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6576>
[2023-08-02 15:48:36,599] torch._dynamo.output_graph: [DEBUG] restore_graphstate: removed 0 nodes
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] FAILED INLINING <code object direct at 0x7fbe4d366810, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6572>
[2023-08-02 15:48:36,599] torch._dynamo.output_graph: [DEBUG] restore_graphstate: removed 0 nodes
```

Test Plan:
add new test

Pull Request resolved: #106491
Approved by: https://github.com/williamwen42, https://github.com/jansel, https://github.com/zou3519
Comment on lines 132 to 168
# fake the new_args with original args
local_scope = {"L": {**locals(), "new_args": args}, "G": globals()}
pos_args = [eval(name, {}, local_scope) for name in example_names]

# We need to extract true_gm and false_gm from export
# as export won't add sym bool
def bind_branch_and_args(gm, pos_args):
ph2orig = dict(
zip((ph for ph in gm.graph.nodes if ph.op == "placeholder"), pos_args)
)
cond_node = next((n for n in gm.graph.nodes if n.target is cond), None)
assert cond_node
true_gm = getattr(gm, cond_node.args[1].name)
false_gm = getattr(gm, cond_node.args[2].name)
pos_args = []
for arg_node in cond_node.args[3]:
if arg_node.op == "placeholder" and arg_node in ph2orig:
pos_args.append(ph2orig[arg_node])
elif arg_node.op == "get_attr":
pos_args.append(getattr(gm, arg_node.target))
else:
raise RuntimeError(f"Cannot bind to original argumentes for {arg_node}")
return true_gm, false_gm, tuple(pos_args)

return cond(args[0], *bind_branch_and_args(gm, pos_args))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO(rzou): read through getattr

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Open question if this is good or if we should use the "trace through NN modules config option"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trace through NN modules config option PR: #103676

@zou3519 zou3519 self-requested a review August 3, 2023 20:23
pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2023
This PR adds a **same_signature** flag to dynamo.export.

**Motivation:**
In #105679, we experimented on **using dynamo to inspect the UDFs** for cond in eager mode (without torch.compile). This helps us to normalize the inputs (e.g. lifting closure to inputs) and makes higher order operator more robust (e.g. forbid python side effects) and less error-prone in general.

We decided to use dynamo.export (instead of torch.compile) to do the inspection (pointed out by @voznesenskym @zou3519):
- We'd like a **whole-graph capture** for the UDF.
- We'd like the dynamo inspection to be **stateless**. Using torch.compile would require resetting dynamo context before and after the inspection because the compile flags may be different from users' torch.compile. This will clear all dynamo cache.
- We can still implement some **caching** based on the guards.

However, this requires export to be able to handle the case where it cannot always rewrite signature: e.g. closure lifted as input.

This PR makes the rewrite optional.

**Implementation:**
We just put all the code that are related to signature rewriting into a function called rewrite_signature and use a same_signature flag to optionally to the transformation.

**Test Plan:**
existing tests.

Pull Request resolved: #106569
Approved by: https://github.com/ezyang
@ydwu4 ydwu4 force-pushed the eager_cond branch 2 times, most recently from 494b701 to 4311861 Compare August 9, 2023 20:25
ydwu4 added a commit that referenced this pull request Aug 11, 2023
Currently, we have the assertion that dynamo won't accept FakeTensor input unless we're exporting. This PR try to remove this restriction to finish #105679.

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Aug 11, 2023
Currently, we have the assertion that dynamo won't accept FakeTensor input unless we're exporting. This PR try to remove this restriction to finish #105679.

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Aug 11, 2023
Currently, we have the assertion that dynamo won't accept FakeTensor input unless we're exporting. This PR try to remove this restriction to finish #105679.

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Aug 14, 2023
Currently, we have the assertion that dynamo won't accept FakeTensor input unless we're exporting. This PR try to remove this restriction to finish #105679.

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Aug 14, 2023
Currently, we have the assertion that dynamo won't accept FakeTensor input unless we're exporting. This PR try to remove this restriction to finish #105679.

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Aug 14, 2023
Currently, we have the assertion that dynamo won't accept FakeTensor input unless we're exporting. This PR try to remove this restriction to finish #105679.

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Aug 15, 2023
Currently, we have the assertion that dynamo won't accept FakeTensor input unless we're exporting. This PR try to remove this restriction to finish #105679.

Pull Request resolved: #107042
Approved by: https://github.com/ezyang, https://github.com/zou3519
ydwu4 added a commit that referenced this pull request Aug 15, 2023
Currently, we have the assertion that dynamo won't accept FakeTensor input unless we're exporting. This PR try to remove this restriction to finish #105679.

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Aug 15, 2023
Currently, we have the assertion that dynamo won't accept FakeTensor input unless we're exporting. This PR try to remove this restriction to finish #105679.

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
# instrumentation will see the meta conversions and the
# tests all break so we just exclude this. In any case
# the to conversion isn't really right anyhow.
(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code probably shouldn't run for all of the cases listed above? (quantized/nested/sparse, etc)

summerdo pushed a commit to summerdo/pytorch that referenced this pull request Aug 17, 2023
Currently, we have the assertion that dynamo won't accept FakeTensor input unless we're exporting. This PR try to remove this restriction to finish pytorch#105679.

Pull Request resolved: pytorch#107042
Approved by: https://github.com/ezyang, https://github.com/zou3519
Cyril-Anto pushed a commit to Cyril-Anto/pytorch that referenced this pull request Aug 17, 2023
…ch#106569)

This PR adds a **same_signature** flag to dynamo.export.

**Motivation:**
In pytorch#105679, we experimented on **using dynamo to inspect the UDFs** for cond in eager mode (without torch.compile). This helps us to normalize the inputs (e.g. lifting closure to inputs) and makes higher order operator more robust (e.g. forbid python side effects) and less error-prone in general.

We decided to use dynamo.export (instead of torch.compile) to do the inspection (pointed out by @voznesenskym @zou3519):
- We'd like a **whole-graph capture** for the UDF.
- We'd like the dynamo inspection to be **stateless**. Using torch.compile would require resetting dynamo context before and after the inspection because the compile flags may be different from users' torch.compile. This will clear all dynamo cache.
- We can still implement some **caching** based on the guards.

However, this requires export to be able to handle the case where it cannot always rewrite signature: e.g. closure lifted as input.

This PR makes the rewrite optional.

**Implementation:**
We just put all the code that are related to signature rewriting into a function called rewrite_signature and use a same_signature flag to optionally to the transformation.

**Test Plan:**
existing tests.

Pull Request resolved: pytorch#106569
Approved by: https://github.com/ezyang
@ydwu4 ydwu4 closed this Aug 31, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants