Skip to content

[ATen][NATIVE][CUDA] Allow all 10.x compute capabilities for using vec8 kernel#174362

Closed
Aidyn-A wants to merge 5 commits intopytorch:mainfrom
Aidyn-A:vec8_on_sm_103
Closed

[ATen][NATIVE][CUDA] Allow all 10.x compute capabilities for using vec8 kernel#174362
Aidyn-A wants to merge 5 commits intopytorch:mainfrom
Aidyn-A:vec8_on_sm_103

Conversation

@Aidyn-A
Copy link
Collaborator

@Aidyn-A Aidyn-A commented Feb 5, 2026

This will allow sm_103 devices call vec8 kernels.
Verification script:

import torch
from torch.profiler import profile, ProfilerActivity

device = torch.device("cuda")

for dtype in (torch.bfloat16, torch.float16,):
    x = torch.randn(1024, device=device, dtype=dtype)
    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        y = torch.relu(x)
    stats = prof.key_averages()
    for entry in stats:
        if "at::native::vectorized_elementwise_kernel" in entry.key:
            print(entry.key)

Before:

void at::native::vectorized_elementwise_kernel<4, at::native::(anonymous namespace)::launch_clamp_scalar(at::TensorIteratorBase&, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::{lambda()#1}::operator()() const::{lambda()#9}::operator()() const::{lambda(c10::BFloat16)#1}, std::array<char*, 2ul> >(int, at::native::(anonymous namespace)::launch_clamp_scalar(at::TensorIteratorBase&, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::{lambda()#1}::operator()() const::{lambda()#9}::operator()() const::{lambda(c10::BFloat16)#1}, std::array<char*, 2ul>)
void at::native::vectorized_elementwise_kernel<4, at::native::(anonymous namespace)::launch_clamp_scalar(at::TensorIteratorBase&, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::{lambda()#1}::operator()() const::{lambda()#8}::operator()() const::{lambda(c10::Half)#1}, std::array<char*, 2ul> >(int, at::native::(anonymous namespace)::launch_clamp_scalar(at::TensorIteratorBase&, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::{lambda()#1}::operator()() const::{lambda()#8}::operator()() const::{lambda(c10::Half)#1}, std::array<char*, 2ul>)

After:

void at::native::vectorized_elementwise_kernel<8, at::native::(anonymous namespace)::launch_clamp_scalar(at::TensorIteratorBase&, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::{lambda()#1}::operator()() const::{lambda()#9}::operator()() const::{lambda(c10::BFloat16)#1}, std::array<char*, 2ul> >(int, at::native::(anonymous namespace)::launch_clamp_scalar(at::TensorIteratorBase&, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::{lambda()#1}::operator()() const::{lambda()#9}::operator()() const::{lambda(c10::BFloat16)#1}, std::array<char*, 2ul>)
void at::native::vectorized_elementwise_kernel<8, at::native::(anonymous namespace)::launch_clamp_scalar(at::TensorIteratorBase&, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::{lambda()#1}::operator()() const::{lambda()#8}::operator()() const::{lambda(c10::Half)#1}, std::array<char*, 2ul> >(int, at::native::(anonymous namespace)::launch_clamp_scalar(at::TensorIteratorBase&, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::{lambda()#1}::operator()() const::{lambda()#8}::operator()() const::{lambda(c10::Half)#1}, std::array<char*, 2ul>)

cc @ptrblck @msaroufim @eqy @jerryzh168 @tinglvv @nWEIdia @manuelcandales @angelayi

@Aidyn-A Aidyn-A requested review from eqy and ngimel February 5, 2026 07:30
@Aidyn-A Aidyn-A self-assigned this Feb 5, 2026
@Aidyn-A Aidyn-A added the module: cuda Related to torch.cuda, and CUDA support in general label Feb 5, 2026
@Aidyn-A Aidyn-A requested a review from syed-ahmed as a code owner February 5, 2026 07:30
@Aidyn-A Aidyn-A added the topic: not user facing topic category label Feb 5, 2026
@Aidyn-A Aidyn-A added the module: core aten Related to change to the Core ATen opset label Feb 5, 2026
@pytorch-bot
Copy link

pytorch-bot bot commented Feb 5, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/174362

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 3 Unrelated Failures

As of commit 44aa38b with merge base c1c6051 (image):

NEW FAILURE - The following job has failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Aidyn-A Aidyn-A changed the title [ATen][NATIVE][CUDA] Add 103 to allowed compute capabilities for vec8 [ATen][NATIVE][CUDA] Add 103 to allowed compute capabilities for vec8 kernel Feb 5, 2026
Copy link
Collaborator

@eqy eqy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For completeness, should we also add

1000 == __CUDA_ARCH_FAMILY_SPECIFIC__

in case some users build with e.g., TORCH_CUDA_ARCH_LIST=10.0f?

__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
if constexpr (vec_size == 8) {
#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000 || __CUDA_ARCH__ == 1030
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we make it a bit more future proof, so that next time we don't have to painfully compare kernel names?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I have made it work for all 10.x. It is not that beneficial to compile vec8 on 11.0 and 12.x, so I omitted them. I will remove all conditions on __CUDA_ARCH__ when we no longer need to maintain CUDA 12.x, so we can take advantage of the binary size compression in CUDA 13+ builds.

__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
if constexpr (vec_size == 8) {
#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ / 100 == 10 || __CUDA_ARCH_FAMILY_SPECIFIC__ == 1000
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need __CUDA_ARCH_FAMILY_SPECIFIC__ macro here because the kernel doesn't use any family-specific instructions, and regular __CUDA_ARCH__ will be set even if someone compiles with 1xxf https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/cpp-language-extensions.html#cuda-arch-specific-and-cuda-arch-family-specific

@jbschlosser jbschlosser added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Feb 9, 2026
@Aidyn-A Aidyn-A requested a review from ngimel February 10, 2026 06:44
@Aidyn-A Aidyn-A changed the title [ATen][NATIVE][CUDA] Add 103 to allowed compute capabilities for vec8 kernel [ATen][NATIVE][CUDA] Allow all 10.x compute capabilities for using vec8 kernel Feb 10, 2026
cudaDeviceProp* p = at::cuda::getDeviceProperties(stream.device().index());
const int computeCapability = p->major * 10 + p->minor;
if (computeCapability != 90 && computeCapability != 100) {
if (p->major != 9 && p->major != 10) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this line and ifdef in the kernel go out of sync, and you keep vec_size 8 for an arch that's excluded by ifdef, you'll get an empty kernel. Can you find a way to make sure this doesn't happen? Even

Copy link
Collaborator Author

@Aidyn-A Aidyn-A Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two things I can do simultaneously:

  1. Raise an error from the kernel if it attempts to call an empty part of kernel.
  2. Implement a custom linter that compares this line vs #if __CUDA_ARCH__ ... line above, so the numbers must match.

Or just add a big fat comment, so anyone who modifies this file must keep the arches in sync.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding device assert in the kernel should suffice

@Aidyn-A Aidyn-A added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 12, 2026
@Aidyn-A
Copy link
Collaborator Author

Aidyn-A commented Feb 12, 2026

Mac OS build failure is unrelated.

@pytorchbot merge -i

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged while ignoring the following 4 checks: trunk / macos-py3-arm64 / build, trunk / linux-jammy-rocm-py3.10 / test (distributed, 3, 3, linux.rocm.gpu.gfx950.4), trunk / linux-jammy-cuda13.0-py3.10-gcc11 / test (distributed, 3, 3, linux.g4dn.12xlarge.nvidia.gpu), trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (distributed, 3, 3, linux.g4dn.12xlarge.nvidia.gpu)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged module: core aten Related to change to the Core ATen opset module: cuda Related to torch.cuda, and CUDA support in general open source topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

6 participants