Closed
Labels: module: flex attention, module: higher order operators (torch.cond and similar), module: pt2-dispatcher, oncall: pt2, triaged
Description
🐛 Describe the bug
Consider the following code:
import torch
from torch.nn import functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

torch.manual_seed(1234)

def mask_mod(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    return causal_mask

mask_1 = create_block_mask(
    mask_mod = mask_mod,
    B = 2,
    H = None,
    Q_LEN = 128,
    KV_LEN = 128,
    device = "cuda",
)

mask_2 = create_block_mask(
    mask_mod = mask_mod,
    B = 2,
    H = None,
    Q_LEN = 128,
    KV_LEN = 256,
    device = "cuda",
)

flex_attention_compiled = torch.compile(flex_attention, dynamic=False)

shape = (2, 1, 2, 16)
q = torch.normal(0.0, 3.0, shape, device = "cuda")
k = torch.normal(0.0, 3.0, shape, device = "cuda")
v = torch.normal(0.0, 3.0, shape, device = "cuda")

y0 = F.scaled_dot_product_attention(q, k, v, is_causal = True)
y1 = flex_attention(q, k, v, block_mask = mask_1)
y2 = flex_attention(q, k, v, block_mask = mask_2)
y3 = flex_attention_compiled(q, k, v, block_mask = mask_1)
y4 = flex_attention_compiled(q, k, v, block_mask = mask_2)

print(y0.sum())
print(y1.sum())
print(y2.sum())
print(y3.sum())
print(y4.sum())
print(y0[1])
print(y4[1])

On my machine it prints out the following:
tensor(-13.8719, device='cuda:0')
tensor(-13.8719, device='cuda:0')
tensor(-13.8719, device='cuda:0')
tensor(-13.8719, device='cuda:0')
tensor(5.0246, device='cuda:0')
tensor([[[-6.1357, -0.8294, 4.0582, -5.9614, 0.7539, -1.2435, 2.3135,
-0.1406, 2.1571, -0.5893, -4.0255, -0.4079, -2.3713, 1.6904,
1.4968, -2.5877],
[-3.8762, -0.2258, 2.3405, -2.2391, -2.3507, -3.4905, 4.8959,
-0.5228, 2.9442, -1.1869, -3.0614, -1.0334, -2.9472, 0.0813,
2.9108, 0.6872]]], device='cuda:0')
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
device='cuda:0')
If flex attention is compiled and the block mask was created with KV_LEN = 256 (as opposed to 128, the actual key/value length of the inputs), then the compiled flex attention ignores the second batch element and returns all zeros for it. The uncompiled flex_attention with the same mask (y2 above) still matches the scaled_dot_product_attention reference.
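For what it's worth, the failure only shows up in the compiled path; a minimal explicit comparison along these lines (just a sketch, reusing q, k, v, mask_1, mask_2, flex_attention and flex_attention_compiled from the repro above; the tolerances are illustrative) should make the mismatch visible:

# Sketch: compare eager vs compiled flex attention for each block mask.
for name, mask in [("mask_1 (KV_LEN=128)", mask_1), ("mask_2 (KV_LEN=256)", mask_2)]:
    eager = flex_attention(q, k, v, block_mask=mask)
    compiled = flex_attention_compiled(q, k, v, block_mask=mask)
    try:
        torch.testing.assert_close(compiled, eager, rtol=1e-4, atol=1e-4)
        print(f"{name}: eager and compiled match")
    except AssertionError as err:
        print(f"{name}: eager and compiled DIFFER\n{err}")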
Versions
Relevant versions:
- Python 3.11.8
- torch @ https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20241031%2Bcu124-cp311-cp311-linux_x86_64.whl
- pytorch-triton @ https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp311-cp311-linux_x86_64.whl
- triton==3.1.0
- numpy==1.26.4
- nvidia-cuda-runtime-cu12==12.4.127
- Linux
- RTX 4090
The collect_env.py script doesn't work for me (I'm using a Rye-managed venv):
File "collect_env.py", line 448, in run_with_pip
for line in out.splitlines()
^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'splitlines'
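From the traceback alone this looks like the pip query inside collect_env.py returning None (presumably because pip isn't invocable the way the script expects inside a Rye-managed venv), so out.splitlines() blows up. A hypothetical guard, sketched only from the traceback rather than the actual collect_env.py source, would be something like:

def safe_splitlines(out):
    # Hypothetical helper: if the pip invocation produced no output (None),
    # treat it as "nothing collected" instead of crashing on None.splitlines().
    if out is None:
        return []
    return out.splitlines()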
cc @ezyang @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @yf225 @Chillee @drisspg @yanboliang @BoyuanFeng