-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Closed
Labels
export-triaged (this tag marks issues that have been looked at by the PT2 Export team, which has determined the next step), module: dynamic shapes, oncall: export, oncall: pt2
Description
🐛 Describe the bug
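This issue occurs when symint values are deserialized during export. A sympy.Expr such as a/b decomposes into a * PowByNatural(b, -1), and the PowByNatural term appears to produce an invalid ValueRange for 1/b in which min > max. The traceback below shows where this happens: `pow_by_natural` intersects the exponent's range with `ValueRanges(0, int_oo)`. Because the exponent is -1, that intersection yields the invalid range [0:-1].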
For example, this test case breaks:
def test_serialize_floordiv_ranges(self):
class Foo(torch.nn.Module):
def forward(self, x):
return x.view(-1, x.shape[0] - 1)
ep = torch.export._trace._export(
Foo(),
(torch.randn(4, 6),),
dynamic_shapes={
"x": (Dim("dx"), Dim("dy")),
},
allow_complex_guards_as_runtime_asserts=True,
)
print(ep)
torch.export.save(ep, "test.pt")
loaded_ep = torch.export.load("test.pt")
Output (the printed ExportedProgram, followed by the error raised by `torch.export.load`):
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s0, s1]"):
#
sym_size_int_4: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
sym_size_int_5: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
mul_3: "Sym(s0*s1)" = sym_size_int_4 * sym_size_int_5; sym_size_int_5 = None
add_4: "Sym(s0 - 1)" = -1 + sym_size_int_4
mod_3: "Sym(Mod(s0*s1, s0 - 1))" = mul_3 % add_4
eq_7: "Sym(Eq(Mod(s0*s1, s0 - 1), 0))" = mod_3 == 0; mod_3 = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(eq_7, "Runtime assertion failed for expression Eq(Mod(s0*s1, s0 - 1), 0) on node 'eq_7'"); eq_7 = _assert_scalar_default = None
floordiv_1: "Sym(((s0*s1)//(s0 - 1)))" = mul_3 // add_4
eq_8: "Sym(Eq(((s0*s1)//(s0 - 1)), s0))" = floordiv_1 == sym_size_int_4
sym_not_3: "Sym(Ne(((s0*s1)//(s0 - 1)), s0))" = torch.sym_not(eq_8); eq_8 = None
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(sym_not_3, "Runtime assertion failed for expression Ne(((s0*s1)//(s0 - 1)), s0) on node 'sym_not_3'"); sym_not_3 = _assert_scalar_default_1 = None
mod_4: "Sym(Mod(s0, ((s0*s1)//(s0 - 1))))" = sym_size_int_4 % floordiv_1
eq_9: "Sym(Eq(Mod(s0, ((s0*s1)//(s0 - 1))), 0))" = mod_4 == 0
sym_not_4: "Sym(Ne(Mod(s0, ((s0*s1)//(s0 - 1))), 0))" = torch.sym_not(eq_9); eq_9 = None
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(sym_not_4, "Runtime assertion failed for expression Ne(Mod(s0, ((s0*s1)//(s0 - 1))), 0) on node 'sym_not_4'"); sym_not_4 = _assert_scalar_default_2 = None
mod_5: "Sym(Mod(s0*s1, ((s0*s1)//(s0 - 1))))" = mul_3 % floordiv_1
eq_10: "Sym(Eq(Mod(s0*s1, ((s0*s1)//(s0 - 1))), 0))" = mod_5 == 0; mod_5 = None
_assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(eq_10, "Runtime assertion failed for expression Eq(Mod(s0*s1, ((s0*s1)//(s0 - 1))), 0) on node 'eq_10'"); eq_10 = _assert_scalar_default_3 = None
eq_11: "Sym(Eq(s0*s1, ((s0*s1)//(s0 - 1))))" = mul_3 == floordiv_1
sym_not_5: "Sym(Ne(s0*s1, ((s0*s1)//(s0 - 1))))" = torch.sym_not(eq_11); eq_11 = None
_assert_scalar_default_4 = torch.ops.aten._assert_scalar.default(sym_not_5, "Runtime assertion failed for expression Ne(s0*s1, ((s0*s1)//(s0 - 1))) on node 'sym_not_5'"); sym_not_5 = _assert_scalar_default_4 = None
eq_12: "Sym(Eq(((s0*s1)//(s0 - 1)), 0))" = floordiv_1 == 0
sym_not_6: "Sym(Ne(((s0*s1)//(s0 - 1)), 0))" = torch.sym_not(eq_12); eq_12 = None
_assert_scalar_default_5 = torch.ops.aten._assert_scalar.default(sym_not_6, "Runtime assertion failed for expression Ne(((s0*s1)//(s0 - 1)), 0) on node 'sym_not_6'"); sym_not_6 = _assert_scalar_default_5 = None
floordiv_2: "Sym(((s0*s1)//(((s0*s1)//(s0 - 1)))))" = mul_3 // floordiv_1
mul_4: "Sym((((s0*s1)//(s0 - 1)))*(((s0*s1)//(((s0*s1)//(s0 - 1))))))" = floordiv_1 * floordiv_2
eq_13: "Sym(Eq((((s0*s1)//(s0 - 1)))*(((s0*s1)//(((s0*s1)//(s0 - 1))))), 0))" = mul_4 == 0; mul_4 = None
sym_not_7: "Sym(Ne((((s0*s1)//(s0 - 1)))*(((s0*s1)//(((s0*s1)//(s0 - 1))))), 0))" = torch.sym_not(eq_13); eq_13 = None
_assert_scalar_default_6 = torch.ops.aten._assert_scalar.default(sym_not_7, "Runtime assertion failed for expression Ne((((s0*s1)//(s0 - 1)))*(((s0*s1)//(((s0*s1)//(s0 - 1))))), 0) on node 'sym_not_7'"); sym_not_7 = _assert_scalar_default_6 = None
eq_14: "Sym(Eq(((s0*s1)//(((s0*s1)//(s0 - 1)))), 1))" = floordiv_2 == 1; floordiv_2 = None
sym_not_8: "Sym(Ne(((s0*s1)//(((s0*s1)//(s0 - 1)))), 1))" = torch.sym_not(eq_14); eq_14 = None
_assert_scalar_default_7 = torch.ops.aten._assert_scalar.default(sym_not_8, "Runtime assertion failed for expression Ne(((s0*s1)//(((s0*s1)//(s0 - 1)))), 1) on node 'sym_not_8'"); sym_not_8 = _assert_scalar_default_7 = None
eq_15: "Sym(Eq(((s0*s1)//(s0 - 1)), 1))" = floordiv_1 == 1
sym_not_9: "Sym(Ne(((s0*s1)//(s0 - 1)), 1))" = torch.sym_not(eq_15); eq_15 = None
_assert_scalar_default_8 = torch.ops.aten._assert_scalar.default(sym_not_9, "Runtime assertion failed for expression Ne(((s0*s1)//(s0 - 1)), 1) on node 'sym_not_9'"); sym_not_9 = _assert_scalar_default_8 = None
# No stacktrace found for following nodes
eq_1: "Sym(Eq(((s0*s1)//(s0 - 1)), s0))" = floordiv_1 == sym_size_int_4; sym_size_int_4 = eq_1 = None
eq_2: "Sym(Eq(Mod(s0, ((s0*s1)//(s0 - 1))), 0))" = mod_4 == 0; mod_4 = eq_2 = None
eq_4: "Sym(Eq(s0*s1, ((s0*s1)//(s0 - 1))))" = mul_3 == floordiv_1; mul_3 = floordiv_1 = eq_4 = None
# File: /data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/test_export.py:2786 in forward, code: return x.view(-1, x.shape[0] - 1)
view: "f32[((s0*s1)//(s0 - 1)), ((s0*s1)//(((s0*s1)//(s0 - 1))))]" = torch.ops.aten.view.default(x, [-1, add_4]); x = add_4 = None
return (view,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='view'), target=None)])
Range constraints: {s0: VR[0, int_oo], s1: VR[0, int_oo]}
...
======================================================================
ERROR: test_serialize_floordiv_ranges_non_strict (caffe2.test.export.test_export_nonstrict.NonStrictExportTestExport)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/testing.py", line 232, in _fn
return fn(*args, **kwargs)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/test_export.py", line 2798, in test_serialize_floordiv_ranges
loaded_ep = torch.export.load("test.pt")
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/export/__init__.py", line 569, in load
ep = deserialize(artifact, expected_opset_version)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_export/serde/serialize.py", line 2447, in deserialize
.deserialize(
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_export/serde/serialize.py", line 2326, in deserialize
.deserialize(
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_export/serde/serialize.py", line 1910, in deserialize
self.deserialize_graph(serialized_graph_module.graph)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_export/serde/serialize.py", line 1616, in deserialize_graph
meta_val = self.deserialize_tensor_meta(tensor_value)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_export/serde/serialize.py", line 1583, in deserialize_tensor_meta
torch.empty_strided(
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_stats.py", line 21, in wrapper
return fn(*args, **kwargs)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_subclasses/fake_tensor.py", line 1241, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_subclasses/fake_tensor.py", line 1695, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_subclasses/fake_tensor.py", line 1342, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_subclasses/fake_tensor.py", line 2019, in _dispatch_impl
op_impl_out = op_impl(self, func, *args, **kwargs)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_subclasses/fake_impls.py", line 176, in constructors
r = func(*args, **new_kwargs)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_ops.py", line 720, in __call__
return self._op(*args, **kwargs)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/sym_node.py", line 480, in expect_size
r = b.expect_true(file, line)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/sym_node.py", line 465, in expect_true
return self.guard_bool(file, line)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/sym_node.py", line 449, in guard_bool
r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/recording.py", line 262, in wrapper
return retlog(fn(*args, **kwargs))
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/symbolic_shapes.py", line 5256, in evaluate_expr
return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/symbolic_shapes.py", line 5332, in _evaluate_expr
static_expr = self._maybe_evaluate_static(expr,
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/symbolic_shapes.py", line 1739, in wrapper
return fn_cache(self, *args, **kwargs)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/symbolic_shapes.py", line 4687, in _maybe_evaluate_static
r = _maybe_evaluate_static_worker(expr, symbol_info, unbacked_only, size_oblivious)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/symbolic_shapes.py", line 1607, in _maybe_evaluate_static_worker
out = bound_sympy(new_expr, new_range_env)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/value_ranges.py", line 1090, in bound_sympy
return sympy_interp(SymPyValueRangeAnalysis, ranges, expr)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in sympy_interp
[sympy_interp(analysis, env, arg) for arg in expr.args], # type: ignore[arg-type]
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in <listcomp>
[sympy_interp(analysis, env, arg) for arg in expr.args], # type: ignore[arg-type]
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in sympy_interp
[sympy_interp(analysis, env, arg) for arg in expr.args], # type: ignore[arg-type]
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in <listcomp>
[sympy_interp(analysis, env, arg) for arg in expr.args], # type: ignore[arg-type]
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in sympy_interp
[sympy_interp(analysis, env, arg) for arg in expr.args], # type: ignore[arg-type]
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in <listcomp>
[sympy_interp(analysis, env, arg) for arg in expr.args], # type: ignore[arg-type]
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in sympy_interp
[sympy_interp(analysis, env, arg) for arg in expr.args], # type: ignore[arg-type]
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in <listcomp>
[sympy_interp(analysis, env, arg) for arg in expr.args], # type: ignore[arg-type]
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in sympy_interp
[sympy_interp(analysis, env, arg) for arg in expr.args], # type: ignore[arg-type]
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in <listcomp>
[sympy_interp(analysis, env, arg) for arg in expr.args], # type: ignore[arg-type]
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in sympy_interp
[sympy_interp(analysis, env, arg) for arg in expr.args], # type: ignore[arg-type]
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in <listcomp>
[sympy_interp(analysis, env, arg) for arg in expr.args], # type: ignore[arg-type]
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 185, in sympy_interp
return _run_sympy_handler(
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 155, in _run_sympy_handler
r = handler(*args)
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/value_ranges.py", line 669, in pow_by_natural
a, b & ValueRanges(0, int_oo), PowByNatural
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/value_ranges.py", line 264, in __and__
return ValueRanges(
File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/value_ranges.py", line 164, in __init__
raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]")
torch.utils._sympy.value_ranges.ValueRangeError: Invalid ranges [0:-1]
Versions
.
cc @ezyang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @chauhang
Metadata
Metadata
Assignees
Labels
export-triaged (this tag marks issues that have been looked at by the PT2 Export team, which has determined the next step), module: dynamic shapes, oncall: export, oncall: pt2