
Flex attention with mask depending on query and key lengths (or how to implement causal_lower_right masking) #137779

@janchorowski

🐛 Describe the bug

I tried to implement causal_lower_right masking with flex attention. This requires the mask function to know the difference between the key and query lengths:

QL = query.size(2)
KL = key.size(2)
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx - QL >= kv_idx - KL

It is easy to use this with flex attention, and it works on the first call (regardless of whether torch.compile is used). However, it fails on a subsequent call with differently shaped query and key tensors.

I don't know whether using the query and key shapes inside the mask function is allowed. If it is, then the second call shouldn't fail. If it isn't, how can one implement causal_lower_right masking, which requires knowing the shapes?
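One workaround I have been considering (only a sketch, I have not confirmed this is the intended usage) is to capture the query/key length difference as a tensor instead of a Python int, since the error below complains that other buffers must be tensors. The offset tensor here is my own addition:

QL = query.size(2)
KL = key.size(2)
# Keep the length difference as a tensor so the mask closure captures a tensor
# buffer rather than a Python int (assumption: captured tensors are supported here).
offset = torch.tensor(KL - QL, device=query.device)

def causal_mask(b, h, q_idx, kv_idx):
    # equivalent to q_idx - QL >= kv_idx - KL
    return q_idx + offset >= kv_idx

I have not verified that this avoids the failure across shape changes.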

Full reproduction code:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_attention(
    query,
    key,
    value,
):
    # all shapes  Bs x Nh x Len x Dim
    B = query.size(0)
    H = query.size(1)
    QL = query.size(2)
    KL = key.size(2)

    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx - QL >= kv_idx - KL

    block_mask = create_block_mask(causal_mask, B, H, QL, KL, device=query.device)
    return flex_attention(
        query,
        key,
        value,
        None,
        block_mask,
    )


def test(ql, kl):
    bs = 32
    nh = 8
    hd = 64
    q = torch.rand(
        bs, nh, ql, hd, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )
    k = torch.rand(
        bs, nh, kl, hd, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )
    v = torch.rand(
        bs, nh, kl, hd, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )
    causal_attention(q, k, v)
    print(f"test({ql}, {kl}) worked")


print("torch.__version__", torch.__version__)

# First calls always succeed.
test(512, 512)
test(512, 512)
# These calls fail unless the calls above are commented out.
test(512, 1024)
test(512, 1024)
test(512, 512)

Traceback:

torch.__version__ 2.6.0.dev20241009
test(512, 512) worked
test(512, 512) worked
Traceback (most recent call last):
  File "/home/janek/projects/llm_ng/flex_trouble.py", line 52, in <module>
    test(512, 1024)
  File "/home/janek/projects/llm_ng/flex_trouble.py", line 42, in test
    causal_attention(q, k, v)
  File "/home/janek/projects/llm_ng/flex_trouble.py", line 20, in causal_attention
    return flex_attention(
  File "/mnt/scratch/janek/pixi/babydragon-12050631407633866471/envs/nightly/lib/python3.10/site-packages/torch/nn/attention/flex_attention.py", line 1113, in flex_attention
    out, lse = torch.compile(
  File "/mnt/scratch/janek/pixi/babydragon-12050631407633866471/envs/nightly/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 487, in _fn
    return fn(*args, **kwargs)
  File "/mnt/scratch/janek/pixi/babydragon-12050631407633866471/envs/nightly/lib/python3.10/site-packages/torch/nn/attention/flex_attention.py", line 1100, in _flex_attention_hop_wrapper
    def _flex_attention_hop_wrapper(*args, **kwargs):
  File "/mnt/scratch/janek/pixi/babydragon-12050631407633866471/envs/nightly/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 654, in _fn
    return fn(*args, **kwargs)
  File "<eval_with_key>.9", line 28, in forward
  File "/mnt/scratch/janek/pixi/babydragon-12050631407633866471/envs/nightly/lib/python3.10/site-packages/torch/_higher_order_ops/flex_attention.py", line 113, in __call__
    raise RuntimeError("Other buffers must be tensors.")
RuntimeError: Other buffers must be tensors.

Versions

Collecting environment information...
PyTorch version: 2.6.0.dev20241009
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

cc @zou3519 @bdhirsh @penguinwu @yf225 @Chillee @drisspg @yanboliang @BoyuanFeng @ezyang @chauhang @ydwu4

Labels: module: flex attention, module: pt2-dispatcher, oncall: pt2, triaged
