Skip to content

FloorDiv/CleanDiv inconsistently applied, impedes axioms #134268

@ezyang

Description

@ezyang

🐛 Describe the bug

# Minimal reproducer for the issue "FloorDiv/CleanDiv inconsistently applied,
# impedes axioms": a fact recorded by an earlier runtime assert is not reused
# when a same-looking expression is queried later.
import torch
import torch._dynamo
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

# Let Dynamo trace .tolist() on a tensor. Each extracted element becomes an
# unbacked SymInt; under TORCH_LOGS=dynamic they are logged as u0 and u1.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    """Query a floor-division condition that earlier checks should already decide.

    Args:
        x: a tensor whose ``tolist()`` unpacks into exactly two integers.

    Returns:
        ``torch.tensor(True)`` if the size-oblivious guard evaluates
        ``(2*u49)//(u49 + u50) == 0`` as true, otherwise ``torch.tensor(False)``.
        Given the ``!= 0`` check recorded just before it, the guard is expected
        to resolve statically to False. Per the issue, it is not resolved,
        because the queried expression contains CleanDiv while the recorded
        fact contains FloorDiv.
    """
    # The local names u49/u50 are arbitrary; the symbols they hold are
    # reported as u0/u1 in the log.
    u49, u50 = x.tolist()
    # Recorded as runtime asserts u0 >= 0 and u1 >= 0.
    torch._check_is_size(u49)
    torch._check_is_size(u50)
    # Divisibility fact, recorded as Eq(Mod(2*u0, u0 + u1), 0). It is presumably
    # what allows the division below to be rewritten into the exact-division
    # (CleanDiv) form -- NOTE(review): confirm against the simplifier.
    torch._check((2*u49) % (u49 + u50) == 0)
    # Recorded as Ne(((2*u0)//(u0 + u1)), 0). The matching key in the
    # substitution table is built with FloorDiv.
    torch._check((2*u49)//(u49 + u50) != 0)
    # The expression queried here prints identically to the one above, but it
    # reaches _maybe_evaluate_static holding CleanDiv, so the FloorDiv-keyed
    # fact does not apply.
    if guard_size_oblivious((2*u49)//(u49 + u50) == 0):
        return torch.tensor(True)
    else:
        return torch.tensor(False)

# With x = [3, 3]: (2*3) % 6 == 0 and (2*3) // 6 == 1 != 0, so both checks
# hold at runtime; the failure is purely in static reasoning during tracing.
fn(torch.tensor([3, 3]))

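Root cause: the two occurrences of `(2*u0)//(u0 + u1)` use different division nodes. The expression being queried contains `CleanDiv`, while the matching key in the substitution table (built from the earlier `Ne(...)` runtime assert) contains `FloorDiv`. Both print identically, but `xreplace` matches structurally, so the known fact is never substituted.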

diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 3d9aadbe078..a6ca484aae1 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -89,6 +89,28 @@ import sympy
 from sympy.printing.str import StrPrinter
 from sympy.printing.precedence import precedence, PRECEDENCE
 
+
+from sympy import Symbol, Expr, Basic
+from sympy.printing.str import StrPrinter
+
+def debug_print(expr):
+    class DebugPrinter(StrPrinter):
+        def _print_Symbol(self, expr):
+            return f"{expr.name}_{id(expr)}"
+
+    printer = DebugPrinter()
+
+    def _recursive_print(expr):
+        if isinstance(expr, Symbol):
+            return printer._print_Symbol(expr)
+        elif isinstance(expr, (Expr, Basic)):
+            return type(expr).__name__ + '(' + ', '.join(_recursive_print(arg) for arg in expr.args) + ')'
+        else:
+            return str(expr)
+
+    return _recursive_print(expr)
+
+
 class GuardOnDataDependentSymNode(RuntimeError):
     cond: sympy.Expr
 
@@ -4471,6 +4493,9 @@ class ShapeEnv:
             if e.free_symbols.issubset(expr.free_symbols):
                 subst.update(dict(self.get_implications(e)))
 
+        if str(expr) == "Eq(((2*u0)//(u0 + u1)), 0)":
+            breakpoint()
+
         expr = expr.xreplace(subst)
 
         symbols = tuple(expr.free_symbols)

and then:

(/home/ezyang/local/b/pytorch-env) [[email protected] ~/local/b/pytorch (585c049f)]$ TORCH_LOGS=dynamic python a.py
I0822 12:28:31.897000 498236 torch/fx/experimental/symbolic_shapes.py:3349] [0/0] create_unbacked_symint u0 [-int_oo, int_oo] at a.py:9 in fn (_subclasses/fake_impls.py:389 in local_scalar_dense)
I0822 12:28:31.898000 498236 torch/fx/experimental/symbolic_shapes.py:628] [0/0] compute_unbacked_bindings [u0]
I0822 12:28:31.899000 498236 torch/fx/experimental/symbolic_shapes.py:3349] [0/0] create_unbacked_symint u1 [-int_oo, int_oo] at a.py:9 in fn (_subclasses/fake_impls.py:389 in local_scalar_dense)
I0822 12:28:31.900000 498236 torch/fx/experimental/symbolic_shapes.py:628] [0/0] compute_unbacked_bindings [u1]
I0822 12:28:31.906000 498236 torch/fx/experimental/symbolic_shapes.py:5138] [0/0] runtime_assert u0 >= 0 [guard added] at a.py:10 in fn (_dynamo/utils.py:2117 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 0"
I0822 12:28:31.909000 498236 torch/fx/experimental/symbolic_shapes.py:5138] [0/0] runtime_assert u1 >= 0 [guard added] at a.py:11 in fn (_dynamo/utils.py:2117 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u1 >= 0"
I0822 12:28:32.057000 498236 torch/fx/experimental/symbolic_shapes.py:5138] [0/0] runtime_assert Eq(Mod(2*u0, u0 + u1), 0) [guard added] at a.py:12 in fn (_dynamo/utils.py:2117 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(Mod(2*u0, u0 + u1), 0)"
I0822 12:28:32.070000 498236 torch/fx/experimental/symbolic_shapes.py:5138] [0/0] runtime_assert Ne(((2*u0)//(u0 + u1)), 0) [guard added] at a.py:13 in fn (_dynamo/utils.py:2117 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Ne(((2*u0)//(u0 + u1)), 0)"
> /data/users/ezyang/b/pytorch/torch/fx/experimental/symbolic_shapes.py(4499)_maybe_evaluate_static()
-> expr = expr.xreplace(subst)
(Pdb) debug_print(expr)
'Equality(CleanDiv(Mul(Integer(), u0_139748737983936), Add(u0_139748737983936, u1_139748738609328)), Zero())'
(Pdb) substs
*** NameError: name 'substs' is not defined
(Pdb) subst
{0 <= u0: True, u0 < 0: False, 0 < u0 + 1: True, u0 + 1 <= 0: False, 0 <= u1: True, u1 < 0: False, 0 < u1 + 1: True, u1 + 1 <= 0: False, Eq(Mod(2*u0, u0 + u1), 0): True, Eq(0, Mod(2*u0, u0 + u1)): True, Ne(Mod(2*u0, u0 + u1), 0): False, Ne(0, Mod(2*u0, u0 + u1)): False, Mod(2*u0, u0 + u1) <= 0: True, 0 < Mod(2*u0, u0 + u1): False, True: True, False: False, Ne(((2*u0)//(u0 + u1)), 0): True, Ne(0, ((2*u0)//(u0 + u1))): True, Eq(((2*u0)//(u0 + u1)), 0): False, Eq(0, ((2*u0)//(u0 + u1))): False}
(Pdb) [debug_print(e) for k, e in subst.items() if str(e) == "Eq(((2*u0)//(u0 + u1)), 0)"]
[]
(Pdb) [debug_print(e) for k, e in subst.items()]
['BooleanTrue()', 'BooleanFalse()', 'BooleanTrue()', 'BooleanFalse()', 'BooleanTrue()', 'BooleanFalse()', 'BooleanTrue()', 'BooleanFalse()', 'BooleanTrue()', 'BooleanTrue()', 'BooleanFalse()', 'BooleanFalse()', 'BooleanTrue()', 'BooleanFalse()', 'BooleanTrue()', 'BooleanFalse()', 'BooleanTrue()', 'BooleanTrue()', 'BooleanFalse()', 'BooleanFalse()']
(Pdb) [debug_print(k) for k in subst.keys if str(k) == "Eq(((2*u0)//(u0 + u1)), 0)"]
*** TypeError: 'builtin_function_or_method' object is not iterable
(Pdb) [debug_print(k) for k in subst.keys() if str(k) == "Eq(((2*u0)//(u0 + u1)), 0)"]
['Equality(FloorDiv(Mul(Integer(), u0_139748737983936), Add(u0_139748737983936, u1_139748738609328)), Zero())']

Versions

main

cc @chauhang @penguinwu

Metadata

Metadata

Assignees

Labels

module: dynamic shapes, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions