Skip to content

sys.maxsize special case doesn't work if you slightly offset the ranges #127396

@ezyang

Description

@ezyang

🐛 Describe the bug

Internal xref: https://fb.workplace.com/groups/6829516587176185/posts/7297216510406188

Repro:

import torch  # Repro for issue #127396 (comments are end-of-line so traceback line numbers still match)
import torch._dynamo.config

torch._dynamo.config.capture_scalar_outputs = True  # needed so x.item() is traced as a symbolic int (u0 in the error)

@torch.compile()
def f(x):
    y = x.item()  # y is the data-dependent (unbacked) symbol u0
    z = torch.randn(y, 2048)  # dim 0 has size u0
    r = torch.cat([z, torch.randn(2, 2048)])  # dim 0 becomes u0 + 2; per the issue, the repro passes without this cat
    return r[:, 0:152]  # full slice on dim 0 leads to the failing guard Eq(u0 + 2, sys.maxsize)

f(torch.tensor(4))

fails with

(/home/ezyang/local/a/pytorch-env) [[email protected] ~/local/a/pytorch (2f1f15c4)]$ python b.py
E0529 06:47:57.013000 139738791269376 torch/fx/experimental/recording.py:280] [0/0] failed while running evaluate_expr(*(Eq(u0 + 2, 9223372036854775807), None), **{'fx_node': None})
Traceback (most recent call last):
  File "/data/users/ezyang/a/pytorch/b.py", line 13, in <module>
    f(torch.tensor(4))
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/eval_frame.py", line 421, in _fn
    return fn(*args, **kwargs)
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/convert_frame.py", line 1077, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/convert_frame.py", line 918, in _convert_frame
    result = inner_convert(
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/convert_frame.py", line 456, in _convert_frame_assert
    return _compile(
  File "/data/users/ezyang/a/pytorch/torch/_utils_internal.py", line 83, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "/data/users/ezyang/a/pytorch/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "/home/ezyang/local/a/pytorch-env/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/convert_frame.py", line 799, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/utils.py", line 218, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/convert_frame.py", line 618, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
    transformations(instructions, code_options)
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/convert_frame.py", line 177, in _fn
    return fn(*args, **kwargs)
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/convert_frame.py", line 564, in transform
    tracer.run()
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/symbolic_convert.py", line 2244, in run
    super().run()
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/symbolic_convert.py", line 886, in run
    while self.step():
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/symbolic_convert.py", line 801, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/symbolic_convert.py", line 2435, in RETURN_VALUE
    self._return(inst)
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/symbolic_convert.py", line 2420, in _return
    self.output.compile_subgraph(
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/output_graph.py", line 1095, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/ezyang/local/a/pytorch-env/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/output_graph.py", line 1312, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/utils.py", line 218, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/output_graph.py", line 1403, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/output_graph.py", line 1384, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/repro/after_dynamo.py", line 127, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/data/users/ezyang/a/pytorch/torch/__init__.py", line 1895, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/ezyang/local/a/pytorch-env/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/ezyang/a/pytorch/torch/_inductor/compile_fx.py", line 1471, in compile_fx
    return aot_autograd(
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/backends/common.py", line 65, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/data/users/ezyang/a/pytorch/torch/_functorch/aot_autograd.py", line 934, in aot_module_simplified
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/utils.py", line 218, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/ezyang/a/pytorch/torch/_functorch/aot_autograd.py", line 551, in create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/data/users/ezyang/a/pytorch/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 162, in inner
    flat_f_outs = f(*flat_f_args)
  File "/data/users/ezyang/a/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 738, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
  File "/data/users/ezyang/a/pytorch/torch/fx/interpreter.py", line 145, in run
    self.env[node] = self.run_node(node)
  File "/data/users/ezyang/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 5388, in run_node
    result = super().run_node(n)
  File "/data/users/ezyang/a/pytorch/torch/fx/interpreter.py", line 202, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/data/users/ezyang/a/pytorch/torch/fx/interpreter.py", line 274, in call_function
    return target(*args, **kwargs)
  File "/data/users/ezyang/a/pytorch/torch/fx/experimental/sym_node.py", line 413, in guard_bool
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
  File "/data/users/ezyang/a/pytorch/torch/fx/experimental/recording.py", line 244, in wrapper
    return fn(*args, **kwargs)
  File "/data/users/ezyang/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 5141, in evaluate_expr
    raise self._make_data_dependent_error(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u0 + 2, 9223372036854775807) (unhinted: Eq(u0 + 2, 9223372036854775807)).  (Size-like symbols: u0)

Potential framework code culprit (scroll up for full backtrace):
  File "/data/users/ezyang/a/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 738, in functional_call
    out = PropagateUnbackedSymInts(mod).run(

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

While executing %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%r, (slice(None, None, None), slice(0, 152, None))), kwargs = {})
Original traceback:
  File "/data/users/ezyang/a/pytorch/b.py", line 11, in f
    return r[:, 0:152]


Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

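This doesn't fail without the `torch.cat`. The reason for the failure is that we cap the upper bound of u0's value range at sys.maxsize - 1, which lets us conclude that u0 != sys.maxsize. After the cat, however, the size of dim 0 is u0 + 2, whose upper bound is sys.maxsize + 1. We can therefore no longer conclude that u0 + 2 does not equal sys.maxsize. That is exactly the guard `Eq(u0 + 2, 9223372036854775807)` that fails while tracing the full slice `slice(None, None, None)` on dim 0 in `r[:, 0:152]`.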

main

Versions

main

cc @msaroufim @bdhirsh @anijain2305 @chauhang

Metadata

Metadata

Assignees

No one assigned

    Labels

    module: dynamic shapes · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions