
Conversation

@lw
Contributor

@lw lw commented Jul 9, 2025

Most of the work had already been done by @jeffdaily in #154680, but there was one remaining check that needed to be modified in order for torch._scaled_mm to use cuBLAS over CUTLASS when available.

I tested this change by rebuilding PyTorch locally with CUDA 12.9 and running torch._scaled_mm under the profiler, and observed that the kernel being launched is called nvjet_qqtst_128x128_128x6_1x1_h_bz_coopA_algo2_ovscale_TNT (where ovscale stands for "outer vector scaling", I believe, which is what cuBLAS calls this scaling mode).
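
For reference, a minimal sketch of this kind of check: run one row-wise scaled matmul under the profiler and print the CUDA kernel names (the shapes and scales below are arbitrary examples, not the configuration I benchmarked):

```python
import torch

# Sketch: inspect which GEMM kernel torch._scaled_mm dispatches to for
# row-wise scaling (example shapes; requires a CUDA build with fp8 support).
M, K, N = 4096, 2048, 4096
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # column-major operand
scale_a = torch.rand(M, 1, device="cuda") + 0.5  # per-row scales for a
scale_b = torch.rand(1, N, device="cuda") + 0.5  # per-column scales for b

with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
    torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
    torch.cuda.synchronize()

for evt in prof.key_averages():
    print(evt.key)  # expect an nvjet_* name for cuBLAS, a cutlass_* name otherwise
```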

I then benchmarked the new kernels against the old CUTLASS ones on a standard 700W H100 GPU. I used the same approach as in #134781, and obtained these speed-ups:
(Two benchmark plots: speed-ups of the new cuBLAS kernels relative to the old CUTLASS kernels across the tested shapes.)

We see that the two kernels perform very closely (I'm surprised; I would have expected cuBLAS to outperform CUTLASS across the board), with some thin/skewed shapes getting worse but some very large shapes getting better.
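
A simplified sketch of such a timing harness (using torch.utils.benchmark rather than the exact script from #134781, and with just one example shape):

```python
import torch
from torch.utils import benchmark

# Sketch of a per-shape timing harness for the row-wise scaled matmul
# (one example shape; a real sweep would cover many shapes).
def make_inputs(m, k, n):
    a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn).t()  # column-major operand
    scale_a = torch.rand(m, 1, device="cuda") + 0.5
    scale_b = torch.rand(1, n, device="cuda") + 0.5
    return a, b, scale_a, scale_b

a, b, scale_a, scale_b = make_inputs(8192, 3072, 8192)
timer = benchmark.Timer(
    stmt="torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)",
    globals={"torch": torch, "a": a, "b": b, "scale_a": scale_a, "scale_b": scale_b},
)
print(timer.timeit(100))  # runs the statement 100 times and reports the timing
```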

I guess the questions are whether we consider this a net-zero change (given that there are both improvements and degradations), and how large we consider the burden of maintaining our own CUTLASS kernels to be.

Stack from ghstack (oldest at bottom):

cc @ptrblck @msaroufim @eqy @jerryzh168

@lw lw requested review from eqy and syed-ahmed as code owners July 9, 2025 09:50
@pytorch-bot

pytorch-bot bot commented Jul 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157905

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit de49ea3 with merge base 7caf6c8:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

lw added a commit that referenced this pull request Jul 9, 2025
ghstack-source-id: 80c0b6f
Pull Request resolved: #157905
@lw
Contributor Author

lw commented Jul 9, 2025

Another thing to note is that our CUTLASS kernels have a persistent schedule without auto-balancing, which was the reason we had to introduce the explicit SM carveout option. cuBLAS does have this auto-balancing, hence it should perform substantially better than CUTLASS whenever there are some comms in the background, and it would allow us to remove the carveout option.

@lw lw added the topic: not user facing (topic category) label Jul 9, 2025
@Skylion007
Collaborator

Skylion007 commented Jul 9, 2025

> Another thing to note is that our CUTLASS kernels have a persistent schedule without auto-balancing, which was the reason we had to introduce the explicit SM carveout option. cuBLAS does have this auto-balancing, hence it should perform substantially better than CUTLASS whenever there are some comms in the background, and it would allow us to remove the carveout option.

It also puts the burden of maintaining this for new architectures like Blackwell on NVIDIA instead of us.

@drisspg
Contributor

drisspg commented Jul 9, 2025

Ohh, very interesting. The biggest annoyance in my mind is that we will still need to maintain support for both until all builds are on CUDA >= 12.9, and I am not entirely sure when that will be.

Do you think it's worthwhile adding some simple checks for >= 12.9 and large problem shapes, at least until we can fully remove the old kernels?

Also, one other curious question I have, @eqy: do we use https://docs.nvidia.com/cuda/cublas/#cublasltmatmulalgogetheuristic anywhere? Or have ways to enable autotuning for these?
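
A rough illustration of the kind of gate being discussed, written in Python for readability (the actual dispatch check lives in ATen's C++ code; the helper name and the size threshold below are made up purely for illustration):

```python
import torch

# Hypothetical sketch of a "CUDA >= 12.9 and large enough problem" gate;
# the function name and the 1024 threshold are illustrative, not real PyTorch API.
def prefer_cublas_rowwise(m: int, n: int, k: int, min_dim: int = 1024) -> bool:
    if torch.version.cuda is None:
        return False  # not a CUDA build (e.g. ROCm or CPU-only)
    major, minor = (int(x) for x in torch.version.cuda.split(".")[:2])
    new_enough = (major, minor) >= (12, 9)
    large_enough = min(m, n, k) >= min_dim
    return new_enough and large_enough

print(prefer_cublas_rowwise(8192, 8192, 3072))  # True on a CUDA 12.9+ build
```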

@eqy
Collaborator

eqy commented Jul 9, 2025

> Ohh, very interesting. The biggest annoyance in my mind is that we will still need to maintain support for both until all builds are on CUDA >= 12.9, and I am not entirely sure when that will be.
>
> Do you think it's worthwhile adding some simple checks for >= 12.9 and large problem shapes, at least until we can fully remove the old kernels?
>
> Also, one other curious question I have, @eqy: do we use https://docs.nvidia.com/cuda/cublas/#cublasltmatmulalgogetheuristic anywhere? Or have ways to enable autotuning for these?

We can consider autotuning, as that has been proposed in the past.

@lw
Contributor Author

lw commented Jul 10, 2025

> Do you think it's worthwhile adding some simple checks for >= 12.9 and large problem shapes, at least until we can fully remove the old kernels?

What do you mean? You want me to add tests? Won't the existing tests be enough to cover this, provided they are built/run with CUDA 12.9?

@lw
Contributor Author

lw commented Jul 10, 2025

I tried to quantify the impact of auto-balancing in the persistent schedule, and the impact can be huge (~2x). I launched an 8k×8k×3k matmul with a torch.cuda._sleep running in a background stream, and the matmul took 3.347 ms with cuBLAS and 6.926 ms with CUTLASS.

Code to repro

```python
import torch

background_stream = torch.cuda.Stream()

# bf16 inputs; a @ b.t() is an 8192 x 8192 x 3072 matmul
a = torch.randn((8192, 3072), dtype=torch.bfloat16, device="cuda", requires_grad=True)
b = torch.randn((8192, 3072), dtype=torch.bfloat16, device="cuda", requires_grad=True)

def scale_rowwise(t):
    # Per-row absmax scaling into float8_e4m3fn; returns the fp8 tensor and its scales
    scale = t.abs().amax(dim=-1, keepdim=True).float() / torch.finfo(torch.float8_e4m3fn).max
    t = t.div(scale).to(torch.float8_e4m3fn)
    return t, scale

a_fp8, scale_a = scale_rowwise(a)
b_fp8, scale_b = scale_rowwise(b)

# Warm-up call
torch._scaled_mm(a_fp8, b_fp8.t(), scale_a=scale_a, scale_b=scale_b.t(), out_dtype=torch.bfloat16, use_fast_accum=False)

with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
    # Keep the GPU busy on a background stream to emulate overlapping comms
    with torch.cuda.stream(background_stream):
        torch.cuda._sleep(10000000000)
    for _ in range(10):
        torch._scaled_mm(a_fp8, b_fp8.t(), scale_a=scale_a, scale_b=scale_b.t(), out_dtype=torch.bfloat16, use_fast_accum=False)
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total"))
```

@drisspg
Contributor

drisspg commented Jul 11, 2025

For provenance: we spoke offline -> the SM auto-balancing alone makes this very valuable on newer CUDA versions. We wanted to determine whether there are any trends in the performance delta, especially in the small-M/N regime, and if so maybe continue to dispatch to CUTLASS even on 12.9+.

@lw lw added the ciflow/trunk (trigger trunk jobs on your pull request) label Jul 11, 2025
@drisspg drisspg added the module: cuda (related to torch.cuda and CUDA support in general) and module: floatx (formerly float8; for torch.float8_e5m2, torch.float8_e4m3, and other sub-8-bit float types) labels Jul 11, 2025
@lw
Contributor Author

lw commented Jul 11, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@lw lw deleted the gh/lw/1/head branch July 25, 2025 13:29
pytorchmergebot pushed a commit that referenced this pull request Aug 6, 2025
…es (#159957)

After #157905 started using cuBLAS for row-wise scaling on CUDA 12.9+, this broke some downstream tests for fp8 which were testing "odd" shapes. After checking in with the cuBLAS team this turned out to be due to the scale tensors' starting addresses not being aligned to 16 bytes. PyTorch storages are always aligned at 256 bytes, hence this came from a "slicing" of the scale tensor being done inside async-TP when chunking a matmul in order to overlap it with reduce-scatter.

Pull Request resolved: #159957
Approved by: https://github.com/vkuzo, https://github.com/danielvegamyhre
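
A minimal sketch of the alignment issue described above (hypothetical shapes, not the actual async-TP code): a row-wise fp32 scale tensor starts out 256-byte aligned, but chunking it along dim 0 can produce views whose start addresses are not 16-byte aligned.

```python
import torch

# Row-wise scales are fp32, i.e. 4 bytes per row (shapes here are made up).
scales = torch.randn(4100, 1, dtype=torch.float32, device="cuda")
print(scales.data_ptr() % 16)  # 0: fresh storages are 256-byte aligned

# Chunking along dim 0 (as async-TP does when splitting a matmul to overlap it
# with reduce-scatter) offsets each view by rows_before * 4 bytes, which for
# "odd" row counts is not a multiple of 16.
for i, chunk in enumerate(scales.chunk(8, dim=0)):
    print(i, chunk.data_ptr() % 16)  # several chunks start at a misaligned address
```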
pytorchmergebot pushed a commit that referenced this pull request Aug 28, 2025
…on H100 (#161305)

Following #157905 I think the macro around
```
  TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
```
was never updated and this would cause `float8` tests to fail. Also it appears the `Lt` accepts two inputs with `e4m3` and `e5m2` dtypes simultaneously, so removing that check here as well...

CC @lw

Pull Request resolved: #161305
Approved by: https://github.com/Skylion007, https://github.com/drisspg, https://github.com/jeffdaily

Co-authored-by: Jeff Daily <[email protected]>
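
A quick, hedged sketch of the mixed fp8-dtype case mentioned above (e4m3 together with e5m2); whether a given dtype/scaling combination is accepted depends on the PyTorch build, the chosen backend, and the CUDA version:

```python
import torch

# Sketch: one operand in float8_e4m3fn, the other in float8_e5m2, with simple
# tensor-wise scales (example shapes; not guaranteed to hit the Lt path on
# every build).
M, K, N = 256, 512, 128
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e5m2).t()  # column-major operand
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")
out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
print(out.shape, out.dtype)  # torch.Size([256, 128]) torch.bfloat16
```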