Skip to content

Conversation

@pianpwk
Copy link
Contributor

@pianpwk pianpwk commented Sep 26, 2024

Summary:
fix for #136797

We have an issue with deserialization for symbolic expressions, where the original expression has a structure mismatch with the deserialized expression. This happens because we serialize the expressions into strings, and deserialize using sympy.sympify().

This seems to cause issues with custom torch sympy.Functions, in this case FloorDiv. In this case, the expression FloorDiv(a, b) turns into sympy.floor(a, PowByNatural(b, -1)), or sympy.floor(sympy.Mul(*a), PowByNatural(b, -1)) if a is also a sympy.Mul. This seems to break downstream when we call torch.empty_strided() on this expression, and the PowByNatural has an invalid ValueRange (e.g. [0, -1]).

This fixes that by adding translations after deserialization, to pattern match and transform expressions into the original format. Other potential strategies are a) we could also serialize SymInts as sympy.Expr structures, but that seems a lot more involved and BC breaking, or b) we could modify sympy.sympify(?), but I don't know if this is possible.

Test Plan: test_export/test_serialize

Differential Revision: D63493615

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136802

Note: Links to docs will display an error until the docs builds have been completed.

❌ 8 New Failures

As of commit b53cc2d with merge base 46f158b (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D63493615

facebook-github-bot pushed a commit that referenced this pull request Sep 26, 2024
…136802)

Summary:

fix for #136797

We have an issue with deserialization for symbolic expressions, where the original expression has a structure mismatch with the deserialized expression. This happens because we serialize the expressions into strings, and deserialize using sympy.sympify().

This seems to cause issues with custom torch sympy.Functions, in this case FloorDiv. In this case, this alters the structure from `FloorDiv(a, b)` into `sympy.floor(a, PowByNatural(b, -1))`, or `sympy.floor(sympy.Mul(*a), PowByNatural(b, -1))` if a is also a sympy.Mul. This seems to break downstream when we call torch.empty_strided() on this expression, and the PowByNatural has an invalid ValueRange (e.g. [0, -1]).

This fixes that by adding translations after deserialization, to pattern match and transform expressions into the original format. Other potential strategies are a) we could also serialize SymInts as sympy.Expr structures, but that seems a lot more involved and BC breaking, or b) we could modify sympy.sympify(?), but I don't know if this is possible.

Test Plan: test_export/test_serialize

Differential Revision: D63493615
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D63493615

@pianpwk pianpwk changed the title [export] translation for FloorDiv deserialization [export] add translations for SymInt/Bool deserialization; FloorDiv Sep 26, 2024
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D63493615

@pianpwk
Copy link
Contributor Author

pianpwk commented Sep 27, 2024

@ezyang would you know of a better way to do this? It seems like even if we let the invalid ValueRange pass or flip min/max, PowByNatural + ValueRanges is the thing to avoid, because it still means we lose precision when truncating the range to ints (e.g. [0, -1]).

compiler_max=vr.upper, # type: ignore[arg-type]
)

sym = _translate_deserialized_sym_expr(sym)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this line go directly after line 1526?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

@ezyang
Copy link
Contributor

ezyang commented Sep 30, 2024

So, the proper way to fix this is to stop serializing the Sympy string expressions as strings, for which we have NO stability or even round-trippability guarantee. Sympy expressions are supposed to be roundtrippable to FX IR, and back, via the Python reference / sympy interpreter stuff (there are some bugs / incompleteness, but we should fix those anyway.)

But I'm guessing you didn't want to hear that? Lol.

@pianpwk
Copy link
Contributor Author

pianpwk commented Sep 30, 2024

So, the proper way to fix this is to stop serializing the Sympy string expressions as strings, for which we have NO stability or even round-trippability guarantee. Sympy expressions are supposed to be roundtrippable to FX IR, and back, via the Python reference / sympy interpreter stuff (there are some bugs / incompleteness, but we should fix those anyway.)

But I'm guessing you didn't want to hear that? Lol.

Hahaha no but this makes total sense

@ezyang
Copy link
Contributor

ezyang commented Sep 30, 2024

If you are dead set on string format, I'd probably still want to have a customized printer and parser. In particular the parser needs to override the mapping of FloorDiv to which class. As a stopgap, we could potentially rename functions in torch.utils._sympy.functions so they don't conflict with Sympy

@bhack
Copy link
Contributor

bhack commented Oct 1, 2024

I think we had the same problem with inductor cache right?

