Labels: high priority, module: flex attention, module: higher order operators, module: pt2-dispatcher, oncall: pt2, triaged
🐛 Describe the bug
We've come across a NaN in the backward pass of Flex Attention. Triggering it requires all of the following:
- a BlockMask slightly larger than the first run's actual seq_len
- `return_lse=True`
- `torch.compile`
```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

torch.set_default_device('cuda')

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def create_qkv(seq_len):
    q = torch.randn(1, 12, seq_len, 64, requires_grad=True)
    k = torch.randn(1, 12, seq_len, 64, requires_grad=True)
    v = torch.randn(1, 12, seq_len, 64, requires_grad=True)
    return q, k, v

causal_block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=4096, KV_LEN=4096)

def forward(q, k, v, mask):
    out, lse = flex_attention(q, k, v, block_mask=mask, return_lse=True)
    return out * torch.exp(lse)[..., None]

def repro(seq_len, backend):
    print("Repro at seq_len", seq_len)
    compiled = torch.compile(forward, backend=backend)
    for i in range(10):
        q, k, v = create_qkv(seq_len)
        out = compiled(q, k, v, causal_block_mask)
        loss = out.sum()
        loss.backward()
        assert not q.grad.isnan().any(), "Q has NaN"
        assert not k.grad.isnan().any(), "K has NaN"
        assert not v.grad.isnan().any(), "V has NaN"
    print("Done")

for backend in ["inductor", "eager"]:
    print(backend, "4096 then 4096 - 128")
    torch.compiler.reset()
    repro(4096, backend)
    repro(4096 - 128, backend)
    print(backend, "just 4096 - 128")
    torch.compiler.reset()
    try:
        repro(4096 - 128, backend)
    except AssertionError as e:
        print("ERROR", e)
```

Output:

```
inductor 4096 then 4096 - 128
Repro at seq_len 4096
Done
Repro at seq_len 3968
Done
inductor just 4096 - 128
Repro at seq_len 3968
ERROR K has NaN
eager 4096 then 4096 - 128
Repro at seq_len 4096
Done
Repro at seq_len 3968
Done
eager just 4096 - 128
Repro at seq_len 3968
Done
```
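A possible mitigation, to be clear an assumption on our part rather than a confirmed fix: since the failure only shows up when the BlockMask is built larger than the sequence length actually used, constructing the mask to match each seq_len should avoid the triggering condition. `mask_for` below is a hypothetical helper reusing `causal_mask` from the repro; the `lru_cache` just avoids rebuilding the mask for sequence lengths already seen:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def mask_for(seq_len):
    # Build the BlockMask at exactly the sequence length that will be passed
    # to flex_attention, instead of reusing a larger precomputed mask.
    return create_block_mask(causal_mask, B=None, H=None,
                             Q_LEN=seq_len, KV_LEN=seq_len)

# e.g. inside repro(): out = compiled(q, k, v, mask_for(seq_len))
```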
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng
Versions
'2.6.0.dev20240916+cu124'