[WIP] Make cond always use dynamo to inspect udfs. #105679
Conversation
```python
return pytree.tree_map(proxy_to_python_var, list(lifted_proxies.keys()))
```
Is the InlinedClosureVariable still a problem?
I haven't checked the test that triggers the InlinedClosureVariable-not-found error yet. I will check it.
Submitted a separate PR: #106491.
```python
return cond(args[0], *bind_branch_and_args(gm, pos_args))
```

```python
def cond_compiled(pred, true_fn, false_fn, args):
```
While attempting to test this locally, I found the following:
- cond's args must apparently be a list. We should make it accept a tuple as well before it goes public.
- cond doesn't handle pytree inputs yet. That seems OK.
- Is cond limited to a single output?
1. Yeah, makes sense! 2. Agreed. 3. Yeah, currently it's single-output.
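For context, the calling convention under discussion can be sketched in plain Python. This is a hypothetical eager stand-in, not the real `torch` implementation (the real `cond` traces both branches rather than executing one):

```python
def cond(pred, true_fn, false_fn, operands):
    """Eager stand-in for the cond() calling convention discussed above.

    - `operands` must currently be a list (the review asks to also allow tuples).
    - Both branches take the same operands and produce a single output.
    """
    if not isinstance(operands, (list, tuple)):
        raise TypeError("cond operands must be a list (or, eventually, a tuple)")
    branch = true_fn if pred else false_fn
    return branch(*operands)

# Example usage with plain numbers standing in for tensors:
out = cond(True, lambda x: x + 1, lambda x: x - 1, [41])  # -> 42
```

The single-output restriction mentioned in point 3 is not enforced in this sketch; it comes from the tracing machinery, not the call signature.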
functorch/experimental/_cond.py (Outdated)

```python
return sym_int
size = pytree.tree_map(to_hint, t.size())
stride = pytree.tree_map(to_hint, t.stride())
return torch.empty_strided(size, stride, requires_grad=t.requires_grad)
```
Uh, what happens if cond captures a variable that is not a concrete Tensor?
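The snippet in question maps possibly-symbolic sizes and strides to concrete hints before building an empty strided tensor. A minimal pure-Python model of that `to_hint` step, with a hypothetical `FakeSymInt` standing in for `torch.SymInt`:

```python
class FakeSymInt:
    """Hypothetical stand-in for torch.SymInt: carries a concrete hint value."""
    def __init__(self, hint):
        self.hint = hint

def to_hint(v):
    # Replace a symbolic value by its concrete hint; plain ints pass through.
    return v.hint if isinstance(v, FakeSymInt) else v

# A "size" mixing symbolic and concrete dims, as tree_map would visit it:
size = [FakeSymInt(3), 4]
concrete_size = [to_hint(s) for s in size]  # -> [3, 4]
```

The reviewer's question stands: this path assumes every captured variable is a concrete (or fake) Tensor with sizes and strides, and says nothing about other captured Python values.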
From #106491: When inlining a function which loads a closure, its direct parent may not load that closure, so we cannot find the closure name in the parent's symbolic locals. In this PR, we fix it by recursively searching the parent instruction translator stack to resolve the closure.

**Background**

When developing #105679, this corner case was triggered. A small repro is added in the test of this PR, where `outer` is loaded by `deep2` but not by `deep`:

```python
def test_inline_closure_not_loaded_by_parent(self):
    def outer(a):
        return a + 1

    def indirect(x):
        return direct(x)

    def direct(x):
        def deep2(c):
            return outer(c)

        def deep(c):
            return deep2(c)

        return deep(x)

    x = torch.randn(3)
    eager = indirect(x)
    counter = CompileCounter()
    compiled = torch._dynamo.optimize(counter)(indirect)(x)
```

Running the test, we have the following error before the PR:

```
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6584, in test_inline_closure_not_loaded_by_parent
    compiled = torch._dynamo.optimize(counter)(indirect)(x)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 321, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 481, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 543, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 130, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 362, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 194, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 531, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(e.__traceback__) from None
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 432, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 417, in transform
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2067, in run
    super().run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1116, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2172, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2279, in inline_call_
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1116, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2172, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2279, in inline_call_
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1116, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2172, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2227, in inline_call_
    sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 471, in bind_args
    result[name] = parent.symbolic_locals[name]
torch._dynamo.exc.InternalTorchDynamoError: outer

from user code:
  File "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6570, in indirect
    return direct(x)
  File "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6579, in direct
    return deep(x)
  File "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6577, in deep
    return deep2(c)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

To execute this test, run the following from the base repo dir:
    python test/dynamo/test_misc.py -k test_inline_closure_not_loaded_by_parent

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
```

Captured stdout call:

```
frames [('total', 1)]
inline_call []
```

Captured stderr call:

```
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping helper /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping enable_dynamic /home/yidi/local/pytorch/torch/_dynamo/eval_frame.py
[2023-08-02 15:48:36,561] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing indirect /home/yidi/local/pytorch/test/dynamo/test_misc.py:6569
TRACE starts_line indirect /home/yidi/local/pytorch/test/dynamo/test_misc.py:6569
    def indirect(x):
[2023-08-02 15:48:36,591] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['x'] (3,) [<DimDynamic.STATIC: 2>] [None]
TRACE starts_line indirect /home/yidi/local/pytorch/test/dynamo/test_misc.py:6570
    return direct(x)
[2023-08-02 15:48:36,594] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_DEREF direct []
[2023-08-02 15:48:36,594] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST x [UserFunctionVariable()]
[2023-08-02 15:48:36,594] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 1 [UserFunctionVariable(), TensorVariable()]
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] INLINING <code object direct at 0x7fbe4d366810, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6572>
TRACE starts_line direct /home/yidi/local/pytorch/test/dynamo/test_misc.py:6572 (inline depth: 1)
    def direct(x):
TRACE starts_line direct /home/yidi/local/pytorch/test/dynamo/test_misc.py:6573 (inline depth: 1)
    def deep2(c):
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CLOSURE outer []
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE BUILD_TUPLE 1 [InlinedClosureVariable()]
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST <code object deep2 at 0x7fbe4d3666b0, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6573> [TupleVariable()]
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST MiscTests.test_inline_closure_not_loaded_by_parent.<locals>.direct.<locals>.deep2 [TupleVariable(), ConstantVariable(code)]
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE MAKE_FUNCTION 8 [TupleVariable(), ConstantVariable(code), ConstantVariable(str)]
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_DEREF deep2 [NestedUserFunctionVariable()]
TRACE starts_line direct /home/yidi/local/pytorch/test/dynamo/test_misc.py:6576 (inline depth: 1)
    def deep(c):
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CLOSURE deep2 []
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE BUILD_TUPLE 1 [NewCellVariable()]
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST <code object deep at 0x7fbe4d366760, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6576> [TupleVariable()]
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST MiscTests.test_inline_closure_not_loaded_by_parent.<locals>.direct.<locals>.deep [TupleVariable(), ConstantVariable(code)]
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE MAKE_FUNCTION 8 [TupleVariable(), ConstantVariable(code), ConstantVariable(str)]
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAST deep [NestedUserFunctionVariable()]
TRACE starts_line direct /home/yidi/local/pytorch/test/dynamo/test_misc.py:6579 (inline depth: 1)
    return deep(x)
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST deep []
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST x [NestedUserFunctionVariable()]
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 1 [NestedUserFunctionVariable(), TensorVariable()]
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] INLINING <code object deep at 0x7fbe4d366760, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6576>
TRACE starts_line deep /home/yidi/local/pytorch/test/dynamo/test_misc.py:6576 (inline depth: 2)
    def deep(c):
TRACE starts_line deep /home/yidi/local/pytorch/test/dynamo/test_misc.py:6577 (inline depth: 2)
    return deep2(c)
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_DEREF deep2 []
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST c [NestedUserFunctionVariable()]
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 1 [NestedUserFunctionVariable(), TensorVariable()]
[2023-08-02 15:48:36,599] torch._dynamo.output_graph: [DEBUG] restore_graphstate: removed 0 nodes
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] FAILED INLINING <code object deep at 0x7fbe4d366760, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6576>
[2023-08-02 15:48:36,599] torch._dynamo.output_graph: [DEBUG] restore_graphstate: removed 0 nodes
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] FAILED INLINING <code object direct at 0x7fbe4d366810, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6572>
[2023-08-02 15:48:36,599] torch._dynamo.output_graph: [DEBUG] restore_graphstate: removed 0 nodes
```

Test Plan: add new test

Pull Request resolved: #106491
Approved by: https://github.com/williamwen42, https://github.com/jansel, https://github.com/zou3519
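The fix described above, searching the parent translator stack recursively instead of only the direct parent, can be sketched with plain objects standing in for each frame's symbolic locals (hypothetical names; the real classes live in `torch/_dynamo/symbolic_convert.py`):

```python
class Translator:
    """Minimal, hypothetical stand-in for a dynamo instruction-translator frame."""
    def __init__(self, symbolic_locals, parent=None):
        self.symbolic_locals = symbolic_locals
        self.parent = parent

def resolve_closure(tx, name):
    # Walk up the translator stack until some ancestor frame knows `name`,
    # rather than consulting only the direct parent (which caused the error).
    while tx is not None:
        if name in tx.symbolic_locals:
            return tx.symbolic_locals[name]
        tx = tx.parent
    raise KeyError(name)

# Mirror the repro: `outer` lives two frames up from where it is needed.
root = Translator({"outer": "outer_fn"})
direct_frame = Translator({}, parent=root)
deep_frame = Translator({}, parent=direct_frame)
resolved = resolve_closure(deep_frame, "outer")  # -> "outer_fn"
```

Looking up only `deep_frame.parent.symbolic_locals` would fail here, which is exactly the pre-fix behavior in the traceback above.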
torch/_higher_order_ops/cond.py (Outdated)

```python
# fake the new_args with original args
local_scope = {"L": {**locals(), "new_args": args}, "G": globals()}
pos_args = [eval(name, {}, local_scope) for name in example_names]


# We need to extract true_gm and false_gm from export
# as export won't add sym bool
def bind_branch_and_args(gm, pos_args):
    ph2orig = dict(
        zip((ph for ph in gm.graph.nodes if ph.op == "placeholder"), pos_args)
    )
    cond_node = next((n for n in gm.graph.nodes if n.target is cond), None)
    assert cond_node
    true_gm = getattr(gm, cond_node.args[1].name)
    false_gm = getattr(gm, cond_node.args[2].name)
    pos_args = []
    for arg_node in cond_node.args[3]:
        if arg_node.op == "placeholder" and arg_node in ph2orig:
            pos_args.append(ph2orig[arg_node])
        elif arg_node.op == "get_attr":
            pos_args.append(getattr(gm, arg_node.target))
        else:
            raise RuntimeError(f"Cannot bind to original arguments for {arg_node}")
    return true_gm, false_gm, tuple(pos_args)


return cond(args[0], *bind_branch_and_args(gm, pos_args))
```
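The placeholder-binding step inside `bind_branch_and_args` can be modeled without FX at all. Below, `Node` is a hypothetical stand-in for `torch.fx.Node` with just the two fields the logic needs:

```python
class Node:
    """Hypothetical stand-in for torch.fx.Node, keeping only `op` and `name`."""
    def __init__(self, op, name):
        self.op = op
        self.name = name

def bind_placeholders(placeholders, originals, cond_args, attrs):
    # Map each cond argument node back to the original value its placeholder
    # corresponds to, or to a module attribute for get_attr nodes.
    ph2orig = dict(zip(placeholders, originals))
    bound = []
    for n in cond_args:
        if n.op == "placeholder" and n in ph2orig:
            bound.append(ph2orig[n])
        elif n.op == "get_attr":
            bound.append(attrs[n.name])
        else:
            raise RuntimeError(f"Cannot bind to original arguments for {n}")
    return tuple(bound)

# One user-supplied arg and one lifted attribute, as in the PR code:
p0 = Node("placeholder", "x")
g0 = Node("get_attr", "buf")
bound = bind_placeholders([p0], [10], [p0, g0], {"buf": 20})  # -> (10, 20)
```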
TODO(rzou): read through getattr
Open question: is this approach good, or should we use the "trace through NN modules" config option?
The "trace through NN modules" config option PR: #103676
From #106569: This PR adds a **same_signature** flag to dynamo.export.

**Motivation:** In #105679, we experimented with **using dynamo to inspect the UDFs** for cond in eager mode (without torch.compile). This helps us normalize the inputs (e.g. lifting closures to inputs) and makes higher-order operators more robust (e.g. forbidding Python side effects) and less error-prone in general. We decided to use dynamo.export (instead of torch.compile) to do the inspection (pointed out by @voznesenskym @zou3519):
- We'd like a **whole-graph capture** for the UDF.
- We'd like the dynamo inspection to be **stateless**. Using torch.compile would require resetting the dynamo context before and after the inspection because the compile flags may differ from the user's torch.compile, and this would clear the entire dynamo cache.
- We can still implement some **caching** based on the guards.

However, this requires export to handle the case where it cannot always rewrite the signature, e.g. a closure lifted as an input. This PR makes the rewrite optional.

**Implementation:** We put all the code related to signature rewriting into a function called rewrite_signature and use a same_signature flag to optionally apply the transformation.

**Test Plan:** existing tests.

Pull Request resolved: #106569
Approved by: https://github.com/ezyang
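The "closure lifted as input" case that blocks signature rewriting can be illustrated in plain Python (a conceptual sketch, not dynamo's actual transformation): a UDF that captures a free variable is equivalent to one taking that value as an extra leading argument, which changes the function's signature:

```python
y = 10  # free variable that the UDF closes over

def udf(x):
    return x + y  # `y` is captured from the enclosing scope

# After lifting, the capture becomes an explicit leading input, so the
# exported graph's signature no longer matches udf's original (x,) signature:
def udf_lifted(y_lifted, x):
    return x + y_lifted

result_original = udf(1)          # -> 11
result_lifted = udf_lifted(y, 1)  # -> 11, same value, different signature
```

With same_signature=True, export would have to rewrite `udf_lifted` back to the `(x,)` calling convention; when it cannot, same_signature=False lets the caller accept the lifted signature as-is.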
Currently, we have the assertion that dynamo won't accept FakeTensor input unless we're exporting. This PR tries to remove this restriction to finish #105679.

Pull Request resolved: #107042
Approved by: https://github.com/ezyang, https://github.com/zou3519

cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @aakhundov
```python
# instrumentation will see the meta conversions and the
# tests all break so we just exclude this. In any case
# the to conversion isn't really right anyhow.
(
```
This code probably shouldn't run for all of the cases listed above? (quantized/nested/sparse, etc)
This PR adds a Python entry point for cond, which will always use dynamo to inspect the UDFs, i.e. true_fn and false_fn.

Implementation-wise, the overall plan is:

We make several workarounds to make the above plan work: turn off all dispatch keys that currently don't work with torch.compile:
1. make_fx(torch.compile)
2. functorch-related, such as functionalize(torch.compile)
Test Plan:
See added and/or modified tests.
cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @aakhundov