Description
🐛 Describe the bug
Hello all,
I have run into a similar error. Since I cannot post my own stack trace for privacy reasons, I want to raise the visibility of this post on the PyTorch Discuss forum.
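I’ve been experimenting with the new flex_attention module and ran into an issue when integrating it with DistributedDataParallel (DDP). Because TorchDynamo captures flex_attention as a higher-order operator, it appears to conflict with Dynamo’s DDP optimizer (DDPOptimizer). This is the compiler pass that splits the compiled graph at DDP bucket boundaries, not a training optimizer.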
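To make the conflict concrete, here is a toy, pure-Python model of the DDPOptimizer idea. It is an illustration under stated assumptions, not PyTorch’s implementation. The model walks the graph’s nodes in reverse (gradient order), starts a new subgraph whenever accumulated parameter bytes reach a bucket cap, and refuses any node marked as a higher-order op. The real backend also raises `NotImplementedError` in that case, per the error quoted later in this issue. The `Node` class, `split_into_buckets`, and the byte sizes are all hypothetical. Turning the pass off corresponds to never splitting, which leaves a single bucket for the whole graph.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Node:
    # Hypothetical stand-in for a graph node: its name, the bytes of
    # parameters it owns, and whether it is a higher-order op (like flex_attention).
    name: str
    param_bytes: int
    is_higher_order_op: bool = False


def split_into_buckets(nodes: List[Node], bucket_cap_bytes: int) -> List[List[str]]:
    """Toy DDPOptimizer-style split: walk nodes in reverse (gradient order)
    and close a subgraph once its parameter bytes reach the cap."""
    buckets: List[List[str]] = []
    current: List[str] = []
    current_bytes = 0
    for node in reversed(nodes):
        if node.is_higher_order_op:
            # The toy refuses to split around a higher-order op, as the real pass does.
            raise NotImplementedError(f"higher order op {node.name!r} in graph")
        current.append(node.name)
        current_bytes += node.param_bytes
        if current_bytes >= bucket_cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
    if current:
        buckets.append(current)
    return buckets


nodes = [Node("layer0", 4), Node("layer1", 2), Node("layer2", 2)]
print(split_into_buckets(nodes, bucket_cap_bytes=4))  # [['layer2', 'layer1'], ['layer0']]

try:
    split_into_buckets(nodes + [Node("flex_attention", 0, is_higher_order_op=True)], 4)
except NotImplementedError as e:
    print(e)  # higher order op 'flex_attention' in graph
```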
Below is a minimal example of my current setup:
```python
import os
import time
import math

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.nn.attention.flex_attention import flex_attention


class Model(torch.nn.Module):
    def __init__(self, S, H, D):
        super().__init__()
        self.S = S  # sequence length
        self.H = H  # number of heads
        self.D = D  # head dimension
        # Per-head ALiBi slopes, registered as a buffer so they move with the module.
        alibi_bias = self.generate_alibi_bias(H)
        self.register_buffer("alibi_bias", alibi_bias, persistent=True)
        self.attention = flex_attention
        self.project_qk = torch.nn.Linear(H * D, H * D * 2)
        self.project_v = torch.nn.Linear(H * D, H * D)

    def forward(self, hidden_states):
        # hidden_states is batch first: (B, S, H * D)
        batch_size, _, _ = hidden_states.size()
        query, key = self.project_qk(hidden_states).chunk(2, dim=2)
        # Split the feature dim into (H, D), then move heads ahead of the
        # sequence dim: flex_attention expects (B, H, S, D).
        query = query.view(batch_size, self.S, self.H, self.D).permute(0, 2, 1, 3)
        key = key.view(batch_size, self.S, self.H, self.D).permute(0, 2, 1, 3)
        value = self.project_v(hidden_states)
        value = value.view(batch_size, self.S, self.H, self.D).permute(0, 2, 1, 3)
        return self.attention(query, key, value, score_mod=self.alibi_score_mod)

    def generate_alibi_bias(self, num_heads):
        # ALiBi slopes 2^(-8 * (i + 1) / num_heads); math.exp2 requires Python 3.11+.
        alibi_bias = [math.exp2(-((i + 1) * 8.0) / num_heads) for i in range(num_heads)]
        return torch.tensor(alibi_bias)

    def alibi_score_mod(self, score, b, h, q_idx, kv_idx):
        # Add a per-head bias proportional to the query/key distance.
        bias = (q_idx - kv_idx) * self.alibi_bias[h]
        return score + bias


if __name__ == "__main__":
    B = 64   # batch size
    H = 12   # number of heads
    S = 512  # sequence length
    D = 64   # head dimension

    # torchrun sets these environment variables for each process.
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    model = Model(S, H, D).to(device)
    model = DistributedDataParallel(model, device_ids=[local_rank])
    # torch.compile returns a new module; without the assignment the model runs uncompiled.
    model = torch.compile(model)

    for i in range(100):
        start = time.perf_counter()
        hidden_states = torch.randn(B, S, H * D, device=device)
        attention_scores = model(hidden_states)
        torch.cuda.synchronize()
        print(f"{i}: {time.perf_counter() - start:.4f}")

    torch.distributed.destroy_process_group()
```

I run the script using the following command:

```shell
torchrun --standalone --nnodes=1 --nproc_per_node=1 flex_attention_test.py
```
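For reference, here is what the `alibi_score_mod` in the repro computes. The following is a slow, pure-Python sketch of one (batch, head) slice of attention with a `score_mod` hook. It assumes flex_attention’s semantics as I understand them: raw scores are QKᵀ scaled by the default 1/sqrt(D), each score is passed through `score_mod(score, b, h, q_idx, kv_idx)`, and a softmax over the keys weights the values. The helpers `reference_attention`, `alibi_slopes`, and `make_alibi_score_mod` are hypothetical names for illustration only. This sketch is useful for checking a score_mod by hand and is not a substitute for the fused kernel.

```python
import math
from typing import Callable, List

Matrix = List[List[float]]
ScoreMod = Callable[[float, int, int, int, int], float]


def alibi_slopes(num_heads: int) -> List[float]:
    # Same formula as Model.generate_alibi_bias, using 2.0 ** x instead of math.exp2.
    return [2.0 ** (-((i + 1) * 8.0) / num_heads) for i in range(num_heads)]


def make_alibi_score_mod(slopes: List[float]) -> ScoreMod:
    def alibi_score_mod(score, b, h, q_idx, kv_idx):
        # Mirrors the repro: bias = (q_idx - kv_idx) * slope[h].
        return score + (q_idx - kv_idx) * slopes[h]
    return alibi_score_mod


def reference_attention(q: Matrix, k: Matrix, v: Matrix, score_mod: ScoreMod,
                        b: int = 0, h: int = 0) -> Matrix:
    """One (batch, head) slice: q is (S_q, D); k and v are (S_kv, D)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out: Matrix = []
    for q_idx, q_row in enumerate(q):
        # Scaled dot-product score for every key, then score_mod per element.
        scores = [
            score_mod(scale * sum(a * c for a, c in zip(q_row, k_row)), b, h, q_idx, kv_idx)
            for kv_idx, k_row in enumerate(k)
        ]
        # Numerically stable softmax over the key dimension.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        out.append([sum(w * v_row[d] for w, v_row in zip(weights, v)) for d in range(len(v[0]))])
    return out


def identity_score_mod(score, b, h, q_idx, kv_idx):
    return score


q = [[0.0, 0.0], [0.0, 0.0]]  # zero queries, so every raw score is 0
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0]]
print(reference_attention(q, k, v, identity_score_mod))  # [[0.5, 0.5], [0.5, 0.5]]
print(reference_attention(q, k, v, make_alibi_score_mod(alibi_slopes(8))))
```

With zero queries, the identity score_mod gives uniform weights. The ALiBi variant shifts each row’s weights according to `(q_idx - kv_idx) * slope[h]`, which makes the sign convention of the bias easy to inspect.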
The script fails with:

```
[rank0]: File "/home/colibri/mambaforge/envs/pytorch2_5/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1457, in _call_user_compiler
[rank0]: raise BackendCompilerFailed(self.compiler_fn, e) from e
[rank0]: torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
[rank0]: NotImplementedError: DDPOptimizer backend: Found a higher order op in the graph. This is not supported. Please turn off DDP optimizer using torch._dynamo.config.optimize_ddp=False. Note that this can cause performance degradation because there will be one bucket for the entire Dynamo graph. Please refer to this issue - https://github.com/pytorch/pytorch/issues/104674.
```
Disabling the DDP optimizer (`torch._dynamo.config.optimize_ddp = False`) avoids the error but causes significant performance degradation.
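To see why a single bucket can be costly, here is a toy timing model. It is an assumption-laden sketch, not a measurement of DDP. In the model, each bucket’s all-reduce starts once that bucket’s gradients are ready and the previous all-reduce has finished. With one bucket for the whole graph, communication cannot start until the entire backward pass is done. The function name `backward_end_time` and all costs are hypothetical.

```python
from typing import List


def backward_end_time(grad_ready: List[float], comm_cost: List[float]) -> float:
    """Toy model of DDP gradient sync: bucket i's all-reduce begins when its
    gradients are ready (grad_ready[i]) and the previous all-reduce is done.
    All-reduces run one at a time; returns when the last one finishes."""
    t = 0.0
    for ready, cost in zip(grad_ready, comm_cost):
        t = max(t, ready) + cost
    return t


# Hypothetical costs: 4 layers, 1.0 time unit of backward compute each,
# 1.0 unit to all-reduce each layer's gradients.
per_layer_buckets = backward_end_time([1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 1.0, 1.0])
single_bucket = backward_end_time([4.0], [4.0])
print(per_layer_buckets, single_bucket)  # 5.0 8.0
```

In this toy, splitting hides all but the last layer’s communication behind backward compute. Real gains depend on bucket sizes, interconnect bandwidth, and kernel timing.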
I’m seeking guidance on whether there’s a proper way to use flex_attention or similar higher-order operations in conjunction with DDP without sacrificing performance. Any advice or insights would be greatly appreciated.
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg @mikaylagawarecki @ezyang @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @yf225 @Chillee @yanboliang @BoyuanFeng
Versions
```
GPU models and configuration: GPU 0: NVIDIA A100 80GB PCIe
Nvidia driver version: 535.54.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 24
On-line CPU(s) list: 0-23
Thread(s) per core: 1
Core(s) per socket: 24
Socket(s): 1
NUMA node(s): 1
Vendor ID: AuthenticAMD
CPU family: 25
Model: 1
Model name: AMD EPYC 7V13 64-Core Processor
Stepping: 1
CPU MHz: 2445.434
BogoMIPS: 4890.86
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 768 KiB
L1i cache: 768 KiB
L2 cache: 12 MiB
L3 cache: 96 MiB
NUMA node0 CPU(s): 0-23
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext perfctr_core invpcid_single vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves clzero xsaveerptr rdpru arat umip vaes vpclmulqdq rdpid fsrm

Versions of relevant libraries:
[pip3] flake8==5.0.4
[pip3] mypy==1.9.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] onnx==1.17.0
[pip3] optree==0.13.0
[pip3] pytest-flake8==1.2.2
[pip3] pytorch-ignite==0.5.0.post2
[pip3] pytorch-lightning==2.4.0
[pip3] pytorch-metric-learning==2.6.0
[pip3] pytorch-triton==3.1.0+cf34004b8a
[pip3] torch==2.6.0.dev20241007+cu121
[pip3] torch-audiomentations==0.11.1
[pip3] torch_pitch_shift==1.2.5
[pip3] torch-stoi==0.2.3
[pip3] torchaudio==2.5.0.dev20241007+cu121
[pip3] torchcde==0.2.5
[pip3] torchcfm==1.0.6
[pip3] torchdiffeq==0.2.2
[pip3] torchdyn==1.0.6
[pip3] torcheval==0.0.7
[pip3] torchmetrics==1.4.2
[pip3] torchsde==0.2.6
[pip3] torchvision==0.20.0.dev20241007+cu121
[pip3] triton==2.3.0
[conda] blas 1.0 mkl conda-forge
[conda] ignite 0.5.0.post2 py_0 pytorch
[conda] libblas 3.9.0 16_linux64_mkl conda-forge
[conda] libcblas 3.9.0 16_linux64_mkl conda-forge
[conda] liblapack 3.9.0 16_linux64_mkl conda-forge
[conda] libopenvino-pytorch-frontend 2024.3.0 he02047a_0 conda-forge
[conda] mkl 2022.1.0 hc2b9512_224
[conda] numpy 1.26.4 py311h64a7726_0 conda-forge
[conda] optree 0.13.0 pypi_0 pypi
[conda] pytorch-cuda 12.1 ha16c6d3_5 pytorch
[conda] pytorch-lightning 2.4.0 pyhd8ed1ab_0 conda-forge
[conda] pytorch-metric-learning 2.6.0 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] pytorch-triton 3.1.0+cf34004b8a pypi_0 pypi
[conda] torch 2.6.0.dev20241007+cu121 pypi_0 pypi
[conda] torch-audiomentations 0.11.1 pypi_0 pypi
[conda] torch-pitch-shift 1.2.5 pypi_0 pypi
[conda] torch-stoi 0.2.3 pypi_0 pypi
[conda] torchaudio 2.5.0.dev20241007+cu121 pypi_0 pypi
[conda] torchcde 0.2.5 pypi_0 pypi
[conda] torchcfm 1.0.6 pypi_0 pypi
[conda] torchdiffeq 0.2.2 pyhd8ed1ab_0 conda-forge
[conda] torchdyn 1.0.6 pypi_0 pypi
[conda] torcheval 0.0.7 pypi_0 pypi
[conda] torchmetrics 1.4.2 pyhd8ed1ab_0 conda-forge
[conda] torchsde 0.2.6 pypi_0 pypi
[conda] torchtriton 2.3.0 py311 pytorch
[conda] torchvision 0.20.0.dev20241007+cu121 pypi_0 pypi
```