
Flex attention returns zeros for batch dimensions > 0 in certain cases #139462

@koute

🐛 Describe the bug

Consider the following code:

import torch
from torch.nn import functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

torch.manual_seed(1234)

def mask_mod(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    return causal_mask

# Two block masks built from the same causal mask_mod; only KV_LEN differs (128 vs 256).
mask_1 = create_block_mask(
    mask_mod=mask_mod,
    B=2,
    H=None,
    Q_LEN=128,
    KV_LEN=128,
    device="cuda",
)

mask_2 = create_block_mask(
    mask_mod=mask_mod,
    B=2,
    H=None,
    Q_LEN=128,
    KV_LEN=256,
    device="cuda",
)

flex_attention_compiled = torch.compile(flex_attention, dynamic=False)

# q/k/v shape is (B, H, S, head_dim) = (2, 1, 2, 16).
shape = (2, 1, 2, 16)
q = torch.normal(0.0, 3.0, shape, device="cuda")
k = torch.normal(0.0, 3.0, shape, device="cuda")
v = torch.normal(0.0, 3.0, shape, device="cuda")

# Reference result from SDPA with a causal mask.
y0 = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Eager flex_attention with both masks.
y1 = flex_attention(q, k, v, block_mask=mask_1)
y2 = flex_attention(q, k, v, block_mask=mask_2)

# Compiled flex_attention with both masks.
y3 = flex_attention_compiled(q, k, v, block_mask=mask_1)
y4 = flex_attention_compiled(q, k, v, block_mask=mask_2)

print(y0.sum())
print(y1.sum())
print(y2.sum())
print(y3.sum())
print(y4.sum())

print(y0[1])
print(y4[1])

On my machine it prints out the following:

tensor(-13.8719, device='cuda:0')
tensor(-13.8719, device='cuda:0')
tensor(-13.8719, device='cuda:0')
tensor(-13.8719, device='cuda:0')
tensor(5.0246, device='cuda:0')
tensor([[[-6.1357, -0.8294,  4.0582, -5.9614,  0.7539, -1.2435,  2.3135,
          -0.1406,  2.1571, -0.5893, -4.0255, -0.4079, -2.3713,  1.6904,
           1.4968, -2.5877],
         [-3.8762, -0.2258,  2.3405, -2.2391, -2.3507, -3.4905,  4.8959,
          -0.5228,  2.9442, -1.1869, -3.0614, -1.0334, -2.9472,  0.0813,
           2.9108,  0.6872]]], device='cuda:0')
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
       device='cuda:0')

If flex attention is compiled and the block mask was created with KV_LEN = 256 (as opposed to 128), flex attention ignores the second batch element and just returns all-zero results for it.
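
A quick check along these lines (a minimal, untested sketch reusing y0 and y4 from the snippet above; the tolerance is arbitrary) captures the symptom, and is consistent with the printed sums, since y4.sum() matches what y0[0].sum() should be:

# Sketch only: compare the compiled flex_attention result (y4) against the SDPA reference (y0) per batch element.
print(torch.allclose(y0[0], y4[0], atol=1e-4))   # expected: True (batch 0 is unaffected)
print(torch.count_nonzero(y4[1]).item())         # expected: 0 (batch 1 is zeroed out)

Note that y3 (compiled, mask_1 with KV_LEN = 128) matches the reference, so only the KV_LEN = 256 mask triggers this.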

Versions

Relevant versions:

  • Python 3.11.8
  • torch @ https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20241031%2Bcu124-cp311-cp311-linux_x86_64.whl
  • pytorch-triton @ https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp311-cp311-linux_x86_64.whl
  • triton==3.1.0
  • numpy==1.26.4
  • nvidia-cuda-runtime-cu12==12.4.127
  • Linux
  • RTX 4090

The collect_env.py script doesn't work for me (I'm using a Rye-managed venv):

  File "collect_env.py", line 448, in run_with_pip
    for line in out.splitlines()
                ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'splitlines'
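
In case the raw details are still needed, the same information can be gathered manually with something along these lines (a minimal sketch; these are only the standard version attributes of each package, not a replacement for collect_env.py):

import platform

import numpy
import torch
import triton

print("python :", platform.python_version())
print("torch  :", torch.__version__)            # 2.6.0.dev20241031+cu124 per the wheel above
print("cuda   :", torch.version.cuda)
print("gpu    :", torch.cuda.get_device_name(0))
print("triton :", triton.__version__)
print("numpy  :", numpy.__version__)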

cc @ezyang @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @yf225 @Chillee @drisspg @yanboliang @BoyuanFeng

Labels

module: flex attention, module: higher order operators, module: pt2-dispatcher, oncall: pt2, triaged
