Speed-up casts to FP8 #110251
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110251
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 3 Pending as of commit 1b52733 with merge base 7e6cf04. This comment was automatically generated by Dr. CI and updates every 15 minutes.
aten/src/ATen/native/cuda/Copy.cu (Outdated)
```cpp
      return Float8_e5m2(value);
#endif
    });
  } else if (other_dtype == kHalf) {
```
are fp64 casts supported?
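(A quick way to check this is the editorial sketch below; it is not part of the PR and assumes a CUDA build that already exposes the float8 dtypes.)
```python
import torch

# Try a float64 -> float8_e5m2 cast on the GPU and report whether it works.
x = torch.rand(4, dtype=torch.float64, device="cuda")
try:
    y = x.to(torch.float8_e5m2)
    print("fp64 -> e5m2 cast worked, result dtype:", y.dtype)
except (RuntimeError, TypeError) as err:
    print("fp64 -> e5m2 cast not supported:", err)
```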
drisspg left a comment:
I think this looks good and the performance wins look great.
I see that
```python
import torch
from transformer_nuggets.utils.tracing import LoggingMode

a = torch.rand(2, dtype=torch.float16, device='cuda')
with LoggingMode():
    a.to(torch.float8_e4m3fn)
```
prints
```
$1: f8e4[2] = aten._to_copy.default($0, dtype=torch.float8_e4m3fn)
```
Would we hit this path using the `.to` function?
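(For context, here is a small editorial sketch, not from the PR, of the two user-facing entry points that should both end up in the same copy kernel: `Tensor.to`, which dispatches through `aten._to_copy`, and an explicit `copy_` into a preallocated float8 tensor as used in the benchmark further down.)
```python
import torch

a = torch.rand(2, dtype=torch.float16, device="cuda")

# Path 1: Tensor.to allocates the float8 output and dispatches via _to_copy,
# which should perform the actual conversion through the copy kernel.
b = a.to(torch.float8_e4m3fn)

# Path 2: explicit copy_ into a preallocated float8 tensor, as in the
# benchmark script below; both paths should exercise the same cast code.
c = torch.empty(2, dtype=torch.float8_e4m3fn, device="cuda")
c.copy_(a)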
Ran
Looks like this is producing very meaningful speedups!
Unlike half/bfloat16 casts, where the entire model is cast to half-precision floats, only parts of the network can be in float8, so the performance of these casts matters. A concrete illustration of such a partial-float8 setup follows below.
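(A minimal editorial sketch, not from the PR: a model kept in bf16 while only one layer's weights are stored as float8; every such conversion goes through the cast path this PR speeds up.)
```python
import torch

# Small bf16 model on the GPU.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
).to(device="cuda", dtype=torch.bfloat16)

# Hold a float8 copy of just one layer's weights; this bf16 -> fp8 cast is
# exactly the copy path being optimized here.
w_fp8 = model[0].weight.detach().to(torch.float8_e4m3fn)

# Cast back to bf16 before using the weight in an ordinary matmul,
# since plain matmuls do not accept float8 inputs.
w_bf16 = w_fp8.to(torch.bfloat16)
```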
Speed up the casts by implementing non-dynamically-castable variants using the newly refactored `gpu_kernel_nocast` template.
Measure performance using the following script:
```python
import torch


def run_cast_bench(size=(10000, 10000), src_dtype=torch.float16, dtype=torch.float8_e5m2):
    x = torch.rand(size, device="cuda", requires_grad=False, dtype=src_dtype)
    z = torch.empty(size, device="cuda", dtype=dtype, requires_grad=False)
    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
        z.copy_(x)
    rc = prof.key_averages()
    print(f"Running bench for src_dtype={src_dtype} dst_dtype={dtype} cuda_time={rc[1].cuda_time}")


if __name__ == "__main__":
    for dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
        run_cast_bench(src_dtype=torch.half, dtype=dtype)
        run_cast_bench(src_dtype=torch.float, dtype=dtype)
        run_cast_bench(src_dtype=torch.bfloat16, dtype=dtype)
```
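As a complement to the timing script, here is a quick sanity check (an editorial addition, not part of the PR) that the explicit `copy_` cast agrees bit-for-bit with the `Tensor.to` result:
```python
import torch

x = torch.rand(1024, device="cuda", dtype=torch.float16)

# Cast via the explicit copy_ path used in the benchmark...
z = torch.empty(1024, device="cuda", dtype=torch.float8_e5m2)
z.copy_(x)

# ...and via Tensor.to; compare the raw bytes, which should be identical.
ref = x.to(torch.float8_e5m2)
assert torch.equal(z.view(torch.uint8), ref.view(torch.uint8))
```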
Below are before and after results:
| Cast type  | Before | After  |
| ---------- | ------ | ------ |
| fp32->e5m2 | 336 us | 228 us |
| fp16->e5m2 | 323 us | 150 us |
| bf16->e5m2 | 322 us | 150 us |
| fp32->e4m3 | 331 us | 227 us |
| fp16->e4m3 | 318 us | 148 us |
| bf16->e4m3 | 318 us | 149 us |
TODO:
- Investigate why `__nv_cvt` intrinsics defined in https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__FP8__MISC.html end up being slower
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Similar to #110251, we're seeing cases where vectorization can benefit casts to fp16/bf16. Pull Request resolved: #137053. Approved by: https://github.com/drisspg
The optimizations are skipped on the ROCm platform.