Use new cuBLAS row-wise fp8 matmul for scaled-mm #157905
Another thing to note is that our CUTLASS kernels use a persistent schedule without auto-balancing, which is why we had to introduce the explicit SM carveout option. cuBLAS does have this auto-balancing, hence it should perform substantially better than CUTLASS whenever there are some comms running in the background, and it would allow us to remove the carveout option.

This also puts the burden of maintaining this for new architectures like Blackwell on NVIDIA instead of on us.
Ohh, very interesting. The biggest annoyance in my mind is that we will still need to maintain support for both until all our builds are on CUDA >= 12.9, and I'm not entirely sure when that will be. Do you think it's worthwhile adding some simple checks for >= 12.9 and large problem shapes, at least until we can fully remove the old kernels? Also, one other curious question I have @eqy: do we use
We can consider autotuning, as that has been proposed in the past. |
What do you mean? You want me to add tests? Won't the existing tests be enough to cover this, provided they are built/run with CUDA 12.9?
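To make the ">= 12.9 and large problem shapes" idea above concrete, here is a minimal sketch, in Python, of the kind of version-and-shape gate being suggested. This is illustrative only: the real dispatch logic lives in ATen's C++ code, and the shape threshold below is a made-up placeholder, not a measured one.

```python
import torch

def prefer_cublas_rowwise_fp8(m: int, n: int, k: int) -> bool:
    """Illustrative only: decide whether to take the cuBLAS row-wise fp8 path."""
    # cuBLAS only exposes the row-wise fp8 epilogue from CUDA 12.9 onward.
    cuda = torch.version.cuda
    if cuda is None:
        return False
    major, minor = (int(x) for x in cuda.split(".")[:2])
    if (major, minor) < (12, 9):
        return False
    # Optionally keep the CUTLASS kernels for small/skewed problems until the
    # performance picture is clearer (1024 is a placeholder threshold).
    return min(m, n) >= 1024
```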
I tried to quantify the impact of auto-balancing in the persistent schedule, and the impact can be huge (~2x). I launched an 8k×8k×3k matmul with a `torch.cuda._sleep` running in a background stream, and the matmul took 3.347ms with cuBLAS and 6.926ms with CUTLASS.

Code to repro:

```python
import torch

background_stream = torch.cuda.Stream()

a = torch.randn((8192, 3072), dtype=torch.bfloat16, device="cuda", requires_grad=True)
b = torch.randn((8192, 3072), dtype=torch.bfloat16, device="cuda", requires_grad=True)

def scale_rowwise(t):
    # Row-wise fp8 quantization: one scale per row, mapping onto the e4m3 range.
    scale = t.abs().amax(dim=-1, keepdim=True).float() / torch.finfo(torch.float8_e4m3fn).max
    t = t.div(scale).to(torch.float8_e4m3fn)
    return t, scale

a_fp8, scale_a = scale_rowwise(a)
b_fp8, scale_b = scale_rowwise(b)

# Warm-up / kernel-selection call outside the profiled region.
torch._scaled_mm(a_fp8, b_fp8.t(), scale_a=scale_a, scale_b=scale_b.t(), out_dtype=torch.bfloat16, use_fast_accum=False)

with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
    # Keep the GPU busy on a background stream to simulate overlapping comms.
    with torch.cuda.stream(background_stream):
        torch.cuda._sleep(10000000000)
    for _ in range(10):
        torch._scaled_mm(a_fp8, b_fp8.t(), scale_a=scale_a, scale_b=scale_b.t(), out_dtype=torch.bfloat16, use_fast_accum=False)
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total"))
```
For provenance: we spoke offline -> the SM auto-balancing alone makes this very valuable on supported CUDA versions. I wanted to determine whether there are any trends in the performance delta, especially in the small M, N regime, and, if so, maybe continue to dispatch to CUTLASS even on 12.9+.
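To make "trends in the small M, N regime" concrete, here is a rough sketch of the kind of sweep that could be run. Shapes, warm-up, and iteration counts below are arbitrary, and this only times whichever backend the current build dispatches to, so a cuBLAS-vs-CUTLASS comparison would need two builds or a dispatch toggle.

```python
import torch

def to_fp8_rowwise(t):
    # One scale per row, mapping the row's max magnitude onto the e4m3 range.
    scale = t.abs().amax(dim=-1, keepdim=True).float() / torch.finfo(torch.float8_e4m3fn).max
    return t.div(scale).to(torch.float8_e4m3fn), scale

def bench_scaled_mm(m, n, k, iters=50, warmup=10):
    a8, sa = to_fp8_rowwise(torch.randn(m, k, device="cuda", dtype=torch.bfloat16))
    b8, sb = to_fp8_rowwise(torch.randn(n, k, device="cuda", dtype=torch.bfloat16))
    for _ in range(warmup):
        torch._scaled_mm(a8, b8.t(), scale_a=sa, scale_b=sb.t(), out_dtype=torch.bfloat16)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        torch._scaled_mm(a8, b8.t(), scale_a=sa, scale_b=sb.t(), out_dtype=torch.bfloat16)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per matmul

# Sweep small M against a fixed large N/K to look for a crossover point.
for m in (16, 64, 256, 1024, 4096):
    print(f"M={m}: {bench_scaled_mm(m, 8192, 8192):.3f} ms")
```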
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…es (#159957)

After #157905 started using cuBLAS for row-wise scaling on CUDA 12.9+, this broke some downstream tests for fp8 which were testing "odd" shapes. After checking in with the cuBLAS team, this turned out to be due to the scale tensors' starting addresses not being aligned to 16 bytes. PyTorch storages are always aligned at 256 bytes, hence this came from a "slicing" of the scale tensor done inside async-TP when chunking a matmul in order to overlap it with reduce-scatter.

Pull Request resolved: #159957
Approved by: https://github.com/vkuzo, https://github.com/danielvegamyhre
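A small self-contained illustration of the alignment issue described above (the slice offset here is arbitrary; in async-TP the offset comes from chunking the scale tensor along the row dimension, but the effect on the data pointer is the same):

```python
import torch

# Fresh storages from PyTorch's allocator are at least 256-byte aligned.
scale = torch.randn(8192, 1, device="cuda", dtype=torch.float32)
print(scale.data_ptr() % 16)   # 0

# A sliced view shares the storage but starts at an element offset, so its data
# pointer can land off a 16-byte boundary (3 floats = 12 bytes here), which is
# what tripped the cuBLAS requirement on the scale tensors.
chunk = scale[3:]
print(chunk.data_ptr() % 16)   # 12
```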
…on H100 (#161305)

Following #157905, I think the macro around

```
TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
```

was never updated, and this would cause `float8` tests to fail. Also, it appears the `Lt` accepts two inputs with `e4m3` and `e5m2` dtypes simultaneously, so removing that check here as well... CC @lw

Pull Request resolved: #161305
Approved by: https://github.com/Skylion007, https://github.com/drisspg, https://github.com/jeffdaily
Co-authored-by: Jeff Daily <[email protected]>
Most of the work had already been done by @jeffdaily in #154680, but there was one remaining check that needed to be modified for `torch._scaled_mm` to use cuBLAS over CUTLASS when available.

I tested this change by rebuilding PyTorch locally with CUDA 12.9, running `torch._scaled_mm` under the profiler, and observing that the kernel being launched is called `nvjet_qqtst_128x128_128x6_1x1_h_bz_coopA_algo2_ovscale_TNT` (where `ovscale` stands for "outer vector scaling", I believe, which is how cuBLAS calls this scaling mode).

I then benchmarked the new kernels against the old CUTLASS ones on a standard 700W H100 GPU. I used the same approach as in #134781, and obtained these speed-ups:




We see that the two kernels perform very closely (I'm surprised; I would have expected cuBLAS to outperform CUTLASS across the board), with some thin/skewed shapes becoming worse but some very large shapes becoming better.

I guess the questions are whether we consider this a net-zero change (given that there are both improvements and degradations), and how large we consider the burden of maintaining our own CUTLASS kernels to be.
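For reference, this is roughly how the kernel check above can be reproduced. The quantization helper and shapes below are illustrative, and the `nvjet_*` name is what showed up on my CUDA 12.9 build and may differ elsewhere.

```python
import torch

def to_fp8_rowwise(t):
    scale = t.abs().amax(dim=-1, keepdim=True).float() / torch.finfo(torch.float8_e4m3fn).max
    return t.div(scale).to(torch.float8_e4m3fn), scale

a8, sa = to_fp8_rowwise(torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16))
b8, sb = to_fp8_rowwise(torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16))

with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
    torch._scaled_mm(a8, b8.t(), scale_a=sa, scale_b=sb.t(), out_dtype=torch.bfloat16)

# The kernel names in this table show whether the cuBLAS (nvjet_*) or the
# CUTLASS row-wise kernel was dispatched.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```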
cc @ptrblck @msaroufim @eqy @jerryzh168