
Conversation

@guilhermeleobas (Collaborator) commented Aug 13, 2024

pytorch-bot bot commented Aug 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133337

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e72efb8 with merge base 3b0f393:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@guilhermeleobas (Collaborator, Author) commented:

There's one failure when PYTORCH_TEST_WITH_DYNAMO is on:

test/functorch/test_aotdispatch.py::TestAOTAutograd::test_input_mutation_false_aliasing W0813 15:14:01.417000 53641 torch/fx/experimental/symbolic_shapes.py:4179] [12/5_1] Failing guard allocated at:
W0813 15:14:01.417000 53641 torch/fx/experimental/symbolic_shapes.py:4179] [12/5_1]
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] Error while creating guard:
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] Name: ''
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]     Source: shape_env
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]     Create Function: SHAPE_ENV
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]     Guard Types: None
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]     Code List: None
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]     Object Weakref: None
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]     Guarded Class Weakref: None
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] Traceback (most recent call last):
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]   File "/home/guilhermeleobas/git/pytorch/torch/_guards.py", line 280, in create
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]     return self.create_fn(builder, self)
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]   File "/home/guilhermeleobas/git/pytorch/torch/_dynamo/guards.py", line 1781, in SHAPE_ENV
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]     guards = output_graph.shape_env.produce_guards(
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]   File "/home/guilhermeleobas/git/pytorch/torch/fx/experimental/symbolic_shapes.py", line 4188, in produce_guards
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]     issue_guard(guard)
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]   File "/home/guilhermeleobas/git/pytorch/torch/fx/experimental/symbolic_shapes.py", line 4153, in issue_guard
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]     guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]   File "/home/guilhermeleobas/micromamba/envs/pytorch-cuda/lib/python3.10/site-packages/sympy-1.13.1-py3.10.egg/sympy/printing/printer.py", line 292, in doprint
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]     return self._str(self._print(expr))
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]   File "/home/guilhermeleobas/micromamba/envs/pytorch-cuda/lib/python3.10/site-packages/sympy-1.13.1-py3.10.egg/sympy/printing/printer.py", line 331, in _print
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]     return printmethod(expr, **kwargs)
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]   File "/home/guilhermeleobas/micromamba/envs/pytorch-cuda/lib/python3.10/site-packages/sympy-1.13.1-py3.10.egg/sympy/printing/str.py", line 776, in _print_Relational
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]     self._print(expr.rhs))
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]   File "/home/guilhermeleobas/micromamba/envs/pytorch-cuda/lib/python3.10/site-packages/sympy-1.13.1-py3.10.egg/sympy/printing/printer.py", line 331, in _print
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]     return printmethod(expr, **kwargs)
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]   File "/home/guilhermeleobas/git/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1588, in _print_Symbol
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1]     assert self.symbol_to_source.get(expr), (
E0813 15:14:01.417000 53641 torch/_guards.py:282] [12/5_1] AssertionError: s0 (could be from ["L['t'].size()[0]"]) not in {s2: ["L['t'].a.size()[0]", "L['t'].b.size()[0]", "L['t'].a.size()[0]", "L['t'].b.size()[0]"], s3: ["L['t'].a.size()[1]", "L['t'].a.stride()[0]", "L['t'].b.size()[1]", "L['t'].b.stride()[0]", "L['t'].a.size()[1]", "L['t'].a.stride()[0]", "L['t'].b.size()[1]", "L['t'].b.stride()[0]"], s6: ["L['t'].size()[0]", "L['t'].a.size()[0]", "L['t'].b.size()[0]"]}.  If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1] Created at:
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1]   File "/home/guilhermeleobas/git/pytorch/torch/_dynamo/convert_frame.py", line 603, in transform
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1]     tracer = InstructionTranslator(
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1]   File "/home/guilhermeleobas/git/pytorch/torch/_dynamo/symbolic_convert.py", line 2631, in __init__
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1]     output=OutputGraph(
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1]   File "/home/guilhermeleobas/git/pytorch/torch/_dynamo/output_graph.py", line 316, in __init__
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1]     self.init_ambient_guards()
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1]   File "/home/guilhermeleobas/git/pytorch/torch/_dynamo/output_graph.py", line 455, in init_ambient_guards
E0813 15:14:01.418000 53641 torch/_guards.py:284] [12/5_1]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
SKIPPED [0.7462s] (s0 (could be from ["L['t'].size()[0]"]) not in {s2: ["L['t'].a.size()[0]", "L['t'].b.size()[0]", "L['t'].a.size()[0]", "L['t'].b.size()[0]"], s3...)

@bdhirsh (Contributor) commented Aug 13, 2024

It would be nice to understand why that test is failing (although running that test under dynamo is probably not common user behavior, since it adds a bunch of graph breaks due to testing-framework code). But np if you don't want to go down that rabbit hole; happy to stamp.

Can we make the xskip into an xfail instead? Or does the test fail flakily? (That would be more concerning.)
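
For reference, a minimal sketch of the skip-vs-xfail distinction (skipIfTorchDynamo is a real helper in torch.testing._internal.common_utils; everything else below is illustrative only):

```python
import unittest
from torch.testing._internal.common_utils import skipIfTorchDynamo

class Examples(unittest.TestCase):
    @skipIfTorchDynamo("EphemeralSource bug, see #133337")
    def test_skipped(self):
        # never runs under PYTORCH_TEST_WITH_DYNAMO, so a later fix goes unnoticed
        self.fail("fails under dynamo")

    @unittest.expectedFailure
    def test_xfailed(self):
        # still runs; an unexpected pass is reported, flagging the stale marker
        # (only appropriate if the failure is deterministic, not flaky)
        self.fail("fails under dynamo")
```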

@albanD added the `triaged` label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Aug 13, 2024
@albanD requested a review from bdhirsh August 13, 2024 20:50
@ezyang (Contributor) commented Aug 15, 2024

The error means t isn't a tracked fake. How exactly does Dynamo decide which tensors to track when it gets a subclass? That's the bug. And yeah, we should fix this before landing this.

@guilhermeleobas (Collaborator, Author) commented:

One of the guards refers to a symbol (s0) which doesn't exist in the list of placeholders.

(Pdb++) placeholders
[FakeTensor(..., size=(s2, s3)), FakeTensor(..., size=(s2, s3)), TwoTensor(FakeTensor(..., size=(s6, 4)), FakeTensor(..., size=(s6, 4))), FakeTensor(..., size=(s2, s3)), FakeTensor(..., size=(s2, s3))]

(Pdb++) self.guards[4]
ShapeGuard(expr=Eq(s1, 4) & Eq(s6, s0), stack=<torch.utils._traceback.CapturedTraceback object at 0x7f22dff07ac0>)

@ezyang (Contributor) commented Aug 21, 2024

I see TwoTensor in the placeholders; naively, I was expecting its outer size to be s0? Is this wrong?

@bdhirsh (Contributor) commented Aug 21, 2024

Stared at this some more; this test is enough to repro:

import torch
from torch.testing._internal.two_tensor import TwoTensor

@torch.compile(backend="eager", dynamic=True)
def f(t):
    tmp = t._base if t._is_view() else t
    return tmp + 1


x_a = torch.randn(4, 4, requires_grad=True)
x = TwoTensor(x_a, x_a.clone())
out = f(x[3])

It looks like:

(1) x is a differentiable view, so we go down the path in meta_utils that replays the view func and constructs SymInts with "ephemeral sources" (view_from_base, link)

(2) when we first generate sizes while fakeifying x (here), we allocate and give fake_x a first dim of s0

(3) we then fakeify x._base (with static sizes), and then run view_from_base to regenerate our fake_x as a differentiable view off of the fake base

(4) in doing so, we have a symint_visitor_fn that creates SymInts with EphemeralSources for every int arg that our view funcs close over (here)

(5) In the repro I put above, the view func we replay is the [3] slice, and 3 gets closed over and turned into a SymInt with an EphemeralSource, s2

(6) we then have an assertion that s0 == s2 here.

IIRC, when @jbschlosser added it, the goal was to use the equality to "remove" the ephemerally-created symint, so all we have left is the original symints that we used to construct the view (s0 in this case). But that doesn't really end up happening here: the result of fake_t = t.view_func(...) (code) is exactly what we return back to dynamo, and this FakeTensor has the s2 symint in its shape (because comparing s0 == s2 won't magically swap out our tensor's shape).

So we end up returning a FakeTensor back to dynamo with an s2 symint on it (ephemeral source), and we end up with an s0 == s2 guard, where dynamo never had a chance to track s0.

@bdhirsh (Contributor) commented Aug 21, 2024

The easiest/dumbest solution to me is to slam in and replace s2 with s0 in the final FakeTensor we return, since we've statically proven that they are equal, and s0 is the SymInt with the non-ephemeral source that we want dynamo to track.

We would need some kind of "unsafe swap symints" API on the TensorImpl though, which feels bad. cc @ezyang if you think there are any better alternatives?
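
A toy sketch of that idea over plain data (to be clear, no such swap API exists on TensorImpl; everything below is hypothetical):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class SymMeta:
    sizes: tuple  # stand-in for a FakeTensor's symbolic sizes

def swap_symbol(meta: SymMeta, old: str, new: str) -> SymMeta:
    # only sound because s2 == s0 has been statically proven
    return replace(meta, sizes=tuple(new if s == old else s for s in meta.sizes))

meta = SymMeta(sizes=("s2", 4))
print(swap_symbol(meta, "s2", "s0"))  # SymMeta(sizes=('s0', 4))
```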

@ezyang (Contributor) commented Aug 21, 2024

I don't understand. Wasn't the point of ephemeral sources that when we perform equality on them, we preferentially eliminate them, so that there should never be ephemeral sources left? You can "eliminate" a symbol by ensuring it has a replacement to another (in this case, s2 to s0).

This is what the is_ephemeral logic is about:

torch/fx/experimental/symbolic_shapes.py:            if not source.is_ephemeral() and r_sources[0].is_ephemeral():
torch/fx/experimental/symbolic_shapes.py:                x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x])

So if it is not working, it seems more direct to try to fix it.

Or if we want to rip up the sidewalks, figure out how to get rid of ephemeral sources entirely.
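
For illustration, a plain-sympy stand-in for this replacement-based elimination (the real logic lives in ShapeEnv's replacement machinery, not shown here):

```python
import sympy

s0, s2 = sympy.symbols("s0 s2", positive=True, integer=True)
replacements = {s2: s0}  # s2 is ephemeral, so it is the symbol eliminated

# the equality guard simplifies away entirely once the replacement is applied
print(sympy.Eq(s0, s2).xreplace(replacements))  # True

# and later expressions no longer mention the ephemeral symbol
print((s2 + 1).xreplace(replacements))  # s0 + 1
```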

@bdhirsh (Contributor) commented Aug 23, 2024

I agree with the assessment above, although I'm not sure if in practice we should try to fix the EphemeralSource logic or "rip up the sidewalk".

cc @jbschlosser in case you have any ideas? (I'm on PTO next week but I can think about it more when I get back too)

@guilhermeleobas (Collaborator, Author) commented:

Given that this problem is reproducible without the changes introduced here, could I merge this PR and work on a fix in a separate PR?

@jbschlosser (Contributor) commented:

> cc @jbschlosser in case you have any ideas? (I'm on PTO next week but I can think about it more when I get back too)

I believe #128649 is related. That issue mentions view(-1), but PR #128659, which uncovered that issue, mentions slicing being problematic as well. I've done a bit of investigation into why the ephemerally-sourced symbols aren't being simplified out, but this needs more work.

@bdhirsh / @guilhermeleobas How high-pri is this - are there real use cases hitting this problem?

@guilhermeleobas (Collaborator, Author) commented:

> @bdhirsh / @guilhermeleobas How high-pri is this - are there real use cases hitting this problem?

No real use case as far as I know. This bug doesn't affect my work that much.

@bdhirsh (Contributor) commented Sep 4, 2024

Ok, I am ok with landing this without the EphemeralSource fix and following up later.

Instead of adding that test failure file, can you just add something like this directly into the test?

if torch.compiler.is_compiling():
    # TODO: fails due to existing bugs with EphemeralSource handling
    # See https://github.com/pytorch/pytorch/pull/133337#issuecomment-2312769762
    return

(and also maybe comment on #128649 with a link to the test change so we know to remove the skip when we look into that issue more)

@bdhirsh (Contributor) left a review comment

nvm, it's an xfail so it should error / flag that we need to update it once someone properly figures out the EphemeralSource situation (I do think this is hi-pri though)

jbschlosser added a commit that referenced this pull request Sep 11, 2024
…views"


**TL;DR:** This PR does the following hacks:
* Allows SymInt replacements for expressions with components that are integers, reciprocals (e.g. `1/s0`), and multiplications of symbols with reciprocals (e.g. `s1/s0`). Previously, only integers were [allowed](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4970). For example, when we find `Eq(s2, s1/s0)`, we can now replace `s2` with `s1/s0`.
* Approximates the value range for `1/s0` with positive integer `s0` as `[0, 1]`.

**Background:** During subclass fake-ification, subclass instances that are views go through the process of "view replay" to reconstruct the base -> view relationship in fake land. Part of view replay involves symbolicizing any closed-over ints in the original view function. For example, a slice view `x[3]` will have a `3` baked in as an arg to slice in the saved view func; we want to symbolicize this 3 for dynamic shapes support.

In practice, before the view is replayed, all closed-over ints are replaced with SymInts with "ephemeral sources" that are intended to disappear when we [assert](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/_subclasses/meta_utils.py#L1073-L1078) the subclass size / stride / storage offset match the allocated, non-ephemeral outer symbolic values. For this to work, we rely on the [SymInt replacement logic](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4740-L4744) (e.g. replace `s1` with `s0 * 5` if `Eq(s1, s0*5)` is found to be true), and [prioritize ephemerally-sourced SymInts](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4909-L4921) to be replaced first.

**Problem:** In certain cases, the SymInt replacement logic does not apply, and an ephemerally-sourced SymInt lives for longer than originally designed, resulting in guard time errors like those reported [here](#133337 (comment)) and in #128649.
```
AssertionError: s7 (could be from ['<ephemeral: symint_visitor_fn>', '<ephemeral: symint_visitor_fn>']) not in {s3: ["L['nt']._values.size()[0]", "L['nt']._values.size()[0]"], s0: ["L['nt'].size()[1]"], s6: ["L['nt'].size()[2]", "L['nt']._values.size()[1]"], s1: ["L['nt']._values.size()[0]"], s7: []}.  If this assert is failing, it could be due to the issue described in #90665
```

**Solution:** For each of the error cases, we can identify why a particular SymInt replacement is not made and enhance the logic to support such replacements. This PR addresses the following case:
```python
@torch.compile(backend="eager", dynamic=True)
def f(t):
    return t._base + 1

x_a = torch.randn(4, 4, requires_grad=True)
x = TwoTensor(x_a, x_a.clone())
out = f(x[3])
```
The input to the compiled func is a subclass view produced by `slice()`. Fake-ification of this subclass view does the following:
1. Allocates non-ephemeral size / stride / storage offset symbols for the outer subclass metadata. Since the slice is contiguous with shape `(4)`, we get `(s0)` for the shape, `(1)` for the stride, and `s1` for the storage offset.
2. Performs view replay on a fake-ified base to reconstruct the base -> view relationship. The captured `3` arg to `slice()` is replaced with an ephemerally-sourced SymInt `s2` before the view func is replayed. The output of view replay has shape `(s0)`, stride `(1)`, and `s0 * s2` for the storage offset.
3. [Asserts](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/_subclasses/meta_utils.py#L1073-L1078) that the output of view replay has the shape metadata allocated in step (1). Since we have `Eq(s1, s0 * s2)` from the storage offset assert and `s2` is ephemeral, we should replace `s2` with `s1/s0`. Before this PR, this replacement was not supported since `s1/s0` is not guaranteed to be an integer. Note that the `try_solve()` call below is correctly able to find `s2=s1/s0`.
https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4969-L4970
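
As a plain-sympy stand-in for that `try_solve()` call (the real helper lives in `torch.utils._sympy.solve`), solving the storage-offset equality for the ephemeral symbol yields exactly this replacement:

```python
import sympy

s0, s1, s2 = sympy.symbols("s0 s1 s2", positive=True, integer=True)
storage_offset_eq = sympy.Eq(s1, s0 * s2)  # from the view-replay assert
print(sympy.solve(storage_offset_eq, s2))  # [s1/s0] -- not known to be integral
```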

This PR hacks around this, expanding the supported set of replacements to include both reciprocals (e.g. `1/s0`, represented as `Pow(s0, -1)`) and multiplications of symbols with reciprocals (e.g. `s1/s0`, which is represented as `Mul(s1, Pow(s0, -1))`).

To make a replacement, it is also necessary to determine a new value range for the expression, and the new value range must be a subset of the existing value range for the symbol to be replaced. Before this PR, the logic did not know how to calculate a new value range for a reciprocal `Pow(s0, -1)`, only supporting positive exponents. This PR hacks in a loose bound for the expected value range of `Pow(s0, -1)` AKA `1/s0` given the value range of `s0` for positive integer `s0`. To keep value ranges integral so as to operate well with other integral value ranges, the value range for `1/s0` is approximated as `[0, 1]`.

In an expression like `s1/s0`, the lower / upper bounds for the ranges of `s1` and `1/s0` are multiplied to determine the final value range for the full expression. For example, given a range of `[0, int_oo]` for `s1` and the approximated `[0, 1]` range for `1/s0`, the final range for `s1/s0` is also found to be `[0, int_oo]`.
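
A tiny stand-in for that interval arithmetic (not the actual `torch.utils._sympy.value_ranges` code; the `0 * oo = 0` convention below is an assumption for these non-negative integer ranges):

```python
import math

def mul_bound(x, y):
    # assumed convention for these non-negative ranges: 0 * oo -> 0
    return 0 if 0 in (x, y) else x * y

def mul_range(a, b):
    # a, b are (lo, hi) pairs; math.inf stands in for int_oo
    corners = [mul_bound(x, y) for x in a for y in b]
    return (min(corners), max(corners))

s1_range = (0, math.inf)     # range of s1
recip_s0_range = (0, 1)      # approximated range of 1/s0
print(mul_range(s1_range, recip_s0_range))  # (0, inf): the range of s1/s0
```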

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames rec

jbschlosser added a commit that referenced this pull request Sep 17, 2024
…views

(commit message identical to the Sep 11 commit above)
jbschlosser added a commit that referenced this pull request Sep 23, 2024
**TL;DR:** This PR does the following:
* Uses `CleanDiv` instead of `FloorDiv` in more places when we know it holds to get strictly more information:
    * In `_try_isolate_lhs()`, if we know that `lhs` and `rhs` are integers, we can solve via `CleanDiv` instead of `/`
        * The PR teaches `_try_isolate_lhs()` how to handle `CleanDiv(numer, denom)` where the `denom` component contains the thing we're trying to isolate. This happens in practice for `slice()` and `view()` views
    * In the `infer_size()` util used by various views (including `view()` and `reshape()`), use `CleanDiv` for inferring the `-1` dimension from the `total size / product of other dims`
        * We also learn that all components of the product are >= 1, so provide that information as well
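
A small concrete-int sketch of the `infer_size()` idea (the real util works over symbolic expressions, where `CleanDiv` carries the divides-evenly guarantee):

```python
def infer_neg_one_dim(numel, sizes):
    # infer the -1 dim as total size / product of the other dims
    prod = 1
    for s in sizes:
        if s != -1:
            prod *= s  # each component is known to be >= 1
    assert numel % prod == 0, "view size must divide evenly"
    return numel // prod

print(infer_neg_one_dim(24, [2, -1, 3]))  # 4, as in x.view(2, -1, 3)
```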

**Background:** During subclass fake-ification, subclass instances that are views go through the process of "view replay" to reconstruct the base -> view relationship in fake land. Part of view replay involves symbolicizing any closed-over ints in the original view function. For example, a slice view `x[3]` will have a `3` baked in as an arg to slice in the saved view func; we want to symbolicize this 3 for dynamic shapes support.

In practice, before the view is replayed, all closed-over ints are replaced with SymInts with "ephemeral sources" that are intended to disappear when we [assert](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/_subclasses/meta_utils.py#L1073-L1078) the subclass size / stride / storage offset match the allocated, non-ephemeral outer symbolic values. For this to work, we rely on the [SymInt replacement logic](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4740-L4744) (e.g. replace `s1` with `s0 * 5` if `Eq(s1, s0*5)` is found to be true), and [prioritize ephemerally-sourced SymInts](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4909-L4921) to be replaced first.

**Problem:** In certain cases, the SymInt replacement logic does not apply, and an ephemerally-sourced SymInt lives for longer than originally designed, resulting in guard time errors like those reported [here](#133337 (comment)) and in #128649.
```
AssertionError: s7 (could be from ['<ephemeral: symint_visitor_fn>', '<ephemeral: symint_visitor_fn>']) not in {s3: ["L['nt']._values.size()[0]", "L['nt']._values.size()[0]"], s0: ["L['nt'].size()[1]"], s6: ["L['nt'].size()[2]", "L['nt']._values.size()[1]"], s1: ["L['nt']._values.size()[0]"], s7: []}.  If this assert is failing, it could be due to the issue described in #90665
```

**Solution:** For each of the error cases, we can identify why a particular SymInt replacement is not made and enhance the logic to support such replacements. This PR addresses this for `slice()` and `view()` views that have been encountered in practice. For example:
```python
@torch.compile(backend="eager", dynamic=True)
def f(t):
    return t._base + 1

x_a = torch.randn(4, 4, requires_grad=True)
x = TwoTensor(x_a, x_a.clone())
out = f(x[3])
```
The input to the compiled func is a subclass view produced by `slice()`. Fake-ification of this subclass view does the following:
1. Allocates non-ephemeral size / stride / storage offset symbols for the outer subclass metadata. Since the slice is contiguous with shape `(4)`, we get `(s0)` for the shape, `(1)` for the stride, and `s1` for the storage offset.
2. Performs view replay on a fake-ified base to reconstruct the base -> view relationship. The captured `3` arg to `slice()` is replaced with an ephemerally-sourced SymInt `s2` before the view func is replayed. The output of view replay has shape `(s0)`, stride `(1)`, and `s0 * s2` for the storage offset.
3. [Asserts](https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/_subclasses/meta_utils.py#L1073-L1078) that the output of view replay has the shape metadata allocated in step (1). Since we have `Eq(s1, s0 * s2)` from the storage offset assert and `s2` is ephemeral, we should replace `s2` with `CleanDiv(s1, s0)`. Before this PR, this replacement was not supported since `s1/s0` is not guaranteed to be an integer. With `CleanDiv`, we have the integer guarantee and the replacement can happen.
https://github.com/pytorch/pytorch/blob/49e0b88aab0fd124c318fe1c59fbbd8726298338/torch/fx/experimental/symbolic_shapes.py#L4969-L4970
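
A minimal sketch of the resulting replacement (the `CleanDiv` import path follows `torch.utils._sympy.functions`, and the `is_integer` property is assumed from `FloorDiv`'s declaration):

```python
import sympy
from torch.utils._sympy.functions import CleanDiv  # exact integer division

s0, s1 = sympy.symbols("s0 s1", positive=True, integer=True)
# With the divides-evenly guarantee, Eq(s1, s0*s2) can be solved as
# s2 = CleanDiv(s1, s0), which stays integral (unlike plain s1/s0)
replacement = CleanDiv(s1, s0)
print(replacement, replacement.is_integer)  # CleanDiv(s1, s0) True
```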

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames rec

guilhermeleobas and others added 5 commits October 17, 2024 16:03
github-actions bot deleted the gh/guilhermeleobas/60/head branch November 28, 2024 02:14

Labels

dynamo-tensor-subclasses, Merged, open source, tensor subclass (Related to tensor subclasses), topic: not user facing (topic category), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
