eqy (Collaborator) commented Nov 5, 2025

cuDNN dispatching heuristics rely on version checks, but currently only the compile-time version is exposed. If we want to allow users to resolve #166643 on their end by updating their cuDNN version locally, we need to check the runtime version rather than the compile-time version.
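
The difference between the two checks can be sketched in C++. This is a minimal illustration under stated assumptions, not the code from this PR: SKETCH_COMPILED_CUDNN_VERSION stands in for the CUDNN_VERSION macro, sketchRuntimeCuDNNVersion stands in for a runtime query of the loaded library (cuDNN's cudnnGetVersion() would be a natural candidate), and the version numbers are arbitrary placeholders rather than the real threshold for #166643.

```cpp
#include <cassert>

// Hypothetical stand-in for the CUDNN_VERSION macro, fixed when the binary is built.
#define SKETCH_COMPILED_CUDNN_VERSION 91000L

// Hypothetical stand-in for a runtime query of the loaded library.
// The caller passes in what the "installed" library would report.
long sketchRuntimeCuDNNVersion(long installed_library_version) {
  return installed_library_version;
}

// Gate on the compile-time version: the answer never changes after the build.
bool sketchGateCompileTime(long min_fixed_version) {
  return SKETCH_COMPILED_CUDNN_VERSION >= min_fixed_version;
}

// Gate on the runtime version: upgrading the local library can flip the answer.
bool sketchGateRuntime(long installed_library_version, long min_fixed_version) {
  return sketchRuntimeCuDNNVersion(installed_library_version) >= min_fixed_version;
}

// Small self-check: built against 91000, user later installs 91500.
void sketchSelfCheck() {
  const long kMinFixed = 91500L;  // placeholder threshold, not a real one
  assert(!sketchGateCompileTime(kMinFixed));     // binary still "sees" 91000
  assert(sketchGateRuntime(91500L, kMinFixed));  // upgraded library is detected
}
```

With a compile-time gate, a user who upgrades cuDNN locally would still be routed through the workaround path; a runtime gate picks up the upgrade without rebuilding.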

cc @csarofeen @ptrblck @xwang233 @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @aditew01

eqy requested a review from syed-ahmed as a code owner, November 5, 2025 19:06
eqy added the label module: cudnn - Related to torch.backends.cudnn, and CuDNN support, Nov 5, 2025
eqy requested a review from Aidyn-A as a code owner, November 5, 2025 19:06
eqy added the labels module: convolution - Problems related to convolutions (THNN, THCUNN, CuDNN); open source; release notes: cudnn; module: sdpa - All things related to torch.nn.functional.scaled_dot_product_attention, Nov 5, 2025
pytorch-bot (bot) commented Nov 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167111

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit caa7a77 with merge base 5c63946:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot (bot) added the label module: cpu - CPU specific problem (e.g., perf, algorithm), Nov 5, 2025
long versionCUDART() const override;
long versionCuDNN() const override;
long versionRuntimeCuDNN() const override;
long versionCuDNNFrontend() const override;
Skylion007 (Collaborator) commented Nov 5, 2025

Why does the cuDNN frontend version matter at runtime? It can't be changed, right? Isn't it a compile-time include header?

eqy (Collaborator, Author) replied

I sidecarred this change in because we'll need it in the near future for SDPA issues whose gating requires a cuDNN frontend version to be available. In theory sdp_utils.cpp could access this itself, but I'm not sure I want to include that directly.

Collaborator replied

Can the runtime version be different for the cuDNN frontend, or should it be constexpr?

static bool hasCuDNN() {
  return detail::getCUDAHooks().hasCuDNN();
}
static long versionCuDNN() {
Skylion007 (Collaborator) commented Nov 5, 2025

If this is really compile-time, why not constexpr? That would enable if constexpr logic, which would simplify critical code paths in cuDNN dispatch.

eqy (Collaborator, Author) replied

Yes, see:

return CUDNN_VERSION;

Other uses of CUDNN_VERSION in the file are macros, etc.

Collaborator replied

Yeah, if they are macros, they should be propagated with constexpr then. :)

Collaborator replied

Yeah, CUDNN_FRONTEND has its equivalent function as constexpr.

eqy (Collaborator, Author) commented Nov 5, 2025

@Skylion007 are we building with C++20 only? I'm not sure virtual functions (which these are, as CUDAHooks) can be constexpr.
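
The constraint raised here can be sketched in C++. This is an illustration under stated assumptions, not PyTorch's actual hooks code: every Sketch-prefixed name and the SKETCH_CUDNN_VERSION value are hypothetical. It shows a non-virtual constexpr helper next to a virtual hook that forwards to it; declaring a virtual function constexpr is only permitted from C++20 on.

```cpp
#include <cassert>

// Hypothetical stand-in for the CUDNN_VERSION macro.
#define SKETCH_CUDNN_VERSION 91000L

// A non-virtual function can be constexpr, so its result is usable in
// static_assert (and in `if constexpr` from C++17 on).
constexpr long sketchCompiledVersion() { return SKETCH_CUDNN_VERSION; }
static_assert(sketchCompiledVersion() == 91000L, "known at compile time");

// Hypothetical hooks interface; the names only loosely mirror the thread.
struct SketchHooksInterface {
  virtual ~SketchHooksInterface() = default;
  // Declaring this `constexpr virtual` is only legal from C++20 on, and even
  // then a call through a reference to a runtime object is an ordinary call.
  virtual long versionCuDNN() const = 0;
};

// Hypothetical CUDA-side implementation that forwards to the constexpr helper.
struct SketchCUDAHooks final : SketchHooksInterface {
  long versionCuDNN() const override { return sketchCompiledVersion(); }
};

// Callers that only know the interface pay for a virtual call.
long sketchQuery(const SketchHooksInterface& hooks) {
  return hooks.versionCuDNN();
}
```

One possible way to get both behaviours, under these assumptions, is to keep the virtual hook for callers that go through the interface and expose the constexpr helper separately to code that can include the header directly.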

eqy (Collaborator, Author) commented Nov 6, 2025

@pytorchmergebot merge

pytorch-bot (bot) added the label ciflow/trunk - Trigger trunk jobs on your pull request, Nov 6, 2025
pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


Skylion007 (Collaborator) commented

> @Skylion007 are we building with C++20 only? not sure if virtual functions (as these are CUDAHooks) can be constexpr

Ah, wasn't aware of that limitation. Not yet, no. :(

pytorchmergebot (Collaborator) commented

Merge failed

Reason: 1 job has failed: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 5, 5, lf.linux.g6.4xlarge.experimental.nvidia.gpu)

Details for Dev Infra team Raised by workflow job

eqy (Collaborator, Author) commented Nov 7, 2025

@pytorchmergebot merge

pytorchmergebot (Collaborator) commented

Merge started


eqy (Collaborator, Author) commented Nov 7, 2025

@pytorchbot cherry-pick --onto release/2.9 --fixes "cuDNN conv3d performance workaround" -c regression

pytorchbot pushed a commit that referenced this pull request Nov 7, 2025
…#167111)

cuDNN dispatching heuristics rely on version checks, but currently only the compile-time version is exposed. If we want to allow users to resolve #166643 on their end by updating their cuDNN version locally, we need to check the runtime version rather than the compile-time version.

Pull Request resolved: #167111
Approved by: https://github.com/Skylion007

(cherry picked from commit e678450)
pytorchbot (Collaborator) commented

Cherry picking #167111

The cherry pick PR is at #167327 and it is linked with issue cuDNN conv3d performance workaround. The following tracker issues are updated:

Details for Dev Infra team Raised by workflow job

atalman pushed a commit that referenced this pull request Nov 7, 2025
…#167327)

[cuDNN][SDPA][Convolution] Expose cuDNN runtime version in CUDA hooks (#167111)

cuDNN dispatching heuristics rely on version checks, but currently only the compile-time version is exposed. If we want to allow users to resolve #166643 on their end by updating their cuDNN version locally, we need to check the runtime version rather than the compile-time version.

Pull Request resolved: #167111
Approved by: https://github.com/Skylion007

(cherry picked from commit e678450)

Co-authored-by: Eddie Yan
jovan2009 referenced this pull request in comfyanonymous/ComfyUI Nov 14, 2025
Silv3S pushed a commit to Silv3S/pytorch that referenced this pull request Nov 18, 2025
…pytorch#167111)

cuDNN dispatching heuristics rely on version checks, but currently only the compile-time version is exposed. If we want to allow users to resolve pytorch#166643 on their end by updating their cuDNN version locally, we need to check the runtime version rather than the compile-time version.

Pull Request resolved: pytorch#167111
Approved by: https://github.com/Skylion007

Labels

ciflow/trunk - Trigger trunk jobs on your pull request
Merged
module: convolution - Problems related to convolutions (THNN, THCUNN, CuDNN)
module: cpu - CPU specific problem (e.g., perf, algorithm)
module: cudnn - Related to torch.backends.cudnn, and CuDNN support
module: sdpa - All things related to torch.nn.functional.scaled_dot_product_attention
open source
release notes: cudnn


Development

Successfully merging this pull request may close these issues.

Significant Memory Regression in F.conv3d with bfloat16 Inputs in PyTorch 2.9.0
