Skip to content

Conversation

@hubertlu-tw
Copy link
Contributor

@hubertlu-tw hubertlu-tw commented Oct 24, 2022

We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when fs (=config_m in our benchmark script) is large and bs (=config_n in our benchmark script) is small (commonly seen in the CvT model) in the benchmark script of PR #68238 on AMD GPUs.

This PR is to replace GammaBetaBackwardCUDAKernel with the Apex layernorm backward kernel with some ROCm-specific parameter tuning when fs (=config_m) is larger than 512 on AMD GPUs.

There are a few PRs for LayerNorm kernel:

Therefore, we have tested and compared the kernel before and at this PR with the input shapes in the last two PRs along with those commonly used in the CvT model on AMD MI100.


Current

M N fwd (half) fwdbwd (half) fwd (float) fwdbwd (float)
50432 384 0.387256 1.372758 0.378975 1.47892
50176 384 0.38231 1.362416 0.378084 1.473886
200704 192 0.997859 4.315875 0.989306 4.560827
802816 64 3.671828 16.68013 3.613515 16.827946
200 256 0.066503 0.332096 0.071422 0.325349
1000 256 0.071848 0.333355 0.073038 0.334753
6000 256 0.086334 0.345139 0.086834 0.347429
6272 256 0.088601 0.347906 0.087855 0.351245
200 512 0.071626 0.329726 0.073798 0.326878
1000 512 0.073975 0.330226 0.074166 0.332751
6000 512 0.099617 0.362367 0.100095 0.378313
6272 512 0.100378 0.358066 0.099857 0.395982
200 1024 0.072954 0.326382 0.073899 0.333007
1000 1024 0.0743 0.325532 0.071126 0.330991
6000 1024 0.127025 0.390084 0.128692 0.471504
6272 1024 0.130704 0.403536 0.135244 0.487133
200 1536 0.070331 0.339169 0.070086 0.331015
1000 1536 0.075085 0.330042 0.076295 0.328778
6000 1536 0.148889 0.44949 0.155781 0.659987
6272 1536 0.154939 0.478871 0.17673 0.716025
200 2048 0.070269 0.335585 0.072804 0.334655
1000 2048 0.080094 0.326991 0.080426 0.32685
6000 2048 0.187888 0.623023 0.245762 0.981635
6272 2048 0.195431 0.65244 0.262574 1.008141
200 3072 0.068205 0.339428 0.073068 0.344034
1000 3072 0.087554 0.328899 0.09218 0.346433
6000 3072 0.240352 0.905058 0.368135 1.280462
6272 3072 0.26179 0.959387 0.387782 1.476524
128 2097152 5.905976 22.724793 10.287974 30.242092
256 1048576 4.561596 19.554308 10.223171 29.42371
512 524288 4.146751 22.7247 11.404285 39.175902
1024 262144 5.193135 23.403325 11.334512 38.947192
2048 131072 4.992907 23.377801 11.400286 40.889191
4096 65536 5.429488 24.275701 11.196778 41.4751
8192 32768 5.35758 21.360312 10.535418 42.875646
16384 16384 5.44947 20.852605 10.357685 34.603408
32768 8192 4.688925 17.379392 9.635596 31.188271

At this PR

M N fwd (half) fwdbwd (half) fwd (float) fwdbwd (float)
50432 384 0.38797 0.93103 0.37966 1.15283
50176 384 0.3874 0.96417 0.38462 1.18595
200704 192 1.00002 2.40876 0.99224 2.55579
802816 64 3.67348 7.98658 3.61871 7.72404
200 256 0.07292 0.35119 0.07195 0.32602
1000 256 0.07354 0.33325 0.07237 0.33742
6000 256 0.08819 0.33283 0.08453 0.3279
6272 256 0.0886 0.33446 0.08774 0.33426
200 512 0.0701 0.33505 0.07072 0.33018
1000 512 0.07042 0.33442 0.074 0.33206
6000 512 0.09931 0.34956 0.09895 0.3572
6272 512 0.10103 0.32976 0.10041 0.36635
200 1024 0.07144 0.33579 0.07209 0.33216
1000 1024 0.0736 0.32803 0.07286 0.32936
6000 1024 0.12584 0.38916 0.12852 0.48273
6272 1024 0.13053 0.38804 0.13464 0.49545
200 1536 0.07159 0.3396 0.07062 0.33545
1000 1536 0.07443 0.33239 0.07366 0.33204
6000 1536 0.14959 0.45043 0.15826 0.69119
6272 1536 0.1542 0.47644 0.18249 0.72208
200 2048 0.07258 0.33982 0.07412 0.33859
1000 2048 0.0793 0.32816 0.07864 0.32583
6000 2048 0.18973 0.571 0.25506 0.91796
6272 2048 0.19719 0.64208 0.26445 0.95055
200 3072 0.07092 0.33867 0.07104 0.34695
1000 3072 0.08727 0.33144 0.09144 0.36633
6000 3072 0.24683 0.87275 0.37761 1.3289
6272 3072 0.26437 0.91178 0.38496 1.53694
128 2097152 6.27936 23.69425 10.40004 30.13699
256 1048576 4.5404 19.47675 10.28494 29.36936
512 524288 4.13951 18.78771 10.09557 32.67083
1024 262144 4.47576 18.00411 9.56488 31.47117
2048 131072 4.28026 16.95619 9.40297 30.82845
4096 65536 4.2653 16.5018 9.03315 30.08392
8192 32768 4.25613 16.13583 8.9258 30.75296
16384 16384 4.20256 16.38207 9.52587 31.31113
32768 8192 4.20231 16.19452 9.31478 31.03514

Performance Improvement (%)

M N fwdbwd, torch.float16 fwdbwd, torch.float32
50432 384 32.178 22.049
50176 384 29.231 19.536
200704 192 44.188 43.962
802816 64 52.119 54.100
200 256 -5.750 -0.206
1000 256 0.031 -0.797
6000 256 3.566 5.621
6272 256 3.865 4.836
200 512 -1.615 -1.010
1000 512 -1.270 0.208
6000 512 3.534 5.581
6272 512 7.905 7.483
200 1024 -2.883 0.254
1000 1024 -0.767 0.493
6000 1024 0.237 -2.381
6272 1024 3.840 -1.707
200 1536 -0.127 -1.340
1000 1536 -0.711 -0.992
6000 1536 -0.209 -4.728
6272 1536 0.508 -0.846
200 2048 -1.262 -1.176
1000 2048 -0.358 0.312
6000 2048 8.350 6.487
6272 2048 1.588 5.713
200 3072 0.223 -0.848
1000 3072 -0.773 -5.743
6000 3072 3.570 -3.783
6272 3072 4.962 -4.092
128 2097152 -4.266 0.348
256 1048576 0.397 0.185
512 524288 17.325 16.605
1024 262144 23.070 19.195
2048 131072 27.469 24.605
4096 65536 32.023 27.465
8192 32768 24.459 28.274
16384 16384 21.439 9.514
32768 8192 6.818 0.491

Benchmark script of this PR

# Ref: 
#       1. https://github.com/pytorch/pytorch/pull/26201
#       2. https://github.com/pytorch/pytorch/pull/68238

from distutils.command.config import config
import torch
from torch.nn import LayerNorm
import timeit

number_runs = 1000  # TODO: Modify this to save time!
def test_forward(layer_norm_cuda, input_cuda):
    layer_norm_cuda(input_cuda); torch.cuda.synchronize()

def test_backward(out_cuda, layer_norm_grad_cuda, create_graph):
    out_cuda.backward(layer_norm_grad_cuda, retain_graph=True, create_graph=create_graph); torch.cuda.synchronize()

def test_fwdbwd(input_cuda, layer_norm_cuda, gO):
    input_cuda.grad = None
    layer_norm_cuda.zero_grad(set_to_none=True)
    out = layer_norm_cuda(input_cuda)
    out.backward(gO)
    torch.cuda.synchronize()


def benchmark(config_m, config_n):

    print("M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)")
    if len(config_m) != len(config_n):
        print("Please make sure the lengths of config_m and config_m are the same.")

    for i in range(len(config_m)):
        normalized_shape = config_n[i]
        results = [config_m[i], config_n[i]]
        for dtype in (torch.half, torch.float):
            if dtype == torch.half:
                layer_norm_cuda = LayerNorm(normalized_shape).half().cuda()
            else:
                layer_norm_cuda = LayerNorm(normalized_shape).cuda()
            
            input_cuda = torch.randn(config_m[i], config_n[i], device='cuda', dtype=dtype, requires_grad=True)

            # print("cuda forward:")
            result_fwd = timeit.timeit(lambda: test_forward(layer_norm_cuda, input_cuda), number=number_runs)
            results.append(result_fwd / number_runs * 1000)

            gO = torch.rand_like(input_cuda)

            result_fwdbwd = timeit.timeit(lambda: test_fwdbwd(input_cuda, layer_norm_cuda, gO), number=number_runs)
            results.append(result_fwdbwd / number_runs * 1000)
        
        print('{:09d}|{:09d}|{:9.5f}|{:9.5f}|{:9.5f}|{:9.5f}'.format(results[0], results[1], results[2], results[3], results[4], results[5]))

    print("Times are in microseconds (us).")

