[ROCm] Optimize layer norm backward kernel for ROCm #87635
Conversation
Replaced the use of GammaBetaBackwardCUDAKernel with the more performant cuComputePartGradGammaBeta and cuComputeGradGammaBeta from Apex.
Replaced the use of layer_norm_grad_input_kernel with the more performant cuComputeGradInput from Apex. (A simplified sketch of the two-pass gamma/beta scheme is shown below.)
ENABLE_APEX_GRADINPUT:
- layer_norm_grad_input_kernel -> cuComputeGradInput

ENABLE_APEX_GAMMABETA:
- GammaBetaBackwardCUDAKernel -> cuComputePartGradGammaBeta, cuComputeGradGammaBeta
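For orientation, the Apex-derived gamma/beta backward path referenced above is a two-pass reduction: one kernel reduces chunks of rows into a partial-sum buffer, and a second kernel collapses those partials into dgamma/dbeta. The sketch below is only a minimal illustration of that idea under simplifying assumptions (float-only, no vectorization or shared-memory tiling, hypothetical kernel and parameter names); it is not the Apex code nor the exact kernels added by this PR.

```
#include <cuda_runtime.h>

// Pass 1: each block reduces a chunk of rows for a slice of columns into partial buffers.
// part_dgamma / part_dbeta have shape [num_chunks, N].
__global__ void part_grad_gamma_beta(const float* dY, const float* X,
                                     const float* mean, const float* rstd,
                                     int M, int N, int rows_per_chunk,
                                     float* part_dgamma, float* part_dbeta) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  int chunk = blockIdx.y;
  if (col >= N) return;
  int row_begin = chunk * rows_per_chunk;
  int row_end = min(row_begin + rows_per_chunk, M);
  float sum_dgamma = 0.f, sum_dbeta = 0.f;
  for (int row = row_begin; row < row_end; ++row) {
    float dy = dY[row * N + col];
    sum_dgamma += dy * (X[row * N + col] - mean[row]) * rstd[row];
    sum_dbeta  += dy;
  }
  part_dgamma[chunk * N + col] = sum_dgamma;
  part_dbeta[chunk * N + col]  = sum_dbeta;
}

// Pass 2: collapse the per-chunk partials into the final dgamma / dbeta (shape [N]).
__global__ void grad_gamma_beta(const float* part_dgamma, const float* part_dbeta,
                                int num_chunks, int N,
                                float* dgamma, float* dbeta) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= N) return;
  float sum_dgamma = 0.f, sum_dbeta = 0.f;
  for (int c = 0; c < num_chunks; ++c) {
    sum_dgamma += part_dgamma[c * N + col];
    sum_dbeta  += part_dbeta[c * N + col];
  }
  dgamma[col] = sum_dgamma;
  dbeta[col]  = sum_dbeta;
}

// Illustrative launch: splitting M across blockIdx.y is what keeps the GPU busy for
// large-M, small-N shapes (the CvT cases benchmarked below).
//   dim3 block(256);
//   int num_chunks = (M + rows_per_chunk - 1) / rows_per_chunk;
//   part_grad_gamma_beta<<<dim3((N + 255) / 256, num_chunks), block>>>(
//       dY, X, mean, rstd, M, N, rows_per_chunk, part_dgamma, part_dbeta);
//   grad_gamma_beta<<<dim3((N + 255) / 256), block>>>(
//       part_dgamma, part_dbeta, num_chunks, N, dgamma, dbeta);
```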
…keep only ENABLE_APEX_GAMMABETA code changes
…comp" This reverts commit 4a75ab0.
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87635
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 0e32924
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@ngimel and @zasdfgbnm, could you please help review this PR?
@jeffdaily, can you send this to the right folks? I'm not so hot about adding so much ROCm-only CUDA code; it will be maintenance trouble, but if the ROCm team agrees to maintain it we can add it.
One of the right folks is the PR author; another is our team member @jataylo. The ROCm team of course agrees to maintain this code.
I invited @jataylo to the repo; once he accepts the invite, please add him as a reviewer.
LGTM ezyang. The reported performance gains from the benchmark have been replicated locally.
Hi @ezyang, just re-pinging this PR in case you had any more comments.
…or ROCm (#87726)

We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs, especially when fs (=config_m in our benchmark script) is large and bs (=config_n in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)), as exercised by the benchmark script of #68238 (comment). This PR replaces layer_norm_grad_input_kernel with the Apex cuComputeGradInput kernel, with some ROCm-specific parameter tuning, when fs (=config_m) is larger than or equal to `32768` on AMD GPUs. Some of the code changes in LayerNormBackwardKernelImplInternal are from another PR: #87635

We used the same benchmark script as in the previous PR and tested the optimized kernel with various input shapes on an AMD MI100 GPU.

**At the previous PR (#87635):**

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38589 | 0.92603 | 0.38367 | 1.15148
50176 | 384 | 0.38719 | 0.91579 | 0.37815 | 1.13761
200704 | 192 | 0.99787 | 2.39954 | 0.98996 | 2.54284
802816 | 64 | 3.66525 | 7.96952 | 3.61293 | 7.69946
200 | 256 | 0.06578 | 0.34613 | 0.06966 | 0.35449
1000 | 256 | 0.07837 | 0.37631 | 0.07725 | 0.37758
6000 | 256 | 0.09318 | 0.3788 | 0.09202 | 0.37989
6272 | 256 | 0.08694 | 0.36267 | 0.08703 | 0.3615
200 | 512 | 0.06975 | 0.34506 | 0.06973 | 0.34208
1000 | 512 | 0.07012 | 0.36363 | 0.07307 | 0.36741
6000 | 512 | 0.09725 | 0.36251 | 0.09908 | 0.37078
6272 | 512 | 0.09899 | 0.36519 | 0.10068 | 0.37514
200 | 1024 | 0.07188 | 0.33896 | 0.0712 | 0.34683
1000 | 1024 | 0.07357 | 0.3625 | 0.0734 | 0.3598
6000 | 1024 | 0.12642 | 0.38949 | 0.12973 | 0.5035
6272 | 1024 | 0.12901 | 0.40759 | 0.13609 | 0.51871
200 | 1536 | 0.06998 | 0.34782 | 0.07419 | 0.3514
1000 | 1536 | 0.07987 | 0.37915 | 0.07888 | 0.37264
6000 | 1536 | 0.15401 | 0.47524 | 0.15416 | 0.68609
6272 | 1536 | 0.15286 | 0.48843 | 0.17681 | 0.72997
200 | 2048 | 0.07054 | 0.34791 | 0.07289 | 0.35138
1000 | 2048 | 0.07767 | 0.37954 | 0.08554 | 0.37464
6000 | 2048 | 0.18744 | 0.5811 | 0.25004 | 0.93338
6272 | 2048 | 0.20037 | 0.63398 | 0.26918 | 0.97018
200 | 3072 | 0.07687 | 0.36739 | 0.08917 | 0.37845
1000 | 3072 | 0.09323 | 0.38901 | 0.09739 | 0.39823
6000 | 3072 | 0.24314 | 0.89029 | 0.38093 | 1.30719
6272 | 3072 | 0.26079 | 0.92023 | 0.38352 | 1.51012
128 | 2097152 | 6.17775 | 23.876 | 10.27952 | 30.10848
256 | 1048576 | 4.51855 | 19.47637 | 10.07609 | 29.42678
512 | 524288 | 4.13615 | 18.80888 | 10.07853 | 32.29804
1024 | 262144 | 4.47397 | 17.88388 | 9.50367 | 31.15699
2048 | 131072 | 4.2458 | 16.70852 | 9.17979 | 30.51708
4096 | 65536 | 4.24412 | 16.43098 | 8.97651 | 30.1617
8192 | 32768 | 4.24556 | 16.09038 | 8.77001 | 30.3643
16384 | 16384 | 4.14642 | 15.80355 | 8.82402 | 30.35291
32768 | 8192 | 4.12599 | 15.68897 | 8.82605 | 30.43423

---

**At this PR:**

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38667 | 0.84133 | 0.37916 | 1.01222
50176 | 384 | 0.3814 | 0.87266 | 0.37858 | 1.04399
200704 | 192 | 0.99902 | 2.14386 | 0.98973 | 2.33265
802816 | 64 | 3.66578 | 6.85376 | 3.6092 | 7.00331
200 | 256 | 0.06607 | 0.34176 | 0.07009 | 0.34548
1000 | 256 | 0.06947 | 0.36461 | 0.07902 | 0.37851
6000 | 256 | 0.09319 | 0.37432 | 0.09342 | 0.36927
6272 | 256 | 0.09544 | 0.37565 | 0.09476 | 0.37377
200 | 512 | 0.07935 | 0.364 | 0.07891 | 0.36894
1000 | 512 | 0.07676 | 0.37552 | 0.07957 | 0.37564
6000 | 512 | 0.10472 | 0.37504 | 0.1051 | 0.38782
6272 | 512 | 0.1069 | 0.36662 | 0.10062 | 0.38506
200 | 1024 | 0.07793 | 0.36561 | 0.08023 | 0.35019
1000 | 1024 | 0.07426 | 0.36729 | 0.07345 | 0.35851
6000 | 1024 | 0.12729 | 0.39219 | 0.12974 | 0.51526
6272 | 1024 | 0.13622 | 0.41627 | 0.14252 | 0.52926
200 | 1536 | 0.07615 | 0.36621 | 0.0797 | 0.3695
1000 | 1536 | 0.08327 | 0.38174 | 0.07938 | 0.37573
6000 | 1536 | 0.14894 | 0.46197 | 0.15268 | 0.63814
6272 | 1536 | 0.15368 | 0.48818 | 0.16309 | 0.71441
200 | 2048 | 0.06935 | 0.36691 | 0.07258 | 0.35548
1000 | 2048 | 0.07738 | 0.36388 | 0.08036 | 0.36452
6000 | 2048 | 0.18757 | 0.58573 | 0.23701 | 0.92915
6272 | 2048 | 0.1938 | 0.61628 | 0.26475 | 0.96896
200 | 3072 | 0.07884 | 0.3673 | 0.07724 | 0.37869
1000 | 3072 | 0.09342 | 0.38193 | 0.09822 | 0.38646
6000 | 3072 | 0.24452 | 0.86776 | 0.38251 | 1.3036
6272 | 3072 | 0.25971 | 0.91053 | 0.38744 | 1.39039
128 | 2097152 | 6.06752 | 23.26379 | 9.87466 | 29.81851
256 | 1048576 | 4.50336 | 19.4614 | 10.11239 | 29.25554
512 | 524288 | 4.12649 | 18.72831 | 10.054 | 32.26784
1024 | 262144 | 4.40855 | 17.77993 | 9.38856 | 31.18679
2048 | 131072 | 4.18716 | 16.74615 | 9.14487 | 30.24603
4096 | 65536 | 4.17374 | 16.34444 | 8.94894 | 30.0326
8192 | 32768 | 4.19095 | 16.05751 | 8.70358 | 30.14669
16384 | 16384 | 4.15404 | 15.83771 | 8.80042 | 30.5022
32768 | 8192 | 4.12515 | 15.5657 | 8.66138 | 28.87386

---

**Performance Improvement (%)**

M | N | fwdbwd, torch.float16 | fwdbwd, torch.float32
-- | -- | -- | --
50432 | 384 | 9.147 | 12.094
50176 | 384 | 4.710 | 8.230
200704 | 192 | 10.655 | 8.266
802816 | 64 | 14.000 | 9.042
200 | 256 | 1.263 | 2.542
1000 | 256 | 3.109 | -0.246
6000 | 256 | 1.183 | 2.796
6272 | 256 | -3.579 | -3.394
200 | 512 | -5.489 | -7.852
1000 | 512 | -3.270 | -2.240
6000 | 512 | -3.456 | -4.596
6272 | 512 | -0.392 | -2.644
200 | 1024 | -7.862 | -0.969
1000 | 1024 | -1.321 | 0.359
6000 | 1024 | -0.693 | -2.336
6272 | 1024 | -2.130 | -2.034
200 | 1536 | -5.287 | -5.151
1000 | 1536 | -0.683 | -0.829
6000 | 1536 | 2.792 | 6.989
6272 | 1536 | 0.051 | 2.132
200 | 2048 | -5.461 | -1.167
1000 | 2048 | 4.126 | 2.701
6000 | 2048 | -0.797 | 0.453
6272 | 2048 | 2.792 | 0.126
200 | 3072 | 0.024 | -0.063
1000 | 3072 | 1.820 | 2.956
6000 | 3072 | 2.531 | 0.275
6272 | 3072 | 1.054 | 7.929
128 | 2097152 | 2.564 | 0.963
256 | 1048576 | 0.077 | 0.582
512 | 524288 | 0.428 | 0.094
1024 | 262144 | 0.581 | -0.096
2048 | 131072 | -0.225 | 0.888
4096 | 65536 | 0.527 | 0.428
8192 | 32768 | 0.204 | 0.717
16384 | 16384 | -0.216 | -0.492
32768 | 8192 | 0.786 | 5.127

CC: @jeffdaily

Pull Request resolved: #87726
Approved by: https://github.com/ngimel
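For readers who want the math that this grad-input path computes, a deliberately naive reference sketch follows: one thread per row, float-only, with hypothetical names. The real cuComputeGradInput parallelizes the two row-wise sums across a block and uses ROCm-tuned launch parameters, none of which this sketch attempts to reproduce.

```
#include <cuda_runtime.h>

// Reference for the LayerNorm input gradient:
//   xhat = (x - mean) * rstd
//   dX   = rstd * (dY*gamma - mean_cols(dY*gamma) - xhat * mean_cols(dY*gamma*xhat))
__global__ void grad_input_reference(const float* dY, const float* X, const float* gamma,
                                     const float* mean, const float* rstd,
                                     int M, int N, float* dX) {
  int row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= M) return;
  const float mu = mean[row];
  const float r  = rstd[row];
  // Two row-wise sums required by the chain rule through the mean and variance.
  float sum_dy_g = 0.f, sum_dy_g_xhat = 0.f;
  for (int j = 0; j < N; ++j) {
    float g    = gamma ? gamma[j] : 1.f;
    float dyg  = dY[row * N + j] * g;
    float xhat = (X[row * N + j] - mu) * r;
    sum_dy_g      += dyg;
    sum_dy_g_xhat += dyg * xhat;
  }
  const float inv_n = 1.f / N;
  for (int j = 0; j < N; ++j) {
    float g    = gamma ? gamma[j] : 1.f;
    float dyg  = dY[row * N + j] * g;
    float xhat = (X[row * N + j] - mu) * r;
    dX[row * N + j] = r * (dyg - inv_n * (sum_dy_g + xhat * sum_dy_g_xhat));
  }
}
```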
We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs, especially when `fs` (=`config_m` in our benchmark script) is large and `bs` (=`config_n` in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)), as exercised by the benchmark script of pytorch#68238 (comment). This PR replaces `GammaBetaBackwardCUDAKernel` with the Apex layernorm backward kernel, with some ROCm-specific parameter tuning, when `fs` (=`config_m`) is larger than 512 on AMD GPUs.

There are a few PRs for the LayerNorm kernel:

- pytorch#26201
- pytorch#27634
- pytorch#68238

We have therefore tested and compared the kernel before and at this PR with the input shapes from the last two PRs, along with those commonly used in the CvT model, on an AMD MI100.

---

**Current**

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.387256 | 1.372758 | 0.378975 | 1.47892
50176 | 384 | 0.38231 | 1.362416 | 0.378084 | 1.473886
200704 | 192 | 0.997859 | 4.315875 | 0.989306 | 4.560827
802816 | 64 | 3.671828 | 16.68013 | 3.613515 | 16.827946
200 | 256 | 0.066503 | 0.332096 | 0.071422 | 0.325349
1000 | 256 | 0.071848 | 0.333355 | 0.073038 | 0.334753
6000 | 256 | 0.086334 | 0.345139 | 0.086834 | 0.347429
6272 | 256 | 0.088601 | 0.347906 | 0.087855 | 0.351245
200 | 512 | 0.071626 | 0.329726 | 0.073798 | 0.326878
1000 | 512 | 0.073975 | 0.330226 | 0.074166 | 0.332751
6000 | 512 | 0.099617 | 0.362367 | 0.100095 | 0.378313
6272 | 512 | 0.100378 | 0.358066 | 0.099857 | 0.395982
200 | 1024 | 0.072954 | 0.326382 | 0.073899 | 0.333007
1000 | 1024 | 0.0743 | 0.325532 | 0.071126 | 0.330991
6000 | 1024 | 0.127025 | 0.390084 | 0.128692 | 0.471504
6272 | 1024 | 0.130704 | 0.403536 | 0.135244 | 0.487133
200 | 1536 | 0.070331 | 0.339169 | 0.070086 | 0.331015
1000 | 1536 | 0.075085 | 0.330042 | 0.076295 | 0.328778
6000 | 1536 | 0.148889 | 0.44949 | 0.155781 | 0.659987
6272 | 1536 | 0.154939 | 0.478871 | 0.17673 | 0.716025
200 | 2048 | 0.070269 | 0.335585 | 0.072804 | 0.334655
1000 | 2048 | 0.080094 | 0.326991 | 0.080426 | 0.32685
6000 | 2048 | 0.187888 | 0.623023 | 0.245762 | 0.981635
6272 | 2048 | 0.195431 | 0.65244 | 0.262574 | 1.008141
200 | 3072 | 0.068205 | 0.339428 | 0.073068 | 0.344034
1000 | 3072 | 0.087554 | 0.328899 | 0.09218 | 0.346433
6000 | 3072 | 0.240352 | 0.905058 | 0.368135 | 1.280462
6272 | 3072 | 0.26179 | 0.959387 | 0.387782 | 1.476524
128 | 2097152 | 5.905976 | 22.724793 | 10.287974 | 30.242092
256 | 1048576 | 4.561596 | 19.554308 | 10.223171 | 29.42371
512 | 524288 | 4.146751 | 22.7247 | 11.404285 | 39.175902
1024 | 262144 | 5.193135 | 23.403325 | 11.334512 | 38.947192
2048 | 131072 | 4.992907 | 23.377801 | 11.400286 | 40.889191
4096 | 65536 | 5.429488 | 24.275701 | 11.196778 | 41.4751
8192 | 32768 | 5.35758 | 21.360312 | 10.535418 | 42.875646
16384 | 16384 | 5.44947 | 20.852605 | 10.357685 | 34.603408
32768 | 8192 | 4.688925 | 17.379392 | 9.635596 | 31.188271

---

**At this PR**

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38797 | 0.93103 | 0.37966 | 1.15283
50176 | 384 | 0.3874 | 0.96417 | 0.38462 | 1.18595
200704 | 192 | 1.00002 | 2.40876 | 0.99224 | 2.55579
802816 | 64 | 3.67348 | 7.98658 | 3.61871 | 7.72404
200 | 256 | 0.07292 | 0.35119 | 0.07195 | 0.32602
1000 | 256 | 0.07354 | 0.33325 | 0.07237 | 0.33742
6000 | 256 | 0.08819 | 0.33283 | 0.08453 | 0.3279
6272 | 256 | 0.0886 | 0.33446 | 0.08774 | 0.33426
200 | 512 | 0.0701 | 0.33505 | 0.07072 | 0.33018
1000 | 512 | 0.07042 | 0.33442 | 0.074 | 0.33206
6000 | 512 | 0.09931 | 0.34956 | 0.09895 | 0.3572
6272 | 512 | 0.10103 | 0.32976 | 0.10041 | 0.36635
200 | 1024 | 0.07144 | 0.33579 | 0.07209 | 0.33216
1000 | 1024 | 0.0736 | 0.32803 | 0.07286 | 0.32936
6000 | 1024 | 0.12584 | 0.38916 | 0.12852 | 0.48273
6272 | 1024 | 0.13053 | 0.38804 | 0.13464 | 0.49545
200 | 1536 | 0.07159 | 0.3396 | 0.07062 | 0.33545
1000 | 1536 | 0.07443 | 0.33239 | 0.07366 | 0.33204
6000 | 1536 | 0.14959 | 0.45043 | 0.15826 | 0.69119
6272 | 1536 | 0.1542 | 0.47644 | 0.18249 | 0.72208
200 | 2048 | 0.07258 | 0.33982 | 0.07412 | 0.33859
1000 | 2048 | 0.0793 | 0.32816 | 0.07864 | 0.32583
6000 | 2048 | 0.18973 | 0.571 | 0.25506 | 0.91796
6272 | 2048 | 0.19719 | 0.64208 | 0.26445 | 0.95055
200 | 3072 | 0.07092 | 0.33867 | 0.07104 | 0.34695
1000 | 3072 | 0.08727 | 0.33144 | 0.09144 | 0.36633
6000 | 3072 | 0.24683 | 0.87275 | 0.37761 | 1.3289
6272 | 3072 | 0.26437 | 0.91178 | 0.38496 | 1.53694
128 | 2097152 | 6.27936 | 23.69425 | 10.40004 | 30.13699
256 | 1048576 | 4.5404 | 19.47675 | 10.28494 | 29.36936
512 | 524288 | 4.13951 | 18.78771 | 10.09557 | 32.67083
1024 | 262144 | 4.47576 | 18.00411 | 9.56488 | 31.47117
2048 | 131072 | 4.28026 | 16.95619 | 9.40297 | 30.82845
4096 | 65536 | 4.2653 | 16.5018 | 9.03315 | 30.08392
8192 | 32768 | 4.25613 | 16.13583 | 8.9258 | 30.75296
16384 | 16384 | 4.20256 | 16.38207 | 9.52587 | 31.31113
32768 | 8192 | 4.20231 | 16.19452 | 9.31478 | 31.03514

---

**Performance Improvement (%)**

M | N | fwdbwd, torch.float16 | fwdbwd, torch.float32
-- | -- | -- | --
50432 | 384 | 32.178 | 22.049
50176 | 384 | 29.231 | 19.536
200704 | 192 | 44.188 | 43.962
802816 | 64 | 52.119 | 54.100
200 | 256 | -5.750 | -0.206
1000 | 256 | 0.031 | -0.797
6000 | 256 | 3.566 | 5.621
6272 | 256 | 3.865 | 4.836
200 | 512 | -1.615 | -1.010
1000 | 512 | -1.270 | 0.208
6000 | 512 | 3.534 | 5.581
6272 | 512 | 7.905 | 7.483
200 | 1024 | -2.883 | 0.254
1000 | 1024 | -0.767 | 0.493
6000 | 1024 | 0.237 | -2.381
6272 | 1024 | 3.840 | -1.707
200 | 1536 | -0.127 | -1.340
1000 | 1536 | -0.711 | -0.992
6000 | 1536 | -0.209 | -4.728
6272 | 1536 | 0.508 | -0.846
200 | 2048 | -1.262 | -1.176
1000 | 2048 | -0.358 | 0.312
6000 | 2048 | 8.350 | 6.487
6272 | 2048 | 1.588 | 5.713
200 | 3072 | 0.223 | -0.848
1000 | 3072 | -0.773 | -5.743
6000 | 3072 | 3.570 | -3.783
6272 | 3072 | 4.962 | -4.092
128 | 2097152 | -4.266 | 0.348
256 | 1048576 | 0.397 | 0.185
512 | 524288 | 17.325 | 16.605
1024 | 262144 | 23.070 | 19.195
2048 | 131072 | 27.469 | 24.605
4096 | 65536 | 32.023 | 27.465
8192 | 32768 | 24.459 | 28.274
16384 | 16384 | 21.439 | 9.514
32768 | 8192 | 6.818 | 0.491

---

**Benchmark script of this PR**

```
# Ref:
# 1. pytorch#26201
# 2. pytorch#68238
from distutils.command.config import config
import torch
from torch.nn import LayerNorm
import timeit

number_runs = 1000  # TODO: Modify this to save time!

def test_forward(layer_norm_cuda, input_cuda):
    layer_norm_cuda(input_cuda); torch.cuda.synchronize()

def test_backward(out_cuda, layer_norm_grad_cuda, create_graph):
    out_cuda.backward(layer_norm_grad_cuda, retain_graph=True, create_graph=create_graph); torch.cuda.synchronize()

def test_fwdbwd(input_cuda, layer_norm_cuda, gO):
    input_cuda.grad = None
    layer_norm_cuda.zero_grad(set_to_none=True)
    out = layer_norm_cuda(input_cuda)
    out.backward(gO)
    torch.cuda.synchronize()

def benchmark(config_m, config_n):
    print("M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)")
    if len(config_m) != len(config_n):
        print("Please make sure the lengths of config_m and config_n are the same.")

    for i in range(len(config_m)):
        normalized_shape = config_n[i]
        results = [config_m[i], config_n[i]]
        for dtype in (torch.half, torch.float):
            if dtype == torch.half:
                layer_norm_cuda = LayerNorm(normalized_shape).half().cuda()
            else:
                layer_norm_cuda = LayerNorm(normalized_shape).cuda()
            input_cuda = torch.randn(config_m[i], config_n[i], device='cuda', dtype=dtype, requires_grad=True)

            # print("cuda forward:")
            result_fwd = timeit.timeit(lambda: test_forward(layer_norm_cuda, input_cuda), number=number_runs)
            results.append(result_fwd / number_runs * 1000)

            gO = torch.rand_like(input_cuda)
            result_fwdbwd = timeit.timeit(lambda: test_fwdbwd(input_cuda, layer_norm_cuda, gO), number=number_runs)
            results.append(result_fwdbwd / number_runs * 1000)
        print('{:09d}|{:09d}|{:9.5f}|{:9.5f}|{:9.5f}|{:9.5f}'.format(results[0], results[1], results[2], results[3], results[4], results[5]))
    print("Times are in milliseconds (ms).")

# CvT
config_m_cvt = [50432, 50176, 200704, 802816]
config_n_cvt = [384, 384, 192, 64]

# pytorch#68238 (comment)
config_m_68238 = [200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272]
config_n_68238 = [256, 256, 256, 256, 512, 512, 512, 512, 1024, 1024, 1024, 1024, 1536, 1536, 1536, 1536, 2048, 2048, 2048, 2048, 3072, 3072, 3072, 3072]

# pytorch#27634
config_m_27634 = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
config_n_27634 = [2097152, 1048576, 524288, 262144, 131072, 65536, 32768, 16384, 8192]

config_m = config_m_cvt + config_m_68238 + config_m_27634
config_n = config_n_cvt + config_n_68238 + config_n_27634

benchmark(config_m, config_n)
```

CC: @jeffdaily

Pull Request resolved: pytorch#87635
Approved by: https://github.com/jataylo, https://github.com/jeffdaily, https://github.com/ezyang
It was raised that the backwards layer norm on AMD was slightly less accurate than the equivalent NVIDIA implementation. On AMD we call into a helper kernel, `cuLoadWriteStridedInputs`, which processes strided input and accumulates the partial gradients into shared memory. In this kernel (#87635) we truncated `mean` and `rstd` from the T_ACC type to T, which caused numerical issues in the warp buffers created in this kernel. This PR uses the correct accumulator type for `mean` and `rstd`.

Note: only AMD calls into this call stack for backwards layer norm, so this was not an issue for NV.

Pull Request resolved: #140259
Approved by: https://github.com/jianyuh
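For illustration, here is a minimal sketch of the type change described above. The signature and indexing are simplified assumptions for this sketch (the real `cuLoadWriteStridedInputs` uses warp-tiled shared-memory buffers and more parameters), but it shows where the `T` vs `T_ACC` truncation mattered:

```cuda
// Simplified sketch only: schematic indexing, no warp tiling or syncs.
// T is the tensor dtype (e.g. half), T_ACC the accumulation dtype (float).
template <typename T, typename T_ACC>
__device__ void load_write_strided_inputs_sketch(
    int row, int col, int n2,
    const T* input, const T* dout,
    const T_ACC* mean, const T_ACC* rstd,
    T_ACC* buf_dbeta, T_ACC* buf_dgamma) {
  // Before the fix the per-row statistics were read as
  //   T curr_mean = mean[row];
  // truncating fp32 mean/rstd to half before they entered the partial sums.
  T_ACC curr_mean = mean[row];   // keep accumulator precision
  T_ACC curr_rstd = rstd[row];
  if (col < n2) {
    T_ACC x  = static_cast<T_ACC>(input[row * n2 + col]);
    T_ACC dy = static_cast<T_ACC>(dout[row * n2 + col]);
    // Partial gradients for one element; the cross-row reduction into
    // shared memory is elided in this sketch.
    buf_dbeta[col]  += dy;
    buf_dgamma[col] += dy * (x - curr_mean) * curr_rstd;
  }
}
```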
We observed that the native PyTorch `LayerNormBackwardKernelImplInternal` has suboptimal performance for certain input sizes on AMD GPUs, especially when `fs` (=`config_m` in our benchmark script) is large and `bs` (=`config_n` in our benchmark script) is small (commonly seen in the CvT model), as measured with the benchmark script of PR #68238.

This PR replaces `GammaBetaBackwardCUDAKernel` with the Apex layernorm backward kernel, with some ROCm-specific parameter tuning, when `fs` (=`config_m`) is larger than 512 on AMD GPUs.

There are a few PRs for the LayerNorm kernel:
Therefore, we have tested and compared the kernel before and at this PR with the input shapes from the last two PRs, along with those commonly used in the CvT model, on an AMD MI100 GPU.
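Before the results, a rough, hypothetical illustration of the shape-based dispatch this PR introduces. The helper and enum below are not ATen symbols; the actual change launches `cuComputePartGradGammaBeta` followed by `cuComputeGradGammaBeta` on the ROCm path and keeps `GammaBetaBackwardCUDAKernel` otherwise:

```cpp
// Hypothetical helper for illustration only; names below are not ATen symbols.
#include <cstdint>

enum class GammaBetaPath { NativeKernel, ApexTwoPass };

// Pick the dgamma/dbeta reduction strategy. M is the outer size
// (fs / config_m), N the normalized size (bs / config_n).
inline GammaBetaPath pick_gamma_beta_path(int64_t M, int64_t /*N*/) {
#if defined(USE_ROCM)
  // Large-M shapes use the Apex-style two-pass reduction
  // (cuComputePartGradGammaBeta, then cuComputeGradGammaBeta);
  // small-M shapes keep GammaBetaBackwardCUDAKernel.
  return (M > 512) ? GammaBetaPath::ApexTwoPass : GammaBetaPath::NativeKernel;
#else
  return GammaBetaPath::NativeKernel;
#endif
}
```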
**Current**

**At this PR**

**Performance Improvement (%)**

**Benchmark script of this PR**
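(The script itself is the Python benchmark from PR #68238 and is not reproduced here. As a rough approximation only, a libtorch loop of the kind used to produce the fwdbwd timings above might look like the following; the shape and iteration counts are arbitrary assumptions.)

```cpp
// Illustrative approximation only: times one (M, N) shape for fwd+bwd in half.
#include <torch/torch.h>
#include <chrono>
#include <cstdio>

int main() {
  const int64_t M = 50432, N = 384;  // config_m, config_n (example shape)
  auto opts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
  auto x = torch::randn({M, N}, opts).requires_grad_(true);
  auto w = torch::randn({N}, opts).requires_grad_(true);
  auto b = torch::randn({N}, opts).requires_grad_(true);
  auto grad_out = torch::randn({M, N}, opts);

  auto run_once = [&] {
    auto y = torch::layer_norm(x, {N}, w, b);
    y.backward(grad_out);
    x.mutable_grad().reset();
    w.mutable_grad().reset();
    b.mutable_grad().reset();
  };

  for (int i = 0; i < 10; ++i) run_once();  // warm-up
  torch::cuda::synchronize();

  const int iters = 100;
  auto t0 = std::chrono::steady_clock::now();
  for (int i = 0; i < iters; ++i) run_once();
  torch::cuda::synchronize();
  auto t1 = std::chrono::steady_clock::now();

  double ms = std::chrono::duration<double, std::milli>(t1 - t0).count() / iters;
  std::printf("fwdbwd (half) M=%lld N=%lld: %.5f ms\n",
              (long long)M, (long long)N, ms);
  return 0;
}
```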
CC: @jeffdaily
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport