Compiling with Inductor, DDP, and Dynamic Shapes Results in Errors #125641

@warner-benjamin

🐛 Describe the bug

torch.compile with the inductor backend fails when dynamic shapes are combined with DistributedDataParallel. Depending on how dynamic shapes are requested, I get one of two failures: a direct ConstraintViolationError: Constraints violated (L['x'].size()[1])! when using torch._dynamo.mark_dynamic, or, with dynamic=True or dynamic=None, repeated recompilation due to a "stride mismatch at index 0" guard failure until the recompile limit is reached.

These errors occur in both PyTorch 2.3 and the latest PyTorch Nightly.

I've created a replication with a simple "transformer" model, with just an embedding layer and linear head layer, so I can vary the shape of the sequence length in the batch. I get the same errors with a full from-scratch transformer with DDP.

I also inconsistently get the ConstraintViolationError when using torch._dynamo.mark_dynamic in a non-distributed context with PyTorch 2.3, specifically with the Hugging Face Transformers Llama implementation, but I have been unable to replicate it with non-HF code.

Error logs

With my replication script below, compiling a DDP model for dynamic shapes with the recommended torch._dynamo.mark_dynamic instead of using torch.compile(..., dynamic=True) using the following command:

torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --use_mark_dynamic

results in the following ConstraintViolationError:

[rank0]: torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['x'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
[rank0]:   - Not all values of RelaxedUnspecConstraint(L['x'].size()[1]) are valid because L['x'].size()[1] was inferred to be a constant (970)

You can turn on logging with --logging, but the dynamo logs don't appear to be that useful compared to other errors I've seen.

 torch/_guards.py:261] [0/0] Traceback (most recent call last):
 torch/_guards.py:261] [0/0]   File "/torch/_guards.py", line 259, in create
 torch/_guards.py:261] [0/0]     return self.create_fn(builder, self)
 torch/_guards.py:261] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 torch/_guards.py:261] [0/0]   File "/torch/_dynamo/guards.py", line 1664, in SHAPE_ENV
 torch/_guards.py:261] [0/0]     guards = output_graph.shape_env.produce_guards(
 torch/_guards.py:261] [0/0]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 torch/_guards.py:261] [0/0]   File "/torch/fx/experimental/symbolic_shapes.py", line 3830, in produce_guards
 torch/_guards.py:261] [0/0]     raise ConstraintViolationError(
 torch/_guards.py:261] [0/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['x'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
 torch/_guards.py:261] [0/0]   - Not all values of RelaxedUnspecConstraint(L['x'].size()[1]) are valid because L['x'].size()[1] was inferred to be a constant (986).
 torch/_guards.py:263] [0/0] Created at:
 torch/_guards.py:263] [0/0]   File "/torch/_dynamo/convert_frame.py", line 499, in transform
 torch/_guards.py:263] [0/0]     tracer = InstructionTranslator(
 torch/_guards.py:263] [0/0]   File "/torch/_dynamo/symbolic_convert.py", line 2143, in __init__
 torch/_guards.py:263] [0/0]     output=OutputGraph(
 torch/_guards.py:263] [0/0]   File "/torch/_dynamo/output_graph.py", line 309, in __init__
 torch/_guards.py:263] [0/0]     self.init_ambient_guards()
 torch/_guards.py:263] [0/0]   File "/torch/_dynamo/output_graph.py", line 448, in init_ambient_guards
 torch/_guards.py:263] [0/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))

The same command using torch.compile(..., dynamic=True) or torch.compile(..., dynamic=None) and relying on the compiler to detect dynamic shapes

torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --dynamic_true
# or
torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen

results in a recompile-limit error:

torch/_dynamo/convert_frame.py:367] torch._dynamo hit config.cache_size_limit (8) torch._dynamo hit config.cache_size_limit (8)
torch/_dynamo/convert_frame.py:367]    function: 'forward' (/replication.py:51)
torch/_dynamo/convert_frame.py:367]    last reason: tensor 'L['x']' stride mismatch at index 0. expected 1014, actual 1023

The logging output also doesn't appear to be very verbose.

torch/_dynamo/guards.py:2546] [__recompiles_verbose] Recompiling function forward in /replication.py:51
torch/_dynamo/guards.py:2546] [__recompiles_verbose]     triggered by the following guard failure(s):
torch/_dynamo/guards.py:2546] [__recompiles_verbose]     guard 0 failures:
torch/_dynamo/guards.py:2546] [__recompiles_verbose]     - tensor 'L['x']' stride mismatch at index 0. expected 985, actual 1011
torch/_dynamo/guards.py:2546] [__recompiles_verbose] 
torch/_dynamo/guards.py:2546] [__recompiles_verbose]     guard 1 failures:
torch/_dynamo/guards.py:2546] [__recompiles_verbose]     - tensor 'L['x']' stride mismatch at index 0. expected 976, actual 1011
torch/_dynamo/guards.py:2546] [__recompiles_verbose] 
torch/_dynamo/guards.py:2546] [__recompiles_verbose]     guard 2 failures:
torch/_dynamo/guards.py:2546] [__recompiles_verbose]     - tensor 'L['x']' stride mismatch at index 0. expected 1015, actual 1011
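For context on the guard message: for a contiguous 2-D tensor of shape (batch, seq_len), the stride at index 0 equals seq_len, so a guard that pins that stride to a constant fails every time the sequence length changes. Below is a minimal pure-Python sketch of that arithmetic. `contiguous_strides` is a hypothetical helper written for illustration, not a PyTorch API; the batch size of 16 is the script's default, and 985 and 1011 are taken from the log above.

```python
from typing import Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major (C-contiguous) strides, in elements, for the given shape."""
    strides = []
    running = 1
    # Walk dimensions from innermost to outermost, accumulating the product.
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


# Hypothetical illustration: two batches with different sequence lengths.
expected = contiguous_strides((16, 985))
actual = contiguous_strides((16, 1011))
print(expected)  # (985, 1)
print(actual)    # (1011, 1)
print(expected[0] == actual[0])  # False: the index-0 stride differs
```

This is consistent with the "expected 985, actual 1011" message if the guard was specialized on the stride of an earlier batch; it does not show why the stride was specialized in the first place.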

I'm happy to add more logging if wanted.

Minified repro

Replication Script
# based on the PyTorch DDP example

import argparse
import logging
import random
import torch
import torch.distributed as dist
import torch.nn as nn
from typing import Tuple
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel

torch.set_float32_matmul_precision("high")


def pytorch_logs_to_file(file: str = "pytorch.log"):
    torch._logging.set_logs(
        dynamo=logging.INFO,
        aot=logging.INFO,
        inductor=logging.INFO,
        dynamic=logging.INFO,
        distributed=logging.INFO,
        graph_breaks=True,
        guards=True,
        recompiles=True,
        recompiles_verbose=True,
        output_code=True,
        graph_code=True,
        graph=True,
        ddp_graphs=True,
    )
    torch._logging._init_logs(file)

    loggers = logging.Logger.manager.loggerDict.keys()
    for logger_name in loggers:
        if logger_name.startswith("torch"):
            logger = logging.getLogger(logger_name)
            if isinstance(logger, logging.Logger):
                handlers = logger.handlers
                for handler in handlers:
                    if isinstance(handler, logging.StreamHandler):
                        logger.removeHandler(handler)


class EmbedHeadModel(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.vocab_embed = nn.Embedding(vocab_size, hidden_size)
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(self, x: Tensor):
        out = self.vocab_embed(x)
        out = self.head(out)
        return out


def get_batch(
    batch_size: int, sequence_length: int, vocab_size: int, device: torch.device, dynamic: bool
) -> Tuple[Tensor, Tensor]:
    if dynamic:
        input = torch.randint(
            0,
            vocab_size - 1,
            (batch_size, sequence_length - random.randint(0, min(512, sequence_length // 2)) // 8 + 1),
            device=device,
        )
    else:
        input = torch.randint(0, vocab_size - 1, (batch_size, sequence_length + 1), device=device)
    return input[:, :-1].contiguous(), input[:, 1:].contiguous()


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--sequence_length", type=int, default=1024)
    parser.add_argument("--ddp", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--variable_seqlen", action="store_true", help="Batches have variable sequence lengths")
    parser.add_argument("--dynamic_true", action="store_true", help="Compile with dynamic=True instead of None")
    parser.add_argument(
        "--use_mark_dynamic", action="store_true", help="Use torch._dynamo.mark_dynamic for dynamic shapes"
    )
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--iterations", type=int, default=100)
    parser.add_argument("--vocab_size", type=int, default=8000)
    parser.add_argument("--hidden_size", type=int, default=2048)
    parser.add_argument("--logging", action="store_true")
    parser.add_argument("--log_name", type=str, default="pytorch.log")

    return parser.parse_args()


def train():
    args = parse_args()

    if args.ddp:
        dist.init_process_group("nccl")
        rank = dist.get_rank()
        print(f"Start running basic DDP example on rank {rank}.")
    else:
        rank = 0

    if args.logging and rank == 0:
        pytorch_logs_to_file(args.log_name)

    device_id = rank % torch.cuda.device_count()
    model = EmbedHeadModel(args.vocab_size, args.hidden_size).to(device=device_id)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

    if args.compile:
        model = torch.compile(model, dynamic=True if args.dynamic_true and not args.use_mark_dynamic else None)

    if args.ddp:
        model = DistributedDataParallel(model, device_ids=[device_id])

    model.train()

    for _ in range(0, args.iterations):
        data, targets = get_batch(
            args.batch_size, args.sequence_length, args.vocab_size, device_id, args.variable_seqlen
        )
        if args.use_mark_dynamic:
            torch._dynamo.mark_dynamic(data, index=1)

        output = model(data)
        loss = criterion(output.view(-1, args.vocab_size), targets.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if args.ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    train()
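To make the variable-length behaviour of get_batch concrete, here is a small pure-Python sketch of the range of sequence lengths it can produce with --variable_seqlen. It assumes the default --sequence_length of 1024, and `seqlen_range` is a hypothetical helper that is not part of the script. Because `//` binds tighter than `-`, the random draw is divided by 8 before being subtracted, and the trailing `+ 1` is removed again by the `input[:, :-1]` slice.

```python
def seqlen_range(sequence_length: int = 1024) -> range:
    """Possible widths of `data` returned by get_batch(..., dynamic=True)."""
    # random.randint(0, min(512, sequence_length // 2)) // 8 lies in [0, max_offset].
    max_offset = min(512, sequence_length // 2) // 8
    # Width before slicing is sequence_length - offset + 1; the slice drops one column.
    return range(sequence_length - max_offset, sequence_length + 1)


lengths = seqlen_range(1024)
print(lengths[0], lengths[-1], len(lengths))  # 960 1024 65

# Values reported in the error logs above all fall inside this range.
logged = [970, 986, 1014, 1023, 985, 1011, 976, 1015]
print(all(n in lengths for n in logged))  # True
```

So under these assumptions each rank can see up to 65 distinct sequence lengths, well above the cache_size_limit of 8 shown in the recompile log.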

Versions

I ran my replication script on fresh conda environments:

conda create -n torchnight python=3.11 pytorch torchvision pytorch-cuda=12.4 -c pytorch-nightly -c nvidia -c conda-forge

conda create -n torch23 python=3.11 pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia -c conda-forge

PyTorch Nightly Environment
PyTorch version: 2.4.0.dev20240506
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 23.10 (x86_64)
GCC version: (Ubuntu 13.2.0-4ubuntu3) 13.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.38

Python version: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-6.5.0-28-generic-x86_64-with-glibc2.38
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
Nvidia driver version: 535.161.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True


Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.4.0.dev20240506
[pip3] torchvision==0.19.0.dev20240506
[pip3] triton==3.0.0
[conda] blas                      1.0                         mkl    conda-forge
[conda] brotlipy                  0.7.0           py311h9bf148f_1002    pytorch-nightly
[conda] cffi                      1.15.1          py311h9bf148f_3    pytorch-nightly
[conda] cryptography              38.0.4          py311h46ebde7_0    pytorch-nightly
[conda] filelock                  3.9.0                   py311_0    pytorch-nightly
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch-nightly
[conda] libopenvino-pytorch-frontend 2024.0.0             he02047a_5    conda-forge
[conda] mkl                       2023.1.0         h213fc3f_46344  
[conda] mpmath                    1.2.1                   py311_0    pytorch-nightly
[conda] numpy                     1.26.4          py311h64a7726_0    conda-forge
[conda] pillow                    9.3.0           py311h3fd9d12_2    pytorch-nightly
[conda] pysocks                   1.7.1                   py311_0    pytorch-nightly
[conda] pytorch                   2.4.0.dev20240506 py3.11_cuda12.4_cudnn8.9.2_0    pytorch-nightly
[conda] pytorch-cuda              12.4                 hc786d27_6    pytorch-nightly
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] requests                  2.28.1                  py311_0    pytorch-nightly
[conda] torchtriton               3.0.0+45fff310c8           py311    pytorch-nightly
[conda] torchvision               0.19.0.dev20240506     py311_cu124    pytorch-nightly
[conda] urllib3                   1.26.14                 py311_0    pytorch-nightly

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @chauhang @penguinwu @bobrenjc93 @bdhirsh @anijain2305

Labels

high priority, module: ddp, module: dynamic shapes, oncall: distributed, oncall: pt2, triaged
