
Optimize some reduction operators on CPU BFloat16#55202

Closed
mingfeima wants to merge 14 commits intogh/mingfeima/16/basefrom
gh/mingfeima/16/head

Conversation

@mingfeima
Collaborator

@mingfeima mingfeima commented Apr 2, 2021

Stack from ghstack:

Differential Revision: D28836790

@facebook-github-bot
Contributor

facebook-github-bot commented Apr 2, 2021

💊 CI failures summary and remediations

As of commit f2c070c (more details on the Dr. CI page and at hud.pytorch.org/pr/55202):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



mingfeima added a commit that referenced this pull request Apr 2, 2021
ghstack-source-id: 15f6804
Pull Request resolved: #55202
@mingfeima
Collaborator Author

This PR aims at enabling or optimizing the following reduction operators on CPU BFloat16:

  • softmax (no BFloat16 support previously)
  • log_softmax
  • max (reduce_all)
  • min (reduce_all)
  • _aminmax (reduce_all; no BFloat16 support previously)

Note that some operators, e.g. log_softmax, already had BFloat16 support previously, but the perf was not good.
This PR specializes all the map<>() variants from vec256/functional.h; corresponding test cases are added in vec256_test_all_types.cpp.

Since we have already specialized all members of Vec256&lt;scalar_t&gt; for {scalar_t=BFloat16} in vec256_bfloat16.h, the following example runs smoothly on BFloat16 (note the lambda must capture `one`):

```cpp
using Vec = Vec256<BFloat16>;
Vec one = Vec(BFloat16(1));
vec256::map([one](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N);
```

The current impl ends up with 3 pairs of dtype conversions, one each for `.exp()`, `+`, and `/`.
The new impl only needs dtype conversions for the input and the output. Benefits:

  • better performance, since there are fewer dtype conversions;
  • less rounding error, since intermediate results are kept in fp32;
  • accumulation is done in fp32.

I am going to use this stack to push more BFloat16 CPU optimizations, following the same approach as proposed in this PR:

  • input load: bf16->fp32; output store: fp32->bf16
  • all intermediate operations (including accumulation) use fp32

@mingfeima
Collaborator Author

mingfeima commented Apr 2, 2021

Since this PR is not related to the parallelization feature, only single-core perf is tested.
NB: if an operator didn't have BFloat16 support previously, the "before" perf refers to a simple impl like the following:

```diff
-   AT_DISPATCH_ALL_TYPES(input.scalar_type(), "_aminmax_all_all", [&] {
+   AT_DISPATCH_ALL_TYPES_AND(kBFloat16, input.scalar_type(), "_aminmax_all_all", [&] {
```
  • performance update on avx512 machine: Xeon(R) Gold 6248 CPU @ 2.50GHz

```
before: softmax: 128x1024: fp32: 151.457 us; bf16: 362.440 us
after:  softmax: 128x1024: fp32: 151.757 us; bf16: 194.105 us

before: log_softmax: 128x1024: fp32: 157.474 us; bf16: 411.537 us
after:  log_softmax: 128x1024: fp32: 152.229 us; bf16: 163.657 us

before: max: 128x1024: fp32: 23.714 us; bf16: 63.077 us
after:  max: 128x1024: fp32: 24.523 us; bf16: 17.484 us

before: min: 128x1024: fp32: 23.707 us; bf16: 63.198 us
after:  min: 128x1024: fp32: 24.067 us; bf16: 17.498 us

before: _aminmax: 128x1024: fp32: 25.156 us; bf16: 69.929 us
after:  _aminmax: 128x1024: fp32: 25.272 us; bf16: 19.487 us
```

  • performance update on avx2 machine: Xeon(R) CPU E5-2680 v3 @ 2.50GHz

```
before: softmax: 128x1024: fp32: 229.048 us; bf16: 493.946 us
after:  softmax: 128x1024: fp32: 248.060 us; bf16: 293.789 us

before: log_softmax: 128x1024: fp32: 274.467 us; bf16: 673.550 us
after:  log_softmax: 128x1024: fp32: 251.243 us; bf16: 241.080 us

before: max: 128x1024: fp32: 36.288 us; bf16: 88.725 us
after:  max: 128x1024: fp32: 32.229 us; bf16: 23.905 us

before: min: 128x1024: fp32: 32.179 us; bf16: 87.669 us
after:  min: 128x1024: fp32: 30.829 us; bf16: 22.847 us

before: _aminmax: 128x1024: fp32: 33.098 us; bf16: 105.431 us
after:  _aminmax: 128x1024: fp32: 30.785 us; bf16: 28.113 us
```

Notes: With this PR, BFloat16 softmax is still slower than float32 because of the lack of native dtype conversion intrinsics (fp32/bf16 conversion currently uses an emulated method); on Sapphire Rapids, BFloat16 is faster.

mingfeima added a commit to mingfeima/pytorch that referenced this pull request Apr 28, 2021
@mingfeima mingfeima requested a review from VitalyFedyunin May 13, 2021 03:01
@mingfeima
Collaborator Author

@VitalyFedyunin Could you please review this stack? We are trying to optimize CPU BFloat16 path performance.

dgl-intel pushed a commit to dgl-intel/pytorch that referenced this pull request May 14, 2021
@mdschatz
Contributor

@mdschatz has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

imaginary-person added a commit to imaginary-person/pytorch-1 that referenced this pull request May 28, 2021
@VitalyFedyunin
Contributor

@VitalyFedyunin has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@VitalyFedyunin
Contributor

Hi! Please rebase to adapt to the more general vectorization, as we are in the process of introducing AVX512 support.

@mingfeima
Collaborator Author

@VitalyFedyunin Hi, this stack has been rebased!

Contributor

@VitalyFedyunin VitalyFedyunin left a comment


Some code moves are required; otherwise looks fine.

@mingfeima mingfeima changed the title Optimize some redunction operators on CPU BFloat16 [WIP] Optimize some redunction operators on CPU BFloat16 Jun 17, 2021
@mingfeima mingfeima changed the title [WIP] Optimize some redunction operators on CPU BFloat16 Optimize some reduction operators on CPU BFloat16 Jun 18, 2021
@mingfeima
Collaborator Author

mingfeima commented Jun 18, 2021

@VitalyFedyunin, updated! Also added a map4 test in vec256_test_all_types_XXX since the other PR got merged. Please check.

@VitalyFedyunin
Contributor

@VitalyFedyunin has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@mdschatz
Contributor

mdschatz commented Jun 24, 2021

@mingfeima, I've noticed that including vec256_float.h in vec256_bfloat16.h is needed for compilation to proceed; without it, I get a lot of __m256 type-cast errors. Should that line be included in this PR or not? @jgong5, can you elaborate on this?

I'd defer to @VitalyFedyunin to decide how to properly address this, though.

@facebook-github-bot
Contributor

@VitalyFedyunin merged this pull request in 5a077bb.

@jgong5
Collaborator

jgong5 commented Jun 25, 2021

@mdschatz The compilation error was due to an incorrect header file inclusion in IPEX (it should include ATen/cpu/vec/vec.h instead of vec256_bfloat16.h directly), and we have fixed it there. So, no worries!

@mingfeima
Collaborator Author

> @mdschatz The compilation error was due to an incorrect header file inclusion in IPEX (it should include ATen/cpu/vec/vec.h instead of vec256_bfloat16.h directly), and we have fixed it there. So, no worries!

Right, you should include &lt;ATen/cpu/vec/vec.h&gt; instead of directly including vec256_xxx.h.
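As a minimal sketch of the advice above (assuming a downstream project such as IPEX that consumes ATen's vectorization headers):

```cpp
// In a downstream translation unit, include the umbrella header, which pulls in
// the right arch-specific implementations in the correct order:
#include <ATen/cpu/vec/vec.h>

// Do NOT include arch-specific headers such as vec256_bfloat16.h directly;
// they depend on definitions (e.g. from vec256_float.h) that the umbrella
// header includes first, so including them in isolation can trigger
// __m256 type-cast errors like the ones reported above.
```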
