
Conversation

@xiaomengy
Contributor

Summary:
Separate from D15194600
Optimize PyTorch layer_norm op, part 1:
optimize layer_norm_forward_cpu
import Eigen Maps to improve the performance of reduction

Differential Revision: D15290608

Contributor

@soumith left a comment

This PR now introduces Eigen into PyTorch optimizations -- without discussion.
The previous comments about moving to native/cpu and using the dispatcher have not been addressed.

@xiaomengy requested a review from soumith May 15, 2019 17:23
@xiaomengy
Contributor Author

Removed Eigen now. For the reduction case, I found that even using Vec256 is still slower than Eigen, so maybe we can discuss or figure out how to make it better.

Contributor

Please move LayerNormForwardCPUImpl and all the CPU logic into the /cpu subfolder using the DECLARE_DISPATCH logic (see CopyKernel for reference). It will allow the OSS build to utilize AVX2 instructions and other optimizations.
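
For readers unfamiliar with that pattern, here is a minimal sketch of the DECLARE_DISPATCH / REGISTER_DISPATCH wiring being suggested. The stub and kernel names (LayerNormKernel, LayerNormKernelImpl) and the function signature are illustrative assumptions, not the exact ones from this PR; CopyKernel is the real reference.

```cpp
// Illustrative sketch only -- names, paths, and signature are hypothetical.

// aten/src/ATen/native/layer_norm.h
#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>

namespace at { namespace native {

using layer_norm_forward_fn = void (*)(
    const Tensor& /*X*/, const Tensor& /*gamma*/, const Tensor& /*beta*/,
    int64_t /*M*/, int64_t /*N*/, double /*eps*/,
    Tensor* /*Y*/, Tensor* /*mean*/, Tensor* /*rstd*/);

DECLARE_DISPATCH(layer_norm_forward_fn, LayerNormKernel);

}} // namespace at::native

// aten/src/ATen/native/layer_norm.cpp (device-independent entry point):
//   DEFINE_DISPATCH(LayerNormKernel);
//   ...
//   LayerNormKernel(kCPU, X, gamma, beta, M, N, eps, &Y, &mean, &rstd);

// aten/src/ATen/native/cpu/layer_norm_kernel.cpp, compiled once per CPU
// capability so AVX2 code paths become available in the OSS build:
//   static void LayerNormKernelImpl(const Tensor& X, ...) { /* vectorized */ }
//   REGISTER_DISPATCH(LayerNormKernel, &LayerNormKernelImpl);
```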

Contributor Author

Done

Contributor

I'm quite sure you want to call at::native::empty_like here and avoid all the tracing/profiling checks.
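
As a hedged illustration of that suggestion (assuming the single-argument overload of at::native::empty_like available at the time of this PR), allocating the output via the native function skips the dispatcher and its tracing/profiling hooks:

```cpp
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>

// Illustrative helper, not code from the PR: allocate the output buffer for a
// CPU kernel. at::empty_like(X) would go back through the dispatcher
// (tracing/profiling included); the direct native call does not.
at::Tensor AllocateOutput(const at::Tensor& X) {
  return at::native::empty_like(X);
}
```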

Contributor

@VitalyFedyunin left a comment

This PR requires new tests to be written for native_layer_norm

@cpuhrsch
Contributor

@BIT-silence - Is this a copy of the C2 kernels?

@VitalyFedyunin
Contributor

You are exposing the native_layer_norm function to the user-level API, which seems to require contiguous inputs. Please add tests to cover the API calls as well as contiguity checks. Alternatively, if the intent is to introduce a fast kernel, implement all the code as a kernel and avoid creating a new native function.

Also, it would be nice to see benchmarks that compare the old and new layer_norm code.

@xiaomengy
Contributor Author

@BIT-silence - Is this a copy of the C2 kernels?

Logically it is, although the C2 kernel is implemented using the Eigen library. I tested the performance difference: for the rowwise-moments part, the Eigen version is faster than this version, which relies on the compiler's auto-vectorization; the elementwise-affine part performs the same. I also tried using Vec256 for the rowwise-moments part, but it was actually a little slower than the for-loop with auto-vectorization.
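
To make that comparison concrete, here is a standalone sketch (not code from this PR) of the two rowwise-moments variants being discussed: a plain loop that leans on the compiler's auto-vectorization, and an Eigen::Map version in the style of the Caffe2 kernel. The function names and the M x N row-major layout are assumptions for illustration.

```cpp
#include <Eigen/Core>
#include <cstdint>

// Rowwise mean/variance over an M x N row-major matrix X,
// using var = E[x^2] - E[x]^2.

// Variant 1: plain loop, relying on the compiler to auto-vectorize.
void RowwiseMomentsLoop(int64_t M, int64_t N, const float* X,
                        float* mean, float* var) {
  const float scale = 1.0f / static_cast<float>(N);
  for (int64_t i = 0; i < M; ++i) {
    const float* row = X + i * N;
    float sum = 0.0f;
    float sum_sq = 0.0f;
    for (int64_t j = 0; j < N; ++j) {
      sum += row[j];
      sum_sq += row[j] * row[j];
    }
    mean[i] = sum * scale;
    var[i] = sum_sq * scale - mean[i] * mean[i];
  }
}

// Variant 2: Eigen Maps over each row, in the style of the Caffe2 kernel.
void RowwiseMomentsEigen(int64_t M, int64_t N, const float* X,
                         float* mean, float* var) {
  for (int64_t i = 0; i < M; ++i) {
    Eigen::Map<const Eigen::ArrayXf> row(X + i * N, N);
    mean[i] = row.mean();
    var[i] = row.square().mean() - mean[i] * mean[i];
  }
}
```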

@pytorchbot added the module: cpu label May 16, 2019
@xiaomengy
Contributor Author

Thanks for the advice. I have removed the native_layer_norm function. Later I will add the backward part for layer_norm and then do autograd for layer_norm instead of batch_norm.

Some benchmark results for this change:
with input shape = [64, 128, 56, 56] and normalized_shape = [128, 56, 56], elementwise_affine=True,
the forward time on a devvm went from 350ms to 87.6ms. With this approach, the Hugging Face BERT model forward pass gets about 12% faster.
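
For reference, a minimal sketch of how such a measurement can be reproduced through the ATen C++ API; the warm-up, iteration count, and eps below are arbitrary choices for illustration, not the exact benchmark behind the numbers above.

```cpp
#include <ATen/ATen.h>
#include <chrono>
#include <cstdio>

int main() {
  at::Tensor X = at::randn({64, 128, 56, 56});
  at::Tensor gamma = at::ones({128, 56, 56});
  at::Tensor beta = at::zeros({128, 56, 56});

  // Warm-up so allocator and thread-pool start-up are not measured.
  for (int i = 0; i < 3; ++i) {
    at::Tensor warm = at::layer_norm(X, {128, 56, 56}, gamma, beta, 1e-5);
  }

  constexpr int kIters = 10;
  const auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < kIters; ++i) {
    at::Tensor Y = at::layer_norm(X, {128, 56, 56}, gamma, beta, 1e-5);
  }
  const auto end = std::chrono::steady_clock::now();
  const double ms =
      std::chrono::duration<double, std::milli>(end - start).count() / kIters;
  std::printf("layer_norm forward: %.1f ms / iter\n", ms);
  return 0;
}
```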

Contributor

@soumith left a comment

Thanks for making the changes. From my side this is good to go.

@xiaomengy
Contributor Author

This PR requires new tests to be written for native_layer_norm

Actually I have removed native_layer_norm. Currently we don't expose any new functions; we just add a fast path for layer_norm. Is that fine?

@fmassa
Member

fmassa commented May 16, 2019

We should ideally test that the fast path gives the same results as the slow path. But maybe this is implicitly tested in the cuda tests

@xiaomengy
Contributor Author

We should ideally test that the fast path gives the same results as the slow path. But maybe this is implicitly tested in the cuda tests

Actually, not only the cuda tests but also the jit tests cover that.

@xiaomengy
Contributor Author

I will add the grad part for this fast path, and then all layer_norm calls on CPU should go through this path. Since the current fast path is already covered by existing tests, I'm wondering if there is any need to add a specific test for it.
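
If a dedicated test is added later, a consistency check along these lines would exercise the fast path directly: compute layer_norm through the ATen API and compare it against a straightforward reference built from reductions. The shapes, eps, and tolerances below are illustrative assumptions, not the project's test conventions.

```cpp
#include <ATen/ATen.h>

// Sketch of a fast-path vs. reference consistency check (illustrative only).
bool LayerNormMatchesReference() {
  const double eps = 1e-5;
  at::Tensor X = at::randn({8, 16, 7, 7});
  at::Tensor gamma = at::randn({16, 7, 7});
  at::Tensor beta = at::randn({16, 7, 7});

  // Fast path under test.
  at::Tensor Y = at::layer_norm(X, {16, 7, 7}, gamma, beta, eps);

  // Reference: normalize each row of the flattened [8, 16*7*7] view.
  at::Tensor X2 = X.reshape({8, -1});
  at::Tensor mean = X2.mean({1}, /*keepdim=*/true);
  at::Tensor var = X2.var({1}, /*unbiased=*/false, /*keepdim=*/true);
  at::Tensor ref =
      ((X2 - mean) / at::sqrt(var + eps)).reshape(X.sizes()) * gamma + beta;

  return at::allclose(Y, ref, /*rtol=*/1e-4, /*atol=*/1e-4);
}
```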

Summary:
Pull Request resolved: pytorch#20345

Separate from D15194600
Optimize PyTorch layer_norm op, part 1:
  optimize layer_norm_forward_cpu
  import Eigen Maps to improve the performance of reduction

Reviewed By: zheng-xq

Differential Revision: D15290608

fbshipit-source-id: d5589f67c515644403ff3ad11006ec43bab18809
@facebook-github-bot
Contributor

This pull request has been merged in c9da011.

zdevito pushed a commit to zdevito/ATen that referenced this pull request May 22, 2019
Summary:
Pull Request resolved: pytorch/pytorch#20345

Separate from D15194600
Optimize PyTorch layer_norm op, part 1:
  optimize layer_norm_forward_cpu
  import Eigen Maps to improve the performance of reduction

Reviewed By: zheng-xq

Differential Revision: D15290608

fbshipit-source-id: cf2c208dfd6fbcbc4c69db3ed60278d9bee156b5
vkuzo added a commit that referenced this pull request Mar 24, 2020
Summary:

Adds a quantized implementation of LayerNorm for server.

Relevant PRs:
* #20345 (floating point LN)
* #33080 (quantized BN)

A future PR will add the Python wrapper.

Test Plan:

numerics match the floating point implementation
TODO: benchmarks

ghstack-source-id: 3c3721f
Pull Request resolved: #35329