Description
Summary
When the cuDNN SDPA backend is selected, the output of `scaled_dot_product_attention` cannot be reshaped back to `(B, T, C)` with a plain `transpose(1, 2).view(...)`: the reshape needs an explicit `.contiguous()` copy, whereas the Flash backend does not. This can have a large performance impact in real attention modules.
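A quick way to see the difference is to probe the memory layout of the SDPA output for each backend. The snippet below is a minimal sketch, assuming both backends are available on the current GPU; the tensor shapes are placeholders chosen to mirror the repro further down. It prints the strides and whether the transposed output is contiguous, which is what decides whether the `.view(B, T, C)` in the module below can be a no-op.

```python
# Minimal layout probe (a sketch; shapes mirror the repro below, and both
# backends are assumed to be available on the current GPU).
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

B, n_head, T, head_dim = 1, 8, 2048, 64  # matches Config(n_embd=512, n_head=8)
q = torch.randn(B, n_head, T, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION):
    with sdpa_kernel(backend):
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # If the transposed output is contiguous, a later .view(B, T, C) is free;
    # otherwise it raises and an explicit copy (.contiguous()/.reshape()) is needed.
    print(backend, y.stride(), y.transpose(1, 2).is_contiguous())
```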
The most common pattern (derived from nanoGPT) is the following:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from torch.nn.attention import sdpa_kernel, SDPBackend


@dataclass
class Config:
    n_embd: int = 512
    n_head: int = 8
    n_layer: int = 6
    n_ctx: int = 2048
    bias: bool = False


class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
        # HERE, WE NEED THIS CONTIGUOUS TO BE A NO-OP
        # y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = y.transpose(1, 2).view(B, T, C)
        y = self.c_proj(y)
        return y


def test_attention(backend: SDPBackend):
    config = Config()
    attention = CausalSelfAttention(config).to("cuda", dtype=torch.float16)
    sample_input = torch.randn(1, 2048, config.n_embd, device="cuda", dtype=torch.float16)
    with sdpa_kernel(backend):
        try:
            out = attention(sample_input)
            print("ALL GOOD")
        except RuntimeError as e:
            print("❗ NOT GOOD ❗")
            print(e)


if __name__ == "__main__":
    width = 100
    print("SDPA-Flash".center(width, "-"))
    test_attention(SDPBackend.FLASH_ATTENTION)
    print("SDPA-CuDNN".center(width, "-"))
    test_attention(SDPBackend.CUDNN_ATTENTION)
```

Output
```
---------------------------------------------SDPA-Flash---------------------------------------------
ALL GOOD
---------------------------------------------SDPA-CuDNN---------------------------------------------
❗ NOT GOOD ❗
view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
```

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @csarofeen @ptrblck @xwang233 @mikaylagawarecki
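Until the layouts match, the module can be kept working under the cuDNN backend by either paying for the copy explicitly (the commented-out `.contiguous()` line in the repro) or by using `reshape()`, which only copies when the strides make a view impossible. The snippet below is a workaround sketch, not a fix for the underlying layout; the shapes are placeholders mirroring the repro.

```python
# Workaround sketch: reshape() instead of view(). Under the Flash backend this
# stays a copy-free view; under the cuDNN backend it silently falls back to a copy.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

B, n_head, T, head_dim = 1, 8, 2048, 64
C = n_head * head_dim
q = torch.randn(B, n_head, T, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Copies here because the transposed cuDNN output is not viewable as (B, T, C).
y = y.transpose(1, 2).reshape(B, T, C)
print(y.shape, y.is_contiguous())
```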