Description
🐛 Describe the bug
I tried to implement causal_lower_right masking in flex attention. This requires the mask function to know the difference between the key and query lengths:

QL = query.size(2)
KL = key.size(2)

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx - QL >= kv_idx - KL

It is easy to pass this mask to flex attention, and it works on the first call (whether or not flex attention is wrapped in torch.compile). However, it fails on a later call with differently shaped query and key tensors.
I don't know whether using the query and key shapes inside the mask function is allowed. If it is, then the call with new shapes shouldn't fail. If it is not, how can one implement causal_lower_right masking, which requires knowing the shapes?
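To make the intended semantics concrete, here is a tiny standalone sketch (illustrative only, not using flex attention) of the dense mask this predicate is meant to produce, with 2 queries aligned to the bottom-right of 4 keys:

import torch

# Illustrative only: dense form of the lower-right causal mask, with
# QL=2 queries and KL=4 keys so the last query attends to the last key.
QL, KL = 2, 4
q_idx = torch.arange(QL).unsqueeze(1)   # column vector of query indices
kv_idx = torch.arange(KL).unsqueeze(0)  # row vector of key indices
mask = q_idx - QL >= kv_idx - KL        # same predicate as causal_mask above
print(mask)
# tensor([[ True,  True,  True, False],
#         [ True,  True,  True,  True]])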
Full reproduction code:
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
def causal_attention(
    query,
    key,
    value,
):
    # all shapes Bs x Nh x Len x Dim
    B = query.size(0)
    H = query.size(1)
    QL = query.size(2)
    KL = key.size(2)

    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx - QL >= kv_idx - KL

    block_mask = create_block_mask(causal_mask, B, H, QL, KL, device=query.device)
    return flex_attention(
        query,
        key,
        value,
        None,
        block_mask,
    )

def test(ql, kl):
    bs = 32
    nh = 8
    hd = 64
    q = torch.rand(
        bs, nh, ql, hd, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )
    k = torch.rand(
        bs, nh, kl, hd, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )
    v = torch.rand(
        bs, nh, kl, hd, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )
    causal_attention(q, k, v)
    print(f"test({ql}, {kl}) worked")

print("torch.__version__", torch.__version__)

# First calls always succeed.
test(512, 512)
test(512, 512)

# These calls fail, unless the above are commented out.
test(512, 1024)
test(512, 1024)
test(512, 512)

Traceback:
torch.__version__ 2.6.0.dev20241009
test(512, 512) worked
test(512, 512) worked
Traceback (most recent call last):
File "/home/janek/projects/llm_ng/flex_trouble.py", line 52, in <module>
test(512, 1024)
File "/home/janek/projects/llm_ng/flex_trouble.py", line 42, in test
causal_attention(q, k, v)
File "/home/janek/projects/llm_ng/flex_trouble.py", line 20, in causal_attention
return flex_attention(
File "/mnt/scratch/janek/pixi/babydragon-12050631407633866471/envs/nightly/lib/python3.10/site-packages/torch/nn/attention/flex_attention.py", line 1113, in flex_attention
out, lse = torch.compile(
File "/mnt/scratch/janek/pixi/babydragon-12050631407633866471/envs/nightly/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 487, in _fn
return fn(*args, **kwargs)
File "/mnt/scratch/janek/pixi/babydragon-12050631407633866471/envs/nightly/lib/python3.10/site-packages/torch/nn/attention/flex_attention.py", line 1100, in _flex_attention_hop_wrapper
def _flex_attention_hop_wrapper(*args, **kwargs):
File "/mnt/scratch/janek/pixi/babydragon-12050631407633866471/envs/nightly/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 654, in _fn
return fn(*args, **kwargs)
File "<eval_with_key>.9", line 28, in forward
File "/mnt/scratch/janek/pixi/babydragon-12050631407633866471/envs/nightly/lib/python3.10/site-packages/torch/_higher_order_ops/flex_attention.py", line 113, in __call__
raise RuntimeError("Other buffers must be tensors.")
RuntimeError: Other buffers must be tensors.
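For completeness, a possible workaround sketch (my assumption, untested as the intended usage: that the error comes from the mask closure capturing the Python ints QL and KL) is to fold the length difference into a single tensor offset; the function name causal_attention_offset is mine:

# Possible workaround sketch (assumption: "Other buffers must be tensors" is
# raised because the mask closure holds Python ints). The predicate
# q_idx - QL >= kv_idx - KL is rewritten as q_idx + (KL - QL) >= kv_idx,
# with the offset stored in a tensor so the closed-over buffer is a tensor.
def causal_attention_offset(query, key, value):
    B, H = query.size(0), query.size(1)
    QL, KL = query.size(2), key.size(2)
    offset = torch.tensor(KL - QL, device=query.device)

    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx + offset >= kv_idx

    block_mask = create_block_mask(causal_mask, B, H, QL, KL, device=query.device)
    return flex_attention(query, key, value, None, block_mask)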
Versions
Collecting environment information...
PyTorch version: 2.6.0.dev20241009
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
cc @zou3519 @bdhirsh @penguinwu @yf225 @Chillee @drisspg @yanboliang @BoyuanFeng @ezyang @chauhang @ydwu4