
Conversation


@malfet malfet commented Sep 28, 2023

Unlike half/bfloat16 casts, where the entire model is cast to half-precision floats, only parts of the network can be in float8, so the performance of the casts is important.

Speed up the casts by implementing non-dynamic-cast variants using the newly refactored `gpu_kernel_nocast` template.

Measure performance using the following script:

```python
import torch

def run_cast_bench(size=(10000, 10000), src_dtype=torch.float16, dtype=torch.float8_e5m2):
    x = torch.rand(size, device="cuda", requires_grad=False, dtype=src_dtype)
    z = torch.empty(size, device="cuda", dtype=dtype, requires_grad=False)
    # Profile CUDA activity only, so the reported time is the cast kernel itself
    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
        z.copy_(x)
    rc = prof.key_averages()
    # rc[1] is assumed to be the averages entry for the cast kernel
    print(f"Running bench for src_dtype={src_dtype} dst_dtype={dtype} cuda_time={rc[1].cuda_time}")

if __name__ == "__main__":
    for dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
        run_cast_bench(src_dtype=torch.half, dtype=dtype)
        run_cast_bench(src_dtype=torch.float, dtype=dtype)
        run_cast_bench(src_dtype=torch.bfloat16, dtype=dtype)
```

Below are before and after results:

| Cast type  | Before | After  |
| ---------- | ------ | ------ |
| fp32->e5m2 | 336 us | 228 us |
| fp16->e5m2 | 323 us | 150 us |
| bf16->e5m2 | 322 us | 150 us |
| fp32->e4m3 | 331 us | 227 us |
| fp16->e4m3 | 318 us | 148 us |
| bf16->e4m3 | 318 us | 149 us |

Skip the optimizations on the ROCm platform.

TODO:
 - Investigate why the `__nv_cvt` intrinsics defined in https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__FP8__MISC.html end up being slower
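
As a sanity check on correctness (a minimal sketch, not part of this PR; it assumes casting through an fp32 intermediate is the reference behavior):

```python
import torch

x = torch.rand(1024, device="cuda", dtype=torch.float16)
direct = x.to(torch.float8_e5m2)                      # exercises the direct cast kernel
via_fp32 = x.to(torch.float32).to(torch.float8_e5m2)  # reference: cast via an fp32 intermediate
# Compare raw bit patterns; float8 elements are 1 byte, so reinterpret as uint8
assert torch.equal(direct.view(torch.uint8), via_fp32.view(torch.uint8))
print("direct fp16->float8 cast matches the fp32-intermediate reference")
```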

@malfet malfet requested review from drisspg and vkuzo September 28, 2023 22:41
@pytorch-bot pytorch-bot bot added the release notes: cuda release notes category label Sep 28, 2023

pytorch-bot bot commented Sep 28, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110251

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 3 Pending

As of commit 1b52733 with merge base 7e6cf04:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Inline review comment on the following hunk:

```cpp
return Float8_e5m2(value);
#endif
});
} else if (other_dtype == kHalf) {
```
are fp64 casts supported?
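
A quick way to probe this from Python (a sketch; whether it succeeds depends on which source dtypes the copy kernel covers):

```python
import torch

x = torch.rand(8, device="cuda", dtype=torch.float64)
y = x.to(torch.float8_e5m2)  # raises a RuntimeError if fp64 -> float8 copies are unsupported
print(y)
```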


@drisspg drisspg left a comment


I think this looks good and the performance wins look great.

I see that

```python
import torch
from transformer_nuggets.utils.tracing import LoggingMode

a = torch.rand(2, dtype=torch.float16, device='cuda')
with LoggingMode():
    a.to(torch.float8_e4m3fn)
```

prints

```
$1: f8e4[2] = aten._to_copy.default($0, dtype=torch.float8_e4m3fn)
```

would we hit this path using the `.to` function?
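
One way to confirm (a sketch along the lines of the benchmark script above): profile the `.to()` call and check which CUDA kernel shows up.

```python
import torch

a = torch.rand(10000, 10000, device="cuda", dtype=torch.float16)
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
    a.to(torch.float8_e4m3fn)
# If .to() reaches the same copy path, the direct cast kernel should dominate this table
print(prof.key_averages().table(sort_by="cuda_time_total"))
```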


drisspg commented Oct 2, 2023

Ran `python benchmarks/bench_linear_float8.py --use_ts -o benchmarks/data/float8_sweep_cast_update` before and after this PR:

| name      | shape               | ref_dtype      | pt_fp8_time_sec | pt_fp8_time_sec_updated | % Improvement |
| --------- | ------------------- | -------------- | --------------- | ----------------------- | ------------- |
| attn.wqkv | (16384, 8192, 1280) | torch.bfloat16 | 0.00416412      | 0.00390124              | 6.31303       |
| attn.w0   | (16384, 1024, 8192) | torch.bfloat16 | 0.00361837      | 0.00335387              | 7.30976       |
| ffn.w13   | (16384, 8192, 7168) | torch.bfloat16 | 0.0107657       | 0.0102925               | 4.39569       |
| ffn.w2    | (16384, 3584, 8192) | torch.bfloat16 | 0.00660132      | 0.00621062              | 5.91851       |
| attn.wqkv | (16384, 8192, 1280) | torch.float16  | 0.00423888      | 0.00398655              | 5.95275       |
| attn.w0   | (16384, 1024, 8192) | torch.float16  | 0.00367196      | 0.0034114               | 7.09605       |
| ffn.w13   | (16384, 8192, 7168) | torch.float16  | 0.0107916       | 0.010287                | 4.67511       |
| ffn.w2    | (16384, 3584, 8192) | torch.float16  | 0.00662924      | 0.00625864              | 5.59038       |
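
(For reference, the % Improvement column is the percentage reduction of pt_fp8_time_sec_updated relative to pt_fp8_time_sec:)

```python
def pct_improvement(before_sec: float, after_sec: float) -> float:
    # Percentage reduction of the updated time relative to the original time
    return (before_sec - after_sec) / before_sec * 100

print(pct_improvement(0.00416412, 0.00390124))  # ~6.31303, matching the attn.wqkv row
```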

Looks like this is producing very meaningful speedups!

@malfet malfet force-pushed the malfet/speedup-float8-cuda-casts branch from ae5be0f to ccfbe73 on October 3, 2023 02:17

malfet commented Oct 3, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 3, 2023
@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@malfet malfet deleted the malfet/speedup-float8-cuda-casts branch December 18, 2023 19:04
pytorchmergebot pushed a commit that referenced this pull request Oct 29, 2024
Similar to #110251 we're seeing cases where vectorization can benefit casts to fp16/bf16

Pull Request resolved: #137053
Approved by: https://github.com/drisspg
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
jithunnair-amd pushed a commit to ROCm/pytorch that referenced this pull request Mar 17, 2025