
NaN in Flex Attention backward if BlockMask is larger than first run seq_len when return_lse=True and torch compiled #136232

@cora-codes


🐛 Describe the bug

We've come across a NaN in the backward pass of Flex Attention. Triggering it requires all of the following:

  • A BlockMask slightly larger than the first run's actual seq_len
  • return_lse = True
  • torch.compile

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
torch.set_default_device('cuda')
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx
def create_qkv(seq_len):
    q = torch.randn(1, 12, seq_len, 64, requires_grad=True)
    k = torch.randn(1, 12, seq_len, 64, requires_grad=True)
    v = torch.randn(1, 12, seq_len, 64, requires_grad=True)
    return q, k, v
# Build the block mask once for Q_LEN = KV_LEN = 4096; it is reused below for shorter sequences.
causal_block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=4096, KV_LEN=4096)
def forward(q, k, v, mask):
    out, lse = flex_attention(q, k, v, block_mask=mask, return_lse=True)
    # Multiply by exp(lse) so the logsumexp participates in the backward pass.
    return out * torch.exp(lse)[..., None]
def repro(seq_len, backend):
    print("Repro at seq_len", seq_len)
    compiled = torch.compile(forward, backend=backend)
    for i in range(10):
        q, k, v = create_qkv(seq_len)
        out = compiled(q, k, v, causal_block_mask)
        loss = out.sum()
        loss.backward()
        assert not q.grad.isnan().any(), "Q has NaN"
        assert not k.grad.isnan().any(), "K has NaN"
        assert not v.grad.isnan().any(), "V has NaN"
    print("Done")

for backend in ["inductor", "eager"]:
    # Warm up at 4096 (matching the mask size) first; the subsequent 4096 - 128 run then succeeds.
    print(backend, "4096 then 4096 - 128")
    torch.compiler.reset()
    repro(4096, backend)
    repro(4096 - 128, backend)

    # A fresh compile directly at 4096 - 128 (smaller than the mask) triggers the NaN.
    print(backend, "just 4096 - 128")
    torch.compiler.reset()
    try:
        repro(4096 - 128, backend)
    except AssertionError as e:
        print("ERROR", e)
inductor 4096 then 4096 - 128
Repro at seq_len 4096
Done
Repro at seq_len 3968
Done
inductor just 4096 - 128
Repro at seq_len 3968
ERROR K has NaN
eager 4096 then 4096 - 128
Repro at seq_len 4096
Done
Repro at seq_len 3968
Done
eager just 4096 - 128
Repro at seq_len 3968
Done
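
A workaround consistent with the trigger conditions above is to build the BlockMask for the actual sequence length instead of reusing a larger one. The sketch below uses only the APIs from the repro; rebuilding the mask per sequence length has a cost, and we have not verified it across every configuration:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

torch.set_default_device('cuda')

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def forward(q, k, v, mask):
    out, lse = flex_attention(q, k, v, block_mask=mask, return_lse=True)
    return out * torch.exp(lse)[..., None]

compiled = torch.compile(forward)

seq_len = 4096 - 128
q, k, v = (torch.randn(1, 12, seq_len, 64, requires_grad=True) for _ in range(3))

# Build the mask to match the actual seq_len instead of reusing the 4096 one,
# so the mask is never larger than the first compiled run's inputs.
mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len)
out = compiled(q, k, v, mask)
out.sum().backward()
assert not k.grad.isnan().any(), "K has NaN"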

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng

Versions

'2.6.0.dev20240916+cu124'
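
For reference, the string above is what torch.__version__ reports; a quick way to confirm the environment (torch.__version__ and torch.version.cuda are standard attributes):

import torch
print(torch.__version__)   # e.g. '2.6.0.dev20240916+cu124'
print(torch.version.cuda)  # e.g. '12.4'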


Labels

high priority, module: flex attention, module: higher order operators, module: pt2-dispatcher, oncall: pt2, triaged
