Preserve original GraphArgs for shape guard codegen #90665
Conversation
Signed-off-by: Edward Z. Yang <[email protected]> [ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90665
❌ 1 failure as of commit 58b38cc.
Signed-off-by: Edward Z. Yang <ezyangfb.com> cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire [ghstack-poisoned]
```python
# Can we actually demonstrate that checkpointing the ShapeEnv is
# necessary?  It's not so easy to induce this case.  Dynamo is very
# eager about adding locals to GraphArgs; any local that is in scope,
# even if it isn't used, is added to GraphArgs (see also
# https://github.com/pytorch/torchdynamo/issues/1925 ).  So long
# as Dynamo eagerly guards in this way, we have an invariant that
# all locals are guaranteed to show up in GraphArgs before the
# inlining function call, in which case we will always have enough
# information to codegen our guards so long as we don't prune the
# unused GraphArgs away (and indeed, the direct fix for this bug
# was to make sure we use original GraphArgs).  Non locals,
# conversely, typically are static, and so won't have guards allocated
# for them.  That being said, there may still be a way to trigger
# this error.
```
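The invariant described in the comment can be illustrated with a small pure-Python sketch. This is a toy model, not Dynamo's actual data structures; the names `add_graph_arg` and `codegen_guard` are invented for illustration. The point it shows: guard codegen only succeeds if every symbol reachable from a guard expression was registered with a source when its GraphArg was added, so pruning "unused" GraphArgs can drop the very sources guard codegen needs.

```python
# Toy model (not Dynamo's real GraphArgs/ShapeEnv): each GraphArg eagerly
# registers a source string for its size symbol; guard codegen looks
# symbols up there at the end of tracing.
graph_args = {}  # symbol -> source, populated eagerly for every local

def add_graph_arg(symbol, source):
    graph_args.setdefault(symbol, source)

def codegen_guard(expr_symbols):
    # Every symbol in the guard must have a source, or codegen fails --
    # this mirrors the AssertionError users of this PR run into.
    missing = [s for s in expr_symbols if s not in graph_args]
    assert not missing, f"{missing} not in {graph_args}"
    return " and ".join(f"{graph_args[s]} == {s}" for s in expr_symbols)

add_graph_arg("s0", "L['x'].size()[0]")
add_graph_arg("s1", "L['y'].size()[0]")  # added even if 'y' is unused
print(codegen_guard(["s0", "s1"]))
```

If the `s1` entry were pruned away before codegen, the `codegen_guard` call would fail exactly as described above.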
This deserves its own test to protect this invariant. We take it as a matter of course for now, but if it changed subtly under the hood, it could cause really nasty-to-debug issues. Let's file an issue to test this (specifically, the exhaustive guarding of locals).
This isn't an invariant. I just failed to come up with a test case that could actually trigger the naughtier fix.
The proper fix is to checkpoint the shape env properly, but I'm hoping we can put that off a bit.
It turns out that you can trigger the edge case discussed here with unspecialized ints, and in fact quite easily.
@pytorchbot merge

Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 2 additional jobs have failed; the first few are: inductor, inductor / cuda11.6-py3.10-gcc7-sm86 / test (inductor_timm, 1, 2, linux.g5.4xlarge.nvidia.gpu).

@pytorchbot merge -f "master failure only"

Merge started: Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
I had an error on nightly mentioning this PR:

```
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: s6 (could be from ["L['self'].last_size_2d[0]"]) not in {s1: ["L['t0']"], s10: ["L['t1']"], s7: ["L['t3']"], s6: []}. If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```

What info do you need?
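For reference, this assert fires inside shape-guard printing when a symbol appearing in a guard expression has no recorded source. A toy sketch of that check (not the actual `ShapeGuardPrinter` code; `print_symbol` is an invented name) looks roughly like:

```python
# Toy model of the symbol_to_source lookup behind this AssertionError:
# every free symbol in a guard expression must map to at least one
# source path (an L['...'] access), or the guard cannot be codegen'd.
# Note the empty list for "s6": the symbol is known but source-less,
# matching the "s6: []" entry in the error message above.
def print_symbol(symbol, symbol_to_source):
    sources = symbol_to_source.get(symbol)
    assert sources, (
        f"{symbol} (no known sources) not in {symbol_to_source}. "
        "If this assert is failing, it could be due to the issue "
        "described in https://github.com/pytorch/pytorch/pull/90665"
    )
    return sources[0]

ok = print_symbol("s1", {"s1": ["L['t0']"], "s6": []})
```

Asking for `"s6"` with this mapping would raise, which is the shape of the failure reported above.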
Are you running the cache by any chance? If so, it's probably #127970

Yes it is with
When we don't dynamo.reset(), we don't recompile on different dynamic shapes.
Additionally, this exposes a bug that fails with the view
```
File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 3996, in produce_guards
sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/printer.py", line 292, in doprint
return self._str(self._print(expr))
File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/printer.py", line 331, in _print
return printmethod(expr, **kwargs)
File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/str.py", line 56, in _print_Add
t = self._print(term)
File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/printer.py", line 331, in _print
return printmethod(expr, **kwargs)
File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/str.py", line 366, in _print_Mul
a_str = [self.parenthesize(x, prec, strict=False) for x in a]
File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/str.py", line 366, in <listcomp>
a_str = [self.parenthesize(x, prec, strict=False) for x in a]
File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/str.py", line 37, in parenthesize
return self._print(item)
File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/printer.py", line 331, in _print
return printmethod(expr, **kwargs)
File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1494, in _print_Symbol
assert self.symbol_to_source.get(expr), (
AssertionError: s3 (could be from ['<ephemeral: symint_visitor_fn>', '<ephemeral: symint_visitor_fn>']) not in {s0: ["L['x'].a.size()[1]", "L['x'].b.size()[1]", "L['x'].size()[1]", "L['x'].a.size()[1]", "L['x'].b.size()[1]", "L['x'].a.size()[1]", "L['x'].b.size()[1]"], s1: ["L['x'].a.stride()[0]", "L['x'].b.stride()[0]", "L['x'].stride()[0]", "L['x'].a.stride()[0]", "L['x'].b.stride()[0]", "L['x'].a.stride()[0]", "L['x'].b.stride()[0]"], s2: ["L['x'].a.storage_offset()", "L['x'].b.storage_offset()", "L['x'].a.storage_offset()", "L['x'].b.storage_offset()"]}. If this assert is failing, it could be due to the issue described in #90665
```
[ghstack-poisoned]
Also, some of the returned views were tuples - so when we `* 2`, we actually just copy all the inputs twice in the tuple. I changed it so that it would just return one of the values from the return tuple. Additionally, this exposes a bug that fails with the slice operation, so I skipped it when we're testing with dynamic shapes (same stack trace as above).
cc cpuhrsch jbschlosser bhosmer drisspg soulitzer ezyang albanD voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang
Pull Request resolved: #128659
Approved by: https://github.com/YuqingJ
I'm commenting here for anyone else who lands here via the link in the exception. I hit this in PyTorch. The issue was that I was slicing the input tensor to my torch.compile'd network, which caused the stride to mismatch the shape. Adding `.contiguous()` to the sliced input fixed this issue. I think there is probably a bug here, though.
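The stride mismatch described above can be illustrated without torch, with a small pure-Python sketch (illustrative only): slicing keeps the parent's strides, so they stop matching the contiguous strides implied by the sliced shape, which is what `.contiguous()` repairs by copying into freshly-strided storage.

```python
# Toy model of row-major (contiguous) strides, showing why a column slice
# of a contiguous tensor is no longer contiguous.
def contiguous_strides(shape):
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.insert(0, acc)
        acc *= dim
    return strides

parent = ([4, 6], contiguous_strides([4, 6]))  # shape (4, 6), strides (6, 1)
# Taking the first 3 columns (like x[:, :3] in torch) changes the shape
# but keeps the parent's strides: shape (4, 3) with strides (6, 1).
sliced = ([4, 3], parent[1])
# A freshly contiguous (4, 3) tensor would have strides (3, 1) instead.
is_contiguous = sliced[1] == contiguous_strides(sliced[0])
```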
Oh, for yours, it's possible nightly has fixed it. I think I need to update the error description here.
**TL;DR:** This PR does the following hacks:

* Allows SymInt replacements for expressions with components that are integers, reciprocals (e.g. `1/s0`), and multiplications of symbols with reciprocals (e.g. `s1/s0`). Previously, only integers were [allowed](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4970). For example, when we find `Eq(s2, s1/s0)`, we can now replace `s2` with `s1/s0`.
* Approximates the value range for `1/s0` with positive integer `s0` as `[0, 1]`.

**Background:** During subclass fake-ification, subclass instances that are views go through the process of "view replay" to reconstruct the base -> view relationship in fake land. Part of view replay involves symbolicizing any closed-over ints in the original view function. For example, a slice view `x[3]` will have a `3` baked in as an arg to slice in the saved view func; we want to symbolicize this 3 for dynamic shapes support.

In practice, before the view is replayed, all closed-over ints are replaced with SymInts with "ephemeral sources" that are intended to disappear when we [assert](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/_subclasses/meta_utils.py#L1073-L1078) the subclass size / stride / storage offset match the allocated, non-ephemeral outer symbolic values. For this to work, we rely on the [SymInt replacement logic](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4740-L4744) (e.g. replace `s1` with `s0 * 5` if `Eq(s1, s0*5)` is found to be true), and [prioritize ephemerally-sourced SymInts](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4909-L4921) to be replaced first.

**Problem:** In certain cases, the SymInt replacement logic does not apply, and an ephemerally-sourced SymInt lives for longer than originally designed, resulting in guard-time errors like those reported [here](#133337 (comment)) and in #128649.

```
AssertionError: s7 (could be from ['<ephemeral: symint_visitor_fn>', '<ephemeral: symint_visitor_fn>']) not in {s3: ["L['nt']._values.size()[0]", "L['nt']._values.size()[0]"], s0: ["L['nt'].size()[1]"], s6: ["L['nt'].size()[2]", "L['nt']._values.size()[1]"], s1: ["L['nt']._values.size()[0]"], s7: []}. If this assert is failing, it could be due to the issue described in #90665
```

**Solution:** For each of the error cases, we can identify why a particular SymInt replacement is not made and enhance the logic to support such replacements. This PR addresses the following case:

```python
@torch.compile(backend="eager", dynamic=True)
def f(t):
    return t._base + 1

x_a = torch.randn(4, 4, requires_grad=True)
x = TwoTensor(x_a, x_a.clone())
out = f(x[3])
```

The input to the compiled func is a subclass view produced by `slice()`. Fake-ification of this subclass view does the following:

1. Allocates non-ephemeral size / stride / storage offset symbols for the outer subclass metadata. Since the slice is contiguous with shape `(4)`, we get `(s0)` for the shape, `(1)` for the stride, and `s1` for the storage offset.
2. Performs view replay on a fake-ified base to reconstruct the base -> view relationship. The captured `3` arg to `slice()` is replaced with an ephemerally-sourced SymInt `s2` before the view func is replayed. The output of view replay has shape `(s0)`, stride `(1)`, and `s0 * s2` for the storage offset.
3. [Asserts](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/_subclasses/meta_utils.py#L1073-L1078) that the output of view replay has the shape metadata allocated in step (1).

Since we have `Eq(s1, s0 * s2)` from the storage offset assert and `s2` is ephemeral, we should replace `s2` with `s1/s0`. Before this PR, this replacement was not supported since `s1/s0` is not guaranteed to be an integer. Note that the `try_solve()` call below is correctly able to find `s2=s1/s0`.

https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4969-L4970

This PR hacks around this, expanding the supported set of replacements to include both reciprocals (e.g. `1/s0`, represented as `Pow(s0, -1)`) and multiplications of symbols with reciprocals (e.g. `s1/s0`, which is represented as `Mul(s1, Pow(s0, -1))`).

To make a replacement, it is also necessary to determine a new value range for the expression, and the new value range must be a subset of the existing value range for the symbol to be replaced. Before this PR, the logic did not know how to calculate a new value range for a reciprocal `Pow(s0, -1)`, only supporting positive exponents. This PR hacks in a rough, loose bound for the expected value range of `Pow(s0, -1)` AKA `1/s0` given the value range of a positive integer `s0`. To keep value ranges integral so as to operate well with other integral value ranges, the value range for `1/s0` is approximated as `[0, 1]`. In an expression like `s1/s0`, the lower / upper bounds for the ranges of `s1` and `1/s0` are multiplied to determine the final value range for the full expression. For example, given a range of `[0, int_oo]` for `s1` and the approximated `[0, 1]` range for `1/s0`, the final range for `s1/s0` is also found to be `[0, int_oo]`.

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames rec

[ghstack-poisoned]
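The value-range approximation described above can be sketched in plain Python. This is an illustrative toy under the stated assumption (positive integer `s0`), not the ShapeEnv's actual ValueRanges implementation; `reciprocal_range` and `mul_range` are invented names.

```python
# Toy sketch of the approximated integral value range for 1/s0 and for
# products of non-negative ranges, mirroring the description above.
INT_OO = float("inf")  # stand-in for sympy's int_oo

def reciprocal_range(lo, hi):
    # For positive integer s0 in [lo, hi], 1/s0 lies in (0, 1]; to keep
    # ranges integral, approximate it loosely as [0, 1].
    assert lo >= 1, "only positive integer s0 is handled"
    return (0, 1)

def mul_range(a, b):
    # Multiply two non-negative ranges bound-by-bound.
    (alo, ahi), (blo, bhi) = a, b
    return (alo * blo, ahi * bhi)

# s1 in [0, int_oo] times the approximated [0, 1] for 1/s0
# gives [0, int_oo] for s1/s0, as described above.
r = mul_range((0, INT_OO), reciprocal_range(1, INT_OO))
```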
…views" **TL;DR:** This PR does the following hacks: * Allows SymInt replacements for expressions with components that are integers, reciprocals (e.g. `1/s0`), and multiplications of symbols with reciprocals (e.g. `s1/s0`). Previously, only integers were [allowed](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4970). For example, when we find `Eq(s2, s1/s0)`, we can now replace `s2` with `s1/s0`. * Approximates the value range for `1/s0` with positive integer `s0` as `[0, 1]`. **Background:** During subclass fake-ification, subclass instances that are views go through the process of "view replay" to reconstruct the base -> view relationship in fake land. Part of view replay involves symbolicizing any closed-over ints in the original view function. For example, a slice view `x[3]` will have a `3` baked in as an arg to slice in the saved view func; we want to symbolicize this 3 for dynamic shapes support. In practice, before the view is replayed, all closed-over ints are replaced with SymInts with "ephemeral sources" that are intended to disappear when we [assert](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/_subclasses/meta_utils.py#L1073-L1078) the subclass size / stride / storage offset match the allocated, non-ephemeral outer symbolic values. For this to work, we rely on the [SymInt replacement logic](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4740-L4744) (e.g. 
replace `s1` with `s0 * 5` if `Eq(s1, s0*5)` is found to be true), and [prioritize ephemerally-sourced SymInts](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4909-L4921) to be replaced first.

**Problem:** In certain cases, the SymInt replacement logic does not apply, and an ephemerally-sourced SymInt lives longer than originally designed, resulting in guard-time errors like those reported [here](#133337 (comment)) and in #128649.

```
AssertionError: s7 (could be from ['<ephemeral: symint_visitor_fn>', '<ephemeral: symint_visitor_fn>']) not in {s3: ["L['nt']._values.size()[0]", "L['nt']._values.size()[0]"], s0: ["L['nt'].size()[1]"], s6: ["L['nt'].size()[2]", "L['nt']._values.size()[1]"], s1: ["L['nt']._values.size()[0]"], s7: []}. If this assert is failing, it could be due to the issue described in #90665
```

**Solution:** For each of the error cases, we can identify why a particular SymInt replacement is not made and enhance the logic to support such replacements. This PR addresses the following case:

```python
@torch.compile(backend="eager", dynamic=True)
def f(t):
    return t._base + 1

x_a = torch.randn(4, 4, requires_grad=True)
x = TwoTensor(x_a, x_a.clone())
out = f(x[3])
```

The input to the compiled func is a subclass view produced by `slice()`. Fake-ification of this subclass view does the following:

1. Allocates non-ephemeral size / stride / storage offset symbols for the outer subclass metadata. Since the slice is contiguous with shape `(4)`, we get `(s0)` for the shape, `(1)` for the stride, and `s1` for the storage offset.
2. Performs view replay on a fake-ified base to reconstruct the base -> view relationship. The captured `3` arg to `slice()` is replaced with an ephemerally-sourced SymInt `s2` before the view func is replayed. The output of view replay has shape `(s0)`, stride `(1)`, and `s0 * s2` for the storage offset.
3. [Asserts](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/_subclasses/meta_utils.py#L1073-L1078) that the output of view replay has the shape metadata allocated in step (1).

Since we have `Eq(s1, s0 * s2)` from the storage offset assert and `s2` is ephemeral, we should replace `s2` with `s1/s0`. Before this PR, this replacement was not supported since `s1/s0` is not guaranteed to be an integer. Note that the `try_solve()` call below is correctly able to find `s2=s1/s0`.

https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4969-L4970

This PR hacks around this, expanding the supported set of replacements to include both reciprocals (e.g. `1/s0`, represented as `Pow(s0, -1)`) and multiplications of symbols with reciprocals (e.g. `s1/s0`, which is represented as `Mul(s1, Pow(s0, -1))`).

To make a replacement, it is also necessary to determine a new value range for the expression, and the new value range must be a subset of the existing value range for the symbol to be replaced. Before this PR, the logic did not know how to calculate a new value range for a reciprocal `Pow(s0, -1)`, only supporting positive exponents. This PR hacks in a rough loose bound for the expected value range of `Pow(s0, -1)` AKA `1/s0` given the value range of `s0` for positive integer `s0`. To keep value ranges integral so as to operate well with other integral value ranges, the value range for `1/s0` is approximated as `[0, 1]`. In an expression like `s1/s0`, the lower / upper bounds for the ranges of `s1` and `1/s0` are multiplied to determine the final value range for the full expression. For example, given a range of `[0, int_oo]` for `s1` and the approximated `[0, 1]` range for `1/s0`, the final range for `s1/s0` is also found to be `[0, int_oo]`.

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames rec

[ghstack-poisoned]
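The loose integral bounding described above can be sketched as plain interval arithmetic. This is an editorial sketch of the idea only, not PyTorch's actual `ValueRanges` implementation; `recip_range` and `mul_range` are hypothetical helper names:

```python
import math

def recip_range(lo, hi):
    # Loose integral bound for 1/s given positive integer s in [lo, hi].
    # True values lie in (0, 1]; rounding outward to keep the range
    # integral gives [0, 1], matching the approximation described above.
    assert lo >= 1
    return (0, 1)

def mul_range(a, b):
    # Bounds of a product from the bounds of its non-negative factors.
    # (Corner cases like inf * 0 are not handled in this sketch.)
    return (a[0] * b[0], a[1] * b[1])

s1_range = (0, math.inf)  # stand-in for [0, int_oo]
print(mul_range(s1_range, recip_range(2, math.inf)))  # -> (0, inf)
```

Multiplying the lower and upper bounds pointwise is what makes the `[0, int_oo]` range for `s1` absorb the `[0, 1]` range for `1/s0`.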
**TL;DR:** This PR does the following:
* Uses `CleanDiv` instead of `FloorDiv` in more places when we know it holds, to get strictly more information:
  * In `_try_isolate_lhs()`, if we know that `lhs` and `rhs` are integers, we can solve via `CleanDiv` instead of `/`
* Teaches `_try_isolate_lhs()` how to handle `CleanDiv(numer, denom)` where the `denom` component contains the thing we're trying to isolate; this happens in practice for `slice()` and `view()` views
* In the `infer_size()` util used by various views (including `view()` and `reshape()`), uses `CleanDiv` for inferring the `-1` dimension from `total size / product of other dims`
  * We also learn that all components of the product are >= 1, so provide that information as well
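The `-1` inference mentioned in the `infer_size()` bullet can be sketched in plain Python. This is a sketch of the semantics only; the real util also operates symbolically, and the exact division below is the spot where `CleanDiv` applies when sizes are SymInts:

```python
def infer_size(shape, numel):
    """Infer the single -1 dimension so `shape` covers `numel` elements."""
    known = 1
    infer_dim = None
    for i, d in enumerate(shape):
        if d == -1:
            assert infer_dim is None, "only one -1 dimension is allowed"
            infer_dim = i
        else:
            known *= d
    if infer_dim is None:
        assert known == numel, "shape is invalid for input size"
        return list(shape)
    # Exact division: symbolically, this is where CleanDiv (rather than
    # FloorDiv) records that the division has no remainder.
    assert known != 0 and numel % known == 0, "shape is invalid for input size"
    out = list(shape)
    out[infer_dim] = numel // known
    return out

print(infer_size([2, -1, 4], 24))  # -> [2, 3, 4]
```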
**Background:** During subclass fake-ification, subclass instances that are views go through the process of "view replay" to reconstruct the base -> view relationship in fake land. Part of view replay involves symbolicizing any closed-over ints in the original view function. For example, a slice view `x[3]` will have a `3` baked in as an arg to slice in the saved view func; we want to symbolicize this 3 for dynamic shapes support.
In practice, before the view is replayed, all closed-over ints are replaced with SymInts with "ephemeral sources" that are intended to disappear when we [assert](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/_subclasses/meta_utils.py#L1073-L1078) the subclass size / stride / storage offset match the allocated, non-ephemeral outer symbolic values. For this to work, we rely on the [SymInt replacement logic](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4740-L4744) (e.g. replace `s1` with `s0 * 5` if `Eq(s1, s0*5)` is found to be true), and [prioritize ephemerally-sourced SymInts](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4909-L4921) to be replaced first.
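The replacement idea referenced above can be illustrated with sympy (purely a demonstration, not the `ShapeEnv` code): once `Eq(s1, s0*5)` is established, every occurrence of `s1` can be rewritten in terms of `s0`:

```python
import sympy

s0, s1 = sympy.symbols("s0 s1", positive=True, integer=True)

# Suppose guard evaluation established Eq(s1, s0*5); record the replacement.
replacements = {s1: s0 * 5}

# Applying it eliminates s1 from any downstream expression.
expr = s1 + 2 * s1 * s0
result = expr.xreplace(replacements)
print(result)
```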
**Problem:** In certain cases, the SymInt replacement logic does not apply, and an ephemerally-sourced SymInt lives longer than originally designed, resulting in guard-time errors like those reported [here](#133337 (comment)) and in #128649.
```
AssertionError: s7 (could be from ['<ephemeral: symint_visitor_fn>', '<ephemeral: symint_visitor_fn>']) not in {s3: ["L['nt']._values.size()[0]", "L['nt']._values.size()[0]"], s0: ["L['nt'].size()[1]"], s6: ["L['nt'].size()[2]", "L['nt']._values.size()[1]"], s1: ["L['nt']._values.size()[0]"], s7: []}. If this assert is failing, it could be due to the issue described in #90665
```
**Solution:** For each of the error cases, we can identify why a particular SymInt replacement is not made and enhance the logic to support such replacements. This PR addresses this for `slice()` and `view()` views that have been encountered in practice. For example:
```python
@torch.compile(backend="eager", dynamic=True)
def f(t):
    return t._base + 1
x_a = torch.randn(4, 4, requires_grad=True)
x = TwoTensor(x_a, x_a.clone())
out = f(x[3])
```
The input to the compiled func is a subclass view produced by `slice()`. Fake-ification of this subclass view does the following:
1. Allocates non-ephemeral size / stride / storage offset symbols for the outer subclass metadata. Since the slice is contiguous with shape `(4)`, we get `(s0)` for the shape, `(1)` for the stride, and `s1` for the storage offset.
2. Performs view replay on a fake-ified base to reconstruct the base -> view relationship. The captured `3` arg to `slice()` is replaced with an ephemerally-sourced SymInt `s2` before the view func is replayed. The output of view replay has shape `(s0)`, stride `(1)`, and `s0 * s2` for the storage offset.
3. [Asserts](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/_subclasses/meta_utils.py#L1073-L1078) that the output of view replay has the shape metadata allocated in step (1). Since we have `Eq(s1, s0 * s2)` from the storage offset assert and `s2` is ephemeral, we should replace `s2` with `CleanDiv(s1, s0)`. Before this PR, this replacement was not supported since `s1/s0` is not guaranteed to be an integer. With `CleanDiv`, we have the integer guarantee and the replacement can happen.
https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4969-L4970
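The kind of isolation performed here can be reproduced with sympy directly (an illustration only; the real `try_solve()` additionally tracks divisibility via `CleanDiv`):

```python
import sympy

s0, s1, s2 = sympy.symbols("s0 s1 s2", positive=True, integer=True)

# Storage-offset assertion from step (3): Eq(s1, s0 * s2), with s2 ephemeral.
eq = sympy.Eq(s1, s0 * s2)

# Isolating the ephemeral symbol yields s2 = s1/s0.  Plain sympy cannot
# record that the division is exact -- that extra fact is what CleanDiv
# carries, making the replacement legal for integer-valued symbols.
(solution,) = sympy.solve(eq, s2)
print(solution)
```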
cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames rec
[ghstack-poisoned]
Stack from ghstack (oldest at bottom):
Signed-off-by: Edward Z. Yang [email protected]
cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire