Skip to content

ValueRange division breaks with pow_by_natural #136797

@pianpwk

Description

@pianpwk

🐛 Describe the bug

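This issue occurs during export deserialization of SymInt values. When we have a sympy.Expr like a/b, it decomposes into a * PowByNatural(b, -1). PowByNatural then produces an invalid ValueRange (min > max) for 1/b. As the traceback below shows, `pow_by_natural` intersects the exponent's range with `ValueRanges(0, int_oo)`. The exponent -1 therefore yields the empty range [0:-1], and `ValueRanges` raises `ValueRangeError`.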

For example, the following test fails when loading the saved program:

def test_serialize_floordiv_ranges(self):
    """Round-trip an exported program whose shapes involve floor division.

    ``x.view(-1, x.shape[0] - 1)`` makes export emit symbolic sizes such as
    ``(s0*s1)//(s0 - 1)`` together with runtime-assert nodes. Loading the
    saved program re-evaluates those expressions; in the reported failure,
    ``torch.export.load`` raises ``ValueRangeError`` from ``pow_by_natural``.
    """
    import io

    class Foo(torch.nn.Module):
        def forward(self, x):
            return x.view(-1, x.shape[0] - 1)

    ep = torch.export._trace._export(
        Foo(),
        (torch.randn(4, 6),),
        dynamic_shapes={
            "x": (Dim("dx"), Dim("dy")),
        },
        # Presumably what produces the _assert_scalar runtime-assert nodes in
        # the printed graph — TODO confirm against export docs.
        allow_complex_guards_as_runtime_asserts=True,
    )
    print(ep)
    # Serialize to an in-memory buffer instead of "test.pt" in the current
    # working directory: no file is left behind, and parallel test runs
    # cannot overwrite each other's artifact.
    buffer = io.BytesIO()
    torch.export.save(ep, buffer)
    buffer.seek(0)  # rewind so load() reads the archive from its start
    # Deserialization is the step that failed in the report.
    torch.export.load(buffer)

Output (the printed ExportedProgram, followed by the error raised by `torch.export.load`):

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0, s1]"):
             # 
            sym_size_int_4: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
            sym_size_int_5: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
            mul_3: "Sym(s0*s1)" = sym_size_int_4 * sym_size_int_5;  sym_size_int_5 = None
            add_4: "Sym(s0 - 1)" = -1 + sym_size_int_4
            mod_3: "Sym(Mod(s0*s1, s0 - 1))" = mul_3 % add_4
            eq_7: "Sym(Eq(Mod(s0*s1, s0 - 1), 0))" = mod_3 == 0;  mod_3 = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq_7, "Runtime assertion failed for expression Eq(Mod(s0*s1, s0 - 1), 0) on node 'eq_7'");  eq_7 = _assert_scalar_default = None
            floordiv_1: "Sym(((s0*s1)//(s0 - 1)))" = mul_3 // add_4
            eq_8: "Sym(Eq(((s0*s1)//(s0 - 1)), s0))" = floordiv_1 == sym_size_int_4
            sym_not_3: "Sym(Ne(((s0*s1)//(s0 - 1)), s0))" = torch.sym_not(eq_8);  eq_8 = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(sym_not_3, "Runtime assertion failed for expression Ne(((s0*s1)//(s0 - 1)), s0) on node 'sym_not_3'");  sym_not_3 = _assert_scalar_default_1 = None
            mod_4: "Sym(Mod(s0, ((s0*s1)//(s0 - 1))))" = sym_size_int_4 % floordiv_1
            eq_9: "Sym(Eq(Mod(s0, ((s0*s1)//(s0 - 1))), 0))" = mod_4 == 0
            sym_not_4: "Sym(Ne(Mod(s0, ((s0*s1)//(s0 - 1))), 0))" = torch.sym_not(eq_9);  eq_9 = None
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(sym_not_4, "Runtime assertion failed for expression Ne(Mod(s0, ((s0*s1)//(s0 - 1))), 0) on node 'sym_not_4'");  sym_not_4 = _assert_scalar_default_2 = None
            mod_5: "Sym(Mod(s0*s1, ((s0*s1)//(s0 - 1))))" = mul_3 % floordiv_1
            eq_10: "Sym(Eq(Mod(s0*s1, ((s0*s1)//(s0 - 1))), 0))" = mod_5 == 0;  mod_5 = None
            _assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(eq_10, "Runtime assertion failed for expression Eq(Mod(s0*s1, ((s0*s1)//(s0 - 1))), 0) on node 'eq_10'");  eq_10 = _assert_scalar_default_3 = None
            eq_11: "Sym(Eq(s0*s1, ((s0*s1)//(s0 - 1))))" = mul_3 == floordiv_1
            sym_not_5: "Sym(Ne(s0*s1, ((s0*s1)//(s0 - 1))))" = torch.sym_not(eq_11);  eq_11 = None
            _assert_scalar_default_4 = torch.ops.aten._assert_scalar.default(sym_not_5, "Runtime assertion failed for expression Ne(s0*s1, ((s0*s1)//(s0 - 1))) on node 'sym_not_5'");  sym_not_5 = _assert_scalar_default_4 = None
            eq_12: "Sym(Eq(((s0*s1)//(s0 - 1)), 0))" = floordiv_1 == 0
            sym_not_6: "Sym(Ne(((s0*s1)//(s0 - 1)), 0))" = torch.sym_not(eq_12);  eq_12 = None
            _assert_scalar_default_5 = torch.ops.aten._assert_scalar.default(sym_not_6, "Runtime assertion failed for expression Ne(((s0*s1)//(s0 - 1)), 0) on node 'sym_not_6'");  sym_not_6 = _assert_scalar_default_5 = None
            floordiv_2: "Sym(((s0*s1)//(((s0*s1)//(s0 - 1)))))" = mul_3 // floordiv_1
            mul_4: "Sym((((s0*s1)//(s0 - 1)))*(((s0*s1)//(((s0*s1)//(s0 - 1))))))" = floordiv_1 * floordiv_2
            eq_13: "Sym(Eq((((s0*s1)//(s0 - 1)))*(((s0*s1)//(((s0*s1)//(s0 - 1))))), 0))" = mul_4 == 0;  mul_4 = None
            sym_not_7: "Sym(Ne((((s0*s1)//(s0 - 1)))*(((s0*s1)//(((s0*s1)//(s0 - 1))))), 0))" = torch.sym_not(eq_13);  eq_13 = None
            _assert_scalar_default_6 = torch.ops.aten._assert_scalar.default(sym_not_7, "Runtime assertion failed for expression Ne((((s0*s1)//(s0 - 1)))*(((s0*s1)//(((s0*s1)//(s0 - 1))))), 0) on node 'sym_not_7'");  sym_not_7 = _assert_scalar_default_6 = None
            eq_14: "Sym(Eq(((s0*s1)//(((s0*s1)//(s0 - 1)))), 1))" = floordiv_2 == 1;  floordiv_2 = None
            sym_not_8: "Sym(Ne(((s0*s1)//(((s0*s1)//(s0 - 1)))), 1))" = torch.sym_not(eq_14);  eq_14 = None
            _assert_scalar_default_7 = torch.ops.aten._assert_scalar.default(sym_not_8, "Runtime assertion failed for expression Ne(((s0*s1)//(((s0*s1)//(s0 - 1)))), 1) on node 'sym_not_8'");  sym_not_8 = _assert_scalar_default_7 = None
            eq_15: "Sym(Eq(((s0*s1)//(s0 - 1)), 1))" = floordiv_1 == 1
            sym_not_9: "Sym(Ne(((s0*s1)//(s0 - 1)), 1))" = torch.sym_not(eq_15);  eq_15 = None
            _assert_scalar_default_8 = torch.ops.aten._assert_scalar.default(sym_not_9, "Runtime assertion failed for expression Ne(((s0*s1)//(s0 - 1)), 1) on node 'sym_not_9'");  sym_not_9 = _assert_scalar_default_8 = None
            
            # No stacktrace found for following nodes
            eq_1: "Sym(Eq(((s0*s1)//(s0 - 1)), s0))" = floordiv_1 == sym_size_int_4;  sym_size_int_4 = eq_1 = None
            eq_2: "Sym(Eq(Mod(s0, ((s0*s1)//(s0 - 1))), 0))" = mod_4 == 0;  mod_4 = eq_2 = None
            eq_4: "Sym(Eq(s0*s1, ((s0*s1)//(s0 - 1))))" = mul_3 == floordiv_1;  mul_3 = floordiv_1 = eq_4 = None
            
             # File: /data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/test_export.py:2786 in forward, code: return x.view(-1, x.shape[0] - 1)
            view: "f32[((s0*s1)//(s0 - 1)), ((s0*s1)//(((s0*s1)//(s0 - 1))))]" = torch.ops.aten.view.default(x, [-1, add_4]);  x = add_4 = None
            return (view,)
            
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='view'), target=None)])
Range constraints: {s0: VR[0, int_oo], s1: VR[0, int_oo]}
...
======================================================================
ERROR: test_serialize_floordiv_ranges_non_strict (caffe2.test.export.test_export_nonstrict.NonStrictExportTestExport)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/testing.py", line 232, in _fn
    return fn(*args, **kwargs)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/caffe2/test/export/test_export.py", line 2798, in test_serialize_floordiv_ranges
    loaded_ep = torch.export.load("test.pt")
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/export/__init__.py", line 569, in load
    ep = deserialize(artifact, expected_opset_version)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_export/serde/serialize.py", line 2447, in deserialize
    .deserialize(
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_export/serde/serialize.py", line 2326, in deserialize
    .deserialize(
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_export/serde/serialize.py", line 1910, in deserialize
    self.deserialize_graph(serialized_graph_module.graph)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_export/serde/serialize.py", line 1616, in deserialize_graph
    meta_val = self.deserialize_tensor_meta(tensor_value)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_export/serde/serialize.py", line 1583, in deserialize_tensor_meta
    torch.empty_strided(
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_subclasses/fake_tensor.py", line 1241, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_subclasses/fake_tensor.py", line 1695, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_subclasses/fake_tensor.py", line 1342, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_subclasses/fake_tensor.py", line 2019, in _dispatch_impl
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_subclasses/fake_impls.py", line 176, in constructors
    r = func(*args, **new_kwargs)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/_ops.py", line 720, in __call__
    return self._op(*args, **kwargs)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/sym_node.py", line 480, in expect_size
    r = b.expect_true(file, line)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/sym_node.py", line 465, in expect_true
    return self.guard_bool(file, line)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/sym_node.py", line 449, in guard_bool
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/recording.py", line 262, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/symbolic_shapes.py", line 5256, in evaluate_expr
    return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/symbolic_shapes.py", line 5332, in _evaluate_expr
    static_expr = self._maybe_evaluate_static(expr,
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/symbolic_shapes.py", line 1739, in wrapper
    return fn_cache(self, *args, **kwargs)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/symbolic_shapes.py", line 4687, in _maybe_evaluate_static
    r = _maybe_evaluate_static_worker(expr, symbol_info, unbacked_only, size_oblivious)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/fx/experimental/symbolic_shapes.py", line 1607, in _maybe_evaluate_static_worker
    out = bound_sympy(new_expr, new_range_env)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/value_ranges.py", line 1090, in bound_sympy
    return sympy_interp(SymPyValueRangeAnalysis, ranges, expr)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in sympy_interp
    [sympy_interp(analysis, env, arg) for arg in expr.args],  # type: ignore[arg-type]
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in <listcomp>
    [sympy_interp(analysis, env, arg) for arg in expr.args],  # type: ignore[arg-type]
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in sympy_interp
    [sympy_interp(analysis, env, arg) for arg in expr.args],  # type: ignore[arg-type]
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in <listcomp>
    [sympy_interp(analysis, env, arg) for arg in expr.args],  # type: ignore[arg-type]
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in sympy_interp
    [sympy_interp(analysis, env, arg) for arg in expr.args],  # type: ignore[arg-type]
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in <listcomp>
    [sympy_interp(analysis, env, arg) for arg in expr.args],  # type: ignore[arg-type]
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in sympy_interp
    [sympy_interp(analysis, env, arg) for arg in expr.args],  # type: ignore[arg-type]
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in <listcomp>
    [sympy_interp(analysis, env, arg) for arg in expr.args],  # type: ignore[arg-type]
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in sympy_interp
    [sympy_interp(analysis, env, arg) for arg in expr.args],  # type: ignore[arg-type]
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in <listcomp>
    [sympy_interp(analysis, env, arg) for arg in expr.args],  # type: ignore[arg-type]
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in sympy_interp
    [sympy_interp(analysis, env, arg) for arg in expr.args],  # type: ignore[arg-type]
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 187, in <listcomp>
    [sympy_interp(analysis, env, arg) for arg in expr.args],  # type: ignore[arg-type]
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 185, in sympy_interp
    return _run_sympy_handler(
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/interp.py", line 155, in _run_sympy_handler
    r = handler(*args)
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/value_ranges.py", line 669, in pow_by_natural
    a, b & ValueRanges(0, int_oo), PowByNatural
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/value_ranges.py", line 264, in __and__
    return ValueRanges(
  File "/data/users/pianpwk/fbsource/buck-out/v2/gen/fbcode/8b975c3cdb061ee6/caffe2/test/__test_export__/test_export#link-tree/torch/utils/_sympy/value_ranges.py", line 164, in __init__
    raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]")
torch.utils._sympy.value_ranges.ValueRangeError: Invalid ranges [0:-1]

Versions

.

cc @ezyang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @chauhang

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions