🐛 Describe the bug
torch.compile with the inductor backend errors out when combining dynamic shapes with DistributedDataParallel. It either raises ConstraintViolationError: Constraints violated (L['x'].size()[1])! directly when using torch._dynamo.mark_dynamic, or it recompiles repeatedly until the recompile limit is reached because of "stride mismatch at index 0" guard failures when using dynamic=True or dynamic=None.
These errors occur in both PyTorch 2.3 and the latest PyTorch Nightly.
I've created a replication with a simple "transformer" model consisting of just an embedding layer and a linear head, so I can vary the sequence length of each batch. I get the same errors with a full from-scratch transformer under DDP.
I also inconsistently get the ConstraintViolationError when using torch._dynamo.mark_dynamic in a non-distributed context with PyTorch 2.3, specifically with the Hugging Face Transformers Llama implementation, but I have been unable to replicate that with non-HF code.
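For clarity, here is a condensed sketch of the two ways the replication script below requests dynamic shapes (it assumes the model, data, and device_id variables from that script, and uses the same compile-then-DDP wrapping order):
# Option A: mark the sequence-length dimension (index 1) of each input batch as dynamic
model = torch.compile(model)  # dynamic=None, the default
model = DistributedDataParallel(model, device_ids=[device_id])
torch._dynamo.mark_dynamic(data, index=1)
output = model(data)
# Option B: ask the compiler itself to compile for dynamic shapes
model = torch.compile(model, dynamic=True)
model = DistributedDataParallel(model, device_ids=[device_id])
output = model(data)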
Error logs
With my replication script below, compiling the DDP model for dynamic shapes with the recommended torch._dynamo.mark_dynamic instead of torch.compile(..., dynamic=True), using the following command:
torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --use_mark_dynamic
results in the following ConstraintViolationError:
[rank0]: torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['x'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
[rank0]: - Not all values of RelaxedUnspecConstraint(L['x'].size()[1]) are valid because L['x'].size()[1] was inferred to be a constant (970)
You can turn on logging with --logging, but the dynamo logs don't appear to be that useful compared to other errors I've seen.
torch/_guards.py:261] [0/0] Traceback (most recent call last):
torch/_guards.py:261] [0/0] File "/torch/_guards.py", line 259, in create
torch/_guards.py:261] [0/0] return self.create_fn(builder, self)
torch/_guards.py:261] [0/0] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch/_guards.py:261] [0/0] File "/torch/_dynamo/guards.py", line 1664, in SHAPE_ENV
torch/_guards.py:261] [0/0] guards = output_graph.shape_env.produce_guards(
torch/_guards.py:261] [0/0] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch/_guards.py:261] [0/0] File "/torch/fx/experimental/symbolic_shapes.py", line 3830, in produce_guards
torch/_guards.py:261] [0/0] raise ConstraintViolationError(
torch/_guards.py:261] [0/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['x'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
torch/_guards.py:261] [0/0] - Not all values of RelaxedUnspecConstraint(L['x'].size()[1]) are valid because L['x'].size()[1] was inferred to be a constant (986).
torch/_guards.py:263] [0/0] Created at:
torch/_guards.py:263] [0/0] File "/torch/_dynamo/convert_frame.py", line 499, in transform
torch/_guards.py:263] [0/0] tracer = InstructionTranslator(
torch/_guards.py:263] [0/0] File "/torch/_dynamo/symbolic_convert.py", line 2143, in __init__
torch/_guards.py:263] [0/0] output=OutputGraph(
torch/_guards.py:263] [0/0] File "/torch/_dynamo/output_graph.py", line 309, in __init__
torch/_guards.py:263] [0/0] self.init_ambient_guards()
torch/_guards.py:263] [0/0] File "/torch/_dynamo/output_graph.py", line 448, in init_ambient_guards
torch/_guards.py:263] [0/0] self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
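If it helps triage, I can also follow the error message's suggestion and rerun with the TORCH_LOGS environment variable for more shape-specialization detail, e.g. (same command as above, logs not included here):
TORCH_LOGS="+dynamic" torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --use_mark_dynamic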
The same command using torch.compile(..., dynamic=True) or torch.compile(..., dynamic=None), relying on the compiler to detect dynamic shapes:
torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --dynamic_true
# or
torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen
results in a recompiles error:
torch/_dynamo/convert_frame.py:367] torch._dynamo hit config.cache_size_limit (8)
torch/_dynamo/convert_frame.py:367] function: 'forward' (/replication.py:51)
torch/_dynamo/convert_frame.py:367] last reason: tensor 'L['x']' stride mismatch at index 0. expected 1014, actual 1023
The logging output also doesn't appear to be very verbose.
torch/_dynamo/guards.py:2546] [__recompiles_verbose] Recompiling function forward in /replication.py:51
torch/_dynamo/guards.py:2546] [__recompiles_verbose] triggered by the following guard failure(s):
torch/_dynamo/guards.py:2546] [__recompiles_verbose] guard 0 failures:
torch/_dynamo/guards.py:2546] [__recompiles_verbose] - tensor 'L['x']' stride mismatch at index 0. expected 985, actual 1011
torch/_dynamo/guards.py:2546] [__recompiles_verbose]
torch/_dynamo/guards.py:2546] [__recompiles_verbose] guard 1 failures:
torch/_dynamo/guards.py:2546] [__recompiles_verbose] - tensor 'L['x']' stride mismatch at index 0. expected 976, actual 1011
torch/_dynamo/guards.py:2546] [__recompiles_verbose]
torch/_dynamo/guards.py:2546] [__recompiles_verbose] guard 2 failures:
torch/_dynamo/guards.py:2546] [__recompiles_verbose] - tensor 'L['x']' stride mismatch at index 0. expected 1015, actual 1011
I'm happy to add more logging if wanted.
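If it's useful for triage, the recompile limit named in the error (config.cache_size_limit) can be raised to check whether each new sequence length keeps triggering a fresh compile; a minimal sketch, assuming the default limit of 8 reported above:
import torch._dynamo
# Raising the limit only lets more specializations accumulate; it does not make the shapes dynamic.
torch._dynamo.config.cache_size_limit = 64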
Minified repro
Replication Script
# based on the PyTorch DDP example
import argparse
import logging
import random
import torch
import torch.distributed as dist
import torch.nn as nn
from typing import Tuple
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel
torch.set_float32_matmul_precision("high")
def pytorch_logs_to_file(file: str = "pytorch.log"):
    # Enable verbose dynamo/aot/inductor/dynamic/distributed logging, route it to a file,
    # and remove console stream handlers so the ranks don't spam stdout.
    torch._logging.set_logs(
        dynamo=logging.INFO,
        aot=logging.INFO,
        inductor=logging.INFO,
        dynamic=logging.INFO,
        distributed=logging.INFO,
        graph_breaks=True,
        guards=True,
        recompiles=True,
        recompiles_verbose=True,
        output_code=True,
        graph_code=True,
        graph=True,
        ddp_graphs=True,
    )
    torch._logging._init_logs(file)
    loggers = logging.Logger.manager.loggerDict.keys()
    for logger_name in loggers:
        if logger_name.startswith("torch"):
            logger = logging.getLogger(logger_name)
            if isinstance(logger, logging.Logger):
                handlers = logger.handlers
                for handler in handlers:
                    if isinstance(handler, logging.StreamHandler):
                        logger.removeHandler(handler)
# Minimal "transformer": an embedding layer followed by a linear head.
class EmbedHeadModel(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.vocab_embed = nn.Embedding(vocab_size, hidden_size)
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(self, x: Tensor):
        out = self.vocab_embed(x)
        out = self.head(out)
        return out
def get_batch(
    batch_size: int, sequence_length: int, vocab_size: int, device: torch.device, dynamic: bool
) -> Tuple[Tensor, Tensor]:
    # With dynamic=True, shorten each batch by a random number of tokens so the
    # sequence length varies from step to step; otherwise use a fixed length.
    if dynamic:
        input = torch.randint(
            0,
            vocab_size - 1,
            (batch_size, sequence_length - random.randint(0, min(512, sequence_length // 2)) // 8 + 1),
            device=device,
        )
    else:
        input = torch.randint(0, vocab_size - 1, (batch_size, sequence_length + 1), device=device)
    return input[:, :-1].contiguous(), input[:, 1:].contiguous()
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--sequence_length", type=int, default=1024)
    parser.add_argument("--ddp", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--variable_seqlen", action="store_true", help="Batches have variable sequence lengths")
    parser.add_argument("--dynamic_true", action="store_true", help="Compile with dynamic=True instead of None")
    parser.add_argument(
        "--use_mark_dynamic", action="store_true", help="Use torch._dynamo.mark_dynamic for dynamic shapes"
    )
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--iterations", type=int, default=100)
    parser.add_argument("--vocab_size", type=int, default=8000)
    parser.add_argument("--hidden_size", type=int, default=2048)
    parser.add_argument("--logging", action="store_true")
    parser.add_argument("--log_name", type=str, default="pytorch.log")
    return parser.parse_args()
def train():
    args = parse_args()

    if args.ddp:
        dist.init_process_group("nccl")
        rank = dist.get_rank()
        print(f"Start running basic DDP example on rank {rank}.")
    else:
        rank = 0

    if args.logging and rank == 0:
        pytorch_logs_to_file(args.log_name)

    device_id = rank % torch.cuda.device_count()
    model = EmbedHeadModel(args.vocab_size, args.hidden_size).to(device=device_id)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

    # Compile first, then wrap in DDP. dynamic=True only when --dynamic_true is set and
    # mark_dynamic is not used; otherwise dynamic=None lets the compiler detect dynamic shapes.
    if args.compile:
        model = torch.compile(model, dynamic=True if args.dynamic_true and not args.use_mark_dynamic else None)
    if args.ddp:
        model = DistributedDataParallel(model, device_ids=[device_id])
    model.train()

    for _ in range(0, args.iterations):
        data, targets = get_batch(
            args.batch_size, args.sequence_length, args.vocab_size, device_id, args.variable_seqlen
        )
        if args.use_mark_dynamic:
            # Mark the sequence-length dimension (index 1) of each input batch as dynamic.
            torch._dynamo.mark_dynamic(data, index=1)
        output = model(data)
        loss = criterion(output.view(-1, args.vocab_size), targets.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if args.ddp:
        dist.destroy_process_group()
if __name__ == "__main__":
    train()
Versions
I ran my replication script on fresh conda environments:
conda create -n torchnight python=3.11 pytorch torchvision pytorch-cuda=12.4 -c pytorch-nightly -c nvidia -c conda-forge
conda create -n torch23 python=3.11 pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia -c conda-forge
PyTorch Nightly Environment
PyTorch version: 2.4.0.dev20240506
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 23.10 (x86_64)
GCC version: (Ubuntu 13.2.0-4ubuntu3) 13.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.38
Python version: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-6.5.0-28-generic-x86_64-with-glibc2.38
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
Nvidia driver version: 535.161.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.4.0.dev20240506
[pip3] torchvision==0.19.0.dev20240506
[pip3] triton==3.0.0
[conda] blas 1.0 mkl conda-forge
[conda] brotlipy 0.7.0 py311h9bf148f_1002 pytorch-nightly
[conda] cffi 1.15.1 py311h9bf148f_3 pytorch-nightly
[conda] cryptography 38.0.4 py311h46ebde7_0 pytorch-nightly
[conda] filelock 3.9.0 py311_0 pytorch-nightly
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch-nightly
[conda] libopenvino-pytorch-frontend 2024.0.0 he02047a_5 conda-forge
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mpmath 1.2.1 py311_0 pytorch-nightly
[conda] numpy 1.26.4 py311h64a7726_0 conda-forge
[conda] pillow 9.3.0 py311h3fd9d12_2 pytorch-nightly
[conda] pysocks 1.7.1 py311_0 pytorch-nightly
[conda] pytorch 2.4.0.dev20240506 py3.11_cuda12.4_cudnn8.9.2_0 pytorch-nightly
[conda] pytorch-cuda 12.4 hc786d27_6 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] requests 2.28.1 py311_0 pytorch-nightly
[conda] torchtriton 3.0.0+45fff310c8 py311 pytorch-nightly
[conda] torchvision 0.19.0.dev20240506 py311_cu124 pytorch-nightly
[conda] urllib3 1.26.14 py311_0 pytorch-nightly
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @chauhang @penguinwu @bobrenjc93 @bdhirsh @anijain2305