Description
🐛 Describe the bug
When running a torch.compile'd model wrapped in DDP on inputs whose sequence length grows, a recompilation happens every time the input shape changes. After 8 recompilations, the cache size limit is reached.
Very similar to this. Version: 2.5.1+cu124
Notes:
- works without DDP
- torch.compile with dynamic=None and with dynamic=True behave the same
- if the model has only one layer, it works
- if it's the batch size that's growing, it works
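A possible reading of the one-layer observation, offered as an assumption rather than something this report establishes: the traceback in the error logs passes through torch/_dynamo/backends/distributed.py, where Dynamo splits the compiled graph at DDP gradient-bucket boundaries. If the first bucket is capped at 1 MiB (assumed here), a single fp32 Linear(768, 768) already exceeds it, so two layers end up in two buckets (two compiled submodules) while one layer stays in a single graph. The greedy bucketing below is a simplified toy, not the real splitting code. If this reading is right, setting torch._dynamo.config.optimize_ddp = False (which turns the splitting off) might be a quick way to test it; that has not been tried here.

```python
FP32_BYTES = 4
D_MODEL = 768
FIRST_BUCKET_CAP = 1024 * 1024   # assumed first-bucket cap (1 MiB)
BUCKET_CAP = 25 * 1024 * 1024    # DDP's default bucket_cap_mb of 25


def linear_param_bytes(d_in, d_out, dtype_bytes=FP32_BYTES):
    """Bytes of weight (d_out x d_in) plus bias (d_out) for an nn.Linear."""
    return (d_in * d_out + d_out) * dtype_bytes


def count_buckets(layer_bytes, first_cap=FIRST_BUCKET_CAP, cap=BUCKET_CAP):
    """Toy greedy bucketing over layers in reverse (gradient) order.

    A bucket closes once its size reaches the current cap; the first bucket
    uses first_cap and later ones use cap. Simplified sketch only.
    """
    buckets, current, limit = 0, 0, first_cap
    for size in reversed(layer_bytes):
        current += size
        if current >= limit:
            buckets += 1
            current, limit = 0, cap
    return buckets + (1 if current else 0)


per_layer = linear_param_bytes(D_MODEL, D_MODEL)
print(per_layer)                              # 2362368
print(count_buckets([per_layer]))             # 1
print(count_buckets([per_layer, per_layer]))  # 2
```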
Minimal code to reproduce:
import os
import time
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
torch.set_float32_matmul_precision("high")
D_MODEL = 768
device = "cuda:0"
# setup DDP
assert torch.cuda.is_available()
dist.init_process_group(backend='nccl')
ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK'])
ddp_world_size = int(os.environ['WORLD_SIZE'])
device = f'cuda:{ddp_local_rank}'
torch.cuda.set_device(device)
print(f"using device: {device}")
master_process = (ddp_rank == 0)
def get_seqlen(it):
    return 16+int(2*it)
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc_0 = nn.Linear(D_MODEL, D_MODEL)
        self.fc_1 = nn.Linear(D_MODEL, D_MODEL)

    def forward(self, x):
        x = self.fc_0(x)
        x = self.fc_1(x)
        return x
model = MyModel().to(device)
model = torch.compile(model, dynamic=None)
model = DDP(model, device_ids=[ddp_local_rank])
start_time = time.time()
for epoch in range(100):
    print(f"[{(time.time()-start_time):.2f}] epoch {epoch}. len={get_seqlen(epoch)}.")
    inputs = torch.randn((16, get_seqlen(epoch), D_MODEL), device=device)
    logits = model(inputs)
dist.destroy_process_group()
Run with torchrun --standalone --nproc_per_node=1 repl_ddp.py
You will see that the model is recompiled at every step (because the sequence length changes at every step), and after 8 compilations the cache size limit is hit.
(Here the "sequence length" is just a second batch dimension. In the setup where this first occurred, a transformer, it really is a sequence length.)
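The failure mode can be modeled with a toy cache (hypothetical; this is not Dynamo's implementation). An entry pinned to one sequence length only matches that length, so the repro's schedule of growing lengths needs a new entry at every step, while an entry that is dynamic in that dimension matches all of them. The two modes below mirror the two logs: the first recompile becomes dynamic (as without DDP), or every compile is specialized (as with DDP). In PyTorch 2.5 the real limit is torch._dynamo.config.cache_size_limit, whose default of 8 matches the count reported here; the toy raises at the limit only to make the cutoff visible.

```python
class CacheLimitExceeded(RuntimeError):
    """Raised when the toy cache would need more entries than its limit."""


class ToyCompileCache:
    """Hypothetical per-function compile cache (a model, not Dynamo's code).

    Each entry is an int (valid for one sequence length only) or the string
    "dynamic" (valid for any length). With always_specialize=False the first
    compile is static and the first recompile produces a dynamic entry, like
    the non-DDP log. With always_specialize=True every entry stays pinned to
    one length, like the DDP log.
    """

    def __init__(self, limit=8, always_specialize=False):
        self.limit = limit
        self.always_specialize = always_specialize
        self.entries = []

    def call(self, seq_len):
        if any(e == "dynamic" or e == seq_len for e in self.entries):
            return  # cache hit: reuse an existing compiled entry
        if len(self.entries) >= self.limit:
            raise CacheLimitExceeded(f"limit {self.limit} reached at len={seq_len}")
        if self.always_specialize or not self.entries:
            self.entries.append(seq_len)    # static entry for this length only
        else:
            self.entries.append("dynamic")  # recompile as dynamic in seq_len


def count_compiles(always_specialize, epochs=100):
    """Drive the toy cache with the repro's schedule (16 + 2 * epoch).

    Returns (number of compiled entries, epoch at which the limit was hit or None).
    """
    cache = ToyCompileCache(always_specialize=always_specialize)
    for epoch in range(epochs):
        try:
            cache.call(16 + 2 * epoch)
        except CacheLimitExceeded:
            return len(cache.entries), epoch
    return len(cache.entries), None


print(count_compiles(always_specialize=False))  # (2, None)
print(count_compiles(always_specialize=True))   # (8, 8)
```

In the specialized mode, epoch 8 (length 32) is the first step that would need a ninth entry.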
Error logs
Here are the relevant excerpts of the output when run with TORCH_LOGS="+dynamic".
- Without DDP (works; compiles only twice):
This is displayed during the 2nd compilation (len was 16, is now 18):
[0/1] create_symbol s0 = 18 for L['x'].size()[1] [2, int_oo]
[...]
[0/1] eval 12288*s0 < 2147483648 [guard added]
[...]
[0/1] track_symint L['x'].size()[1] s0 RelaxedUnspecConstraint(warn_only=True)
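The guard 12288*s0 < 2147483648 is plain arithmetic on the input shape (16, s0, 768): 12288 = 16 * 768 and 2147483648 = 2**31, so it bounds the tensor's element count. Reading it as a 32-bit indexing check is an interpretation, not something the log states. Solving it for s0 shows that, as far as this guard is concerned, the dynamic graph stays valid for any sequence length up to 174762:

```python
BATCH, D_MODEL = 16, 768
LIMIT = 2 ** 31                 # 2147483648, the right-hand side of the guard

coeff = BATCH * D_MODEL         # 12288, the coefficient of s0 in the guard
max_s0 = (LIMIT - 1) // coeff   # largest integer s0 with coeff * s0 < LIMIT

print(coeff, max_s0)            # 12288 174762
```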
- With DDP (doesn't work; recompiles every time the sequence length changes):
This is displayed during the 2nd compilation (len was 16, is now 18):
[0/1] create_symbol s0 = 18 for L['x'].size()[1] [2, int_oo]
[...]
[0/1] eval 12288*s0 < 2147483648 [guard added]
[0/1] _update_var_to_range s0 = VR[18, 18] (update)
[0/1] set_replacement s0 = 18 (range_refined_to_singleton) VR[18, 18]
[0/1] eval Eq(13824, 768*s0) [guard added] (_refs/__init__.py:3687 in _reshape_view_helper), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(13824, 768*s0)"
[...]
[0/1] track_symint L['x'].size()[1] 18 RelaxedUnspecConstraint(warn_only=True)
(similar logs for every recompilation)
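The extra guard in the DDP log, Eq(13824, 768*s0), is what pins the shape: 13824 = 18 * 768, so the equality has exactly one integer solution, s0 = 18. That is consistent with the log lines that collapse s0's range to VR[18, 18] and replace s0 with the constant 18, leaving a graph that only accepts length 18. A quick check:

```python
# s0 was created with range [2, int_oo]; scan a finite slice of it.
solutions = [s0 for s0 in range(2, 100_000) if 768 * s0 == 13824]
quotient, remainder = divmod(13824, 768)

print(solutions)            # [18]
print(quotient, remainder)  # 18 0
```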
- With DDP, and with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(13824, 768*s0)":
Full log of TORCH_LOGS="+dynamic" TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(13824, 768*s0)" nohup torchrun --standalone --nproc_per_node=1 repl_ddp.py:
[0/1] eval Eq(13824, 768*s0) [guard added] (_refs/__init__.py:3687 in _reshape_view_helper)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] Stack (most recent call last):
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/workspace/repl_ddp.py", line 50, in <module>
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] logits = model(inputs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return self._call_impl(*args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return forward_call(*args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/nn/parallel/distributed.py", line 1643, in forward
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] else self._run_ddp_forward(*inputs, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/nn/parallel/distributed.py", line 1459, in _run_ddp_forward
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return self.module(*inputs, **kwargs) # type: ignore[index]
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return self._call_impl(*args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return forward_call(*args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return fn(*args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return self._call_impl(*args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return forward_call(*args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 1263, in __call__
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return hijacked_callback(
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 1064, in __call__
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] result = self._inner_convert(
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return _compile(
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return _compile_inner(code, one_graph, hooks, transform)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_utils_internal.py", line 87, in wrapper_function
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return function(*args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] out_code = transform_code_object(code, transform)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] transformations(instructions, code_options)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return fn(*args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 634, in transform
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] tracer.run()
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] super().run()
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] while self.step():
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] self.dispatch_table[inst.opcode](self, inst)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2987, in RETURN_VALUE
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] self._return(inst)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2972, in _return
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] self.output.compile_subgraph(
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] compiled_fn = self.call_user_compiler(gm)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return self._call_user_compiler(gm)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] compiled_fn = compiler_fn(gm, self.example_inputs())
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/backends/distributed.py", line 546, in compile_fn
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] submod_compiler.run(*example_inputs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/fx/interpreter.py", line 146, in run
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] self.env[node] = self.run_node(node)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/backends/distributed.py", line 285, in run_node
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] out = compiled_submod_real(*new_args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return self._call_impl(*args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return forward_call(*args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/backends/distributed.py", line 154, in forward
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] x = self.submod(*args)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return fn(*args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/aot_autograd.py", line 1100, in forward
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return compiled_fn(full_args)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 308, in runtime_wrapper
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] all_outs = call_func_at_runtime_with_args(
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 124, in call_func_at_runtime_with_args
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] out = normalize_as_list(f(args))
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 98, in g
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return f(*args)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 575, in apply
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return super().apply(*args, **kwargs) # type: ignore[misc]
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1593, in forward
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] raw_returns[raw_return_idx] = torch.ops.aten._unsafe_view(
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 1116, in __call__
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return self._op(*args, **(kwargs or {}))
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/utils/_stats.py", line 21, in wrapper
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return fn(*args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 1238, in __torch_dispatch__
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return self.dispatch(func, types, args, kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 1692, in dispatch
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return self._cached_dispatch_impl(func, types, args, kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 1348, in _cached_dispatch_impl
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] output = self._dispatch_impl(func, types, args, kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 1943, in _dispatch_impl
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return decomposition_table[func](*args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_prims_common/wrappers.py", line 273, in _fn
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] result = fn(*args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_decomp/decompositions.py", line 3799, in _reshape_alias
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return aten.view(x, shape)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 1116, in __call__
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return self._op(*args, **(kwargs or {}))
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/utils/_stats.py", line 21, in wrapper
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return fn(*args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 1238, in __torch_dispatch__
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return self.dispatch(func, types, args, kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 1692, in dispatch
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return self._cached_dispatch_impl(func, types, args, kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 1348, in _cached_dispatch_impl
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] output = self._dispatch_impl(func, types, args, kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 1943, in _dispatch_impl
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return decomposition_table[func](*args, **kwargs)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_refs/__init__.py", line 4591, in view
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return _reshape_view_helper(a, *shape, allow_copy=False)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/_refs/__init__.py", line 3687, in _reshape_view_helper
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] if a.is_contiguous():
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/sym_node.py", line 501, in guard_size_oblivious
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] r = self.shape_env.evaluate_expr(
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/recording.py", line 262, in wrapper
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return retlog(fn(*args, **kwargs))
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5122, in evaluate_expr
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5304, in _evaluate_expr
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] self._log_guard("eval", g, forcing_spec=forcing_spec)
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5106, in _log_guard
[rank0]:I1110 16:29:04.502000 41799 torch/fx/experimental/symbolic_shapes.py:5106] [0/1] self.log.info(
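The stack ends in _reshape_view_helper at if a.is_contiguous():, reached from aten._unsafe_view in the AOTAutograd runtime wrapper while the DDP backend (torch/_dynamo/backends/distributed.py) runs a compiled submodule on fake inputs. One way a contiguity check can produce exactly Eq(13824, 768*s0), offered as an assumption since the report does not show the tensor's metadata, is a tensor whose sizes are symbolic, (16, s0, 768), but whose outermost stride is the concrete number 13824. The simplified, pure-Python check below (not PyTorch's implementation) lists the stride equalities a row-major contiguity test would have to decide, treating symbols as strings:

```python
def expected_strides(sizes):
    """Row-major (contiguous) strides as (int_coeff, [symbols]) pairs.

    sizes may mix ints and symbol names such as "s0".
    """
    strides = []
    coeff, symbols = 1, []
    for size in reversed(sizes):
        strides.append((coeff, list(symbols)))
        if isinstance(size, int):
            coeff *= size
        else:
            symbols.append(size)
    return list(reversed(strides))


def render(coeff, symbols):
    """Format a product such as (768, ["s0"]) as "768*s0"."""
    parts = ([str(coeff)] if coeff != 1 or not symbols else []) + symbols
    return "*".join(parts)


def contiguity_guards(sizes, actual_strides):
    """Equalities a contiguity check would need to decide (simplified sketch).

    Size-1 dimensions and symbolic simplification are deliberately ignored.
    """
    guards = []
    for (coeff, symbols), actual in zip(expected_strides(sizes), actual_strides):
        if symbols:
            expected = render(coeff, symbols)
            if actual != expected:  # concrete vs symbolic: needs a guard
                guards.append(f"Eq({actual}, {expected})")
        elif actual != coeff:       # both concrete: decided without a guard
            guards.append(f"not contiguous: {actual} != {coeff}")
    return guards


print(contiguity_guards((16, "s0", 768), (13824, 768, 1)))     # ['Eq(13824, 768*s0)']
print(contiguity_guards((16, "s0", 768), ("768*s0", 768, 1)))  # []
```

With a fully symbolic outer stride of 768*s0 there is nothing to guard; with the concrete 13824, answering the question requires the equality guard, which then specializes s0.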
Versions
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35
Python version: 3.11.10 (main, Sep 7 2024, 18:35:41) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.8.0-40-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4090
Nvidia driver version: 560.35.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: AuthenticAMD
Model name: AMD Ryzen Threadripper PRO 3995WX 64-Cores
CPU family: 23
Model: 49
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 1
Stepping: 0
Frequency boost: enabled
CPU max MHz: 4308.3979
CPU min MHz: 2200.0000
BogoMIPS: 5400.06
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sev sev_es
Virtualization: AMD-V
L1d cache: 2 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 32 MiB (64 instances)
L3 cache: 256 MiB (16 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.5.1
[pip3] torchaudio==2.5.1
[pip3] torchvision==0.20.1
[pip3] triton==3.1.0
[conda] Could not collect
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @chauhang @penguinwu @ezyang @bobrenjc93