Skip to content

[Performance] [CuDNN-Attention] CuDNN backend should return the output in the same stride order as input Query #138340

@drisspg

Description

@drisspg

Summary

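This can have a large performance impact in real attention modules: when the output does not follow the query's stride order, merging the heads back with `y.transpose(1, 2).contiguous()` becomes a full copy instead of a no-op.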

The most common pattern (derived from nanoGPT) is shown below:

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

from torch.nn.attention import bias, sdpa_kernel, SDPBackend

@dataclass
class Config:
    """Hyperparameters for the ``CausalSelfAttention`` repro.

    ``CausalSelfAttention`` reads only ``n_embd``, ``n_head`` and ``bias``
    from this object.
    """

    n_embd: int = 512  # model width; CausalSelfAttention asserts it is divisible by n_head
    n_head: int = 8  # number of attention heads (head size = n_embd // n_head)
    n_layer: int = 6  # presumably the number of transformer layers; not read by CausalSelfAttention
    n_ctx: int = 2048  # presumably the maximum context (sequence) length; not read by CausalSelfAttention
    bias: bool = False  # whether the c_attn / c_proj Linear layers carry a bias term

class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention built on ``F.scaled_dot_product_attention``.

    A single fused ``Linear`` produces query, key and value for every head at
    once. The per-head outputs are merged back to the embedding width and fed
    through an output projection.

    Args:
        config: Object exposing ``n_embd``, ``n_head`` and ``bias``.
            ``n_embd`` must be divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        width = config.n_embd
        # One matmul yields q, k and v for all heads, stacked on the last dim.
        self.c_attn = nn.Linear(width, 3 * width, bias=config.bias)
        # Maps the merged head outputs back to the embedding width.
        self.c_proj = nn.Linear(width, width, bias=config.bias)
        self.n_head = config.n_head
        self.n_embd = width

    def forward(self, x):
        """Apply causal self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, T, C)`` with ``C == n_embd``.

        Returns:
            Tensor of shape ``(B, T, C)``.

        Raises:
            RuntimeError: From ``.view`` when the attention output's strides
                cannot be merged back to ``(B, T, C)`` without a copy. This is
                what the cuDNN backend triggers in the repro output.
        """
        batch, seq_len, channels = x.size()
        head_dim = channels // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, nh, T, hs) as a strided view; no data is copied.
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = self.c_attn(x).split(self.n_embd, dim=2)
        attended = F.scaled_dot_product_attention(
            to_heads(query),
            to_heads(key),
            to_heads(value),
            attn_mask=None,
            is_causal=True,
        )

        # Deliberately `.view` rather than `.contiguous().view`: merging the
        # heads is copy-free only if SDPA hands back memory ordered like the
        # query, i.e. (B, T, nh, hs). Otherwise this line raises, which is the
        # behaviour the repro exposes.
        merged = attended.transpose(1, 2).view(batch, seq_len, channels)
        return self.c_proj(merged)

def test_attention(backend: SDPBackend, *, device="cuda", dtype=torch.float16):
    """Run one forward pass of ``CausalSelfAttention`` under a forced SDPA backend.

    Prints "ALL GOOD" when the forward pass succeeds. When it raises
    ``RuntimeError``, prints "❗ NOT GOOD ❗" followed by the error message.
    Errors raised outside the forward pass (e.g. while moving the module to
    ``device``) are not caught.

    Args:
        backend: The SDPA backend forced via ``sdpa_kernel``.
        device: Device for the module and the sample input. Defaults to
            ``"cuda"``, as in the original repro.
        dtype: Dtype for the module and the sample input. Defaults to
            ``torch.float16``, as in the original repro.

    Returns:
        ``True`` if the forward pass completed, ``False`` if it raised
        ``RuntimeError``.
    """
    config = Config()
    attention = CausalSelfAttention(config).to(device, dtype=dtype)
    # Use the configured context length (2048 by default) as the sequence length.
    sample_input = torch.randn(
        1, config.n_ctx, config.n_embd, device=device, dtype=dtype
    )
    with sdpa_kernel(backend):
        try:
            attention(sample_input)
        except RuntimeError as e:
            print("❗ NOT GOOD ❗")
            print(e)
            return False
    print("ALL GOOD")
    return True

if __name__ == "__main__":
    width = 100
    # Run the same module under each backend, preceded by a centred banner.
    for title, backend in (
        ("SDPA-Flash", SDPBackend.FLASH_ATTENTION),
        ("SDPA-CuDNN", SDPBackend.CUDNN_ATTENTION),
    ):
        print(title.center(width, "-"))
        test_attention(backend)

Output

---------------------------------------------SDPA-Flash---------------------------------------------
ALL GOOD
---------------------------------------------SDPA-CuDNN---------------------------------------------
❗ NOT GOOD ❗
view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @csarofeen @ptrblck @xwang233 @mikaylagawarecki

Metadata

Metadata

Assignees

No one assigned

    Labels

    high priority · module: cudnn (Related to torch.backends.cudnn, and CuDNN support) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Type

    No type

    Projects

    No projects

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions