Conversation

@pytorchbot
Collaborator

Stack from ghstack (oldest at bottom):

Summary

Currently we have a `cudnn_order` that says, on H100 with a new enough cuDNN backend (we ship a 9.1 version in OSS), to try cuDNN attention first. We have already encountered a few bugs since the release of 2.5:

1. SDPA: CUDNN backend error w/ q_seq_len = 1 #138529
2. RuntimeError: cuDNN Frontend error: [cudnn_frontend] Error: No execution plans support the graph. huggingface/diffusers#9704
3. [cuDNN][SDPA] Match query's memory layout ordering for output in cuDNN SDPA #138354

In light of the above, we are going to make the cuDNN backend opt-in by default.

This can be done easily with the context manager for choosing backends, i.e.:

```python
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```

This PR puts the cuDNN backend at the lowest precedence in the backend list, meaning that the Math backend will always be chosen ahead of it unless Math is disabled (which is done via the context manager).
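For context, here is a hedged, self-contained sketch of opting into the cuDNN backend after this change; the tensor shapes and dtype below are illustrative assumptions, not taken from this PR. `sdpa_kernel` also accepts a list of backends, which allows a fallback when cuDNN cannot serve the inputs:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Illustrative shapes/dtype (assumptions, not from this PR):
# batch=2, heads=8, seq_len=128, head_dim=64, fp16 on CUDA.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Restrict SDPA to the listed backends; within the context, dispatch
# picks among them by its internal priority order, so flash attention
# serves as a fallback if cuDNN cannot handle the inputs.
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)
```

Inside the context only the listed backends are eligible, so Math is excluded and dispatch falls through the listed options in the dispatcher's priority order.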

Cc @atalman

cc @mikaylagawarecki

Pull Request resolved: #138522
Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/malfet

(cherry picked from commit 9a9a0ab)
@pytorch-bot

pytorch-bot bot commented Oct 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138587

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 61a03e4 with merge base b7eb725:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@malfet malfet self-requested a review October 22, 2024 14:13
@kit1980 kit1980 merged commit 848e7ac into release/2.5 Oct 22, 2024
@github-actions github-actions bot deleted the cherry-pick-138522-by-pytorch_bot_bot_ branch November 22, 2024 02:09
aostrowski-hbn pushed a commit to HabanaAI/pytorch-fork that referenced this pull request Jan 7, 2025
pytorch/pytorch@v2.5.0...v2.5.1

Squashed new commits are as follows:
* update getting started xpu (pytorch#138090)
* [Cherry-Pick] Use cuda 12.4 pytorch_extra_install_requirements as default (pytorch#138526)
* Don't try to load cufile (pytorch#138539)
* Add link to torch.compile the missing manual in troubleshooting (pytorch#137369)
* Update cpuinfo submodule (pytorch#138600)
* Update doc copyrights to 2024 (pytorch#138650)
* [SDPA-CUDNN] Make CuDNN Attention Opt in (pytorch#138587)
* [MPS] Fix sliced cast (pytorch#138535)
* Disabling amp context when invoking compiler (pytorch#138659)

Change-Id: I3e282e8b4809b97b38605420c64d1bd1b0b938aa