Compilation of flex attention with dynamic enabled fails when BlockMask is used #134560

@SamGalanakis

Description

🐛 Describe the bug

When compiling flex attention with dynamic=True, I get the error below. It only happens when a BlockMask is passed; with block_mask=None it works fine. With dynamic=False it also works, but it recompiles on every batch-size change, which is what I am trying to avoid for document packing. This is on pytorch-nightly:2.5.0.dev20240826-cuda12.4-cudnn9-devel.

import torch
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention.flex_attention import (
    BlockMask,
    _score_mod_signature,
    create_block_mask
)
from torch import nn

class SelfAttentionLayer(nn.Module):
    def __init__(
        self,
        dim: int,
        n_head: int,
        dropout: float = 0.0,
        bias=False,
    ):
        super().__init__()
        assert (
            dim % n_head == 0
        ), f"dim must be divisible by n_head found: {dim} and {n_head}"
        self.qkv = nn.Linear(dim, 3 * dim, bias=bias)
        self.c_proj = nn.Linear(dim, dim, bias=bias)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.n_head = n_head
        self.head_dim = dim // n_head
        self.n_embd = dim
        self.dropout = dropout

    def forward(
        self,
        x,
        score_mod: None | _score_mod_signature = None,
        block_mask: None | BlockMask = None,
    ):
        B, T, C = x.size()  # batch size, sequence length, embedding dim
        qkv = self.qkv(x)
        qkv = qkv.view(B, T, 3, self.n_head, self.head_dim)
        # -> (3, B, n_head, T, head_dim). contiguous() is necessary here, see
        # https://github.com/pytorch/pytorch/issues/134471 (TODO: check if this is still necessary)
        qkv = qkv.permute(2, 0, 3, 1, 4).contiguous()
        q, k, v = qkv

        # y: (B, n_head, T, head_dim); c_proj is not applied in this repro
        y = flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)

        return y


device = "cuda"
compile = True
dynamic = True
compile_block_mask = False

T = 256
E = 512
H = 8

assert T % 128 == 0, "T must be divisible by 128"

layer = SelfAttentionLayer(dim=E, n_head=H).to(device)
if compile:
    layer.compile(mode="default", dynamic=dynamic)


def causal(b, h, q_idx, kv_idx):
    # mask_mod: query position q_idx may attend to key position kv_idx only if kv_idx <= q_idx
    return q_idx >= kv_idx


block_mask = create_block_mask(
    mask_mod=causal,
    B=None,
    H=None,
    Q_LEN=T,
    KV_LEN=T,
    device=device,
    _compile=compile_block_mask,
)

for batch_size in range(2, 25):
    x = torch.randn(batch_size, T, E, device=device)
    y = layer(x, block_mask=block_mask)
    loss = y.mean()
    loss.backward()

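For reference, the BlockMask in the script is tiny: with T = 256 and 128-token blocks (the block size the script's T % 128 == 0 check assumes), it has 2 × 2 tiles. The stdlib-only sketch below classifies each tile of the causal mask_mod as full (every entry True), partial, or empty. It only illustrates the idea and is not how create_block_mask is implemented; classify_blocks and BLOCK_SIZE are hypothetical names. Per-q-block counts and index lists like these appear to be what the [1, 1, n] and [1, 1, n, n] int32 inputs in args[4] of the log hold, with n = T / 128 = 2 here. For T = 256, n happens to equal the first batch size in the loop (2), and the log shows the same symbol, s0, for both.

```python
# Stdlib-only sketch: classify each (q-block, kv-block) tile of a mask_mod as
# "full" (every entry True), "partial" (some entries True) or empty (all False).
BLOCK_SIZE = 128  # hypothetical name; assumed block size, matching the T % 128 == 0 check


def causal(b, h, q_idx, kv_idx):
    # Same mask_mod as in the repro: a query attends to itself and earlier keys.
    return q_idx >= kv_idx


def classify_blocks(mask_mod, q_len, kv_len, block_size=BLOCK_SIZE):
    """Return (partial, full): per q-block lists of partial and full kv-block indices."""
    assert q_len % block_size == 0 and kv_len % block_size == 0
    partial, full = [], []
    for qb in range(q_len // block_size):
        row_partial, row_full = [], []
        q_range = range(qb * block_size, (qb + 1) * block_size)
        for kb in range(kv_len // block_size):
            kv_range = range(kb * block_size, (kb + 1) * block_size)
            # Batch and head indices fixed at 0: one mask shared by all batches/heads.
            values = [mask_mod(0, 0, q, k) for q in q_range for k in kv_range]
            if all(values):
                row_full.append(kb)
            elif any(values):
                row_partial.append(kb)
            # Tiles where every entry is False are not recorded at all.
        partial.append(row_partial)
        full.append(row_full)
    return partial, full


T = 256
partial, full = classify_blocks(causal, T, T)
kv_num_blocks = [len(row) for row in partial]
full_kv_num_blocks = [len(row) for row in full]
print("partial:", partial, "full:", full)
print("kv_num_blocks:", kv_num_blocks, "full_kv_num_blocks:", full_kv_num_blocks)
```

Running it prints partial: [[0], [1]] full: [[], [0]], so the two diagonal tiles are partial, the lower-left tile is full, and the upper-right tile is skipped.
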
Error:

torch._inductor.codecache.BypassFxGraphCache: Can't cache HigherOrderOperators.

During handling of the above exception, another exception occurred:

torch._inductor.exc.LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice. 
  target: flex_attention
  args[0]: TensorBox(
    ReinterpretView(
      StorageBox(
        ComputedBuffer(name='buf1', layout=FixedLayout('cuda', torch.float32, size=[3, s0, 8, s1, 64], stride=[512*s0*s1, 512*s1, 64*s1, 64, 1]), data=Pointwise(
          'cuda',
          torch.float32,
          def inner_fn(index):
              i0, i1, i2, i3, i4 = index
              tmp0 = ops.load(buf0, i4 + 64 * i2 + 512 * i0 + 1536 * i3 + 1536 * i1 * s1)
              return tmp0
          ,
          ranges=[3, s0, 8, s1, 64],
          origin_node=clone,
          origins=OrderedSet([clone])
        ))
      ),
      FixedLayout('cuda', torch.float32, size=[s0, 8, s1, 64], stride=[512*s1, 64*s1, 64, 1]),
      origins=OrderedSet([clone, select])
    )
  )
  args[1]: TensorBox(
    ReinterpretView(
      StorageBox(
        ComputedBuffer(name='buf1', layout=FixedLayout('cuda', torch.float32, size=[3, s0, 8, s1, 64], stride=[512*s0*s1, 512*s1, 64*s1, 64, 1]), data=Pointwise(
          'cuda',
          torch.float32,
          def inner_fn(index):
              i0, i1, i2, i3, i4 = index
              tmp0 = ops.load(buf0, i4 + 64 * i2 + 512 * i0 + 1536 * i3 + 1536 * i1 * s1)
              return tmp0
          ,
          ranges=[3, s0, 8, s1, 64],
          origin_node=clone,
          origins=OrderedSet([clone])
        ))
      ),
      FixedLayout('cuda', torch.float32, size=[s0, 8, s1, 64], stride=[512*s1, 64*s1, 64, 1], offset=512*s0*s1),
      origins=OrderedSet([select_1])
    )
  )
  args[2]: TensorBox(
    ReinterpretView(
      StorageBox(
        ComputedBuffer(name='buf1', layout=FixedLayout('cuda', torch.float32, size=[3, s0, 8, s1, 64], stride=[512*s0*s1, 512*s1, 64*s1, 64, 1]), data=Pointwise(
          'cuda',
          torch.float32,
          def inner_fn(index):
              i0, i1, i2, i3, i4 = index
              tmp0 = ops.load(buf0, i4 + 64 * i2 + 512 * i0 + 1536 * i3 + 1536 * i1 * s1)
              return tmp0
          ,
          ranges=[3, s0, 8, s1, 64],
          origin_node=clone,
          origins=OrderedSet([clone])
        ))
      ),
      FixedLayout('cuda', torch.float32, size=[s0, 8, s1, 64], stride=[512*s1, 64*s1, 64, 1], offset=1024*s0*s1),
      origins=OrderedSet([select_2])
    )
  )
  args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
  args[4]: (TensorBox(StorageBox(
    InputBuffer(name='primals_7', layout=FixedLayout('cuda', torch.int32, size=[1, 1, s0], stride=[s0, s0, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_8', layout=FixedLayout('cuda', torch.int32, size=[1, 1, s0, s0], stride=[s0**2, s0**2, s0, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_9', layout=FixedLayout('cuda', torch.int32, size=[1, 1, s0], stride=[s0, s0, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_10', layout=FixedLayout('cuda', torch.int32, size=[1, 1, s0, s0], stride=[s0**2, s0**2, s0, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_11', layout=FixedLayout('cuda', torch.int32, size=[1, 1, s0], stride=[s0, s0, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_12', layout=FixedLayout('cuda', torch.int32, size=[1, 1, s0, s0], stride=[s0**2, s0**2, s0, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_13', layout=FixedLayout('cuda', torch.int32, size=[1, 1, s0], stride=[s0, s0, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_14', layout=FixedLayout('cuda', torch.int32, size=[1, 1, s0, s0], stride=[s0**2, s0**2, s0, 1]))
  )), s5, s6, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
  args[5]: 0.125
  args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': True}
  args[7]: ()
  args[8]: ()

The above exception was the direct cause of the following exception:

  File "/workspaces/methylation_prediction/flex.py", line 95, in <module>
    loss = y.mean()
        ^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice. 
  target: flex_attention
  args[0]: TensorBox(
    ReinterpretView(
      StorageBox(
        ComputedBuffer(name='buf1', layout=FixedLayout('cuda', torch.float32, size=[3, s0, 8, s1, 64], stride=[512*s0*s1, 512*s1, 64*s1, 64, 1]), data=Pointwise(
          'cuda',
          torch.float32,
          def inner_fn(index):
              i0, i1, i2, i3, i4 = index
              tmp0 = ops.load(buf0, i4 + 64 * i2 + 512 * i0 + 1536 * i3 + 1536 * i1 * s1)
              return tmp0
          ,
          ranges=[3, s0, 8, s1, 64],
          origin_node=clone,
          origins=OrderedSet([clone])
        ))
      ),
      FixedLayout('cuda', torch.float32, size=[s0, 8, s1, 64], stride=[512*s1, 64*s1, 64, 1]),
      origins=OrderedSet([clone, select])
    )
  )
  args[1]: TensorBox(
    ReinterpretView(
      StorageBox(
        ComputedBuffer(name='buf1', layout=FixedLayout('cuda', torch.float32, size=[3, s0, 8, s1, 64], stride=[512*s0*s1, 512*s1, 64*s1, 64, 1]), data=Pointwise(
          'cuda',
          torch.float32,
          def inner_fn(index):
              i0, i1, i2, i3, i4 = index
              tmp0 = ops.load(buf0, i4 + 64 * i2 + 512 * i0 + 1536 * i3 + 1536 * i1 * s1)
              return tmp0
          ,
          ranges=[3, s0, 8, s1, 64],
          origin_node=clone,
          origins=OrderedSet([clone])
        ))
      ),
      FixedLayout('cuda', torch.float32, size=[s0, 8, s1, 64], stride=[512*s1, 64*s1, 64, 1], offset=512*s0*s1),
      origins=OrderedSet([select_1])
    )
  )
  args[2]: TensorBox(
    ReinterpretView(
      StorageBox(
        ComputedBuffer(name='buf1', layout=FixedLayout('cuda', torch.float32, size=[3, s0, 8, s1, 64], stride=[512*s0*s1, 512*s1, 64*s1, 64, 1]), data=Pointwise(
          'cuda',
          torch.float32,
          def inner_fn(index):
              i0, i1, i2, i3, i4 = index
              tmp0 = ops.load(buf0, i4 + 64 * i2 + 512 * i0 + 1536 * i3 + 1536 * i1 * s1)
              return tmp0
          ,
          ranges=[3, s0, 8, s1, 64],
          origin_node=clone,
          origins=OrderedSet([clone])
        ))
      ),
      FixedLayout('cuda', torch.float32, size=[s0, 8, s1, 64], stride=[512*s1, 64*s1, 64, 1], offset=1024*s0*s1),
      origins=OrderedSet([select_2])
    )
  )
  args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
  args[4]: (TensorBox(StorageBox(
    InputBuffer(name='primals_7', layout=FixedLayout('cuda', torch.int32, size=[1, 1, s0], stride=[s0, s0, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_8', layout=FixedLayout('cuda', torch.int32, size=[1, 1, s0, s0], stride=[s0**2, s0**2, s0, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_9', layout=FixedLayout('cuda', torch.int32, size=[1, 1, s0], stride=[s0, s0, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_10', layout=FixedLayout('cuda', torch.int32, size=[1, 1, s0, s0], stride=[s0**2, s0**2, s0, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_11', layout=FixedLayout('cuda', torch.int32, size=[1, 1, s0], stride=[s0, s0, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_12', layout=FixedLayout('cuda', torch.int32, size=[1, 1, s0, s0], stride=[s0**2, s0**2, s0, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_13', layout=FixedLayout('cuda', torch.int32, size=[1, 1, s0], stride=[s0, s0, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_14', layout=FixedLayout('cuda', torch.int32, size=[1, 1, s0, s0], stride=[s0**2, s0**2, s0, 1]))
  )), s5, s6, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
  args[5]: 0.125
  args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': True}
  args[7]: ()
  args[8]: ()

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

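The layouts printed in the log can be checked by hand. The stdlib-only sketch below (permute_shape and contiguous_strides are hypothetical helpers, not PyTorch APIs) redoes the shape bookkeeping of SelfAttentionLayer.forward for the first loop iteration, batch_size = 2 and T = 256. Reading s0 as the batch size and s1 as T, it reproduces the strides [512*s0*s1, 512*s1, 64*s1, 64, 1] of buf1, the k and v offsets 512*s0*s1 and 1024*s0*s1, and args[5] = 0.125. That value appears to be flex_attention's default scale of 1/sqrt(head_dim), since the script passes no scale.

```python
import math


def permute_shape(shape, dims):
    # Shape after tensor.permute(*dims): result dim i takes input dim dims[i].
    return tuple(shape[d] for d in dims)


def contiguous_strides(shape):
    # Row-major (C-contiguous) strides, in elements, as produced by .contiguous().
    strides = []
    acc = 1
    for size in reversed(shape):
        strides.append(acc)
        acc *= size
    return tuple(reversed(strides))


B, T, E, H = 2, 256, 512, 8  # first loop iteration of the repro
D = E // H                   # head_dim = 64

viewed = (B, T, 3, H, D)                         # qkv.view(B, T, 3, n_head, head_dim)
packed = permute_shape(viewed, (2, 0, 3, 1, 4))  # (3, B, n_head, T, head_dim)
strides = contiguous_strides(packed)
# q, k, v = qkv unpacks along dim 0, so each is a view starting at i * strides[0];
# each view keeps strides[1:] = (512*T, 64*T, 64, 1).
offsets = [i * strides[0] for i in range(3)]
scale = 1 / math.sqrt(D)  # 1/sqrt(head_dim)

print(packed, strides, offsets, scale)
```

With these values it prints (3, 2, 8, 256, 64) (262144, 131072, 16384, 64, 1) [0, 262144, 524288] 0.125.
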
Versions

Collecting environment information...
PyTorch version: 2.5.0.dev20240825+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.4
Libc version: glibc-2.35

Python version: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1071-azure-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100 80GB PCIe
Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 24
On-line CPU(s) list: 0-23
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7V13 64-Core Processor
CPU family: 25
Model: 1
Thread(s) per core: 1
Core(s) per socket: 24
Socket(s): 1
Stepping: 1
BogoMIPS: 4890.88
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext perfctr_core invpcid_single vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves clzero xsaveerptr rdpru arat umip vaes vpclmulqdq rdpid fsrm
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 768 KiB (24 instances)
L1i cache: 768 KiB (24 instances)
L2 cache: 12 MiB (24 instances)
L3 cache: 96 MiB (3 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-23
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] optree==0.12.1
[pip3] pytorch-triton==3.0.0+dedb7bdf33
[pip3] torch==2.5.0.dev20240825+cu124
[pip3] torch_scatter==2.1.2
[pip3] torchaudio==2.4.0.dev20240825+cu124
[pip3] torchelastic==0.2.2
[pip3] torcheval==0.0.7
[pip3] torchvision==0.20.0.dev20240825+cu124
[conda] numpy 1.26.4 pypi_0 pypi
[conda] optree 0.12.1 pypi_0 pypi
[conda] pytorch-triton 3.0.0+dedb7bdf33 pypi_0 pypi
[conda] torch 2.5.0.dev20240825+cu124 pypi_0 pypi
[conda] torch-scatter 2.1.2 pypi_0 pypi
[conda] torchaudio 2.4.0.dev20240825+cu124 pypi_0 pypi
[conda] torchelastic 0.2.2 pypi_0 pypi
[conda] torcheval 0.0.7 pypi_0 pypi
[conda] torchvision 0.20.0.dev20240825+cu124 pypi_0 pypi

Metadata

Labels: module: flex attention, triaged