site-packages/torch/utils/_sympy/interp.py:159] [0/1] failed while executing pow_by_natural([VR[0, int_oo], VR[-1, -1]])
site-packages/torch/_inductor/codecache.py:609] [0/1] Can't pickle
site-packages/torch/_inductor/codecache.py:609] [0/1] Traceback (most recent call last):
site-packages/torch/_inductor/codecache.py:609] [0/1]   File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 605, in dumps
site-packages/torch/_inductor/codecache.py:609] [0/1]     pickler.dump(obj)
site-packages/torch/_inductor/codecache.py:609] [0/1] AttributeError: Can't pickle local object 'make_opaque_unary_fn.<locals>.OpaqueUnaryFn'
site-packages/torch/utils/_sympy/interp.py:159] [0/1] failed while executing pow_by_natural([VR[0, int_oo], VR[-1, -1]])

…136802)

Summary:

fix for #136797

We have an issue with deserialization for symbolic expressions, where the original expression has a structure mismatch with the deserialized expression. This happens because we serialize the expressions into strings, and deserialize using sympy.sympify().

This seems to cause issues with custom torch sympy.Functions, in this case FloorDiv. In this case, this alters the structure from `FloorDiv(a, b)` into `sympy.floor(a, PowByNatural(b, -1))`, or `sympy.floor(sympy.Mul(*a), PowByNatural(b, -1))` if a is also a sympy.Mul. This seems to break downstream when we call torch.empty_strided() on this expression, and the PowByNatural has an invalid ValueRange (e.g. [0, -1]).

This fixes that by adding translations after deserialization, to pattern match and transform expressions into the original format. Other potential strategies are a) we could also serialize SymInts as sympy.Expr structures, but that seems a lot more involved and BC breaking, or b) we could modify sympy.sympify(?), but I don't know if this is possible.

Test Plan: test_export/test_serialize

Differential Revision: D63493615
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D63493615

@pianpwk
Copy link
Contributor Author

pianpwk commented Oct 3, 2024

Yeah, I think we should just move to not string format, and add to the export serialization schema

@bhack
Copy link
Contributor

bhack commented Oct 4, 2024

#136802 (comment)

Is this related right?

@bhack
Copy link
Contributor

bhack commented Oct 10, 2024

Do I need to open a new ticket for #136802 (comment) or we are on the same page here?

@bhack
Copy link
Contributor

bhack commented Oct 23, 2024

Is there a roadmap to review this?

@bhack
Copy link
Contributor

bhack commented Nov 5, 2024

Is this stalled?

@bhack
Copy link
Contributor

bhack commented Nov 6, 2024

I've tried to hotpatch the latest nightly with this PR and the deserializzation issue is still here.

facebook-github-bot pushed a commit that referenced this pull request Nov 22, 2024
Summary:

