Update TwoTensor impl. to accept outer_size/outer_stride #133337

Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133337. Note: links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit e72efb8 with merge base 3b0f393. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
There's one failure when running test/functorch/test_aotdispatch.py::TestAOTAutograd::test_input_mutation_false_aliasing:

W0813 15:14:01.417000 53641 torch/fx/experimental/symbolic_shapes.py:4179] [12/5_1] Failing guard allocated at:
W0813 15:14:01.417000 53641 torch/fx/experimental/symbolic_shapes.py:4179] [12/5_1]
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] Error while creating guard:
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] Name: ''
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] Source: shape_env
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] Create Function: SHAPE_ENV
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] Guard Types: None
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] Code List: None
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] Object Weakref: None
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] Guarded Class Weakref: None
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] Traceback (most recent call last):
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] File "/home/guilhermeleobas/git/pytorch/torch/_guards.py", line 280, in create
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] return self.create_fn(builder, self)
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] File "/home/guilhermeleobas/git/pytorch/torch/_dynamo/guards.py", line 1781, in SHAPE_ENV
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] guards = output_graph.shape_env.produce_guards(
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] File "/home/guilhermeleobas/git/pytorch/torch/fx/experimental/symbolic_shapes.py", line 4188, in produce_guards
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] issue_guard(guard)
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] File "/home/guilhermeleobas/git/pytorch/torch/fx/experimental/symbolic_shapes.py", line 4153, in issue_guard
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] File "/home/guilhermeleobas/micromamba/envs/pytorch-cuda/lib/python3.10/site-packages/sympy-1.13.1-py3.10.egg/sympy/printing/printer.py", line 292, in doprint
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] return self._str(self._print(expr))
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] File "/home/guilhermeleobas/micromamba/envs/pytorch-cuda/lib/python3.10/site-packages/sympy-1.13.1-py3.10.egg/sympy/printing/printer.py", line 331, in _print
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] return printmethod(expr, **kwargs)
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] File "/home/guilhermeleobas/micromamba/envs/pytorch-cuda/lib/python3.10/site-packages/sympy-1.13.1-py3.10.egg/sympy/printing/str.py", line 776, in _print_Relational
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] self._print(expr.rhs))
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] File "/home/guilhermeleobas/micromamba/envs/pytorch-cuda/lib/python3.10/site-packages/sympy-1.13.1-py3.10.egg/sympy/printing/printer.py", line 331, in _print
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] return printmethod(expr, **kwargs)
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] File "/home/guilhermeleobas/git/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1588, in _print_Symbol
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] assert self.symbol_to_source.get(expr), (
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] AssertionError: s0 (could be from ["L['t'].size()[0]"]) not in {s2: ["L['t'].a.size()[0]", "L['t'].b.size()[0]", "L['t'].a.size()[0]", "L['t'].b.size()[0]"], s3: ["L['t'].a.size()[1]", "L['t'].a.stride()[0]", "L['t'].b.size()[1]", "L['t'].b.stride()[0]", "L['t'].a.size()[1]", "L['t'].a.stride()[0]", "L['t'].b.size()[1]", "L['t'].b.stride()[0]"], s6: ["L['t'].size()[0]", "L['t'].a.size()[0]", "L['t'].b.size()[0]"]}. If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1] Created at:
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1] File "/home/guilhermeleobas/git/pytorch/torch/_dynamo/convert_frame.py", line 603, in transform
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1] tracer = InstructionTranslator(
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1] File "/home/guilhermeleobas/git/pytorch/torch/_dynamo/symbolic_convert.py", line 2631, in __init__
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1] output=OutputGraph(
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1] File "/home/guilhermeleobas/git/pytorch/torch/_dynamo/output_graph.py", line 316, in __init__
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1] self.init_ambient_guards()
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1] File "/home/guilhermeleobas/git/pytorch/torch/_dynamo/output_graph.py", line 455, in init_ambient_guards
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1] self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
SKIPPED [0.7462s] (s0 (could be from ["L['t'].size()[0]"]) not in {s2: ["L['t'].a.size()[0]", "L['t'].b.size()[0]", "L['t'].a.size()[0]", "L['t'].b.size()[0]"], s3...)
It would be nice to understand why that test is failing (although running that test under dynamo is probably not common user behavior, since it adds a bunch of graph breaks due to testing framework code). But np if you don't want to go down that rabbit hole, happy to stamp. Can we make the xskip into an xfail instead? Or does the test fail flakily? (That would be more concerning.)
The error means `t` isn't a tracked fake. How exactly does Dynamo decide what tensors to track when it gets a subclass? That's the bug. And yeah, we should fix this before landing this.
One of the guards refers to a symbol (`s0`) that isn't among the tracked placeholder sources:

(Pdb++) placeholders
[FakeTensor(..., size=(s2, s3)), FakeTensor(..., size=(s2, s3)), TwoTensor(FakeTensor(..., size=(s6, 4)), FakeTensor(..., size=(s6, 4))), FakeTensor(..., size=(s2, s3)), FakeTensor(..., size=(s2, s3))]
(Pdb++) self.guards[4]
ShapeGuard(expr=Eq(s1, 4) & Eq(s6, s0), stack=<torch.utils._traceback.CapturedTraceback object at 0x7f22dff07ac0>)
I see `TwoTensor` in the placeholders; naively, I was expecting its outer size to be `s0`? Is this wrong?
Stared at this some more; this test is enough to repro. It looks like:

(1)
(2) when we first generate sizes when fakeifying
(3) we then fakeify
(4) in doing so, we have a
(5) In the repro I put above, the view func we replay is the
(6) we then have an assertion that

IIRC, when @jbschlosser added it, the goal was to use the equality to "remove" the ephemerally-created symint, so all we have left is the original symints that we used to construct the view ( So we end up returning a FakeTensor back to dynamo with an
The easiest/dumbest solution to me is to slam-in-and-replace. We would need some kind of "unsafe swap symints" API on the TensorImpl though, which feels bad.

cc @ezyang if you think there are any better alternatives?
I don't understand. Wasn't the point of ephemeral sources that when we perform equality on them, we preferentially eliminate them, so that there should never be ephemeral sources left? You can "eliminate" a symbol by ensuring it has a replacement to another (in this case, s2 to s0). This is what the `is_ephemeral` logic is about. So if it is not working, it seems more direct to try to fix it. Or, if we want to rip up the sidewalks, figure out how to get rid of ephemeral sources entirely.
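The elimination mechanism described here can be illustrated with a toy sketch, assuming a much-simplified model of the real ShapeEnv (all names below are hypothetical, not PyTorch APIs): record a replacement for the ephemeral side of each equality, and resolve replacements before emitting guards, so no ephemeral symbol survives to guard time.

```python
replacements = {}   # symbol -> replacement expression
ephemeral = {"s2"}  # symbols allocated with ephemeral sources

def on_equality(lhs, rhs):
    # Prefer eliminating the ephemeral side, mirroring the is_ephemeral
    # prioritization: the ephemeral symbol receives the replacement.
    if lhs in ephemeral:
        replacements[lhs] = rhs
    elif rhs in ephemeral:
        replacements[rhs] = lhs

def resolve(sym):
    # Follow replacements until reaching a symbol with no replacement.
    while sym in replacements:
        sym = replacements[sym]
    return sym

# e.g. from asserting that the replayed size equals the allocated size:
on_equality("s2", "s0")
guard = ("Eq", resolve("s6"), resolve("s2"))
print(guard)  # ('Eq', 's6', 's0') -- no ephemeral symbol remains
```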
I agree with the assessment above, although I'm not sure if in practice we should try to fix the EphemeralSource logic or "rip up the sidewalk". cc @jbschlosser in case you have any ideas? (I'm on PTO next week but I can think about it more when I get back too)
Given that this problem is reproducible without the changes introduced here, could I merge this PR and work on a fix for this problem in a separate PR?
I believe #128649 is related; that issue mentions the same assertion failure.

@bdhirsh / @guilhermeleobas How high-pri is this? Are there real use cases hitting this problem?
No real use case as far as I know. This bug doesn't affect my work that much.
Ok, I am ok with landing this without the EphemeralSource fix and following up later. Instead of adding that test failure file, can you just add something like this directly into the test? (And also maybe comment on #128649 with a link to the test change so we know to remove the skip when we look into that issue more.)
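The snippet referenced here didn't survive the page scrape, but a minimal sketch of marking the failure directly in the test, assuming a plain `unittest`-style expected failure (the class name and test body below are hypothetical stand-ins, not the real PyTorch test):

```python
import io
import unittest

class TwoTensorViewTests(unittest.TestCase):
    # Hypothetical sketch: mark the known failure as an xfail directly in
    # the test so the suite flags it once the EphemeralSource issue is fixed.
    @unittest.expectedFailure  # remove once #128649 is resolved
    def test_input_mutation_false_aliasing_under_dynamo(self):
        # Stand-in for the real repro, which currently raises at guard time.
        raise AssertionError("EphemeralSource symbol leaked into guards")

suite = unittest.TestLoader().loadTestsFromTestCase(TwoTensorViewTests)
result = unittest.TextTestRunner(stream=io.StringIO()).run(suite)
print(result.wasSuccessful(), len(result.expectedFailures))  # True 1
```

Unlike a skip, an `expectedFailure` starts erroring (as an "unexpected success") the moment the underlying bug is fixed, which is exactly the reminder asked for above.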
bdhirsh left a comment:
nvm, it's an xfail so it should error / flag that we need to update it once someone properly figures out the EphemeralSource situation (I do think this is hi-pri though)
…views"

**TL;DR:** This PR does the following hacks:

* Allows SymInt replacements for expressions with components that are integers, reciprocals (e.g. `1/s0`), and multiplications of symbols with reciprocals (e.g. `s1/s0`). Previously, only integers were [allowed](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4970). For example, when we find `Eq(s2, s1/s0)`, we can now replace `s2` with `s1/s0`.
* Approximates the value range for `1/s0` with positive integer `s0` as `[0, 1]`.

**Background:** During subclass fake-ification, subclass instances that are views go through the process of "view replay" to reconstruct the base -> view relationship in fake land. Part of view replay involves symbolicizing any closed-over ints in the original view function. For example, a slice view `x[3]` will have a `3` baked in as an arg to slice in the saved view func; we want to symbolicize this 3 for dynamic shapes support. In practice, before the view is replayed, all closed-over ints are replaced with SymInts with "ephemeral sources" that are intended to disappear when we [assert](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/_subclasses/meta_utils.py#L1073-L1078) the subclass size / stride / storage offset match the allocated, non-ephemeral outer symbolic values. For this to work, we rely on the [SymInt replacement logic](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4740-L4744) (e.g. replace `s1` with `s0 * 5` if `Eq(s1, s0*5)` is found to be true), and [prioritize ephemerally-sourced SymInts](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4909-L4921) to be replaced first.

**Problem:** In certain cases, the SymInt replacement logic does not apply, and an ephemerally-sourced SymInt lives for longer than originally designed, resulting in guard time errors like those reported [here](#133337 (comment)) and in #128649.

```
AssertionError: s7 (could be from ['<ephemeral: symint_visitor_fn>', '<ephemeral: symint_visitor_fn>']) not in {s3: ["L['nt']._values.size()[0]", "L['nt']._values.size()[0]"], s0: ["L['nt'].size()[1]"], s6: ["L['nt'].size()[2]", "L['nt']._values.size()[1]"], s1: ["L['nt']._values.size()[0]"], s7: []}. If this assert is failing, it could be due to the issue described in #90665
```

**Solution:** For each of the error cases, we can identify why a particular SymInt replacement is not made and enhance the logic to support such replacements. This PR addresses the following case:

```python
@torch.compile(backend="eager", dynamic=True)
def f(t):
    return t._base + 1

x_a = torch.randn(4, 4, requires_grad=True)
x = TwoTensor(x_a, x_a.clone())
out = f(x[3])
```

The input to the compiled func is a subclass view produced by `slice()`. Fake-ification of this subclass view does the following:

1. Allocates non-ephemeral size / stride / storage offset symbols for the outer subclass metadata. Since the slice is contiguous with shape `(4)`, we get `(s0)` for the shape, `(1)` for the stride, and `s1` for the storage offset.
2. Performs view replay on a fake-ified base to reconstruct the base -> view relationship. The captured `3` arg to `slice()` is replaced with an ephemerally-sourced SymInt `s2` before the view func is replayed. The output of view replay has shape `(s0)`, stride `(1)`, and `s0 * s2` for the storage offset.
3. [Asserts](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/_subclasses/meta_utils.py#L1073-L1078) that the output of view replay has the shape metadata allocated in step (1). Since we have `Eq(s1, s0 * s2)` from the storage offset assert and `s2` is ephemeral, we should replace `s2` with `s1/s0`. Before this PR, this replacement was not supported since `s1/s0` is not guaranteed to be an integer. Note that the `try_solve()` call below is correctly able to find `s2=s1/s0`.

https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4969-L4970

This PR hacks around this, expanding the supported set of replacements to include both reciprocals (e.g. `1/s0`, represented as `Pow(s0, -1)`) and multiplications of symbols with reciprocals (e.g. `s1/s0`, which is represented as `Mul(s1, Pow(s0, -1))`). To make a replacement, it is also necessary to determine a new value range for the expression, and the new value range must be a subset of the existing value range for the symbol to be replaced. Before this PR, the logic did not know how to calculate a new value range for a reciprocal `Pow(s0, -1)`, only supporting positive exponents. This PR hacks in a rough loose bound for the expected value range of `Pow(s0, -1)` AKA `1/s0` given the value range of `s0` for positive integer `s0`. To keep value ranges integral so as to operate well with other integral value ranges, the value range for `1/s0` is approximated as `[0, 1]`. In an expression like `s1/s0`, the lower / upper bounds for the ranges of `s1` and `1/s0` are multiplied to determine the final value range for the full expression. For example, given a range of `[0, int_oo]` for `s1` and the approximated `[0, 1]` range for `1/s0`, the final range for `s1/s0` is also found to be `[0, int_oo]`.

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames rec

[ghstack-poisoned]
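The range arithmetic described above can be sketched in plain Python. This is a toy model with hypothetical helper names, not PyTorch's actual `ValueRanges` implementation, and it uses finite bounds in place of `int_oo` for illustration:

```python
def recip_range_pos_int(lo, hi):
    """Loose integral bound for 1/s given a positive integer s in [lo, hi].
    The true values lie in (0, 1], approximated as the integer range [0, 1]
    to stay compatible with other integral value ranges."""
    assert lo >= 1, "only defined for positive integer symbols"
    return (0, 1)

def mul_ranges(a, b):
    """Range of a product: multiply the bounds of two non-negative ranges
    pairwise and take the min/max of the candidate products."""
    products = [x * y for x in a for y in b]
    return (min(products), max(products))

s1_range = (0, 100)                      # finite stand-in for [0, int_oo]
recip_s0 = recip_range_pos_int(1, 100)   # approximated range of 1/s0
print(mul_ranges(s1_range, recip_s0))    # (0, 100): range for s1/s0
```

With the `[0, 1]` approximation for `1/s0`, the product range for `s1/s0` matches the range of `s1`, which is the behavior the commit message describes for `[0, int_oo]`.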
**TL;DR:** This PR does the following:
* Uses `CleanDiv` instead of `FloorDiv` in more places when we know it holds, to get strictly more information:
  * In `_try_isolate_lhs()`, if we know that `lhs` and `rhs` are integers, we can solve via `CleanDiv` instead of `/`.
* The PR teaches `_try_isolate_lhs()` how to handle `CleanDiv(numer, denom)` where the `denom` component contains the thing we're trying to isolate. This happens in practice for `slice()` and `view()` views.
* In the `infer_size()` util used by various views (including `view()` and `reshape()`), use `CleanDiv` for inferring the `-1` dimension from the total size / product of other dims.
  * We also learn that all components of the product are >= 1, so provide that information as well.
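As a rough illustration of the `infer_size()` behavior referenced above, here is a plain-Python sketch (not the real util; with symbolic shapes the exact division below would be expressed as `CleanDiv(numel, known)` rather than `//`):

```python
def infer_size(shape, numel):
    """Infer the -1 dimension of a view/reshape target shape.
    Simplified sketch of the infer_size idea, not PyTorch's implementation."""
    infer_idx = None
    known = 1  # product of all explicitly given dims
    for i, d in enumerate(shape):
        if d == -1:
            infer_idx = i
        else:
            known *= d
    out = list(shape)
    if infer_idx is not None:
        # The division must be exact for the view to be valid; symbolically
        # this is where CleanDiv's integrality guarantee comes in.
        assert numel % known == 0, "shape is invalid for this number of elements"
        out[infer_idx] = numel // known
    return out

print(infer_size([-1, 4], 16))  # [4, 4]
```

Because every explicit dim is known to be >= 1, the product `known` is >= 1 too, which is the extra fact the bullet above says the PR now provides.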
**Background:** During subclass fake-ification, subclass instances that are views go through the process of "view replay" to reconstruct the base -> view relationship in fake land. Part of view replay involves symbolicizing any closed-over ints in the original view function. For example, a slice view `x[3]` will have a `3` baked in as an arg to slice in the saved view func; we want to symbolicize this 3 for dynamic shapes support.
In practice, before the view is replayed, all closed-over ints are replaced with SymInts with "ephemeral sources" that are intended to disappear when we [assert](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/_subclasses/meta_utils.py#L1073-L1078) the subclass size / stride / storage offset match the allocated, non-ephemeral outer symbolic values. For this to work, we rely on the [SymInt replacement logic](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4740-L4744) (e.g. replace `s1` with `s0 * 5` if `Eq(s1, s0*5)` is found to be true), and [prioritize ephemerally-sourced SymInts](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4909-L4921) to be replaced first.
**Problem:** In certain cases, the SymInt replacement logic does not apply, and an ephemerally-sourced SymInt lives for longer than originally designed, resulting in guard time errors like those reported [here](#133337 (comment)) and in #128649.
```
AssertionError: s7 (could be from ['<ephemeral: symint_visitor_fn>', '<ephemeral: symint_visitor_fn>']) not in {s3: ["L['nt']._values.size()[0]", "L['nt']._values.size()[0]"], s0: ["L['nt'].size()[1]"], s6: ["L['nt'].size()[2]", "L['nt']._values.size()[1]"], s1: ["L['nt']._values.size()[0]"], s7: []}. If this assert is failing, it could be due to the issue described in #90665
```
**Solution:** For each of the error cases, we can identify why a particular SymInt replacement is not made and enhance the logic to support such replacements. This PR addresses this for `slice()` and `view()` views that have been encountered in practice. For example:
```python
@torch.compile(backend="eager", dynamic=True)
def f(t):
    return t._base + 1
x_a = torch.randn(4, 4, requires_grad=True)
x = TwoTensor(x_a, x_a.clone())
out = f(x[3])
```
The input to the compiled func is a subclass view produced by `slice()`. Fake-ification of this subclass view does the following:
1. Allocates non-ephemeral size / stride / storage offset symbols for the outer subclass metadata. Since the slice is contiguous with shape `(4)`, we get `(s0)` for the shape, `(1)` for the stride, and `s1` for the storage offset.
2. Perform view replay on a fake-ified base to reconstruct the base -> view relationship. The captured `3` arg to `slice()` is replaced with an ephemerally-sourced SymInt `s2` before the view func is replayed. The output of view replay has shape `(s0)`, stride `(1)`, and `s0 * s2` for the storage offset.
3. [Asserts](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/_subclasses/meta_utils.py#L1073-L1078) that the output of view replay has the shape metadata allocated in step (1). Since we have `Eq(s1, s0 * s2)` from the storage offset assert and `s2` is ephemeral, we should replace `s2` with `CleanDiv(s1, s0)`. Before this PR, this replacement was not supported since `s1/s0` is not guaranteed to be an integer. With `CleanDiv`, we have the integer guarantee and the replacement can happen.
https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4969-L4970
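As a concrete sanity check on the replacement in step (3), here is a small standalone sketch (plain Python with a hypothetical `clean_div` helper, not PyTorch's symbolic `CleanDiv`) showing that whenever `Eq(s1, s0 * s2)` holds, exact division recovers `s2`:

```python
def clean_div(a, b):
    """Exact integer division: only valid when b divides a evenly,
    mirroring the integrality guarantee CleanDiv carries."""
    assert a % b == 0, "CleanDiv requires exact divisibility"
    return a // b

# Storage offset from view replay: s1 == s0 * s2
# (slice length times captured index), checked on small concrete values.
for s0 in range(1, 5):
    for s2 in range(0, 5):
        s1 = s0 * s2
        # The replacement s2 -> CleanDiv(s1, s0) is sound here.
        assert clean_div(s1, s0) == s2
```

The integrality guarantee is exactly what was missing before this PR: plain `s1/s0` could be fractional, so the replacement was rejected.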
cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames rec
[ghstack-poisoned]
…125941) Pull Request resolved: #125941 Approved by: https://github.com/bdhirsh ghstack dependencies: #133337
Stack from ghstack (oldest at bottom):
* Update TwoTensor impl. to accept outer_size/outer_stride #133337

cc @ezyang @albanD