[ATen][NATIVE][CUDA] Allow all 10.x compute capabilities for using vec8 kernel #174362
Aidyn-A wants to merge 5 commits into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/174362
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 3 Unrelated Failures as of commit 44aa38b with merge base c1c6051 (FLAKY: the following jobs failed but were likely due to flakiness present on trunk).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
 __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
   if constexpr (vec_size == 8) {
-#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
+#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000 || __CUDA_ARCH__ == 1030
```
Can we make it a bit more future-proof, so that next time we don't have to painfully compare kernel names?
Sure, I have made it work for all 10.x. It is not that beneficial to compile vec8 on 11.0 and 12.x, so I omitted them. I will remove all conditions on __CUDA_ARCH__ when we no longer need to maintain CUDA 12.x, so we can take advantage of the binary size compression in CUDA 13+ builds.
```diff
 __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
   if constexpr (vec_size == 8) {
-#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
+#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ / 100 == 10 || __CUDA_ARCH_FAMILY_SPECIFIC__ == 1000
```
I don't think we need the __CUDA_ARCH_FAMILY_SPECIFIC__ macro here, because the kernel doesn't use any family-specific instructions, and the regular __CUDA_ARCH__ will be set even if someone compiles with 1xxf: https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/cpp-language-extensions.html#cuda-arch-specific-and-cuda-arch-family-specific
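To illustrate the point above, here is a preprocessor-only sketch (not the PR's code), assuming the documented nvcc behavior that a family-specific target like sm_100f still defines the regular `__CUDA_ARCH__`:

```cuda
// Sketch under the assumption that:
//   - compiling for sm_100  defines __CUDA_ARCH__ == 1000
//   - compiling for sm_100f defines __CUDA_ARCH__ == 1000 as well,
//     plus __CUDA_ARCH_FAMILY_SPECIFIC__ == 1000
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900 || __CUDA_ARCH__ / 100 == 10)
// This branch is already taken for sm_100f builds, so an extra
// __CUDA_ARCH_FAMILY_SPECIFIC__ clause would not enable anything new.
#endif
```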
```diff
 cudaDeviceProp* p = at::cuda::getDeviceProperties(stream.device().index());
 const int computeCapability = p->major * 10 + p->minor;
-if (computeCapability != 90 && computeCapability != 100) {
+if (p->major != 9 && p->major != 10) {
```
If this line and ifdef in the kernel go out of sync, and you keep vec_size 8 for an arch that's excluded by ifdef, you'll get an empty kernel. Can you find a way to make sure this doesn't happen? Even
There are two things I can do simultaneously:
- Raise an error from the kernel if it attempts to call an empty part of the kernel.
- Implement a custom linter that compares this line vs. the `#if __CUDA_ARCH__ ...` line above, so the numbers must match.

Or just add a big fat comment, so anyone who modifies this file must keep the arches in sync.
Adding device assert in the kernel should suffice
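The device-assert suggestion could look roughly like the following sketch (not the PR's actual code; the elided vec8 body and the use of `__trap()` rather than PyTorch's own assert macro are assumptions for illustration):

```cuda
// Sketch: if the host-side dispatch and this #if ever go out of sync
// (host picks vec_size == 8 for an arch excluded here), the kernel
// traps loudly instead of silently running an empty body.
template <int vec_size, typename func_t, typename array_t>
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
  if constexpr (vec_size == 8) {
#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ / 100 == 10
    // ... real vec8 body elided ...
#else
    // Unreachable when host and device checks agree; fail fast otherwise.
    __trap();
#endif
  }
}
```

`__trap()` aborts the kernel and surfaces a launch error on the next CUDA API call, so the out-of-sync case would be caught in testing rather than producing silent no-op results.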
Mac OS build failure is unrelated. @pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 4 checks:
- trunk / macos-py3-arm64 / build
- trunk / linux-jammy-rocm-py3.10 / test (distributed, 3, 3, linux.rocm.gpu.gfx950.4)
- trunk / linux-jammy-cuda13.0-py3.10-gcc11 / test (distributed, 3, 3, linux.g4dn.12xlarge.nvidia.gpu)
- trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (distributed, 3, 3, linux.g4dn.12xlarge.nvidia.gpu)

Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This will allow sm_103 devices to call vec8 kernels.

Verification script:
Before:
After:
cc @ptrblck @msaroufim @eqy @jerryzh168 @tinglvv @nWEIdia @manuelcandales @angelayi