Latest attempt after [136802](#136802) and [140084](#140084) got shelved. 

This keeps the string format for `expr_str`, but calls `sympy.printing.repr.srepr(s)` instead of `str(s)`, which prints expressions more explicitly, e.g.
```
((2*x)//(3*y + 4)) -> "FloorDiv(Mul(Integer(2), Symbol('x')), Add(Mul(Integer(3), Symbol('y')), Integer(4)))"
```

This is nice because:
- we have better roundtrippability for deserialization, robust to pretty printing changes like [this](https://github.com/pytorch/pytorch/blob/6c9bfd52b6a76ddff053bcff4d23ea7f4c280e9a/torch/utils/_sympy/functions.py#L208) that caused the issue in the first place.
- this preserves the BC surface for both 1) sigmoid thrift serialization, by keeping the string format, and 2) deserialization for old IRs, since `sympy.sympify(...)` still handles the old `str(s)` format.
- more memory efficient than storing ASTs; the [AST attempt](#140084) increased artifact size by 20% on some toy programs.
- doesn't even require a schema version bump.

Additionally to push some test cases over the line, this redoes expression processing (handling ranges, symbol caching) by doing bottom-up processing instead of the current hacky-ish workflow.

Test Plan: test_serdes, test_serialize, internal tests broken by AST PR

Reviewed By: zhxchen17

Differential Revision: D66283208
pytorchmergebot pushed a commit that referenced this pull request Nov 22, 2024
Summary:
Latest attempt after [136802](#136802) and [140084](#140084) got shelved.

This keeps the string format for `expr_str`, but calls `sympy.printing.repr.srepr(s)` instead of `str(s)`, which prints expressions more explicitly, e.g.
```
((2*x)//(3*y + 4)) -> "FloorDiv(Mul(Integer(2), Symbol('x')), Add(Mul(Integer(3), Symbol('y')), Integer(4)))"
```

This is nice because:
- we have better roundtrippability for deserialization, robust to pretty printing changes like [this](https://github.com/pytorch/pytorch/blob/6c9bfd52b6a76ddff053bcff4d23ea7f4c280e9a/torch/utils/_sympy/functions.py#L208) that caused the issue in the first place.
- this preserves the BC surface for both 1) sigmoid thrift serialization, by keeping the string format, and 2) deserialization for old IRs, since `sympy.sympify(...)` still handles the old `str(s)` format.
- more memory efficient than storing ASTs; the [AST attempt](#140084) increased artifact size by 20% on some toy programs.
- doesn't even require a schema version bump.

Additionally to push some test cases over the line, this redoes expression processing (handling ranges, symbol caching) by doing bottom-up processing instead of the current hacky-ish workflow.

Test Plan: test_serdes, test_serialize, internal tests broken by AST PR

Differential Revision: D66283208

Pull Request resolved: #141284
Approved by: https://github.com/zhxchen17
Ryo-not-rio pushed a commit to Ryo-not-rio/pytorch that referenced this pull request Dec 2, 2024
…141284)

Summary:
Latest attempt after [136802](pytorch#136802) and [140084](pytorch#140084) got shelved.

This keeps the string format for `expr_str`, but calls `sympy.printing.repr.srepr(s)` instead of `str(s)`, which prints expressions more explicitly, e.g.
```
((2*x)//(3*y + 4)) -> "FloorDiv(Mul(Integer(2), Symbol('x')), Add(Mul(Integer(3), Symbol('y')), Integer(4)))"
```

This is nice because:
- we have better roundtrippability for deserialization, robust to pretty printing changes like [this](https://github.com/pytorch/pytorch/blob/6c9bfd52b6a76ddff053bcff4d23ea7f4c280e9a/torch/utils/_sympy/functions.py#L208) that caused the issue in the first place.
- this preserves the BC surface for both 1) sigmoid thrift serialization, by keeping the string format, and 2) deserialization for old IRs, since `sympy.sympify(...)` still handles the old `str(s)` format.
- more memory efficient than storing ASTs; the [AST attempt](pytorch#140084) increased artifact size by 20% on some toy programs.
- doesn't even require a schema version bump.

Additionally to push some test cases over the line, this redoes expression processing (handling ranges, symbol caching) by doing bottom-up processing instead of the current hacky-ish workflow.

Test Plan: test_serdes, test_serialize, internal tests broken by AST PR

Differential Revision: D66283208

Pull Request resolved: pytorch#141284
Approved by: https://github.com/zhxchen17
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…141284)

Summary:
Latest attempt after [136802](pytorch#136802) and [140084](pytorch#140084) got shelved.

This keeps the string format for `expr_str`, but calls `sympy.printing.repr.srepr(s)` instead of `str(s)`, which prints expressions more explicitly, e.g.
```
((2*x)//(3*y + 4)) -> "FloorDiv(Mul(Integer(2), Symbol('x')), Add(Mul(Integer(3), Symbol('y')), Integer(4)))"
```

This is nice because:
- we have better roundtrippability for deserialization, robust to pretty printing changes like [this](https://github.com/pytorch/pytorch/blob/6c9bfd52b6a76ddff053bcff4d23ea7f4c280e9a/torch/utils/_sympy/functions.py#L208) that caused the issue in the first place.
- this preserves the BC surface for both 1) sigmoid thrift serialization, by keeping the string format, and 2) deserialization for old IRs, since `sympy.sympify(...)` still handles the old `str(s)` format.
- more memory efficient than storing ASTs; the [AST attempt](pytorch#140084) increased artifact size by 20% on some toy programs.
- doesn't even require a schema version bump.

Additionally to push some test cases over the line, this redoes expression processing (handling ranges, symbol caching) by doing bottom-up processing instead of the current hacky-ish workflow.

Test Plan: test_serdes, test_serialize, internal tests broken by AST PR

Differential Revision: D66283208

Pull Request resolved: pytorch#141284
Approved by: https://github.com/zhxchen17
@github-actions
Copy link
Contributor

github-actions bot commented Jan 5, 2025

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jan 5, 2025
@bhack
Copy link
Contributor

bhack commented Jan 5, 2025

if this is going to be close as stale do we still have a ticket to track this?

@avikchaudhuri
Copy link
Contributor

Was this fixed by #141284?

@bhack
Copy link
Contributor

bhack commented Jan 6, 2025

I've tested on Nov, 2 and it was merge on Nov, 22. So probably it is ok

@github-actions github-actions bot closed this Feb 5, 2025
@github-actions github-actions bot deleted the export-D63493615 branch March 8, 2025 01:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants