[FlexAttention] Using FlexAttention with DDP complains about a "higher order optimizer" #137481

@moinnadeem

Description

🐛 Describe the bug

Hello all,

I have experienced a similar error to the one described in this post on PyTorch Discuss. Since I cannot share my own stack trace for privacy reasons, I wanted to raise visibility for that post here.

I’ve been experimenting with the new flex_attention module and encountered an issue when trying to integrate it with DistributedDataParallel (DDP). Since flex_attention is a higher-order operator, it seems to conflict with Dynamo's DDPOptimizer, the backend pass that splits the compiled graph along DDP gradient buckets.

Below is a minimal example of my current setup:

import os
import time
import math

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.nn.attention.flex_attention import flex_attention

class Model(torch.nn.Module):
    def __init__(self, S, H, D):
        super().__init__()

        self.S = S
        self.H = H
        self.D = D

        alibi_bias = self.generate_alibi_bias(H)
        self.register_buffer("alibi_bias", alibi_bias, persistent=True)
        self.attention = flex_attention

        self.project_qk = torch.nn.Linear(H * D, H * D * 2)
        self.project_v = torch.nn.Linear(H * D, H * D)

    def forward(self, hidden_states):
        # hidden_states: (batch_size, S, H * D)
        batch_size, _, _ = hidden_states.size()

        query, key = self.project_qk(hidden_states).chunk(2, dim=2)
        # Reshape each projection to (batch_size, H, S, D), the layout flex_attention expects.
        query = query.view(batch_size, self.S, self.H, self.D)
        query = query.permute(0, 2, 1, 3)

        key = key.view(batch_size, self.S, self.H, self.D)
        key = key.permute(0, 2, 1, 3)

        value = self.project_v(hidden_states)
        value = value.view(batch_size, self.S, self.H, self.D)
        value = value.permute(0, 2, 1, 3)

        return self.attention(query, key, value, score_mod=self.alibi_score_mod)

    def generate_alibi_bias(self, num_heads):
        # Standard ALiBi slopes: 2^(-8 * (i + 1) / num_heads) for head i.
        alibi_bias = [math.exp2(-((i + 1) * 8.0) / num_heads) for i in range(num_heads)]
        return torch.tensor(alibi_bias)

    def alibi_score_mod(self, score, b, h, q_idx, kv_idx):
        # Linear distance penalty, scaled per head; passing this callable is
        # what makes flex_attention a higher-order op from Dynamo's point of view.
        bias = (q_idx - kv_idx) * self.alibi_bias[h]
        return score + bias

if __name__ == "__main__":

    B = 64
    H = 12
    S = 512
    D = 64

    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    model = Model(S, H, D)
    model.to(device)
    model = DistributedDataParallel(model, device_ids=[local_rank])
    # Assign the result back; torch.compile does not modify the module in place.
    model = torch.compile(model)

    for i in range(100):
        start = time.perf_counter()
        hidden_states = torch.randn(B, S, H * D).to(device)
        attention_scores = model(hidden_states)
        torch.cuda.synchronize()
        print(f"{i}: {time.perf_counter() - start:.4f}")

I run the script using the following command:

torchrun --standalone --nnodes=1 --nproc_per_node=1 flex_attention_test.py

This fails during compilation with the following error:
[rank0]:   File "/home/colibri/mambaforge/envs/pytorch2_5/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1457, in _call_user_compiler
[rank0]:     raise BackendCompilerFailed(self.compiler_fn, e) from e
[rank0]: torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
[rank0]: NotImplementedError: DDPOptimizer backend: Found a higher order op in the graph. This is not supported. Please turn off DDP optimizer using torch._dynamo.config.optimize_ddp=False. Note that this can cause performance degradation because there will be one bucket for the entire Dynamo graph. Please refer to this issue - https://github.com/pytorch/pytorch/issues/104674.

Disabling the DDP optimizer resolves the error but results in significant performance degradation.
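For reference, this is how I disable it, using the flag named in the error message (a minimal sketch; the flag must be set before the compiled model first runs):

import torch._dynamo

# Per the error message, this allows the higher-order flex_attention op to
# compile, but leaves one bucket for the entire Dynamo graph, which is the
# source of the performance degradation.
torch._dynamo.config.optimize_ddp = False

model = DistributedDataParallel(model, device_ids=[local_rank])
model = torch.compile(model)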

I’m seeking guidance on whether there’s a proper way to use flex_attention or similar higher-order operations in conjunction with DDP without sacrificing performance. Any advice or insights would be greatly appreciated.

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg @mikaylagawarecki @ezyang @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @yf225 @Chillee @yanboliang @BoyuanFeng

Versions

GPU models and configuration: GPU 0: NVIDIA A100 80GB PCIe
Nvidia driver version: 535.54.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   48 bits physical, 48 bits virtual
CPU(s):                          24
On-line CPU(s) list:             0-23
Thread(s) per core:              1
Core(s) per socket:              24
Socket(s):                       1
NUMA node(s):                    1
Vendor ID:                       AuthenticAMD
CPU family:                      25
Model:                           1
Model name:                      AMD EPYC 7V13 64-Core Processor
Stepping:                        1
CPU MHz:                         2445.434
BogoMIPS:                        4890.86
Hypervisor vendor:               Microsoft
Virtualization type:             full
L1d cache:                       768 KiB
L1i cache:                       768 KiB
L2 cache:                        12 MiB
L3 cache:                        96 MiB
NUMA node0 CPU(s):               0-23
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext perfctr_core invpcid_single vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves clzero xsaveerptr rdpru arat umip vaes vpclmulqdq rdpid fsrm

Versions of relevant libraries:
[pip3] flake8==5.0.4
[pip3] mypy==1.9.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] onnx==1.17.0
[pip3] optree==0.13.0
[pip3] pytest-flake8==1.2.2
[pip3] pytorch-ignite==0.5.0.post2
[pip3] pytorch-lightning==2.4.0
[pip3] pytorch-metric-learning==2.6.0
[pip3] pytorch-triton==3.1.0+cf34004b8a
[pip3] torch==2.6.0.dev20241007+cu121
[pip3] torch-audiomentations==0.11.1
[pip3] torch_pitch_shift==1.2.5
[pip3] torch-stoi==0.2.3
[pip3] torchaudio==2.5.0.dev20241007+cu121
[pip3] torchcde==0.2.5
[pip3] torchcfm==1.0.6
[pip3] torchdiffeq==0.2.2
[pip3] torchdyn==1.0.6
[pip3] torcheval==0.0.7
[pip3] torchmetrics==1.4.2
[pip3] torchsde==0.2.6
[pip3] torchvision==0.20.0.dev20241007+cu121
[pip3] triton==2.3.0
[conda] blas                      1.0                         mkl    conda-forge
[conda] ignite                    0.5.0.post2                py_0    pytorch
[conda] libblas                   3.9.0            16_linux64_mkl    conda-forge
[conda] libcblas                  3.9.0            16_linux64_mkl    conda-forge
[conda] liblapack                 3.9.0            16_linux64_mkl    conda-forge
[conda] libopenvino-pytorch-frontend 2024.3.0             he02047a_0    conda-forge
[conda] mkl                       2022.1.0           hc2b9512_224
[conda] numpy                     1.26.4          py311h64a7726_0    conda-forge
[conda] optree                    0.13.0                   pypi_0    pypi
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-lightning         2.4.0              pyhd8ed1ab_0    conda-forge
[conda] pytorch-metric-learning   2.6.0                    pypi_0    pypi
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] pytorch-triton            3.1.0+cf34004b8a          pypi_0    pypi
[conda] torch                     2.6.0.dev20241007+cu121          pypi_0    pypi
[conda] torch-audiomentations     0.11.1                   pypi_0    pypi
[conda] torch-pitch-shift         1.2.5                    pypi_0    pypi
[conda] torch-stoi                0.2.3                    pypi_0    pypi
[conda] torchaudio                2.5.0.dev20241007+cu121          pypi_0    pypi
[conda] torchcde                  0.2.5                    pypi_0    pypi
[conda] torchcfm                  1.0.6                    pypi_0    pypi
[conda] torchdiffeq               0.2.2              pyhd8ed1ab_0    conda-forge
[conda] torchdyn                  1.0.6                    pypi_0    pypi
[conda] torcheval                 0.0.7                    pypi_0    pypi
[conda] torchmetrics              1.4.2              pyhd8ed1ab_0    conda-forge
[conda] torchsde                  0.2.6                    pypi_0    pypi
[conda] torchtriton               2.3.0                     py311    pytorch
[conda] torchvision               0.20.0.dev20241007+cu121          pypi_0    pypi


Labels

module: flex attention, module: higher order operators, module: pt2-dispatcher, oncall: distributed, oncall: pt2, triaged
