[ATen][CUDA] Implement 128 bit vectorization #141959
Conversation
🔗 See artifacts and rendered test results at hud.pytorch.org/pr/141959
✅ No failures as of commit 9292e7c with merge base aa95618. (This comment was automatically generated by Dr. CI.)
This unconditionally sets items per thread to 8. Is it neutral/beneficial on A100?
```cpp
using cpp_type = typename function_traits<func_t>::result_type;
uint16_t vec_size = 16 / static_cast<uint16_t>(sizeof(cpp_type));
if (vec_size == 16) {
  vec_size = 4;
}
```
why is this vec_size set to 4 and not 8? Also, why is this conditional needed at all?
I used this condition to avoid vec8 for 1-byte data (bool and uint8), as some tests in test/test_binary_ufuncs.py were failing with mismatches.
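To make the width selection being discussed concrete, here is a minimal standalone sketch; `pick_vec_size` is a hypothetical helper name, and the body is simplified from the snippet above rather than copied from the PR:

```cpp
#include <cstdint>

// Minimal sketch (assumption: simplified, standalone version of the logic
// quoted above; pick_vec_size is a hypothetical name, not an ATen symbol).
// Choose vec_size so that vec_size * sizeof(cpp_type) == 16 bytes (128 bits):
// 4 for float, 8 for half/bfloat16. For 1-byte types (bool, uint8) the
// formula would give 16, which is capped at 4 because wider vectors
// produced mismatches in test/test_binary_ufuncs.py.
template <typename cpp_type>
uint16_t pick_vec_size() {
  uint16_t vec_size = 16 / static_cast<uint16_t>(sizeof(cpp_type));
  if (vec_size == 16) {
    vec_size = 4;
  }
  return vec_size;
}
```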
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Successfully rebased eec3f3c to b55c2fa.
It improves performance on A100 as well. A100 benchmark results are here.
```cpp
auto stream = at::cuda::getCurrentCUDAStream();
using cpp_type = typename function_traits<func_t>::result_type;
const uint16_t max_vec_size = memory::can_vectorize_up_to<func_t>(data, sizeof(cpp_type) == 2);
// Here 1-byte dtypes are intentionally avoided as the vec8 causes incorrect outputs
```
Sorry, what? Why does vec8 cause incorrect outputs? It has to be fixed. Some 1-byte ops already use 8 elements per thread, so why would 8-vectorization be different?
We need more info on 8-vectorization causing incorrect results
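For readers unfamiliar with `can_vectorize_up_to` in the snippet above: the general technique behind it is an alignment check, sketched below for a single data pointer (an illustrative assumption; ATen's actual implementation has a different signature and checks every operand pointer):

```cpp
#include <cstdint>

// Illustrative sketch only (assumption: not ATen's implementation).
// A pointer may be loaded vec elements at a time only if its address is
// aligned to vec * sizeof(T) bytes; halve the width until that holds.
template <typename T>
int vectorize_width(const void* ptr, int max_vec /* e.g. 8 for 2-byte T */) {
  auto addr = reinterpret_cast<std::uintptr_t>(ptr);
  int vec = max_vec;
  while (vec > 1 && addr % (vec * sizeof(T)) != 0) {
    vec /= 2;
  }
  return vec;
}
```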
The compiler issue looks real. For now, after discussing with @Aidyn-A, I think we'll gate the problematic cases via …
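The exact guard is truncated above. Purely as an illustration of this kind of gating (the macro check below is an assumption, not the PR's actual condition):

```cpp
// Hypothetical gate (assumption: illustrative only; the real condition used
// in the PR is not shown in this conversation). Fall back to 64-bit
// vectorization when the toolchain is suspected of miscompiling 128-bit paths.
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12)
constexpr int kMaxVecBytes = 16; // 128-bit loads/stores allowed
#else
constexpr int kMaxVecBytes = 8;  // conservative 64-bit fallback
#endif
```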
I will need to rebase and resolve conflicts with #143269, which apparently did the same thing.
Thanks for digging in, ping me when the PR is ready for review!
This is a rebase of my previous PR #141959. Description from the original PR:

This PR implements 128-bit vectorization. It improves the performance of contiguous elementwise ops by 4-10% on Hopper H100.

The benchmark code used:

```Python
import time

import torch
from torch.profiler import profile, ProfilerActivity


def benchmark(function, dtype=torch.float32, check_numerics=True, print_profile=False):
    device = torch.device("cuda")
    shapes = []
    for p in range(24, 30):
        shape = 1 << p
        shapes.append(shape)
    for shape in shapes:
        # warm-up
        for _ in range(6):
            x = torch.randn(shape, device=device, dtype=dtype)
            y = function(x)
        if print_profile:
            x = torch.randn(shape, device=device, dtype=dtype)
            with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
                y = function(x)
            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        x = torch.randn(shape, device=device, dtype=dtype)
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        for _ in range(6):
            y = function(x)
        torch.cuda.synchronize()
        t2 = time.perf_counter()
        perf_time = (t2 - t1) / 6
        print(f"{function.__name__}, {dtype}, {shape}, {perf_time}")
        if check_numerics:
            x_cpu = x.cpu()
            y_cpu = function(x_cpu).cuda()
            try:
                torch.testing.assert_allclose(y_cpu, y)
            except AssertionError as error:
                print("An exception occurred:", error)


def main():
    ops = [
        torch.relu,
        torch.sigmoid,
        torch.tanh,
        torch.nn.functional.gelu,
        torch.sin,
        torch.exp,
    ]
    dtypes = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ]
    for op in ops:
        for dtype in dtypes:
            benchmark(op, dtype=dtype)
            torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
```

Results:

| op | dtype | size | time after | time before | % improvement |
| ---- | ---- | ---- | ---- | ---- | ---- |
| relu | torch.float16 | 33554432 | 4.84E-05 | 5.06E-05 | 4.66296539127052 |
| relu | torch.float16 | 67108864 | 9.22E-05 | 9.64E-05 | 4.56491432752297 |
| relu | torch.float16 | 134217728 | 0.000180343495837102 | 0.000187981834945579 | 4.23543919508829 |
| relu | torch.float16 | 268435456 | 0.000355071155354381 | 0.000370856161074092 | 4.44558942107169 |
| relu | torch.float16 | 536870912 | 0.000704489842367669 | 0.000736006341564159 | 4.47366268483987 |
| relu | torch.bfloat16 | 16777216 | 3.03E-05 | 3.04E-05 | 0.166504085842689 |
| relu | torch.bfloat16 | 33554432 | 4.89E-05 | 5.06E-05 | 3.45848238875716 |
| relu | torch.bfloat16 | 67108864 | 9.32E-05 | 9.65E-05 | 3.56122651631445 |
| relu | torch.bfloat16 | 134217728 | 0.000180805509444326 | 0.000187998676362137 | 3.97840029317567 |
| relu | torch.bfloat16 | 268435456 | 0.000356242332297067 | 0.000371279485989362 | 4.22104627356745 |
| relu | torch.bfloat16 | 536870912 | 0.000708114336399982 | 0.000736773828975856 | 4.04729732229083 |
| relu | torch.float32 | 16777216 | 5.61E-05 | 5.61E-05 | 0.0442587268354941 |
| relu | torch.float32 | 33554432 | 9.33E-05 | 9.30E-05 | -0.259070913799022 |
| relu | torch.float32 | 67108864 | 0.000181321326332788 | 0.000181289506144822 | -0.0175490597877115 |
| relu | torch.float32 | 134217728 | 0.000356896334172537 | 0.000356570177245885 | -0.0913870206618981 |
| relu | torch.float32 | 268435456 | 0.000709421835684528 | 0.000707465515006334 | -0.275762681635911 |
| relu | torch.float32 | 536870912 | 0.00141372415237129 | 0.00141036518228551 | -0.237597276678471 |
| sigmoid | torch.float16 | 16777216 | 3.10E-05 | 3.16E-05 | 2.10012593866895 |
| sigmoid | torch.float16 | 33554432 | 4.91E-05 | 5.23E-05 | 6.37710600666122 |
| sigmoid | torch.float16 | 67108864 | 9.30E-05 | 0.000100057009452333 | 7.61866144555331 |
| sigmoid | torch.float16 | 134217728 | 0.000180928347011407 | 0.000194982004662355 | 7.76752669390248 |
| sigmoid | torch.float16 | 268435456 | 0.000355658994521946 | 0.00038468533117945 | 8.16128288742412 |
| sigmoid | torch.float16 | 536870912 | 0.000705982849467546 | 0.000764021339515845 | 8.22094900634937 |
| sigmoid | torch.bfloat16 | 16777216 | 3.08E-05 | 3.17E-05 | 2.90965915673149 |
| sigmoid | torch.bfloat16 | 33554432 | 4.87E-05 | 5.24E-05 | 7.63503884668234 |
| sigmoid | torch.bfloat16 | 67108864 | 9.33E-05 | 0.000100019678939134 | 7.21238137428013 |
| sigmoid | torch.bfloat16 | 134217728 | 0.000180786165098349 | 0.000194868014659733 | 7.78922964250206 |
| sigmoid | torch.bfloat16 | 268435456 | 0.000355564659306159 | 0.000384909333661199 | 8.25297835063321 |
| sigmoid | torch.bfloat16 | 536870912 | 0.000705831005082776 | 0.000764102345177283 | 8.2557070566308 |
| sigmoid | torch.float32 | 16777216 | 4.93E-05 | 5.65E-05 | 14.5314136197766 |
| sigmoid | torch.float32 | 33554432 | 9.32E-05 | 9.31E-05 | -0.120169865610833 |
| sigmoid | torch.float32 | 67108864 | 0.000181328505277634 | 0.000180455681402236 | -0.481349512069855 |
| sigmoid | torch.float32 | 134217728 | 0.000357362829769651 | 0.000356093340087682 | -0.35523831137877 |
| sigmoid | torch.float32 | 268435456 | 0.000708921831877281 | 0.000707052337626616 | -0.263709504574663 |
| sigmoid | torch.float32 | 536870912 | 0.00141358317341656 | 0.0014090768333214 | -0.318788464654745 |
| tanh | torch.float16 | 16777216 | 3.03E-05 | 3.03E-05 | -0.0912564658661808 |
| tanh | torch.float16 | 33554432 | 4.90E-05 | 5.07E-05 | 3.46644442974484 |
| tanh | torch.float16 | 67108864 | 9.30E-05 | 9.68E-05 | 3.99871369815531 |
| tanh | torch.float16 | 134217728 | 0.00018052199933057 | 0.000188717152923346 | 4.53969799978138 |
| tanh | torch.float16 | 268435456 | 0.000355684508879979 | 0.000373026006855071 | 4.8755280430115 |
| tanh | torch.float16 | 536870912 | 0.000706660988119741 | 0.000740105014604827 | 4.73268328765002 |
| tanh | torch.bfloat16 | 16777216 | 2.99E-05 | 3.03E-05 | 1.21049563135981 |
| tanh | torch.bfloat16 | 33554432 | 4.89E-05 | 5.06E-05 | 3.48836101041744 |
| tanh | torch.bfloat16 | 67108864 | 9.28E-05 | 9.69E-05 | 4.39944918036626 |
| tanh | torch.bfloat16 | 134217728 | 0.000180710999605556 | 0.000189167990659674 | 4.67984299382829 |
| tanh | torch.bfloat16 | 268435456 | 0.000356062994493792 | 0.000372666652159144 | 4.66312363882606 |
| tanh | torch.bfloat16 | 536870912 | 0.000707100164921333 | 0.000740134331863374 | 4.67178040408393 |
| tanh | torch.float32 | 16777216 | 5.61E-05 | 5.64E-05 | 0.439595755746353 |
| tanh | torch.float32 | 33554432 | 9.31E-05 | 9.31E-05 | 0.00287633090228212 |
| tanh | torch.float32 | 67108864 | 0.000181465332085888 | 0.000180895323865116 | -0.31411411437098 |
| tanh | torch.float32 | 134217728 | 0.000356963835656643 | 0.000356073161431899 | -0.249513854283251 |
| tanh | torch.float32 | 268435456 | 0.000709201170442005 | 0.00070707315656667 | -0.300057862849997 |
| tanh | torch.float32 | 536870912 | 0.00141367283261692 | 0.00141030051357423 | -0.238550176877922 |
| gelu | torch.float16 | 16777216 | 2.73E-05 | 3.17E-05 | 15.921079070745 |
| gelu | torch.float16 | 33554432 | 5.06E-05 | 5.55E-05 | 9.76345374333098 |
| gelu | torch.float16 | 67108864 | 9.65E-05 | 0.000106600326641152 | 10.4308039074712 |
| gelu | torch.float16 | 134217728 | 0.000187776672343413 | 0.000208565829476962 | 11.0712139447915 |
| gelu | torch.float16 | 268435456 | 0.000370216167842348 | 0.000412251994324227 | 11.3544005187205 |
| gelu | torch.float16 | 536870912 | 0.000737301345604161 | 0.000819394170927505 | 11.1342296895002 |
| gelu | torch.bfloat16 | 16777216 | 3.02E-05 | 3.08E-05 | 1.78405479367653 |
| gelu | torch.bfloat16 | 33554432 | 5.13E-05 | 5.69E-05 | 10.9929393318302 |
| gelu | torch.bfloat16 | 67108864 | 9.76E-05 | 0.00010968199543034 | 12.3420807512356 |
| gelu | torch.bfloat16 | 134217728 | 0.000189661824454864 | 0.000214487663470209 | 13.0895287371091 |
| gelu | torch.bfloat16 | 268435456 | 0.000374197009174774 | 0.000423670164309442 | 13.2211519391275 |
| gelu | torch.bfloat16 | 536870912 | 0.000743675006863972 | 0.000842577001700799 | 13.299088166737 |
| gelu | torch.float32 | 16777216 | 5.06E-05 | 5.04E-05 | -0.413385894716413 |
| gelu | torch.float32 | 33554432 | 9.31E-05 | 9.32E-05 | 0.134157041722546 |
| gelu | torch.float32 | 67108864 | 0.000181480175039421 | 0.000180836669945469 | -0.354586992112075 |
| gelu | torch.float32 | 134217728 | 0.000356874331676712 | 0.000356305002545317 | -0.159532104402047 |
| gelu | torch.float32 | 268435456 | 0.000708909006789327 | 0.000706991491218408 | -0.270488250615287 |
| gelu | torch.float32 | 536870912 | 0.00141321367118508 | 0.00140937082081412 | -0.271922813181618 |
| sin | torch.float16 | 16777216 | 3.04E-05 | 3.11E-05 | 2.21834939018859 |
| sin | torch.float16 | 33554432 | 4.85E-05 | 5.23E-05 | 7.72165512511596 |
| sin | torch.float16 | 67108864 | 9.31E-05 | 9.98E-05 | 7.24947099480072 |
| sin | torch.float16 | 134217728 | 0.000180371008658161 | 0.000194791161144773 | 7.99471744039613 |
| sin | torch.float16 | 268435456 | 0.000355454161763191 | 0.000384903668115536 | 8.28503630574026 |
| sin | torch.float16 | 536870912 | 0.000705183832906187 | 0.000764360166310022 | 8.39161799270973 |
| sin | torch.bfloat16 | 16777216 | 3.11E-05 | 3.10E-05 | -0.257677954940036 |
| sin | torch.bfloat16 | 33554432 | 4.89E-05 | 5.24E-05 | 7.34808420323539 |
| sin | torch.bfloat16 | 67108864 | 9.26E-05 | 0.000100248667877167 | 8.22347488801205 |
| sin | torch.bfloat16 | 134217728 | 0.000180674154156198 | 0.00019567032965521 | 8.30012215584937 |
| sin | torch.bfloat16 | 268435456 | 0.000355360486234228 | 0.000386023331278314 | 8.62865913118873 |
| sin | torch.bfloat16 | 536870912 | 0.00070483615854755 | 0.000766805159704139 | 8.79197248964745 |
| sin | torch.float32 | 16777216 | 5.67E-05 | 5.64E-05 | -0.441348534920039 |
| sin | torch.float32 | 33554432 | 9.34E-05 | 9.30E-05 | -0.496458540364117 |
| sin | torch.float32 | 67108864 | 0.000181706990891447 | 0.000180556671693921 | -0.633062708199702 |
| sin | torch.float32 | 134217728 | 0.000356894995396336 | 0.000356046327700218 | -0.237791985616354 |
| sin | torch.float32 | 268435456 | 0.000708777321657787 | 0.000707602652255446 | -0.165731798471427 |
| sin | torch.float32 | 536870912 | 0.00141263716310884 | 0.00140912582476934 | -0.248566187496451 |
| exp | torch.float16 | 16777216 | 3.00E-05 | 3.04E-05 | 1.40099098901014 |
| exp | torch.float16 | 33554432 | 4.86E-05 | 5.03E-05 | 3.44611943643906 |
| exp | torch.float16 | 67108864 | 9.37E-05 | 9.55E-05 | 1.96412400380129 |
| exp | torch.float16 | 134217728 | 0.000180913504057874 | 0.000187193179347863 | 3.47109262113439 |
| exp | torch.float16 | 268435456 | 0.00035607748820136 | 0.000369079003576189 | 3.65131630210701 |
| exp | torch.float16 | 536870912 | 0.000707551507124056 | 0.000732363162872692 | 3.50669251620789 |
| exp | torch.bfloat16 | 16777216 | 2.98E-05 | 3.04E-05 | 1.74345594341654 |
| exp | torch.bfloat16 | 33554432 | 4.88E-05 | 5.04E-05 | 3.40217856534821 |
| exp | torch.bfloat16 | 67108864 | 9.32E-05 | 9.62E-05 | 3.29219958210226 |
| exp | torch.bfloat16 | 134217728 | 0.000180999826019009 | 0.000187239318620414 | 3.44723679499521 |
| exp | torch.bfloat16 | 268435456 | 0.000355944503098726 | 0.000369370992605885 | 3.77207384585864 |
| exp | torch.bfloat16 | 536870912 | 0.000707135167128096 | 0.000733066000975668 | 3.66702648277075 |
| exp | torch.float32 | 16777216 | 4.89E-05 | 5.63E-05 | 15.1245314346532 |
| exp | torch.float32 | 33554432 | 9.34E-05 | 9.31E-05 | -0.259945454477446 |
| exp | torch.float32 | 67108864 | 0.000181152504713585 | 0.000180474346658836 | -0.374357536939058 |
| exp | torch.float32 | 134217728 | 0.000356771342922002 | 0.000355627329554409 | -0.3206573034212 |
| exp | torch.float32 | 268435456 | 0.000708404501589636 | 0.00070713268360123 | -0.179532736671163 |
| exp | torch.float32 | 536870912 | 0.00141283582585553 | 0.00140944866385932 | -0.23974208002295 |

Pull Request resolved: #145746
Approved by: https://github.com/eqy, https://github.com/ngimel
Co-authored-by: Aaron Gokaslan <[email protected]>
Closing in favor of #145746.
This PR implements 128-bit vectorization. It improves the performance of contiguous elementwise ops by 4-10% on Hopper H100.
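As a minimal illustration of the idea (a standalone toy kernel, an assumption for exposition rather than the ATen code): each thread moves 16 bytes per memory transaction by treating the buffer as CUDA's built-in 16-byte vector type `float4`, so the compiler can emit single 128-bit load/store instructions.

```cuda
#include <cuda_runtime.h>

// Toy sketch (assumption: standalone example, not ATen's kernel).
// n_vec is the number of float4 elements, i.e. n_floats / 4; a tail that
// does not fill a whole float4 would be handled by a separate scalar path.
__global__ void relu_vec4(const float4* __restrict__ in,
                          float4* __restrict__ out, int n_vec) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_vec) {
    float4 v = in[i];       // one 128-bit load
    v.x = fmaxf(v.x, 0.f);
    v.y = fmaxf(v.y, 0.f);
    v.z = fmaxf(v.z, 0.f);
    v.w = fmaxf(v.w, 0.f);
    out[i] = v;             // one 128-bit store
  }
}
```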
The benchmark code used and the results are the same as those listed under #145746 above.
cc @msaroufim @ptrblck @eqy @manuelcandales @SherlockNoMad @angelayi