# CVT
config_m_cvt = [50432, 50176, 200704, 802816]
config_n_cvt = [384, 384, 192, 64]

# https://github.com/pytorch/pytorch/pull/68238#issue-1051621716
config_m_68238 = [200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272]
config_n_68238 = [256,256,256,256,512,512,512,512,1024,1024,1024,1024,1536,1536,1536,1536,2048,2048,2048,2048,3072,3072,3072,3072]

# https://github.com/pytorch/pytorch/pull/27634
config_m_27634 = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
config_n_27634 = [2097152, 1048576, 524288, 262144, 131072, 65536, 32768, 16384, 8192]

config_m = config_m_cvt + config_m_68238 + config_m_27634
config_n = config_n_cvt + config_n_68238 + config_n_27634


benchmark(config_m, config_n)

CC: @jeffdaily

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport

jataylo and others added 11 commits October 24, 2022 19:43
Replaced the use of GammaBetaBackwardCUDAKernel with the more performant cuComputePartGradGammaBeta and cuComputeGradGammaBeta from Apex
Replaced the use of layer_norm_grad_input with he more performant cuComputeGradInputfrom Apex
ENABLE_APEX_GRADINPUT:
- layer_norm_grad_input_kernel -> cuComputeGradInput

ENABLE_APEX_GAMMABETA:
- GammaBetaBackwardCUDAKernel -> cuComputePartGradGammaBeta, cuComputeGradGammaBeta
…keep only ENABLE_APEX_GAMMABETA code changes
@pytorch-bot
Copy link

pytorch-bot bot commented Oct 24, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87635

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0e32924:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: cuda release notes category label Oct 24, 2022
@linux-foundation-easycla
Copy link

linux-foundation-easycla bot commented Oct 24, 2022

CLA Signed

The committers listed above are authorized under a signed CLA.

@hubertlu-tw hubertlu-tw changed the title Layernorm bwd optim cvt [ROCm] Optimize layer norm backward kernel for ROCm Oct 24, 2022
@pytorch-bot pytorch-bot bot added the module: rocm AMD GPU support for Pytorch label Oct 24, 2022
@jeffdaily jeffdaily added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 24, 2022
@hubertlu-tw
Copy link
Contributor Author

@ngimel and @zasdfgbnm could you please help review this PR?

@ezyang ezyang requested a review from jeffdaily October 31, 2022 21:05
@ezyang
Copy link
Contributor

ezyang commented Oct 31, 2022

@jeffdaily can you send this to the right folks

I'm not so hot about adding so much ROCm only CUDA code, it will be maintenance trouble, but if the ROCm team agrees to maintain it we can add it.

@jithunnair-amd jithunnair-amd requested a review from ngimel November 1, 2022 16:19
@H-Huang H-Huang added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Nov 2, 2022
@jeffdaily
Copy link
Collaborator

@jeffdaily can you send this to the right folks

I'm not so hot about adding so much ROCm only CUDA code, it will be maintenance trouble, but if the ROCm team agrees to maintain it we can add it.

One of the right folks is the PR author, or our other team member @jataylo. ROCm team of course agrees to maintain this code.

@ezyang
Copy link
Contributor

ezyang commented Nov 4, 2022

I invited @jataylo to the repo, once he accepts the invite please add him as reviewer.

Copy link
Collaborator

@jataylo jataylo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM ezyang. The reported performance gains from the benchmark have been replicated locally.

@jataylo
Copy link
Collaborator

jataylo commented Nov 11, 2022

Hi @ezyang just repinging this PR if you had any more comments.

pytorchmergebot pushed a commit that referenced this pull request Nov 28, 2022
…or ROCm (#87726)

We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when fs (=config_m in our benchmark script) is large and bs (=config_n in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of #68238 (comment) on AMD GPUs.

This PR is to replace layer_norm_grad_input_kernel with the Apex cuComputeGradInput kernel with some ROCm-specific parameter tuning when fs (=config_m) is larger than or equal to `32768` on AMD GPUs. Some of the code changes in LayerNormBackwardKernelImplInternal are from another PR: #87635

We used the same benchmark script in the previous PR and tested the optimized kernel with various input shapes on AMD MI100 GPU.

**At [the previous PR](#87635
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl65
	{color:windowtext;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38589 | 0.92603 | 0.38367 | 1.15148
50176 | 384 | 0.38719 | 0.91579 | 0.37815 | 1.13761
200704 | 192 | 0.99787 | 2.39954 | 0.98996 | 2.54284
802816 | 64 | 3.66525 | 7.96952 | 3.61293 | 7.69946
200 | 256 | 0.06578 | 0.34613 | 0.06966 | 0.35449
1000 | 256 | 0.07837 | 0.37631 | 0.07725 | 0.37758
6000 | 256 | 0.09318 | 0.3788 | 0.09202 | 0.37989
6272 | 256 | 0.08694 | 0.36267 | 0.08703 | 0.3615
200 | 512 | 0.06975 | 0.34506 | 0.06973 | 0.34208
1000 | 512 | 0.07012 | 0.36363 | 0.07307 | 0.36741
6000 | 512 | 0.09725 | 0.36251 | 0.09908 | 0.37078
6272 | 512 | 0.09899 | 0.36519 | 0.10068 | 0.37514
200 | 1024 | 0.07188 | 0.33896 | 0.0712 | 0.34683
1000 | 1024 | 0.07357 | 0.3625 | 0.0734 | 0.3598
6000 | 1024 | 0.12642 | 0.38949 | 0.12973 | 0.5035
6272 | 1024 | 0.12901 | 0.40759 | 0.13609 | 0.51871
200 | 1536 | 0.06998 | 0.34782 | 0.07419 | 0.3514
1000 | 1536 | 0.07987 | 0.37915 | 0.07888 | 0.37264
6000 | 1536 | 0.15401 | 0.47524 | 0.15416 | 0.68609
6272 | 1536 | 0.15286 | 0.48843 | 0.17681 | 0.72997
200 | 2048 | 0.07054 | 0.34791 | 0.07289 | 0.35138
1000 | 2048 | 0.07767 | 0.37954 | 0.08554 | 0.37464
6000 | 2048 | 0.18744 | 0.5811 | 0.25004 | 0.93338
6272 | 2048 | 0.20037 | 0.63398 | 0.26918 | 0.97018
200 | 3072 | 0.07687 | 0.36739 | 0.08917 | 0.37845
1000 | 3072 | 0.09323 | 0.38901 | 0.09739 | 0.39823
6000 | 3072 | 0.24314 | 0.89029 | 0.38093 | 1.30719
6272 | 3072 | 0.26079 | 0.92023 | 0.38352 | 1.51012
128 | 2097152 | 6.17775 | 23.876 | 10.27952 | 30.10848
256 | 1048576 | 4.51855 | 19.47637 | 10.07609 | 29.42678
512 | 524288 | 4.13615 | 18.80888 | 10.07853 | 32.29804
1024 | 262144 | 4.47397 | 17.88388 | 9.50367 | 31.15699
2048 | 131072 | 4.2458 | 16.70852 | 9.17979 | 30.51708
4096 | 65536 | 4.24412 | 16.43098 | 8.97651 | 30.1617
8192 | 32768 | 4.24556 | 16.09038 | 8.77001 | 30.3643
16384 | 16384 | 4.14642 | 15.80355 | 8.82402 | 30.35291
32768 | 8192 | 4.12599 | 15.68897 | 8.82605 | 30.43423

</body>

</html>

----

**At this PR:**

<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl65
	{color:windowtext;}
.xl66
	{background:yellow;
	mso-pattern:black none;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38667 | 0.84133 | 0.37916 | 1.01222
50176 | 384 | 0.3814 | 0.87266 | 0.37858 | 1.04399
200704 | 192 | 0.99902 | 2.14386 | 0.98973 | 2.33265
802816 | 64 | 3.66578 | 6.85376 | 3.6092 | 7.00331
200 | 256 | 0.06607 | 0.34176 | 0.07009 | 0.34548
1000 | 256 | 0.06947 | 0.36461 | 0.07902 | 0.37851
6000 | 256 | 0.09319 | 0.37432 | 0.09342 | 0.36927
6272 | 256 | 0.09544 | 0.37565 | 0.09476 | 0.37377
200 | 512 | 0.07935 | 0.364 | 0.07891 | 0.36894
1000 | 512 | 0.07676 | 0.37552 | 0.07957 | 0.37564
6000 | 512 | 0.10472 | 0.37504 | 0.1051 | 0.38782
6272 | 512 | 0.1069 | 0.36662 | 0.10062 | 0.38506
200 | 1024 | 0.07793 | 0.36561 | 0.08023 | 0.35019
1000 | 1024 | 0.07426 | 0.36729 | 0.07345 | 0.35851
6000 | 1024 | 0.12729 | 0.39219 | 0.12974 | 0.51526
6272 | 1024 | 0.13622 | 0.41627 | 0.14252 | 0.52926
200 | 1536 | 0.07615 | 0.36621 | 0.0797 | 0.3695
1000 | 1536 | 0.08327 | 0.38174 | 0.07938 | 0.37573
6000 | 1536 | 0.14894 | 0.46197 | 0.15268 | 0.63814
6272 | 1536 | 0.15368 | 0.48818 | 0.16309 | 0.71441
200 | 2048 | 0.06935 | 0.36691 | 0.07258 | 0.35548
1000 | 2048 | 0.07738 | 0.36388 | 0.08036 | 0.36452
6000 | 2048 | 0.18757 | 0.58573 | 0.23701 | 0.92915
6272 | 2048 | 0.1938 | 0.61628 | 0.26475 | 0.96896
200 | 3072 | 0.07884 | 0.3673 | 0.07724 | 0.37869
1000 | 3072 | 0.09342 | 0.38193 | 0.09822 | 0.38646
6000 | 3072 | 0.24452 | 0.86776 | 0.38251 | 1.3036
6272 | 3072 | 0.25971 | 0.91053 | 0.38744 | 1.39039
128 | 2097152 | 6.06752 | 23.26379 | 9.87466 | 29.81851
256 | 1048576 | 4.50336 | 19.4614 | 10.11239 | 29.25554
512 | 524288 | 4.12649 | 18.72831 | 10.054 | 32.26784
1024 | 262144 | 4.40855 | 17.77993 | 9.38856 | 31.18679
2048 | 131072 | 4.18716 | 16.74615 | 9.14487 | 30.24603
4096 | 65536 | 4.17374 | 16.34444 | 8.94894 | 30.0326
8192 | 32768 | 4.19095 | 16.05751 | 8.70358 | 30.14669
16384 | 16384 | 4.15404 | 15.83771 | 8.80042 | 30.5022
32768 | 8192 | 4.12515 | 15.5657 | 8.66138 | 28.87386

</body>

</html>

---

**Performance Improvement (%)**

<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl65
	{color:windowtext;}
.xl66
	{mso-number-format:"0\.000";}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwdbwd, torch.float16 | fwdbwd, torch.float32
-- | -- | -- | --
50432 | 384 | 9.147 | 12.094
50176 | 384 | 4.710 | 8.230
200704 | 192 | 10.655 | 8.266
802816 | 64 | 14.000 | 9.042
200 | 256 | 1.263 | 2.542
1000 | 256 | 3.109 | -0.246
6000 | 256 | 1.183 | 2.796
6272 | 256 | -3.579 | -3.394
200 | 512 | -5.489 | -7.852
1000 | 512 | -3.270 | -2.240
6000 | 512 | -3.456 | -4.596
6272 | 512 | -0.392 | -2.644
200 | 1024 | -7.862 | -0.969
1000 | 1024 | -1.321 | 0.359
6000 | 1024 | -0.693 | -2.336
6272 | 1024 | -2.130 | -2.034
200 | 1536 | -5.287 | -5.151
1000 | 1536 | -0.683 | -0.829
6000 | 1536 | 2.792 | 6.989
6272 | 1536 | 0.051 | 2.132
200 | 2048 | -5.461 | -1.167
1000 | 2048 | 4.126 | 2.701
6000 | 2048 | -0.797 | 0.453
6272 | 2048 | 2.792 | 0.126
200 | 3072 | 0.024 | -0.063
1000 | 3072 | 1.820 | 2.956
6000 | 3072 | 2.531 | 0.275
6272 | 3072 | 1.054 | 7.929
128 | 2097152 | 2.564 | 0.963
256 | 1048576 | 0.077 | 0.582
512 | 524288 | 0.428 | 0.094
1024 | 262144 | 0.581 | -0.096
2048 | 131072 | -0.225 | 0.888
4096 | 65536 | 0.527 | 0.428
8192 | 32768 | 0.204 | 0.717
16384 | 16384 | -0.216 | -0.492
32768 | 8192 | 0.786 | 5.127

</body>

</html>

CC: @jeffdaily

Pull Request resolved: #87726
Approved by: https://github.com/ngimel
jithunnair-amd pushed a commit to ROCm/pytorch that referenced this pull request Dec 1, 2022
We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when `fs`  (=`config_m` in our benchmark script) is large and `bs`  (=`config_n` in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of [PR pytorch#68238](pytorch#68238 (comment)) on AMD GPUs.

This PR is to replace `GammaBetaBackwardCUDAKernel` with the Apex layernorm backward kernel with some ROCm-specific parameter tuning when `fs`  (=`config_m`) is larger than 512 on AMD GPUs.

There are a few PRs for LayerNorm kernel:
- pytorch#26201
- pytorch#27634
- pytorch#68238

Therefore, we have tested and compared the kernel before and at this PR with the input shapes in the last two PRs along with those commonly used in the CvT model on AMD MI100.

---
**Current**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.387256 | 1.372758 | 0.378975 | 1.47892
50176 | 384 | 0.38231 | 1.362416 | 0.378084 | 1.473886
200704 | 192 | 0.997859 | 4.315875 | 0.989306 | 4.560827
802816 | 64 | 3.671828 | 16.68013 | 3.613515 | 16.827946
200 | 256 | 0.066503 | 0.332096 | 0.071422 | 0.325349
1000 | 256 | 0.071848 | 0.333355 | 0.073038 | 0.334753
6000 | 256 | 0.086334 | 0.345139 | 0.086834 | 0.347429
6272 | 256 | 0.088601 | 0.347906 | 0.087855 | 0.351245
200 | 512 | 0.071626 | 0.329726 | 0.073798 | 0.326878
1000 | 512 | 0.073975 | 0.330226 | 0.074166 | 0.332751
6000 | 512 | 0.099617 | 0.362367 | 0.100095 | 0.378313
6272 | 512 | 0.100378 | 0.358066 | 0.099857 | 0.395982
200 | 1024 | 0.072954 | 0.326382 | 0.073899 | 0.333007
1000 | 1024 | 0.0743 | 0.325532 | 0.071126 | 0.330991
6000 | 1024 | 0.127025 | 0.390084 | 0.128692 | 0.471504
6272 | 1024 | 0.130704 | 0.403536 | 0.135244 | 0.487133
200 | 1536 | 0.070331 | 0.339169 | 0.070086 | 0.331015
1000 | 1536 | 0.075085 | 0.330042 | 0.076295 | 0.328778
6000 | 1536 | 0.148889 | 0.44949 | 0.155781 | 0.659987
6272 | 1536 | 0.154939 | 0.478871 | 0.17673 | 0.716025
200 | 2048 | 0.070269 | 0.335585 | 0.072804 | 0.334655
1000 | 2048 | 0.080094 | 0.326991 | 0.080426 | 0.32685
6000 | 2048 | 0.187888 | 0.623023 | 0.245762 | 0.981635
6272 | 2048 | 0.195431 | 0.65244 | 0.262574 | 1.008141
200 | 3072 | 0.068205 | 0.339428 | 0.073068 | 0.344034
1000 | 3072 | 0.087554 | 0.328899 | 0.09218 | 0.346433
6000 | 3072 | 0.240352 | 0.905058 | 0.368135 | 1.280462
6272 | 3072 | 0.26179 | 0.959387 | 0.387782 | 1.476524
128 | 2097152 | 5.905976 | 22.724793 | 10.287974 | 30.242092
256 | 1048576 | 4.561596 | 19.554308 | 10.223171 | 29.42371
512 | 524288 | 4.146751 | 22.7247 | 11.404285 | 39.175902
1024 | 262144 | 5.193135 | 23.403325 | 11.334512 | 38.947192
2048 | 131072 | 4.992907 | 23.377801 | 11.400286 | 40.889191
4096 | 65536 | 5.429488 | 24.275701 | 11.196778 | 41.4751
8192 | 32768 | 5.35758 | 21.360312 | 10.535418 | 42.875646
16384 | 16384 | 5.44947 | 20.852605 | 10.357685 | 34.603408
32768 | 8192 | 4.688925 | 17.379392 | 9.635596 | 31.188271

</body>

</html>

---------
**At this PR**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">

<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl63
	{color:windowtext;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38797 | 0.93103 | 0.37966 | 1.15283
50176 | 384 | 0.3874 | 0.96417 | 0.38462 | 1.18595
200704 | 192 | 1.00002 | 2.40876 | 0.99224 | 2.55579
802816 | 64 | 3.67348 | 7.98658 | 3.61871 | 7.72404
200 | 256 | 0.07292 | 0.35119 | 0.07195 | 0.32602
1000 | 256 | 0.07354 | 0.33325 | 0.07237 | 0.33742
6000 | 256 | 0.08819 | 0.33283 | 0.08453 | 0.3279
6272 | 256 | 0.0886 | 0.33446 | 0.08774 | 0.33426
200 | 512 | 0.0701 | 0.33505 | 0.07072 | 0.33018
1000 | 512 | 0.07042 | 0.33442 | 0.074 | 0.33206
6000 | 512 | 0.09931 | 0.34956 | 0.09895 | 0.3572
6272 | 512 | 0.10103 | 0.32976 | 0.10041 | 0.36635
200 | 1024 | 0.07144 | 0.33579 | 0.07209 | 0.33216
1000 | 1024 | 0.0736 | 0.32803 | 0.07286 | 0.32936
6000 | 1024 | 0.12584 | 0.38916 | 0.12852 | 0.48273
6272 | 1024 | 0.13053 | 0.38804 | 0.13464 | 0.49545
200 | 1536 | 0.07159 | 0.3396 | 0.07062 | 0.33545
1000 | 1536 | 0.07443 | 0.33239 | 0.07366 | 0.33204
6000 | 1536 | 0.14959 | 0.45043 | 0.15826 | 0.69119
6272 | 1536 | 0.1542 | 0.47644 | 0.18249 | 0.72208
200 | 2048 | 0.07258 | 0.33982 | 0.07412 | 0.33859
1000 | 2048 | 0.0793 | 0.32816 | 0.07864 | 0.32583
6000 | 2048 | 0.18973 | 0.571 | 0.25506 | 0.91796
6272 | 2048 | 0.19719 | 0.64208 | 0.26445 | 0.95055
200 | 3072 | 0.07092 | 0.33867 | 0.07104 | 0.34695
1000 | 3072 | 0.08727 | 0.33144 | 0.09144 | 0.36633
6000 | 3072 | 0.24683 | 0.87275 | 0.37761 | 1.3289
6272 | 3072 | 0.26437 | 0.91178 | 0.38496 | 1.53694
128 | 2097152 | 6.27936 | 23.69425 | 10.40004 | 30.13699
256 | 1048576 | 4.5404 | 19.47675 | 10.28494 | 29.36936
512 | 524288 | 4.13951 | 18.78771 | 10.09557 | 32.67083
1024 | 262144 | 4.47576 | 18.00411 | 9.56488 | 31.47117
2048 | 131072 | 4.28026 | 16.95619 | 9.40297 | 30.82845
4096 | 65536 | 4.2653 | 16.5018 | 9.03315 | 30.08392
8192 | 32768 | 4.25613 | 16.13583 | 8.9258 | 30.75296
16384 | 16384 | 4.20256 | 16.38207 | 9.52587 | 31.31113
32768 | 8192 | 4.20231 | 16.19452 | 9.31478 | 31.03514

</body>

</html>

---------

**Performance Improvement (%)**
<html xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:dt="uuid:C2F41010-65B3-11d1-A29F-00AA00C14882"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=OneNote.File>
<meta name=Generator content="Microsoft OneNote 15">
</head>

<body lang=en-US style='font-family:Calibri;font-size:11.0pt'>
<!--StartFragment-->

<div style='direction:ltr'>

M | N | fwdbwd,   torch.float16 | fwdbwd,   torch.float32
-- | -- | -- | --
50432 | 384 | 32.178 | 22.049
50176 | 384 | 29.231 | 19.536
200704 | 192 | 44.188 | 43.962
802816 | 64 | 52.119 | 54.100
200 | 256 | -5.750 | -0.206
1000 | 256 | 0.031 | -0.797
6000 | 256 | 3.566 | 5.621
6272 | 256 | 3.865 | 4.836
200 | 512 | -1.615 | -1.010
1000 | 512 | -1.270 | 0.208
6000 | 512 | 3.534 | 5.581
6272 | 512 | 7.905 | 7.483
200 | 1024 | -2.883 | 0.254
1000 | 1024 | -0.767 | 0.493
6000 | 1024 | 0.237 | -2.381
6272 | 1024 | 3.840 | -1.707
200 | 1536 | -0.127 | -1.340
1000 | 1536 | -0.711 | -0.992
6000 | 1536 | -0.209 | -4.728
6272 | 1536 | 0.508 | -0.846
200 | 2048 | -1.262 | -1.176
1000 | 2048 | -0.358 | 0.312
6000 | 2048 | 8.350 | 6.487
6272 | 2048 | 1.588 | 5.713
200 | 3072 | 0.223 | -0.848
1000 | 3072 | -0.773 | -5.743
6000 | 3072 | 3.570 | -3.783
6272 | 3072 | 4.962 | -4.092
128 | 2097152 | -4.266 | 0.348
256 | 1048576 | 0.397 | 0.185
512 | 524288 | 17.325 | 16.605
1024 | 262144 | 23.070 | 19.195
2048 | 131072 | 27.469 | 24.605
4096 | 65536 | 32.023 | 27.465
8192 | 32768 | 24.459 | 28.274
16384 | 16384 | 21.439 | 9.514
32768 | 8192 | 6.818 | 0.491

</div>

<!--EndFragment-->
</body>

</html>

---------
**Benchmark script of this PR**
```

from distutils.command.config import config
import torch
from torch.nn import LayerNorm
import timeit

number_runs = 1000  # TODO: Modify this to save time!
def test_forward(layer_norm_cuda, input_cuda):
    layer_norm_cuda(input_cuda); torch.cuda.synchronize()

def test_backward(out_cuda, layer_norm_grad_cuda, create_graph):
    out_cuda.backward(layer_norm_grad_cuda, retain_graph=True, create_graph=create_graph); torch.cuda.synchronize()

def test_fwdbwd(input_cuda, layer_norm_cuda, gO):
    input_cuda.grad = None
    layer_norm_cuda.zero_grad(set_to_none=True)
    out = layer_norm_cuda(input_cuda)
    out.backward(gO)
    torch.cuda.synchronize()

def benchmark(config_m, config_n):

    print("M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)")
    if len(config_m) != len(config_n):
        print("Please make sure the lengths of config_m and config_m are the same.")

    for i in range(len(config_m)):
        normalized_shape = config_n[i]
        results = [config_m[i], config_n[i]]
        for dtype in (torch.half, torch.float):
            if dtype == torch.half:
                layer_norm_cuda = LayerNorm(normalized_shape).half().cuda()
            else:
                layer_norm_cuda = LayerNorm(normalized_shape).cuda()

            input_cuda = torch.randn(config_m[i], config_n[i], device='cuda', dtype=dtype, requires_grad=True)

            # print("cuda forward:")
            result_fwd = timeit.timeit(lambda: test_forward(layer_norm_cuda, input_cuda), number=number_runs)
            results.append(result_fwd / number_runs * 1000)

            gO = torch.rand_like(input_cuda)

            result_fwdbwd = timeit.timeit(lambda: test_fwdbwd(input_cuda, layer_norm_cuda, gO), number=number_runs)
            results.append(result_fwdbwd / number_runs * 1000)

        print('{:09d}|{:09d}|{:9.5f}|{:9.5f}|{:9.5f}|{:9.5f}'.format(results[0], results[1], results[2], results[3], results[4], results[5]))

    print("Times are in microseconds (us).")

config_m_cvt = [50432, 50176, 200704, 802816]
config_n_cvt = [384, 384, 192, 64]

config_m_68238 = [200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272]
config_n_68238 = [256,256,256,256,512,512,512,512,1024,1024,1024,1024,1536,1536,1536,1536,2048,2048,2048,2048,3072,3072,3072,3072]

config_m_27634 = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
config_n_27634 = [2097152, 1048576, 524288, 262144, 131072, 65536, 32768, 16384, 8192]

config_m = config_m_cvt + config_m_68238 + config_m_27634
config_n = config_n_cvt + config_n_68238 + config_n_27634

benchmark(config_m, config_n)
```

CC: @jeffdaily

Pull Request resolved: pytorch#87635
Approved by: https://github.com/jataylo, https://github.com/jeffdaily, https://github.com/ezyang
jithunnair-amd pushed a commit to ROCm/pytorch that referenced this pull request Dec 1, 2022
…or ROCm (pytorch#87726)

We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when fs (=config_m in our benchmark script) is large and bs (=config_n in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of pytorch#68238 (comment) on AMD GPUs.

This PR is to replace layer_norm_grad_input_kernel with the Apex cuComputeGradInput kernel with some ROCm-specific parameter tuning when fs (=config_m) is larger than or equal to `32768` on AMD GPUs. Some of the code changes in LayerNormBackwardKernelImplInternal are from another PR: pytorch#87635

We used the same benchmark script in the previous PR and tested the optimized kernel with various input shapes on AMD MI100 GPU.

**At [the previous PR](pytorch#87635
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl65
	{color:windowtext;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38589 | 0.92603 | 0.38367 | 1.15148
50176 | 384 | 0.38719 | 0.91579 | 0.37815 | 1.13761
200704 | 192 | 0.99787 | 2.39954 | 0.98996 | 2.54284
802816 | 64 | 3.66525 | 7.96952 | 3.61293 | 7.69946
200 | 256 | 0.06578 | 0.34613 | 0.06966 | 0.35449
1000 | 256 | 0.07837 | 0.37631 | 0.07725 | 0.37758
6000 | 256 | 0.09318 | 0.3788 | 0.09202 | 0.37989
6272 | 256 | 0.08694 | 0.36267 | 0.08703 | 0.3615
200 | 512 | 0.06975 | 0.34506 | 0.06973 | 0.34208
1000 | 512 | 0.07012 | 0.36363 | 0.07307 | 0.36741
6000 | 512 | 0.09725 | 0.36251 | 0.09908 | 0.37078
6272 | 512 | 0.09899 | 0.36519 | 0.10068 | 0.37514
200 | 1024 | 0.07188 | 0.33896 | 0.0712 | 0.34683
1000 | 1024 | 0.07357 | 0.3625 | 0.0734 | 0.3598
6000 | 1024 | 0.12642 | 0.38949 | 0.12973 | 0.5035
6272 | 1024 | 0.12901 | 0.40759 | 0.13609 | 0.51871
200 | 1536 | 0.06998 | 0.34782 | 0.07419 | 0.3514
1000 | 1536 | 0.07987 | 0.37915 | 0.07888 | 0.37264
6000 | 1536 | 0.15401 | 0.47524 | 0.15416 | 0.68609
6272 | 1536 | 0.15286 | 0.48843 | 0.17681 | 0.72997
200 | 2048 | 0.07054 | 0.34791 | 0.07289 | 0.35138
1000 | 2048 | 0.07767 | 0.37954 | 0.08554 | 0.37464
6000 | 2048 | 0.18744 | 0.5811 | 0.25004 | 0.93338
6272 | 2048 | 0.20037 | 0.63398 | 0.26918 | 0.97018
200 | 3072 | 0.07687 | 0.36739 | 0.08917 | 0.37845
1000 | 3072 | 0.09323 | 0.38901 | 0.09739 | 0.39823
6000 | 3072 | 0.24314 | 0.89029 | 0.38093 | 1.30719
6272 | 3072 | 0.26079 | 0.92023 | 0.38352 | 1.51012
128 | 2097152 | 6.17775 | 23.876 | 10.27952 | 30.10848
256 | 1048576 | 4.51855 | 19.47637 | 10.07609 | 29.42678
512 | 524288 | 4.13615 | 18.80888 | 10.07853 | 32.29804
1024 | 262144 | 4.47397 | 17.88388 | 9.50367 | 31.15699
2048 | 131072 | 4.2458 | 16.70852 | 9.17979 | 30.51708
4096 | 65536 | 4.24412 | 16.43098 | 8.97651 | 30.1617
8192 | 32768 | 4.24556 | 16.09038 | 8.77001 | 30.3643
16384 | 16384 | 4.14642 | 15.80355 | 8.82402 | 30.35291
32768 | 8192 | 4.12599 | 15.68897 | 8.82605 | 30.43423

</body>

</html>

----

**At this PR:**

<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl65
	{color:windowtext;}
.xl66
	{background:yellow;
	mso-pattern:black none;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38667 | 0.84133 | 0.37916 | 1.01222
50176 | 384 | 0.3814 | 0.87266 | 0.37858 | 1.04399
200704 | 192 | 0.99902 | 2.14386 | 0.98973 | 2.33265
802816 | 64 | 3.66578 | 6.85376 | 3.6092 | 7.00331
200 | 256 | 0.06607 | 0.34176 | 0.07009 | 0.34548
1000 | 256 | 0.06947 | 0.36461 | 0.07902 | 0.37851
6000 | 256 | 0.09319 | 0.37432 | 0.09342 | 0.36927
6272 | 256 | 0.09544 | 0.37565 | 0.09476 | 0.37377
200 | 512 | 0.07935 | 0.364 | 0.07891 | 0.36894
1000 | 512 | 0.07676 | 0.37552 | 0.07957 | 0.37564
6000 | 512 | 0.10472 | 0.37504 | 0.1051 | 0.38782
6272 | 512 | 0.1069 | 0.36662 | 0.10062 | 0.38506
200 | 1024 | 0.07793 | 0.36561 | 0.08023 | 0.35019
1000 | 1024 | 0.07426 | 0.36729 | 0.07345 | 0.35851
6000 | 1024 | 0.12729 | 0.39219 | 0.12974 | 0.51526
6272 | 1024 | 0.13622 | 0.41627 | 0.14252 | 0.52926
200 | 1536 | 0.07615 | 0.36621 | 0.0797 | 0.3695
1000 | 1536 | 0.08327 | 0.38174 | 0.07938 | 0.37573
6000 | 1536 | 0.14894 | 0.46197 | 0.15268 | 0.63814
6272 | 1536 | 0.15368 | 0.48818 | 0.16309 | 0.71441
200 | 2048 | 0.06935 | 0.36691 | 0.07258 | 0.35548
1000 | 2048 | 0.07738 | 0.36388 | 0.08036 | 0.36452
6000 | 2048 | 0.18757 | 0.58573 | 0.23701 | 0.92915
6272 | 2048 | 0.1938 | 0.61628 | 0.26475 | 0.96896
200 | 3072 | 0.07884 | 0.3673 | 0.07724 | 0.37869
1000 | 3072 | 0.09342 | 0.38193 | 0.09822 | 0.38646
6000 | 3072 | 0.24452 | 0.86776 | 0.38251 | 1.3036
6272 | 3072 | 0.25971 | 0.91053 | 0.38744 | 1.39039
128 | 2097152 | 6.06752 | 23.26379 | 9.87466 | 29.81851
256 | 1048576 | 4.50336 | 19.4614 | 10.11239 | 29.25554
512 | 524288 | 4.12649 | 18.72831 | 10.054 | 32.26784
1024 | 262144 | 4.40855 | 17.77993 | 9.38856 | 31.18679
2048 | 131072 | 4.18716 | 16.74615 | 9.14487 | 30.24603
4096 | 65536 | 4.17374 | 16.34444 | 8.94894 | 30.0326
8192 | 32768 | 4.19095 | 16.05751 | 8.70358 | 30.14669
16384 | 16384 | 4.15404 | 15.83771 | 8.80042 | 30.5022
32768 | 8192 | 4.12515 | 15.5657 | 8.66138 | 28.87386

</body>

</html>

---

**Performance Improvement (%)**

<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl65
	{color:windowtext;}
.xl66
	{mso-number-format:"0\.000";}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwdbwd, torch.float16 | fwdbwd, torch.float32
-- | -- | -- | --
50432 | 384 | 9.147 | 12.094
50176 | 384 | 4.710 | 8.230
200704 | 192 | 10.655 | 8.266
802816 | 64 | 14.000 | 9.042
200 | 256 | 1.263 | 2.542
1000 | 256 | 3.109 | -0.246
6000 | 256 | 1.183 | 2.796
6272 | 256 | -3.579 | -3.394
200 | 512 | -5.489 | -7.852
1000 | 512 | -3.270 | -2.240
6000 | 512 | -3.456 | -4.596
6272 | 512 | -0.392 | -2.644
200 | 1024 | -7.862 | -0.969
1000 | 1024 | -1.321 | 0.359
6000 | 1024 | -0.693 | -2.336
6272 | 1024 | -2.130 | -2.034
200 | 1536 | -5.287 | -5.151
1000 | 1536 | -0.683 | -0.829
6000 | 1536 | 2.792 | 6.989
6272 | 1536 | 0.051 | 2.132
200 | 2048 | -5.461 | -1.167
1000 | 2048 | 4.126 | 2.701
6000 | 2048 | -0.797 | 0.453
6272 | 2048 | 2.792 | 0.126
200 | 3072 | 0.024 | -0.063
1000 | 3072 | 1.820 | 2.956
6000 | 3072 | 2.531 | 0.275
6272 | 3072 | 1.054 | 7.929
128 | 2097152 | 2.564 | 0.963
256 | 1048576 | 0.077 | 0.582
512 | 524288 | 0.428 | 0.094
1024 | 262144 | 0.581 | -0.096
2048 | 131072 | -0.225 | 0.888
4096 | 65536 | 0.527 | 0.428
8192 | 32768 | 0.204 | 0.717
16384 | 16384 | -0.216 | -0.492
32768 | 8192 | 0.786 | 5.127

</body>

</html>

CC: @jeffdaily

Pull Request resolved: pytorch#87726
Approved by: https://github.com/ngimel
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when `fs`  (=`config_m` in our benchmark script) is large and `bs`  (=`config_n` in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of [PR pytorch#68238](pytorch#68238 (comment)) on AMD GPUs.

This PR is to replace `GammaBetaBackwardCUDAKernel` with the Apex layernorm backward kernel with some ROCm-specific parameter tuning when `fs`  (=`config_m`) is larger than 512 on AMD GPUs.

There are a few PRs for LayerNorm kernel:
- pytorch#26201
- pytorch#27634
- pytorch#68238

Therefore, we have tested and compared the kernel before and at this PR with the input shapes in the last two PRs along with those commonly used in the CvT model on AMD MI100.

---
**Current**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.387256 | 1.372758 | 0.378975 | 1.47892
50176 | 384 | 0.38231 | 1.362416 | 0.378084 | 1.473886
200704 | 192 | 0.997859 | 4.315875 | 0.989306 | 4.560827
802816 | 64 | 3.671828 | 16.68013 | 3.613515 | 16.827946
200 | 256 | 0.066503 | 0.332096 | 0.071422 | 0.325349
1000 | 256 | 0.071848 | 0.333355 | 0.073038 | 0.334753
6000 | 256 | 0.086334 | 0.345139 | 0.086834 | 0.347429
6272 | 256 | 0.088601 | 0.347906 | 0.087855 | 0.351245
200 | 512 | 0.071626 | 0.329726 | 0.073798 | 0.326878
1000 | 512 | 0.073975 | 0.330226 | 0.074166 | 0.332751
6000 | 512 | 0.099617 | 0.362367 | 0.100095 | 0.378313
6272 | 512 | 0.100378 | 0.358066 | 0.099857 | 0.395982
200 | 1024 | 0.072954 | 0.326382 | 0.073899 | 0.333007
1000 | 1024 | 0.0743 | 0.325532 | 0.071126 | 0.330991
6000 | 1024 | 0.127025 | 0.390084 | 0.128692 | 0.471504
6272 | 1024 | 0.130704 | 0.403536 | 0.135244 | 0.487133
200 | 1536 | 0.070331 | 0.339169 | 0.070086 | 0.331015
1000 | 1536 | 0.075085 | 0.330042 | 0.076295 | 0.328778
6000 | 1536 | 0.148889 | 0.44949 | 0.155781 | 0.659987
6272 | 1536 | 0.154939 | 0.478871 | 0.17673 | 0.716025
200 | 2048 | 0.070269 | 0.335585 | 0.072804 | 0.334655
1000 | 2048 | 0.080094 | 0.326991 | 0.080426 | 0.32685
6000 | 2048 | 0.187888 | 0.623023 | 0.245762 | 0.981635
6272 | 2048 | 0.195431 | 0.65244 | 0.262574 | 1.008141
200 | 3072 | 0.068205 | 0.339428 | 0.073068 | 0.344034
1000 | 3072 | 0.087554 | 0.328899 | 0.09218 | 0.346433
6000 | 3072 | 0.240352 | 0.905058 | 0.368135 | 1.280462
6272 | 3072 | 0.26179 | 0.959387 | 0.387782 | 1.476524
128 | 2097152 | 5.905976 | 22.724793 | 10.287974 | 30.242092
256 | 1048576 | 4.561596 | 19.554308 | 10.223171 | 29.42371
512 | 524288 | 4.146751 | 22.7247 | 11.404285 | 39.175902
1024 | 262144 | 5.193135 | 23.403325 | 11.334512 | 38.947192
2048 | 131072 | 4.992907 | 23.377801 | 11.400286 | 40.889191
4096 | 65536 | 5.429488 | 24.275701 | 11.196778 | 41.4751
8192 | 32768 | 5.35758 | 21.360312 | 10.535418 | 42.875646
16384 | 16384 | 5.44947 | 20.852605 | 10.357685 | 34.603408
32768 | 8192 | 4.688925 | 17.379392 | 9.635596 | 31.188271

</body>

</html>

---------
**At this PR**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">

<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl63
	{color:windowtext;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38797 | 0.93103 | 0.37966 | 1.15283
50176 | 384 | 0.3874 | 0.96417 | 0.38462 | 1.18595
200704 | 192 | 1.00002 | 2.40876 | 0.99224 | 2.55579
802816 | 64 | 3.67348 | 7.98658 | 3.61871 | 7.72404
200 | 256 | 0.07292 | 0.35119 | 0.07195 | 0.32602
1000 | 256 | 0.07354 | 0.33325 | 0.07237 | 0.33742
6000 | 256 | 0.08819 | 0.33283 | 0.08453 | 0.3279
6272 | 256 | 0.0886 | 0.33446 | 0.08774 | 0.33426
200 | 512 | 0.0701 | 0.33505 | 0.07072 | 0.33018
1000 | 512 | 0.07042 | 0.33442 | 0.074 | 0.33206
6000 | 512 | 0.09931 | 0.34956 | 0.09895 | 0.3572
6272 | 512 | 0.10103 | 0.32976 | 0.10041 | 0.36635
200 | 1024 | 0.07144 | 0.33579 | 0.07209 | 0.33216
1000 | 1024 | 0.0736 | 0.32803 | 0.07286 | 0.32936
6000 | 1024 | 0.12584 | 0.38916 | 0.12852 | 0.48273
6272 | 1024 | 0.13053 | 0.38804 | 0.13464 | 0.49545
200 | 1536 | 0.07159 | 0.3396 | 0.07062 | 0.33545
1000 | 1536 | 0.07443 | 0.33239 | 0.07366 | 0.33204
6000 | 1536 | 0.14959 | 0.45043 | 0.15826 | 0.69119
6272 | 1536 | 0.1542 | 0.47644 | 0.18249 | 0.72208
200 | 2048 | 0.07258 | 0.33982 | 0.07412 | 0.33859
1000 | 2048 | 0.0793 | 0.32816 | 0.07864 | 0.32583
6000 | 2048 | 0.18973 | 0.571 | 0.25506 | 0.91796
6272 | 2048 | 0.19719 | 0.64208 | 0.26445 | 0.95055
200 | 3072 | 0.07092 | 0.33867 | 0.07104 | 0.34695
1000 | 3072 | 0.08727 | 0.33144 | 0.09144 | 0.36633
6000 | 3072 | 0.24683 | 0.87275 | 0.37761 | 1.3289
6272 | 3072 | 0.26437 | 0.91178 | 0.38496 | 1.53694
128 | 2097152 | 6.27936 | 23.69425 | 10.40004 | 30.13699
256 | 1048576 | 4.5404 | 19.47675 | 10.28494 | 29.36936
512 | 524288 | 4.13951 | 18.78771 | 10.09557 | 32.67083
1024 | 262144 | 4.47576 | 18.00411 | 9.56488 | 31.47117
2048 | 131072 | 4.28026 | 16.95619 | 9.40297 | 30.82845
4096 | 65536 | 4.2653 | 16.5018 | 9.03315 | 30.08392
8192 | 32768 | 4.25613 | 16.13583 | 8.9258 | 30.75296
16384 | 16384 | 4.20256 | 16.38207 | 9.52587 | 31.31113
32768 | 8192 | 4.20231 | 16.19452 | 9.31478 | 31.03514

</body>

</html>

---------

**Performance Improvement (%)**
<html xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:dt="uuid:C2F41010-65B3-11d1-A29F-00AA00C14882"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=OneNote.File>
<meta name=Generator content="Microsoft OneNote 15">
</head>

<body lang=en-US style='font-family:Calibri;font-size:11.0pt'>
<!--StartFragment-->

<div style='direction:ltr'>

M | N | fwdbwd,   torch.float16 | fwdbwd,   torch.float32
-- | -- | -- | --
50432 | 384 | 32.178 | 22.049
50176 | 384 | 29.231 | 19.536
200704 | 192 | 44.188 | 43.962
802816 | 64 | 52.119 | 54.100
200 | 256 | -5.750 | -0.206
1000 | 256 | 0.031 | -0.797
6000 | 256 | 3.566 | 5.621
6272 | 256 | 3.865 | 4.836
200 | 512 | -1.615 | -1.010
1000 | 512 | -1.270 | 0.208
6000 | 512 | 3.534 | 5.581
6272 | 512 | 7.905 | 7.483
200 | 1024 | -2.883 | 0.254
1000 | 1024 | -0.767 | 0.493
6000 | 1024 | 0.237 | -2.381
6272 | 1024 | 3.840 | -1.707
200 | 1536 | -0.127 | -1.340
1000 | 1536 | -0.711 | -0.992
6000 | 1536 | -0.209 | -4.728
6272 | 1536 | 0.508 | -0.846
200 | 2048 | -1.262 | -1.176
1000 | 2048 | -0.358 | 0.312
6000 | 2048 | 8.350 | 6.487
6272 | 2048 | 1.588 | 5.713
200 | 3072 | 0.223 | -0.848
1000 | 3072 | -0.773 | -5.743
6000 | 3072 | 3.570 | -3.783
6272 | 3072 | 4.962 | -4.092
128 | 2097152 | -4.266 | 0.348
256 | 1048576 | 0.397 | 0.185
512 | 524288 | 17.325 | 16.605
1024 | 262144 | 23.070 | 19.195
2048 | 131072 | 27.469 | 24.605
4096 | 65536 | 32.023 | 27.465
8192 | 32768 | 24.459 | 28.274
16384 | 16384 | 21.439 | 9.514
32768 | 8192 | 6.818 | 0.491

</div>

<!--EndFragment-->
</body>

</html>

---------
**Benchmark script of this PR**
```
# Ref:
#       1. pytorch#26201
#       2. pytorch#68238

from distutils.command.config import config
import torch
from torch.nn import LayerNorm
import timeit

number_runs = 1000  # TODO: Modify this to save time!
def test_forward(layer_norm_cuda, input_cuda):
    layer_norm_cuda(input_cuda); torch.cuda.synchronize()

def test_backward(out_cuda, layer_norm_grad_cuda, create_graph):
    out_cuda.backward(layer_norm_grad_cuda, retain_graph=True, create_graph=create_graph); torch.cuda.synchronize()

def test_fwdbwd(input_cuda, layer_norm_cuda, gO):
    input_cuda.grad = None
    layer_norm_cuda.zero_grad(set_to_none=True)
    out = layer_norm_cuda(input_cuda)
    out.backward(gO)
    torch.cuda.synchronize()

def benchmark(config_m, config_n):

    print("M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)")
    if len(config_m) != len(config_n):
        print("Please make sure the lengths of config_m and config_m are the same.")

    for i in range(len(config_m)):
        normalized_shape = config_n[i]
        results = [config_m[i], config_n[i]]
        for dtype in (torch.half, torch.float):
            if dtype == torch.half:
                layer_norm_cuda = LayerNorm(normalized_shape).half().cuda()
            else:
                layer_norm_cuda = LayerNorm(normalized_shape).cuda()

            input_cuda = torch.randn(config_m[i], config_n[i], device='cuda', dtype=dtype, requires_grad=True)

            # print("cuda forward:")
            result_fwd = timeit.timeit(lambda: test_forward(layer_norm_cuda, input_cuda), number=number_runs)
            results.append(result_fwd / number_runs * 1000)

            gO = torch.rand_like(input_cuda)

            result_fwdbwd = timeit.timeit(lambda: test_fwdbwd(input_cuda, layer_norm_cuda, gO), number=number_runs)
            results.append(result_fwdbwd / number_runs * 1000)

        print('{:09d}|{:09d}|{:9.5f}|{:9.5f}|{:9.5f}|{:9.5f}'.format(results[0], results[1], results[2], results[3], results[4], results[5]))

    print("Times are in microseconds (us).")

# CVT
config_m_cvt = [50432, 50176, 200704, 802816]
config_n_cvt = [384, 384, 192, 64]

# pytorch#68238 (comment)
config_m_68238 = [200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272]
config_n_68238 = [256,256,256,256,512,512,512,512,1024,1024,1024,1024,1536,1536,1536,1536,2048,2048,2048,2048,3072,3072,3072,3072]

# pytorch#27634
config_m_27634 = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
config_n_27634 = [2097152, 1048576, 524288, 262144, 131072, 65536, 32768, 16384, 8192]

config_m = config_m_cvt + config_m_68238 + config_m_27634
config_n = config_n_cvt + config_n_68238 + config_n_27634

benchmark(config_m, config_n)
```

CC: @jeffdaily

Pull Request resolved: pytorch#87635
Approved by: https://github.com/jataylo, https://github.com/jeffdaily, https://github.com/ezyang
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
…or ROCm (pytorch#87726)

We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when fs (=config_m in our benchmark script) is large and bs (=config_n in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of pytorch#68238 (comment) on AMD GPUs.

This PR is to replace layer_norm_grad_input_kernel with the Apex cuComputeGradInput kernel with some ROCm-specific parameter tuning when fs (=config_m) is larger than or equal to `32768` on AMD GPUs. Some of the code changes in LayerNormBackwardKernelImplInternal are from another PR: pytorch#87635

We used the same benchmark script in the previous PR and tested the optimized kernel with various input shapes on AMD MI100 GPU.

**At [the previous PR](pytorch#87635
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl65
	{color:windowtext;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38589 | 0.92603 | 0.38367 | 1.15148
50176 | 384 | 0.38719 | 0.91579 | 0.37815 | 1.13761
200704 | 192 | 0.99787 | 2.39954 | 0.98996 | 2.54284
802816 | 64 | 3.66525 | 7.96952 | 3.61293 | 7.69946
200 | 256 | 0.06578 | 0.34613 | 0.06966 | 0.35449
1000 | 256 | 0.07837 | 0.37631 | 0.07725 | 0.37758
6000 | 256 | 0.09318 | 0.3788 | 0.09202 | 0.37989
6272 | 256 | 0.08694 | 0.36267 | 0.08703 | 0.3615
200 | 512 | 0.06975 | 0.34506 | 0.06973 | 0.34208
1000 | 512 | 0.07012 | 0.36363 | 0.07307 | 0.36741
6000 | 512 | 0.09725 | 0.36251 | 0.09908 | 0.37078
6272 | 512 | 0.09899 | 0.36519 | 0.10068 | 0.37514
200 | 1024 | 0.07188 | 0.33896 | 0.0712 | 0.34683
1000 | 1024 | 0.07357 | 0.3625 | 0.0734 | 0.3598
6000 | 1024 | 0.12642 | 0.38949 | 0.12973 | 0.5035
6272 | 1024 | 0.12901 | 0.40759 | 0.13609 | 0.51871
200 | 1536 | 0.06998 | 0.34782 | 0.07419 | 0.3514
1000 | 1536 | 0.07987 | 0.37915 | 0.07888 | 0.37264
6000 | 1536 | 0.15401 | 0.47524 | 0.15416 | 0.68609
6272 | 1536 | 0.15286 | 0.48843 | 0.17681 | 0.72997
200 | 2048 | 0.07054 | 0.34791 | 0.07289 | 0.35138
1000 | 2048 | 0.07767 | 0.37954 | 0.08554 | 0.37464
6000 | 2048 | 0.18744 | 0.5811 | 0.25004 | 0.93338
6272 | 2048 | 0.20037 | 0.63398 | 0.26918 | 0.97018
200 | 3072 | 0.07687 | 0.36739 | 0.08917 | 0.37845
1000 | 3072 | 0.09323 | 0.38901 | 0.09739 | 0.39823
6000 | 3072 | 0.24314 | 0.89029 | 0.38093 | 1.30719
6272 | 3072 | 0.26079 | 0.92023 | 0.38352 | 1.51012
128 | 2097152 | 6.17775 | 23.876 | 10.27952 | 30.10848
256 | 1048576 | 4.51855 | 19.47637 | 10.07609 | 29.42678
512 | 524288 | 4.13615 | 18.80888 | 10.07853 | 32.29804
1024 | 262144 | 4.47397 | 17.88388 | 9.50367 | 31.15699
2048 | 131072 | 4.2458 | 16.70852 | 9.17979 | 30.51708
4096 | 65536 | 4.24412 | 16.43098 | 8.97651 | 30.1617
8192 | 32768 | 4.24556 | 16.09038 | 8.77001 | 30.3643
16384 | 16384 | 4.14642 | 15.80355 | 8.82402 | 30.35291
32768 | 8192 | 4.12599 | 15.68897 | 8.82605 | 30.43423

</body>

</html>

----

**At this PR:**

<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl65
	{color:windowtext;}
.xl66
	{background:yellow;
	mso-pattern:black none;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38667 | 0.84133 | 0.37916 | 1.01222
50176 | 384 | 0.3814 | 0.87266 | 0.37858 | 1.04399
200704 | 192 | 0.99902 | 2.14386 | 0.98973 | 2.33265
802816 | 64 | 3.66578 | 6.85376 | 3.6092 | 7.00331
200 | 256 | 0.06607 | 0.34176 | 0.07009 | 0.34548
1000 | 256 | 0.06947 | 0.36461 | 0.07902 | 0.37851
6000 | 256 | 0.09319 | 0.37432 | 0.09342 | 0.36927
6272 | 256 | 0.09544 | 0.37565 | 0.09476 | 0.37377
200 | 512 | 0.07935 | 0.364 | 0.07891 | 0.36894
1000 | 512 | 0.07676 | 0.37552 | 0.07957 | 0.37564
6000 | 512 | 0.10472 | 0.37504 | 0.1051 | 0.38782
6272 | 512 | 0.1069 | 0.36662 | 0.10062 | 0.38506
200 | 1024 | 0.07793 | 0.36561 | 0.08023 | 0.35019
1000 | 1024 | 0.07426 | 0.36729 | 0.07345 | 0.35851
6000 | 1024 | 0.12729 | 0.39219 | 0.12974 | 0.51526
6272 | 1024 | 0.13622 | 0.41627 | 0.14252 | 0.52926
200 | 1536 | 0.07615 | 0.36621 | 0.0797 | 0.3695
1000 | 1536 | 0.08327 | 0.38174 | 0.07938 | 0.37573
6000 | 1536 | 0.14894 | 0.46197 | 0.15268 | 0.63814
6272 | 1536 | 0.15368 | 0.48818 | 0.16309 | 0.71441
200 | 2048 | 0.06935 | 0.36691 | 0.07258 | 0.35548
1000 | 2048 | 0.07738 | 0.36388 | 0.08036 | 0.36452
6000 | 2048 | 0.18757 | 0.58573 | 0.23701 | 0.92915
6272 | 2048 | 0.1938 | 0.61628 | 0.26475 | 0.96896
200 | 3072 | 0.07884 | 0.3673 | 0.07724 | 0.37869
1000 | 3072 | 0.09342 | 0.38193 | 0.09822 | 0.38646
6000 | 3072 | 0.24452 | 0.86776 | 0.38251 | 1.3036
6272 | 3072 | 0.25971 | 0.91053 | 0.38744 | 1.39039
128 | 2097152 | 6.06752 | 23.26379 | 9.87466 | 29.81851
256 | 1048576 | 4.50336 | 19.4614 | 10.11239 | 29.25554
512 | 524288 | 4.12649 | 18.72831 | 10.054 | 32.26784
1024 | 262144 | 4.40855 | 17.77993 | 9.38856 | 31.18679
2048 | 131072 | 4.18716 | 16.74615 | 9.14487 | 30.24603
4096 | 65536 | 4.17374 | 16.34444 | 8.94894 | 30.0326
8192 | 32768 | 4.19095 | 16.05751 | 8.70358 | 30.14669
16384 | 16384 | 4.15404 | 15.83771 | 8.80042 | 30.5022
32768 | 8192 | 4.12515 | 15.5657 | 8.66138 | 28.87386

</body>

</html>

---

**Performance Improvement (%)**

<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl65
	{color:windowtext;}
.xl66
	{mso-number-format:"0\.000";}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwdbwd, torch.float16 | fwdbwd, torch.float32
-- | -- | -- | --
50432 | 384 | 9.147 | 12.094
50176 | 384 | 4.710 | 8.230
200704 | 192 | 10.655 | 8.266
802816 | 64 | 14.000 | 9.042
200 | 256 | 1.263 | 2.542
1000 | 256 | 3.109 | -0.246
6000 | 256 | 1.183 | 2.796
6272 | 256 | -3.579 | -3.394
200 | 512 | -5.489 | -7.852
1000 | 512 | -3.270 | -2.240
6000 | 512 | -3.456 | -4.596
6272 | 512 | -0.392 | -2.644
200 | 1024 | -7.862 | -0.969
1000 | 1024 | -1.321 | 0.359
6000 | 1024 | -0.693 | -2.336
6272 | 1024 | -2.130 | -2.034
200 | 1536 | -5.287 | -5.151
1000 | 1536 | -0.683 | -0.829
6000 | 1536 | 2.792 | 6.989
6272 | 1536 | 0.051 | 2.132
200 | 2048 | -5.461 | -1.167
1000 | 2048 | 4.126 | 2.701
6000 | 2048 | -0.797 | 0.453
6272 | 2048 | 2.792 | 0.126
200 | 3072 | 0.024 | -0.063
1000 | 3072 | 1.820 | 2.956
6000 | 3072 | 2.531 | 0.275
6272 | 3072 | 1.054 | 7.929
128 | 2097152 | 2.564 | 0.963
256 | 1048576 | 0.077 | 0.582
512 | 524288 | 0.428 | 0.094
1024 | 262144 | 0.581 | -0.096
2048 | 131072 | -0.225 | 0.888
4096 | 65536 | 0.527 | 0.428
8192 | 32768 | 0.204 | 0.717
16384 | 16384 | -0.216 | -0.492
32768 | 8192 | 0.786 | 5.127

</body>

</html>

CC: @jeffdaily

Pull Request resolved: pytorch#87726
Approved by: https://github.com/ngimel
pytorchmergebot pushed a commit that referenced this pull request Nov 11, 2024
It was raised that the backwards layer norm on AMD was slightly off the accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which processes strided input and accumulates the partial gradients into shared memory.

In this kernel (#87635) we truncated `mean` and `rstd` from T_ACC type to T which causes numerical issues in the warp buffers created in this kernel. This PR will use the correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so this was not an issue for NV.

Pull Request resolved: #140259
Approved by: https://github.com/jianyuh
jataylo added a commit to ROCm/pytorch that referenced this pull request Dec 4, 2024
…ch#140259)

It was raised that the backwards layer norm on AMD was slightly off the accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which processes strided input and accumulates the partial gradients into shared memory.

In this kernel (pytorch#87635) we truncated `mean` and `rstd` from T_ACC type to T which causes numerical issues in the warp buffers created in this kernel. This PR will use the correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)
jataylo added a commit to ROCm/pytorch that referenced this pull request Dec 4, 2024
…ch#140259)

It was raised that the backwards layer norm on AMD was slightly off the accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which processes strided input and accumulates the partial gradients into shared memory.

In this kernel (pytorch#87635) we truncated `mean` and `rstd` from T_ACC type to T which causes numerical issues in the warp buffers created in this kernel. This PR will use the correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)
jataylo added a commit to ROCm/pytorch that referenced this pull request Dec 4, 2024
…ch#140259)

It was raised that the backwards layer norm on AMD was slightly off the accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which processes strided input and accumulates the partial gradients into shared memory.

In this kernel (pytorch#87635) we truncated `mean` and `rstd` from T_ACC type to T which causes numerical issues in the warp buffers created in this kernel. This PR will use the correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)
jataylo added a commit to ROCm/pytorch that referenced this pull request Dec 4, 2024
…ch#140259)

It was raised that the backwards layer norm on AMD was slightly off the accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which processes strided input and accumulates the partial gradients into shared memory.

In this kernel (pytorch#87635) we truncated `mean` and `rstd` from T_ACC type to T which causes numerical issues in the warp buffers created in this kernel. This PR will use the correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…ch#140259)

It was raised that the backwards layer norm on AMD was slightly off the accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which processes strided input and accumulates the partial gradients into shared memory.

In this kernel (pytorch#87635) we truncated `mean` and `rstd` from T_ACC type to T which causes numerical issues in the warp buffers created in this kernel. This PR will use the correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Dec 6, 2024
… kernel (pytorch#140259) (#1766)

It was raised that the backwards layer norm on AMD was slightly off the
accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which
processes strided input and accumulates the partial gradients into
shared memory.

In this kernel (pytorch#87635) we
truncated `mean` and `rstd` from T_ACC type to T which causes numerical
issues in the warp buffers created in this kernel. This PR will use the
correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so
this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Dec 6, 2024
… kernel (pytorch#140259) (#1767)

It was raised that the backwards layer norm on AMD was slightly off the
accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which
processes strided input and accumulates the partial gradients into
shared memory.

In this kernel (pytorch#87635) we
truncated `mean` and `rstd` from T_ACC type to T which causes numerical
issues in the warp buffers created in this kernel. This PR will use the
correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so
this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)

Fixes #ISSUE_NUMBER
rocm-mici pushed a commit to ROCm/pytorch that referenced this pull request Dec 6, 2024
… kernel (pytorch#140259) (#1767)

It was raised that the backwards layer norm on AMD was slightly off the
accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which
processes strided input and accumulates the partial gradients into
shared memory.

In this kernel (pytorch#87635) we
truncated `mean` and `rstd` from T_ACC type to T which causes numerical
issues in the warp buffers created in this kernel. This PR will use the
correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so
this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)

Fixes #ISSUE_NUMBER
rocm-mici pushed a commit to ROCm/pytorch that referenced this pull request Dec 13, 2024
… kernel (pytorch#140259) (#1766)

It was raised that the backwards layer norm on AMD was slightly off the
accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which
processes strided input and accumulates the partial gradients into
shared memory.

In this kernel (pytorch#87635) we
truncated `mean` and `rstd` from T_ACC type to T which causes numerical
issues in the warp buffers created in this kernel. This PR will use the
correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so
this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)
rocm-mici pushed a commit to ROCm/pytorch that referenced this pull request Dec 13, 2024
… kernel (pytorch#140259) (#1766)

It was raised that the backwards layer norm on AMD was slightly off the
accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which
processes strided input and accumulates the partial gradients into
shared memory.

In this kernel (pytorch#87635) we
truncated `mean` and `rstd` from T_ACC type to T which causes numerical
issues in the warp buffers created in this kernel. This PR will use the
correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so
this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)
rocm-mici pushed a commit to ROCm/pytorch that referenced this pull request Dec 13, 2024
… kernel (pytorch#140259) (#1766)

It was raised that the backwards layer norm on AMD was slightly off the
accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which
processes strided input and accumulates the partial gradients into
shared memory.

In this kernel (pytorch#87635) we
truncated `mean` and `rstd` from T_ACC type to T which causes numerical
issues in the warp buffers created in this kernel. This PR will use the
correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so
this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)
rocm-mici pushed a commit to ROCm/pytorch that referenced this pull request Dec 13, 2024
… kernel (pytorch#140259) (#1767)

It was raised that the backwards layer norm on AMD was slightly off the
accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which
processes strided input and accumulates the partial gradients into
shared memory.

In this kernel (pytorch#87635) we
truncated `mean` and `rstd` from T_ACC type to T which causes numerical
issues in the warp buffers created in this kernel. This PR will use the
correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so
this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)

Fixes #ISSUE_NUMBER
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged module: rocm AMD GPU support for Pytorch open source release notes: cuda release notes category release notes: rocm mandatorylabel rocm priority high priority ROCm PRs from performance or other aspects rocm This tag is for PRs from ROCm team triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants