Skip to content

Conversation

@zasdfgbnm
Copy link
Collaborator

@zasdfgbnm zasdfgbnm commented Sep 13, 2019

Originally implemented in:
https://github.com/NVIDIA/apex/blob/master/csrc/layer_norm_cuda_kernel.cu

Changes:

  • Copy-paste the implementation from APEX to ATen. Codes are modified a bit to fit into ATen (for example: add support for the case grad_out is undefined, and to make it compatible with ROCm).
  • API change: at::layer_norm no longer requires the bool cudnn_enabled as an argument. (do I need to keep this argument for back compatibility? It is no longer used.)
  • layer_norm is now infinitely differentiable, but if higher order derivatives are asked (create_graph=True when computing backward), the performance of backwards is no longer optimized. (described below)
  • Some minor maintainability thing:
    • The variables currently are named quite inconsistent, such as X vs input, M vs n1, gamma vs weight, I renamed a lot of them to make them consistent.
    • Also inconsistency in code style: )\n{ vs ) {, I changed all of them to )\n{

About differentiability:

Current implementation:

  • CPU: twice differentiable, optimized forward, backward and double backward code.
  • GPU: infinitely differentiable, not optimized for performance

This PR:
Both CPU and GPU code now becomes infinitely differentiable and has optimized forward. And for backward, if create_graph set to false, then autograd will use the optimized code for computing backwards. But if create_graph set to true, then autograd will use a fallback infinitely differentiable implementation of the backward using ATen operators, which is not optimized for performance. This behavior is similar to what is done in WeightNorm (See: #10842).

Due to this change, the optimized CPU code for double backward becomes a dead code and is removed.

Benchmark

Code (Jupyter Notebook):

import torch
from torch.nn import LayerNorm
import warnings
import gc


LINE_WIDTH = 80
warnings.filterwarnings('ignore')
print('PyTorch version:', torch.__version__)
print()


def benchmark(*sizes):
    print('=' * LINE_WIDTH)
    print("Benchmarking input shape", sizes)
    normalized_shape = sizes[1:]
    layer_norm_cuda = LayerNorm(normalized_shape).cuda()
    layer_norm_cpu = LayerNorm(normalized_shape).cpu()
    
    input_cuda = torch.randn(*sizes, device='cuda', requires_grad=True)
    input_cpu = torch.randn(*sizes, device='cpu', requires_grad=True)

    input_bytes = input_cuda.numel() * input_cuda.element_size()
    print("Element size", input_cuda.element_size())
    print("Size of the input tensor is", input_bytes, "bytes")

    print('-' * LINE_WIDTH)

    print("cuda forward:")
    %timeit layer_norm_cuda(input_cuda); torch.cuda.synchronize()
    print("cpu forward:")
    %timeit layer_norm_cpu(input_cpu)

    print('-' * LINE_WIDTH)
    
    out_cuda = layer_norm_cuda(input_cuda)
    out_cpu = layer_norm_cpu(input_cpu)
    upstream_grad_cuda = torch.randn_like(out_cuda)
    upstream_grad_cpu = torch.randn_like(out_cpu)
    
    print('cuda backward, create_graph=False:')
    %timeit out_cuda.backward(upstream_grad_cuda, retain_graph=True); torch.cuda.synchronize()
    gc.collect()
    print('cpu backward, create_graph=False:')
    %timeit out_cpu.backward(upstream_grad_cpu, retain_graph=True)
    gc.collect()

    print('-' * LINE_WIDTH)
    
    print('cuda backward, create_graph=True:')
    %timeit out_cuda.backward(upstream_grad_cuda, retain_graph=True, create_graph=True); torch.cuda.synchronize()
    gc.collect()
    print('cpu backward, create_graph=True:')
    %timeit out_cpu.backward(upstream_grad_cpu, retain_graph=True, create_graph=True)
    gc.collect()

    print('=' * LINE_WIDTH)
    print()


benchmark(100, 100)
benchmark(1000, 100)
benchmark(100, 500)

Result on torch-nightly installed by pip:

PyTorch version: 1.3.0.dev20190920

================================================================================
Benchmarking input shape (100, 100)
Element size 4
Size of the input tensor is 40000 bytes
--------------------------------------------------------------------------------
cuda forward:
146 µs ± 16 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
cpu forward:
55.1 µs ± 5.51 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
--------------------------------------------------------------------------------
cuda backward, create_graph=False:
392 µs ± 66.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
cpu backward, create_graph=False:
104 µs ± 15.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
--------------------------------------------------------------------------------
cuda backward, create_graph=True:
1.01 ms ± 135 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
cpu backward, create_graph=True:
146 µs ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
================================================================================

================================================================================
Benchmarking input shape (1000, 100)
Element size 4
Size of the input tensor is 400000 bytes
--------------------------------------------------------------------------------
cuda forward:
116 µs ± 29.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
cpu forward:
191 µs ± 383 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
--------------------------------------------------------------------------------
cuda backward, create_graph=False:
210 µs ± 7.03 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
cpu backward, create_graph=False:
298 µs ± 1.24 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
--------------------------------------------------------------------------------
cuda backward, create_graph=True:
735 µs ± 53.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
cpu backward, create_graph=True:
341 µs ± 1.73 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
================================================================================

================================================================================
Benchmarking input shape (100, 500)
Element size 4
Size of the input tensor is 200000 bytes
--------------------------------------------------------------------------------
cuda forward:
82 µs ± 958 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
cpu forward:
101 µs ± 200 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
--------------------------------------------------------------------------------
cuda backward, create_graph=False:
209 µs ± 10.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
cpu backward, create_graph=False:
168 µs ± 2.45 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
--------------------------------------------------------------------------------
cuda backward, create_graph=True:
980 µs ± 337 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
cpu backward, create_graph=True:
7.18 ms ± 1.84 ms per loop (mean ± std. dev. of 7 runs, 1000 loops each)
================================================================================

Result for this PR:

PyTorch version: 1.3.0a0+77e6902

================================================================================
Benchmarking input shape (100, 100)
Element size 4
Size of the input tensor is 40000 bytes
--------------------------------------------------------------------------------
cuda forward:
33.7 µs ± 97.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
cpu forward:
26.3 µs ± 66.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
--------------------------------------------------------------------------------
cuda backward, create_graph=False:
112 µs ± 3.06 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
cpu backward, create_graph=False:
57.7 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
--------------------------------------------------------------------------------
cuda backward, create_graph=True:
1.15 ms ± 68 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
cpu backward, create_graph=True:
369 µs ± 7.04 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
================================================================================

================================================================================
Benchmarking input shape (1000, 100)
Element size 4
Size of the input tensor is 400000 bytes
--------------------------------------------------------------------------------
cuda forward:
37.4 µs ± 95.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
cpu forward:
136 µs ± 49.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
--------------------------------------------------------------------------------
cuda backward, create_graph=False:
116 µs ± 1.83 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
cpu backward, create_graph=False:
201 µs ± 7.78 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
--------------------------------------------------------------------------------
cuda backward, create_graph=True:
The slowest run took 4.15 times longer than the fastest. This could mean that an intermediate result is being cached.
1.84 ms ± 1.16 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
cpu backward, create_graph=True:
592 µs ± 27.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
================================================================================

================================================================================
Benchmarking input shape (100, 500)
Element size 4
Size of the input tensor is 200000 bytes
--------------------------------------------------------------------------------
cuda forward:
34.5 µs ± 91.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
cpu forward:
70.4 µs ± 25.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
--------------------------------------------------------------------------------
cuda backward, create_graph=False:
118 µs ± 5.52 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
cpu backward, create_graph=False:
131 µs ± 7.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
--------------------------------------------------------------------------------
cuda backward, create_graph=True:
1.72 ms ± 125 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
cpu backward, create_graph=True:
471 µs ± 45.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
================================================================================

some cleanups

some refactor

more

more

more

more

more

more

more

more

more

more

more

more

more

fixes

more

more

style fix

fixes

fixes

fixes

fixes

fixes

fix functional
@pytorchbot pytorchbot added module: cpu CPU specific problem (e.g., perf, algorithm) module: cuda Related to torch.cuda, and CUDA support in general module: internals Related to internal abstractions in c10 and ATen module: nn Related to torch.nn module: operators oncall: jit Add this issue/PR to JIT oncall triage queue module: onnx Related to torch.onnx labels Sep 13, 2019
remove tabs

fix dispatch in native_functions

fix grads

fix jit signature of aten::layer_norm

Fix ROCm build

revert changes of test_torch

fix more at jit

try fix bindings

fix binding name

fix bindings

remove double backward kernel

infinitely_differentiable_native_layer_norm_backward

fix onnx

fix more
@zasdfgbnm zasdfgbnm force-pushed the fused-layernorm branch 2 times, most recently from 6148ba6 to e757eeb Compare September 18, 2019 18:37
clean

some fixes

try fix grad

try fix

fix

fix

cleanup

more cleanup

more cleanup

fix rsqrt binding

fix

cleanup
try infinitely_differentiable_native_layer_norm_backward

fix infinitely_differentiable_native_layer_norm_backward

fixes

fix

fix

infinitely_differentiable_native_layer_norm_backward seems working

fixes

enable non_differentiable_native_layer_norm_backward

fix  native_functions.yaml

fix
@zasdfgbnm zasdfgbnm changed the title [WIP] Port fused layer_norm from APEX to ATen Port fused layer_norm from APEX to ATen Sep 20, 2019
@zasdfgbnm zasdfgbnm marked this pull request as ready for review September 20, 2019 21:45
@zasdfgbnm zasdfgbnm requested a review from apaszke as a code owner September 20, 2019 21:45
@zasdfgbnm
Copy link
Collaborator Author

@pytorchbot rebase this please

@zasdfgbnm
Copy link
Collaborator Author

@pytorchbot rebase this please

@yf225 yf225 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Sep 25, 2019
@zasdfgbnm
Copy link
Collaborator Author

Closing in favor of #27634

@zasdfgbnm zasdfgbnm closed this Oct 16, 2019
@zasdfgbnm zasdfgbnm deleted the fused-layernorm branch October 16, 2019 23:38
pytorchmergebot pushed a commit that referenced this pull request Nov 22, 2022
We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when `fs`  (=`config_m` in our benchmark script) is large and `bs`  (=`config_n` in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of [PR #68238](#68238 (comment)) on AMD GPUs.

This PR is to replace `GammaBetaBackwardCUDAKernel` with the Apex layernorm backward kernel with some ROCm-specific parameter tuning when `fs`  (=`config_m`) is larger than 512 on AMD GPUs.

There are a few PRs for LayerNorm kernel:
- #26201
- #27634
- #68238

Therefore, we have tested and compared the kernel before and at this PR with the input shapes in the last two PRs along with those commonly used in the CvT model on AMD MI100.

---
**Current**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.387256 | 1.372758 | 0.378975 | 1.47892
50176 | 384 | 0.38231 | 1.362416 | 0.378084 | 1.473886
200704 | 192 | 0.997859 | 4.315875 | 0.989306 | 4.560827
802816 | 64 | 3.671828 | 16.68013 | 3.613515 | 16.827946
200 | 256 | 0.066503 | 0.332096 | 0.071422 | 0.325349
1000 | 256 | 0.071848 | 0.333355 | 0.073038 | 0.334753
6000 | 256 | 0.086334 | 0.345139 | 0.086834 | 0.347429
6272 | 256 | 0.088601 | 0.347906 | 0.087855 | 0.351245
200 | 512 | 0.071626 | 0.329726 | 0.073798 | 0.326878
1000 | 512 | 0.073975 | 0.330226 | 0.074166 | 0.332751
6000 | 512 | 0.099617 | 0.362367 | 0.100095 | 0.378313
6272 | 512 | 0.100378 | 0.358066 | 0.099857 | 0.395982
200 | 1024 | 0.072954 | 0.326382 | 0.073899 | 0.333007
1000 | 1024 | 0.0743 | 0.325532 | 0.071126 | 0.330991
6000 | 1024 | 0.127025 | 0.390084 | 0.128692 | 0.471504
6272 | 1024 | 0.130704 | 0.403536 | 0.135244 | 0.487133
200 | 1536 | 0.070331 | 0.339169 | 0.070086 | 0.331015
1000 | 1536 | 0.075085 | 0.330042 | 0.076295 | 0.328778
6000 | 1536 | 0.148889 | 0.44949 | 0.155781 | 0.659987
6272 | 1536 | 0.154939 | 0.478871 | 0.17673 | 0.716025
200 | 2048 | 0.070269 | 0.335585 | 0.072804 | 0.334655
1000 | 2048 | 0.080094 | 0.326991 | 0.080426 | 0.32685
6000 | 2048 | 0.187888 | 0.623023 | 0.245762 | 0.981635
6272 | 2048 | 0.195431 | 0.65244 | 0.262574 | 1.008141
200 | 3072 | 0.068205 | 0.339428 | 0.073068 | 0.344034
1000 | 3072 | 0.087554 | 0.328899 | 0.09218 | 0.346433
6000 | 3072 | 0.240352 | 0.905058 | 0.368135 | 1.280462
6272 | 3072 | 0.26179 | 0.959387 | 0.387782 | 1.476524
128 | 2097152 | 5.905976 | 22.724793 | 10.287974 | 30.242092
256 | 1048576 | 4.561596 | 19.554308 | 10.223171 | 29.42371
512 | 524288 | 4.146751 | 22.7247 | 11.404285 | 39.175902
1024 | 262144 | 5.193135 | 23.403325 | 11.334512 | 38.947192
2048 | 131072 | 4.992907 | 23.377801 | 11.400286 | 40.889191
4096 | 65536 | 5.429488 | 24.275701 | 11.196778 | 41.4751
8192 | 32768 | 5.35758 | 21.360312 | 10.535418 | 42.875646
16384 | 16384 | 5.44947 | 20.852605 | 10.357685 | 34.603408
32768 | 8192 | 4.688925 | 17.379392 | 9.635596 | 31.188271

</body>

</html>

---------
**At this PR**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">

<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl63
	{color:windowtext;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38797 | 0.93103 | 0.37966 | 1.15283
50176 | 384 | 0.3874 | 0.96417 | 0.38462 | 1.18595
200704 | 192 | 1.00002 | 2.40876 | 0.99224 | 2.55579
802816 | 64 | 3.67348 | 7.98658 | 3.61871 | 7.72404
200 | 256 | 0.07292 | 0.35119 | 0.07195 | 0.32602
1000 | 256 | 0.07354 | 0.33325 | 0.07237 | 0.33742
6000 | 256 | 0.08819 | 0.33283 | 0.08453 | 0.3279
6272 | 256 | 0.0886 | 0.33446 | 0.08774 | 0.33426
200 | 512 | 0.0701 | 0.33505 | 0.07072 | 0.33018
1000 | 512 | 0.07042 | 0.33442 | 0.074 | 0.33206
6000 | 512 | 0.09931 | 0.34956 | 0.09895 | 0.3572
6272 | 512 | 0.10103 | 0.32976 | 0.10041 | 0.36635
200 | 1024 | 0.07144 | 0.33579 | 0.07209 | 0.33216
1000 | 1024 | 0.0736 | 0.32803 | 0.07286 | 0.32936
6000 | 1024 | 0.12584 | 0.38916 | 0.12852 | 0.48273
6272 | 1024 | 0.13053 | 0.38804 | 0.13464 | 0.49545
200 | 1536 | 0.07159 | 0.3396 | 0.07062 | 0.33545
1000 | 1536 | 0.07443 | 0.33239 | 0.07366 | 0.33204
6000 | 1536 | 0.14959 | 0.45043 | 0.15826 | 0.69119
6272 | 1536 | 0.1542 | 0.47644 | 0.18249 | 0.72208
200 | 2048 | 0.07258 | 0.33982 | 0.07412 | 0.33859
1000 | 2048 | 0.0793 | 0.32816 | 0.07864 | 0.32583
6000 | 2048 | 0.18973 | 0.571 | 0.25506 | 0.91796
6272 | 2048 | 0.19719 | 0.64208 | 0.26445 | 0.95055
200 | 3072 | 0.07092 | 0.33867 | 0.07104 | 0.34695
1000 | 3072 | 0.08727 | 0.33144 | 0.09144 | 0.36633
6000 | 3072 | 0.24683 | 0.87275 | 0.37761 | 1.3289
6272 | 3072 | 0.26437 | 0.91178 | 0.38496 | 1.53694
128 | 2097152 | 6.27936 | 23.69425 | 10.40004 | 30.13699
256 | 1048576 | 4.5404 | 19.47675 | 10.28494 | 29.36936
512 | 524288 | 4.13951 | 18.78771 | 10.09557 | 32.67083
1024 | 262144 | 4.47576 | 18.00411 | 9.56488 | 31.47117
2048 | 131072 | 4.28026 | 16.95619 | 9.40297 | 30.82845
4096 | 65536 | 4.2653 | 16.5018 | 9.03315 | 30.08392
8192 | 32768 | 4.25613 | 16.13583 | 8.9258 | 30.75296
16384 | 16384 | 4.20256 | 16.38207 | 9.52587 | 31.31113
32768 | 8192 | 4.20231 | 16.19452 | 9.31478 | 31.03514

</body>

</html>

---------

**Performance Improvement (%)**
<html xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:dt="uuid:C2F41010-65B3-11d1-A29F-00AA00C14882"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=OneNote.File>
<meta name=Generator content="Microsoft OneNote 15">
</head>

<body lang=en-US style='font-family:Calibri;font-size:11.0pt'>
<!--StartFragment-->

<div style='direction:ltr'>

M | N | fwdbwd,   torch.float16 | fwdbwd,   torch.float32
-- | -- | -- | --
50432 | 384 | 32.178 | 22.049
50176 | 384 | 29.231 | 19.536
200704 | 192 | 44.188 | 43.962
802816 | 64 | 52.119 | 54.100
200 | 256 | -5.750 | -0.206
1000 | 256 | 0.031 | -0.797
6000 | 256 | 3.566 | 5.621
6272 | 256 | 3.865 | 4.836
200 | 512 | -1.615 | -1.010
1000 | 512 | -1.270 | 0.208
6000 | 512 | 3.534 | 5.581
6272 | 512 | 7.905 | 7.483
200 | 1024 | -2.883 | 0.254
1000 | 1024 | -0.767 | 0.493
6000 | 1024 | 0.237 | -2.381
6272 | 1024 | 3.840 | -1.707
200 | 1536 | -0.127 | -1.340
1000 | 1536 | -0.711 | -0.992
6000 | 1536 | -0.209 | -4.728
6272 | 1536 | 0.508 | -0.846
200 | 2048 | -1.262 | -1.176
1000 | 2048 | -0.358 | 0.312
6000 | 2048 | 8.350 | 6.487
6272 | 2048 | 1.588 | 5.713
200 | 3072 | 0.223 | -0.848
1000 | 3072 | -0.773 | -5.743
6000 | 3072 | 3.570 | -3.783
6272 | 3072 | 4.962 | -4.092
128 | 2097152 | -4.266 | 0.348
256 | 1048576 | 0.397 | 0.185
512 | 524288 | 17.325 | 16.605
1024 | 262144 | 23.070 | 19.195
2048 | 131072 | 27.469 | 24.605
4096 | 65536 | 32.023 | 27.465
8192 | 32768 | 24.459 | 28.274
16384 | 16384 | 21.439 | 9.514
32768 | 8192 | 6.818 | 0.491

</div>

<!--EndFragment-->
</body>

</html>

---------
**Benchmark script of this PR**
```
# Ref:
#       1. #26201
#       2. #68238

from distutils.command.config import config
import torch
from torch.nn import LayerNorm
import timeit

number_runs = 1000  # TODO: Modify this to save time!
def test_forward(layer_norm_cuda, input_cuda):
    layer_norm_cuda(input_cuda); torch.cuda.synchronize()

def test_backward(out_cuda, layer_norm_grad_cuda, create_graph):
    out_cuda.backward(layer_norm_grad_cuda, retain_graph=True, create_graph=create_graph); torch.cuda.synchronize()

def test_fwdbwd(input_cuda, layer_norm_cuda, gO):
    input_cuda.grad = None
    layer_norm_cuda.zero_grad(set_to_none=True)
    out = layer_norm_cuda(input_cuda)
    out.backward(gO)
    torch.cuda.synchronize()

def benchmark(config_m, config_n):

    print("M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)")
    if len(config_m) != len(config_n):
        print("Please make sure the lengths of config_m and config_m are the same.")

    for i in range(len(config_m)):
        normalized_shape = config_n[i]
        results = [config_m[i], config_n[i]]
        for dtype in (torch.half, torch.float):
            if dtype == torch.half:
                layer_norm_cuda = LayerNorm(normalized_shape).half().cuda()
            else:
                layer_norm_cuda = LayerNorm(normalized_shape).cuda()

            input_cuda = torch.randn(config_m[i], config_n[i], device='cuda', dtype=dtype, requires_grad=True)

            # print("cuda forward:")
            result_fwd = timeit.timeit(lambda: test_forward(layer_norm_cuda, input_cuda), number=number_runs)
            results.append(result_fwd / number_runs * 1000)

            gO = torch.rand_like(input_cuda)

            result_fwdbwd = timeit.timeit(lambda: test_fwdbwd(input_cuda, layer_norm_cuda, gO), number=number_runs)
            results.append(result_fwdbwd / number_runs * 1000)

        print('{:09d}|{:09d}|{:9.5f}|{:9.5f}|{:9.5f}|{:9.5f}'.format(results[0], results[1], results[2], results[3], results[4], results[5]))

    print("Times are in microseconds (us).")

# CVT
config_m_cvt = [50432, 50176, 200704, 802816]
config_n_cvt = [384, 384, 192, 64]

# #68238 (comment)
config_m_68238 = [200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272]
config_n_68238 = [256,256,256,256,512,512,512,512,1024,1024,1024,1024,1536,1536,1536,1536,2048,2048,2048,2048,3072,3072,3072,3072]

# #27634
config_m_27634 = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
config_n_27634 = [2097152, 1048576, 524288, 262144, 131072, 65536, 32768, 16384, 8192]

config_m = config_m_cvt + config_m_68238 + config_m_27634
config_n = config_n_cvt + config_n_68238 + config_n_27634

benchmark(config_m, config_n)
```

CC: @jeffdaily

Pull Request resolved: #87635
Approved by: https://github.com/jataylo, https://github.com/jeffdaily, https://github.com/ezyang
jithunnair-amd pushed a commit to ROCm/pytorch that referenced this pull request Dec 1, 2022
We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when `fs`  (=`config_m` in our benchmark script) is large and `bs`  (=`config_n` in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of [PR pytorch#68238](pytorch#68238 (comment)) on AMD GPUs.

This PR is to replace `GammaBetaBackwardCUDAKernel` with the Apex layernorm backward kernel with some ROCm-specific parameter tuning when `fs`  (=`config_m`) is larger than 512 on AMD GPUs.

There are a few PRs for LayerNorm kernel:
- pytorch#26201
- pytorch#27634
- pytorch#68238

Therefore, we have tested and compared the kernel before and at this PR with the input shapes in the last two PRs along with those commonly used in the CvT model on AMD MI100.

---
**Current**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.387256 | 1.372758 | 0.378975 | 1.47892
50176 | 384 | 0.38231 | 1.362416 | 0.378084 | 1.473886
200704 | 192 | 0.997859 | 4.315875 | 0.989306 | 4.560827
802816 | 64 | 3.671828 | 16.68013 | 3.613515 | 16.827946
200 | 256 | 0.066503 | 0.332096 | 0.071422 | 0.325349
1000 | 256 | 0.071848 | 0.333355 | 0.073038 | 0.334753
6000 | 256 | 0.086334 | 0.345139 | 0.086834 | 0.347429
6272 | 256 | 0.088601 | 0.347906 | 0.087855 | 0.351245
200 | 512 | 0.071626 | 0.329726 | 0.073798 | 0.326878
1000 | 512 | 0.073975 | 0.330226 | 0.074166 | 0.332751
6000 | 512 | 0.099617 | 0.362367 | 0.100095 | 0.378313
6272 | 512 | 0.100378 | 0.358066 | 0.099857 | 0.395982
200 | 1024 | 0.072954 | 0.326382 | 0.073899 | 0.333007
1000 | 1024 | 0.0743 | 0.325532 | 0.071126 | 0.330991
6000 | 1024 | 0.127025 | 0.390084 | 0.128692 | 0.471504
6272 | 1024 | 0.130704 | 0.403536 | 0.135244 | 0.487133
200 | 1536 | 0.070331 | 0.339169 | 0.070086 | 0.331015
1000 | 1536 | 0.075085 | 0.330042 | 0.076295 | 0.328778
6000 | 1536 | 0.148889 | 0.44949 | 0.155781 | 0.659987
6272 | 1536 | 0.154939 | 0.478871 | 0.17673 | 0.716025
200 | 2048 | 0.070269 | 0.335585 | 0.072804 | 0.334655
1000 | 2048 | 0.080094 | 0.326991 | 0.080426 | 0.32685
6000 | 2048 | 0.187888 | 0.623023 | 0.245762 | 0.981635
6272 | 2048 | 0.195431 | 0.65244 | 0.262574 | 1.008141
200 | 3072 | 0.068205 | 0.339428 | 0.073068 | 0.344034
1000 | 3072 | 0.087554 | 0.328899 | 0.09218 | 0.346433
6000 | 3072 | 0.240352 | 0.905058 | 0.368135 | 1.280462
6272 | 3072 | 0.26179 | 0.959387 | 0.387782 | 1.476524
128 | 2097152 | 5.905976 | 22.724793 | 10.287974 | 30.242092
256 | 1048576 | 4.561596 | 19.554308 | 10.223171 | 29.42371
512 | 524288 | 4.146751 | 22.7247 | 11.404285 | 39.175902
1024 | 262144 | 5.193135 | 23.403325 | 11.334512 | 38.947192
2048 | 131072 | 4.992907 | 23.377801 | 11.400286 | 40.889191
4096 | 65536 | 5.429488 | 24.275701 | 11.196778 | 41.4751
8192 | 32768 | 5.35758 | 21.360312 | 10.535418 | 42.875646
16384 | 16384 | 5.44947 | 20.852605 | 10.357685 | 34.603408
32768 | 8192 | 4.688925 | 17.379392 | 9.635596 | 31.188271

</body>

</html>

---------
**At this PR**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">

<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl63
	{color:windowtext;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38797 | 0.93103 | 0.37966 | 1.15283
50176 | 384 | 0.3874 | 0.96417 | 0.38462 | 1.18595
200704 | 192 | 1.00002 | 2.40876 | 0.99224 | 2.55579
802816 | 64 | 3.67348 | 7.98658 | 3.61871 | 7.72404
200 | 256 | 0.07292 | 0.35119 | 0.07195 | 0.32602
1000 | 256 | 0.07354 | 0.33325 | 0.07237 | 0.33742
6000 | 256 | 0.08819 | 0.33283 | 0.08453 | 0.3279
6272 | 256 | 0.0886 | 0.33446 | 0.08774 | 0.33426
200 | 512 | 0.0701 | 0.33505 | 0.07072 | 0.33018
1000 | 512 | 0.07042 | 0.33442 | 0.074 | 0.33206
6000 | 512 | 0.09931 | 0.34956 | 0.09895 | 0.3572
6272 | 512 | 0.10103 | 0.32976 | 0.10041 | 0.36635
200 | 1024 | 0.07144 | 0.33579 | 0.07209 | 0.33216
1000 | 1024 | 0.0736 | 0.32803 | 0.07286 | 0.32936
6000 | 1024 | 0.12584 | 0.38916 | 0.12852 | 0.48273
6272 | 1024 | 0.13053 | 0.38804 | 0.13464 | 0.49545
200 | 1536 | 0.07159 | 0.3396 | 0.07062 | 0.33545
1000 | 1536 | 0.07443 | 0.33239 | 0.07366 | 0.33204
6000 | 1536 | 0.14959 | 0.45043 | 0.15826 | 0.69119
6272 | 1536 | 0.1542 | 0.47644 | 0.18249 | 0.72208
200 | 2048 | 0.07258 | 0.33982 | 0.07412 | 0.33859
1000 | 2048 | 0.0793 | 0.32816 | 0.07864 | 0.32583
6000 | 2048 | 0.18973 | 0.571 | 0.25506 | 0.91796
6272 | 2048 | 0.19719 | 0.64208 | 0.26445 | 0.95055
200 | 3072 | 0.07092 | 0.33867 | 0.07104 | 0.34695
1000 | 3072 | 0.08727 | 0.33144 | 0.09144 | 0.36633
6000 | 3072 | 0.24683 | 0.87275 | 0.37761 | 1.3289
6272 | 3072 | 0.26437 | 0.91178 | 0.38496 | 1.53694
128 | 2097152 | 6.27936 | 23.69425 | 10.40004 | 30.13699
256 | 1048576 | 4.5404 | 19.47675 | 10.28494 | 29.36936
512 | 524288 | 4.13951 | 18.78771 | 10.09557 | 32.67083
1024 | 262144 | 4.47576 | 18.00411 | 9.56488 | 31.47117
2048 | 131072 | 4.28026 | 16.95619 | 9.40297 | 30.82845
4096 | 65536 | 4.2653 | 16.5018 | 9.03315 | 30.08392
8192 | 32768 | 4.25613 | 16.13583 | 8.9258 | 30.75296
16384 | 16384 | 4.20256 | 16.38207 | 9.52587 | 31.31113
32768 | 8192 | 4.20231 | 16.19452 | 9.31478 | 31.03514

</body>

</html>

---------

**Performance Improvement (%)**
<html xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:dt="uuid:C2F41010-65B3-11d1-A29F-00AA00C14882"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=OneNote.File>
<meta name=Generator content="Microsoft OneNote 15">
</head>

<body lang=en-US style='font-family:Calibri;font-size:11.0pt'>
<!--StartFragment-->

<div style='direction:ltr'>

M | N | fwdbwd,   torch.float16 | fwdbwd,   torch.float32
-- | -- | -- | --
50432 | 384 | 32.178 | 22.049
50176 | 384 | 29.231 | 19.536
200704 | 192 | 44.188 | 43.962
802816 | 64 | 52.119 | 54.100
200 | 256 | -5.750 | -0.206
1000 | 256 | 0.031 | -0.797
6000 | 256 | 3.566 | 5.621
6272 | 256 | 3.865 | 4.836
200 | 512 | -1.615 | -1.010
1000 | 512 | -1.270 | 0.208
6000 | 512 | 3.534 | 5.581
6272 | 512 | 7.905 | 7.483
200 | 1024 | -2.883 | 0.254
1000 | 1024 | -0.767 | 0.493
6000 | 1024 | 0.237 | -2.381
6272 | 1024 | 3.840 | -1.707
200 | 1536 | -0.127 | -1.340
1000 | 1536 | -0.711 | -0.992
6000 | 1536 | -0.209 | -4.728
6272 | 1536 | 0.508 | -0.846
200 | 2048 | -1.262 | -1.176
1000 | 2048 | -0.358 | 0.312
6000 | 2048 | 8.350 | 6.487
6272 | 2048 | 1.588 | 5.713
200 | 3072 | 0.223 | -0.848
1000 | 3072 | -0.773 | -5.743
6000 | 3072 | 3.570 | -3.783
6272 | 3072 | 4.962 | -4.092
128 | 2097152 | -4.266 | 0.348
256 | 1048576 | 0.397 | 0.185
512 | 524288 | 17.325 | 16.605
1024 | 262144 | 23.070 | 19.195
2048 | 131072 | 27.469 | 24.605
4096 | 65536 | 32.023 | 27.465
8192 | 32768 | 24.459 | 28.274
16384 | 16384 | 21.439 | 9.514
32768 | 8192 | 6.818 | 0.491

</div>

<!--EndFragment-->
</body>

</html>

---------
**Benchmark script of this PR**
```

from distutils.command.config import config
import torch
from torch.nn import LayerNorm
import timeit

number_runs = 1000  # TODO: Modify this to save time!
def test_forward(layer_norm_cuda, input_cuda):
    layer_norm_cuda(input_cuda); torch.cuda.synchronize()

def test_backward(out_cuda, layer_norm_grad_cuda, create_graph):
    out_cuda.backward(layer_norm_grad_cuda, retain_graph=True, create_graph=create_graph); torch.cuda.synchronize()

def test_fwdbwd(input_cuda, layer_norm_cuda, gO):
    input_cuda.grad = None
    layer_norm_cuda.zero_grad(set_to_none=True)
    out = layer_norm_cuda(input_cuda)
    out.backward(gO)
    torch.cuda.synchronize()

def benchmark(config_m, config_n):

    print("M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)")
    if len(config_m) != len(config_n):
        print("Please make sure the lengths of config_m and config_m are the same.")

    for i in range(len(config_m)):
        normalized_shape = config_n[i]
        results = [config_m[i], config_n[i]]
        for dtype in (torch.half, torch.float):
            if dtype == torch.half:
                layer_norm_cuda = LayerNorm(normalized_shape).half().cuda()
            else:
                layer_norm_cuda = LayerNorm(normalized_shape).cuda()

            input_cuda = torch.randn(config_m[i], config_n[i], device='cuda', dtype=dtype, requires_grad=True)

            # print("cuda forward:")
            result_fwd = timeit.timeit(lambda: test_forward(layer_norm_cuda, input_cuda), number=number_runs)
            results.append(result_fwd / number_runs * 1000)

            gO = torch.rand_like(input_cuda)

            result_fwdbwd = timeit.timeit(lambda: test_fwdbwd(input_cuda, layer_norm_cuda, gO), number=number_runs)
            results.append(result_fwdbwd / number_runs * 1000)

        print('{:09d}|{:09d}|{:9.5f}|{:9.5f}|{:9.5f}|{:9.5f}'.format(results[0], results[1], results[2], results[3], results[4], results[5]))

    print("Times are in microseconds (us).")

config_m_cvt = [50432, 50176, 200704, 802816]
config_n_cvt = [384, 384, 192, 64]

config_m_68238 = [200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272]
config_n_68238 = [256,256,256,256,512,512,512,512,1024,1024,1024,1024,1536,1536,1536,1536,2048,2048,2048,2048,3072,3072,3072,3072]

config_m_27634 = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
config_n_27634 = [2097152, 1048576, 524288, 262144, 131072, 65536, 32768, 16384, 8192]

config_m = config_m_cvt + config_m_68238 + config_m_27634
config_n = config_n_cvt + config_n_68238 + config_n_27634

benchmark(config_m, config_n)
```

CC: @jeffdaily

Pull Request resolved: pytorch#87635
Approved by: https://github.com/jataylo, https://github.com/jeffdaily, https://github.com/ezyang
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when `fs`  (=`config_m` in our benchmark script) is large and `bs`  (=`config_n` in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of [PR pytorch#68238](pytorch#68238 (comment)) on AMD GPUs.

This PR is to replace `GammaBetaBackwardCUDAKernel` with the Apex layernorm backward kernel with some ROCm-specific parameter tuning when `fs`  (=`config_m`) is larger than 512 on AMD GPUs.

There are a few PRs for LayerNorm kernel:
- pytorch#26201
- pytorch#27634
- pytorch#68238

Therefore, we have tested and compared the kernel before and at this PR with the input shapes in the last two PRs along with those commonly used in the CvT model on AMD MI100.

---
**Current**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.387256 | 1.372758 | 0.378975 | 1.47892
50176 | 384 | 0.38231 | 1.362416 | 0.378084 | 1.473886
200704 | 192 | 0.997859 | 4.315875 | 0.989306 | 4.560827
802816 | 64 | 3.671828 | 16.68013 | 3.613515 | 16.827946
200 | 256 | 0.066503 | 0.332096 | 0.071422 | 0.325349
1000 | 256 | 0.071848 | 0.333355 | 0.073038 | 0.334753
6000 | 256 | 0.086334 | 0.345139 | 0.086834 | 0.347429
6272 | 256 | 0.088601 | 0.347906 | 0.087855 | 0.351245
200 | 512 | 0.071626 | 0.329726 | 0.073798 | 0.326878
1000 | 512 | 0.073975 | 0.330226 | 0.074166 | 0.332751
6000 | 512 | 0.099617 | 0.362367 | 0.100095 | 0.378313
6272 | 512 | 0.100378 | 0.358066 | 0.099857 | 0.395982
200 | 1024 | 0.072954 | 0.326382 | 0.073899 | 0.333007
1000 | 1024 | 0.0743 | 0.325532 | 0.071126 | 0.330991
6000 | 1024 | 0.127025 | 0.390084 | 0.128692 | 0.471504
6272 | 1024 | 0.130704 | 0.403536 | 0.135244 | 0.487133
200 | 1536 | 0.070331 | 0.339169 | 0.070086 | 0.331015
1000 | 1536 | 0.075085 | 0.330042 | 0.076295 | 0.328778
6000 | 1536 | 0.148889 | 0.44949 | 0.155781 | 0.659987
6272 | 1536 | 0.154939 | 0.478871 | 0.17673 | 0.716025
200 | 2048 | 0.070269 | 0.335585 | 0.072804 | 0.334655
1000 | 2048 | 0.080094 | 0.326991 | 0.080426 | 0.32685
6000 | 2048 | 0.187888 | 0.623023 | 0.245762 | 0.981635
6272 | 2048 | 0.195431 | 0.65244 | 0.262574 | 1.008141
200 | 3072 | 0.068205 | 0.339428 | 0.073068 | 0.344034
1000 | 3072 | 0.087554 | 0.328899 | 0.09218 | 0.346433
6000 | 3072 | 0.240352 | 0.905058 | 0.368135 | 1.280462
6272 | 3072 | 0.26179 | 0.959387 | 0.387782 | 1.476524
128 | 2097152 | 5.905976 | 22.724793 | 10.287974 | 30.242092
256 | 1048576 | 4.561596 | 19.554308 | 10.223171 | 29.42371
512 | 524288 | 4.146751 | 22.7247 | 11.404285 | 39.175902
1024 | 262144 | 5.193135 | 23.403325 | 11.334512 | 38.947192
2048 | 131072 | 4.992907 | 23.377801 | 11.400286 | 40.889191
4096 | 65536 | 5.429488 | 24.275701 | 11.196778 | 41.4751
8192 | 32768 | 5.35758 | 21.360312 | 10.535418 | 42.875646
16384 | 16384 | 5.44947 | 20.852605 | 10.357685 | 34.603408
32768 | 8192 | 4.688925 | 17.379392 | 9.635596 | 31.188271

</body>

</html>

---------
**At this PR**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">

<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl63
	{color:windowtext;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38797 | 0.93103 | 0.37966 | 1.15283
50176 | 384 | 0.3874 | 0.96417 | 0.38462 | 1.18595
200704 | 192 | 1.00002 | 2.40876 | 0.99224 | 2.55579
802816 | 64 | 3.67348 | 7.98658 | 3.61871 | 7.72404
200 | 256 | 0.07292 | 0.35119 | 0.07195 | 0.32602
1000 | 256 | 0.07354 | 0.33325 | 0.07237 | 0.33742
6000 | 256 | 0.08819 | 0.33283 | 0.08453 | 0.3279
6272 | 256 | 0.0886 | 0.33446 | 0.08774 | 0.33426
200 | 512 | 0.0701 | 0.33505 | 0.07072 | 0.33018
1000 | 512 | 0.07042 | 0.33442 | 0.074 | 0.33206
6000 | 512 | 0.09931 | 0.34956 | 0.09895 | 0.3572
6272 | 512 | 0.10103 | 0.32976 | 0.10041 | 0.36635
200 | 1024 | 0.07144 | 0.33579 | 0.07209 | 0.33216
1000 | 1024 | 0.0736 | 0.32803 | 0.07286 | 0.32936
6000 | 1024 | 0.12584 | 0.38916 | 0.12852 | 0.48273
6272 | 1024 | 0.13053 | 0.38804 | 0.13464 | 0.49545
200 | 1536 | 0.07159 | 0.3396 | 0.07062 | 0.33545
1000 | 1536 | 0.07443 | 0.33239 | 0.07366 | 0.33204
6000 | 1536 | 0.14959 | 0.45043 | 0.15826 | 0.69119
6272 | 1536 | 0.1542 | 0.47644 | 0.18249 | 0.72208
200 | 2048 | 0.07258 | 0.33982 | 0.07412 | 0.33859
1000 | 2048 | 0.0793 | 0.32816 | 0.07864 | 0.32583
6000 | 2048 | 0.18973 | 0.571 | 0.25506 | 0.91796
6272 | 2048 | 0.19719 | 0.64208 | 0.26445 | 0.95055
200 | 3072 | 0.07092 | 0.33867 | 0.07104 | 0.34695
1000 | 3072 | 0.08727 | 0.33144 | 0.09144 | 0.36633
6000 | 3072 | 0.24683 | 0.87275 | 0.37761 | 1.3289
6272 | 3072 | 0.26437 | 0.91178 | 0.38496 | 1.53694
128 | 2097152 | 6.27936 | 23.69425 | 10.40004 | 30.13699
256 | 1048576 | 4.5404 | 19.47675 | 10.28494 | 29.36936
512 | 524288 | 4.13951 | 18.78771 | 10.09557 | 32.67083
1024 | 262144 | 4.47576 | 18.00411 | 9.56488 | 31.47117
2048 | 131072 | 4.28026 | 16.95619 | 9.40297 | 30.82845
4096 | 65536 | 4.2653 | 16.5018 | 9.03315 | 30.08392
8192 | 32768 | 4.25613 | 16.13583 | 8.9258 | 30.75296
16384 | 16384 | 4.20256 | 16.38207 | 9.52587 | 31.31113
32768 | 8192 | 4.20231 | 16.19452 | 9.31478 | 31.03514

</body>

</html>

---------

**Performance Improvement (%)**
<html xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:dt="uuid:C2F41010-65B3-11d1-A29F-00AA00C14882"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=OneNote.File>
<meta name=Generator content="Microsoft OneNote 15">
</head>

<body lang=en-US style='font-family:Calibri;font-size:11.0pt'>
<!--StartFragment-->

<div style='direction:ltr'>

M | N | fwdbwd,   torch.float16 | fwdbwd,   torch.float32
-- | -- | -- | --
50432 | 384 | 32.178 | 22.049
50176 | 384 | 29.231 | 19.536
200704 | 192 | 44.188 | 43.962
802816 | 64 | 52.119 | 54.100
200 | 256 | -5.750 | -0.206
1000 | 256 | 0.031 | -0.797
6000 | 256 | 3.566 | 5.621
6272 | 256 | 3.865 | 4.836
200 | 512 | -1.615 | -1.010
1000 | 512 | -1.270 | 0.208
6000 | 512 | 3.534 | 5.581
6272 | 512 | 7.905 | 7.483
200 | 1024 | -2.883 | 0.254
1000 | 1024 | -0.767 | 0.493
6000 | 1024 | 0.237 | -2.381
6272 | 1024 | 3.840 | -1.707
200 | 1536 | -0.127 | -1.340
1000 | 1536 | -0.711 | -0.992
6000 | 1536 | -0.209 | -4.728
6272 | 1536 | 0.508 | -0.846
200 | 2048 | -1.262 | -1.176
1000 | 2048 | -0.358 | 0.312
6000 | 2048 | 8.350 | 6.487
6272 | 2048 | 1.588 | 5.713
200 | 3072 | 0.223 | -0.848
1000 | 3072 | -0.773 | -5.743
6000 | 3072 | 3.570 | -3.783
6272 | 3072 | 4.962 | -4.092
128 | 2097152 | -4.266 | 0.348
256 | 1048576 | 0.397 | 0.185
512 | 524288 | 17.325 | 16.605
1024 | 262144 | 23.070 | 19.195
2048 | 131072 | 27.469 | 24.605
4096 | 65536 | 32.023 | 27.465
8192 | 32768 | 24.459 | 28.274
16384 | 16384 | 21.439 | 9.514
32768 | 8192 | 6.818 | 0.491

</div>

<!--EndFragment-->
</body>

</html>

---------
**Benchmark script of this PR**
```
# Ref:
#       1. pytorch#26201
#       2. pytorch#68238

from distutils.command.config import config
import torch
from torch.nn import LayerNorm
import timeit

number_runs = 1000  # TODO: Modify this to save time!
def test_forward(layer_norm_cuda, input_cuda):
    layer_norm_cuda(input_cuda); torch.cuda.synchronize()

def test_backward(out_cuda, layer_norm_grad_cuda, create_graph):
    out_cuda.backward(layer_norm_grad_cuda, retain_graph=True, create_graph=create_graph); torch.cuda.synchronize()

def test_fwdbwd(input_cuda, layer_norm_cuda, gO):
    input_cuda.grad = None
    layer_norm_cuda.zero_grad(set_to_none=True)
    out = layer_norm_cuda(input_cuda)
    out.backward(gO)
    torch.cuda.synchronize()

def benchmark(config_m, config_n):

    print("M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)")
    if len(config_m) != len(config_n):
        print("Please make sure the lengths of config_m and config_m are the same.")

    for i in range(len(config_m)):
        normalized_shape = config_n[i]
        results = [config_m[i], config_n[i]]
        for dtype in (torch.half, torch.float):
            if dtype == torch.half:
                layer_norm_cuda = LayerNorm(normalized_shape).half().cuda()
            else:
                layer_norm_cuda = LayerNorm(normalized_shape).cuda()

            input_cuda = torch.randn(config_m[i], config_n[i], device='cuda', dtype=dtype, requires_grad=True)

            # print("cuda forward:")
            result_fwd = timeit.timeit(lambda: test_forward(layer_norm_cuda, input_cuda), number=number_runs)
            results.append(result_fwd / number_runs * 1000)

            gO = torch.rand_like(input_cuda)

            result_fwdbwd = timeit.timeit(lambda: test_fwdbwd(input_cuda, layer_norm_cuda, gO), number=number_runs)
            results.append(result_fwdbwd / number_runs * 1000)

        print('{:09d}|{:09d}|{:9.5f}|{:9.5f}|{:9.5f}|{:9.5f}'.format(results[0], results[1], results[2], results[3], results[4], results[5]))

    print("Times are in microseconds (us).")

# CVT
config_m_cvt = [50432, 50176, 200704, 802816]
config_n_cvt = [384, 384, 192, 64]

# pytorch#68238 (comment)
config_m_68238 = [200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272]
config_n_68238 = [256,256,256,256,512,512,512,512,1024,1024,1024,1024,1536,1536,1536,1536,2048,2048,2048,2048,3072,3072,3072,3072]

# pytorch#27634
config_m_27634 = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
config_n_27634 = [2097152, 1048576, 524288, 262144, 131072, 65536, 32768, 16384, 8192]

config_m = config_m_cvt + config_m_68238 + config_m_27634
config_n = config_n_cvt + config_n_68238 + config_n_27634

benchmark(config_m, config_n)
```

CC: @jeffdaily

Pull Request resolved: pytorch#87635
Approved by: https://github.com/jataylo, https://github.com/jeffdaily, https://github.com/ezyang
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

module: cpu CPU specific problem (e.g., perf, algorithm) module: cuda Related to torch.cuda, and CUDA support in general module: internals Related to internal abstractions in c10 and ATen module: nn Related to torch.nn module: onnx Related to torch.onnx oncall: jit Add this issue/PR to JIT oncall triage queue open source triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants