
Conversation

jianyuh (Member) commented Nov 4, 2019

Stack from ghstack:

We would like to provide a vectorized implementation for layer norm. This PR reuses #23349.

Single Core:
(Note that our benchmark generates batch_size=47 for the first case and batch_size=56 for the second case. Even so, the vectorized version is still faster than the original non-vectorized reference implementation.)

- Before the PR:
```
native_layer_norm        0.81%            5.884ms          0.81%            5.884ms          122.580us        NaN              0.000us          0.000us          48               [[47, 1, 1024], [1024], [1024]]
```

- After the PR:
```
native_layer_norm        0.68%            5.053ms          0.68%            5.053ms          105.272us        NaN              0.000us          0.000us          48               [[56, 1, 1024], [1024], [1024]]
```

20 Cores:

- Before the PR:
```
native_layer_norm        1.65%            41.682ms         1.65%            41.682ms         868.365us        NaN              0.000us          0.000us          48               [[61, 64, 1024], [1024], [1024]]
```

- After the PR:
```
native_layer_norm        1.34%            33.829ms         1.34%            33.829ms         704.771us        NaN              0.000us          0.000us          48               [[61, 64, 1024], [1024], [1024]]
```
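For context, `native_layer_norm` normalizes each row over the last dimension using a per-row mean and variance, then applies the affine weight and bias. A minimal NumPy sketch of that computation, using the `[[47, 1, 1024], [1024], [1024]]` shapes from the benchmark rows above (this illustrates the math only, not the PR's C++ kernel):

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize over the last dimension; gamma/beta match its size (1024).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta

x = np.random.randn(47, 1, 1024).astype(np.float32)
gamma = np.ones(1024, dtype=np.float32)
beta = np.zeros(1024, dtype=np.float32)
y = layer_norm(x, gamma, beta)
# Each output row has approximately zero mean and unit variance.
```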

Differential Revision: D18293522

jianyuh added a commit that referenced this pull request Nov 4, 2019
…ec256

We would like to provide the vectorized implementation for layer norm. This PR reuses #23349.

Differential Revision: [D18293522](https://our.internmc.facebook.com/intern/diff/D18293522/)

ghstack-source-id: 93164082
Pull Request resolved: #29104
@jianyuh jianyuh requested a review from xiaomengy November 4, 2019 01:12
jianyuh added a commit that referenced this pull request Nov 10, 2019
…ec256

Pull Request resolved: #29104

We would like to provide the vectorized implementation for layer norm. This PR reuses #23349.

Differential Revision: [D18293522](https://our.internmc.facebook.com/intern/diff/D18293522/)
ghstack-source-id: 93608529
jianyuh (Member, Author) commented Nov 10, 2019

I can reproduce it on a Skylake machine before my PR with

python run_test.py -i nn -- TestNN.test_LayerNorm_1d_no_elementwise_affine_eval

Output:

```
$ python run_test.py -i nn -- TestNN.test_LayerNorm_1d_no_elementwise_affine_eval
which: no nvcc in (/root/miniconda/bin:/root/miniconda/condabin:/usr/local/sbin:/usr/local/bin:/sbin:/bin:/usr/sbin:/usr/bin:/usr/facebook/ops/scripts:/usr/facebook/scripts:/opt/local/bin:/usr/facebook/scripts:/usr/facebook/scripts/db:/root/bin)
Test executor: ['/root/miniconda/bin/python']
Running test_nn ... [2019-11-09 21:25:55.865797]
F
======================================================================
FAIL: test_LayerNorm_1d_no_elementwise_affine_eval (__main__.TestNN)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test_nn.py", line 7344, in <lambda>
    add(test_name, lambda self, test=test: test(self))
  File "/root/jhuang_test/pytorch/test/common_nn.py", line 3511, in __call__
    self.test_noncontig(test_case, module, input)
  File "/root/jhuang_test/pytorch/test/common_nn.py", line 3570, in test_noncontig
    test_case.assertEqual(out, output)
  File "/root/jhuang_test/pytorch/test/common_utils.py", line 736, in assertEqual
    assertTensorsEqual(x, y)
  File "/root/jhuang_test/pytorch/test/common_utils.py", line 706, in assertTensorsEqual
    self.assertLessEqual(max_err, prec, message)
AssertionError: tensor(76.1324, grad_fn=<MaxBackward1>) not less than or equal to 1e-05 :

----------------------------------------------------------------------
Ran 1 test in 0.009s

FAILED (failures=1)
Traceback (most recent call last):
  File "run_test.py", line 455, in <module>
    main()
  File "run_test.py", line 447, in main
    raise RuntimeError(message)
RuntimeError: test_nn failed!
```
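The failure comes from `test_noncontig`, which feeds the module a strided, non-contiguous view of the input. A small NumPy illustration of the property a vectorized kernel has to handle (names here are illustrative, not from the PR):

```python
import numpy as np

# A transposed view shares the original buffer but is not C-contiguous:
# its rows are not stride-1 runs in memory, so a SIMD kernel cannot
# simply load consecutive lanes from them.
a = np.arange(12, dtype=np.float32).reshape(3, 4)
t = a.T

print(a.flags['C_CONTIGUOUS'])  # True
print(t.flags['C_CONTIGUOUS'])  # False

# One safe strategy is to materialize a contiguous copy before the
# vectorized pass, at the cost of an extra copy.
t_contig = np.ascontiguousarray(t)
print(t_contig.flags['C_CONTIGUOUS'])  # True
```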

jianyuh added a commit that referenced this pull request Nov 10, 2019
…ec256

Pull Request resolved: #29104

We would like to provide the vectorized implementation for layer norm. This PR reuses #23349.

Differential Revision: [D18293522](https://our.internmc.facebook.com/intern/diff/D18293522/)
ghstack-source-id: 93611515
jianyuh (Member, Author) commented Nov 10, 2019

Fixed the issue by reusing vec256::reduce_all and vec256::map_reduce_all. There might be some overhead, though, since the two loops are no longer fused.
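The two helpers split the statistics computation into separate passes: one reduction for the mean, then a map-plus-reduction for the variance. A rough scalar sketch of that two-pass structure in Python (the names mirror the `vec256` helpers, but this is an illustration under those assumptions, not the PR's C++):

```python
import math

def reduce_all(op, xs):
    # Analogue of vec256::reduce_all: fold the whole row with one operator.
    acc = xs[0]
    for v in xs[1:]:
        acc = op(acc, v)
    return acc

def map_reduce_all(map_op, red_op, xs):
    # Analogue of vec256::map_reduce_all: apply map_op elementwise, then reduce.
    return reduce_all(red_op, [map_op(v) for v in xs])

def row_stats(xs, eps=1e-5):
    n = len(xs)
    # Pass 1: mean via a plain sum reduction.
    mean = reduce_all(lambda a, b: a + b, xs) / n
    # Pass 2: variance via map (squared deviation) + reduce (sum).
    # The two passes over the row are no longer fused into one loop,
    # which is the overhead mentioned above.
    var = map_reduce_all(lambda v: (v - mean) ** 2,
                         lambda a, b: a + b, xs) / n
    rstd = 1.0 / math.sqrt(var + eps)
    return mean, rstd

mean, rstd = row_stats([1.0, 2.0, 3.0, 4.0])
print(mean)  # 2.5
```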

@jianyuh jianyuh requested a review from xiaomengy November 10, 2019 09:23
jamesr66a (Collaborator) left a comment

LGTM. Does this improve some benchmarks?

@jianyuh jianyuh dismissed xiaomengy’s stale review November 21, 2019 05:45

Will check the performance.

jianyuh (Member, Author) commented Nov 21, 2019

> LGTM. Does this improve some benchmarks?

Will check the performance before landing. Thanks!

jianyuh (Member, Author) commented Dec 11, 2019

Updated the performance numbers in the summary.

jianyuh added a commit that referenced this pull request Dec 11, 2019
…ec256

Pull Request resolved: #29104

We would like to provide the vectorized implementation for layer norm. This PR reuses #23349.

Differential Revision: [D18293522](https://our.internmc.facebook.com/intern/diff/D18293522/)
ghstack-source-id: 95345939
facebook-github-bot (Contributor)

This pull request has been merged in d6d6075.

@facebook-github-bot facebook-github-bot deleted the gh/jianyuh/44/head branch December 14, 2019 15:15
wuhuikx pushed a commit to wuhuikx/pytorch that referenced this pull request Jan 30, 2020
…29104)

Summary:
Pull Request resolved: pytorch#29104

We would like to provide the vectorized implementation for layer norm. This PR reuses pytorch#23349.

Test Plan:
buck test mode/dev-nosan //caffe2/test:nn -- "LayerNorm"

buck test mode/dev-nosan //caffe2/test:nn -- "test_LayerNorm_1d_no_elementwise_affine_eval"

 python run_test.py -i nn -- TestNN.test_LayerNorm_1d_no_elementwise_affine_eval

Differential Revision: D18293522

fbshipit-source-id: f4cfed6e62bac1b43ee00c32b495ecc836bd9ec5