
Conversation

@Rohanjames1997 (Contributor) commented Jul 19, 2023

Fixes #104729

As suggested in the blog, I subclassed the `VecISA` class and implemented a NEON version of the `vec_reduce_all()` function, to go along with the existing AVX2 and AVX512 versions. Any operation that calls `vec_reduce_all()` will also take the NEON path and benefit from its vectorization.

`vec_reduce_all()` is invoked by Softmax and other operations such as norms. Using the fast path yields roughly 30% time savings for Softmax compared with the previously taken slow path.

|   | Slow path | Fast path (NEON intrinsics) |
| -- | -- | -- |
| Softmax (100 passes, 1024 dimension) | 623.706 ms | 452.011 ms |
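
For context, here is a minimal, self-contained sketch of a NEON horizontal reduction. The actual change implements the reduction through ATen's `Vectorized<float>` abstraction and takes an arbitrary reduction op, so `neon_reduce_sum` below is only an illustrative stand-in, not the PR's code:

```cpp
// Illustrative NEON horizontal sum over a float buffer (AArch64).
// Stand-in for the idea behind a NEON vec_reduce_all(); the real code
// lives in ATen's Vectorized<float> machinery.
#include <arm_neon.h>
#include <cstddef>

float neon_reduce_sum(const float* data, std::size_t n) {
  float32x4_t acc = vdupq_n_f32(0.0f);            // 4-lane accumulator
  std::size_t i = 0;
  for (; i + 4 <= n; i += 4) {
    acc = vaddq_f32(acc, vld1q_f32(data + i));    // vectorized partial sums
  }
  float result = vaddvq_f32(acc);                 // horizontal add of the 4 lanes
  for (; i < n; ++i) {                            // scalar tail
    result += data[i];
  }
  return result;
}
```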

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang @Xia-Weiwen @ngimel @malfet

@pytorch-bot (bot) commented Jul 19, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/105590

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit e5628d0 with merge base 24d5cab:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla (bot) commented Jul 19, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

github-actions bot added the module: cpu, module: inductor, and ciflow/inductor labels Jul 19, 2023
@github-actions bot (Contributor) commented

This PR needs a `release notes:` label

If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`.

If not, please add the `topic: not user facing` label.

To add a label, you can comment to pytorchbot, for example:
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@jgong5 (Collaborator) left a comment

Overall LGTM, will stamp after we make sure the changes are covered by UT and the PR passes the UT.

@zou3519 added the triaged label Jul 20, 2023
@Rohanjames1997 (Contributor, Author) commented

Thanks for the review, @jgong5.
I decided to extend the existing UT infrastructure to cover NEON as well. Since those changes also involve editing some critical CMake files, I've raised them in a separate PR; my code passes the vec_test_all_types_NEON UT.

Let me know if you'd like me to merge those changes into this PR, or keep them separate.
Thanks!

@jgong5 (Collaborator) commented Jul 24, 2023

> Let me know if you'd like me to merge those changes into this PR, or keep them separate.

How about making them separate and letting the other PR land first?

@Rohanjames1997 (Contributor, Author) commented

@jgong5, sure. I have raised the changes exclusively required to create the UT in this PR: #105823

@github-actions bot (Contributor) commented Oct 8, 2023

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@kit1980 (Contributor) commented Feb 23, 2024

@pytorchmergebot merge -i

@Rohanjames1997 please be careful when ignoring checks; you ignored a lint failure.

@ezyang (Contributor) commented Mar 7, 2024

@pytorchbot revert -m "#121288 (comment)"

@pytorch-bot (bot) commented Mar 7, 2024

❌ 🤖 pytorchbot command failed:

@pytorchbot revert: error: the following arguments are required: -c/--classification

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Try @pytorchbot --help for more info.

@ezyang (Contributor) commented Mar 7, 2024

@pytorchbot revert -m "#121288 (comment)" -c nosignal

@pytorchmergebot (Collaborator) commented

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator) commented

@Rohanjames1997 your PR has been successfully reverted.

pytorch-bot bot dismissed stale reviews from jgong5 and malfet March 7, 2024 23:06

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.

@Rohanjames1997 requested review from jgong5 and malfet March 12, 2024 17:21
Rohanjames1997 added a commit to Rohanjames1997/pytorch that referenced this pull request Mar 12, 2024
…rch#105590)

Fixes pytorch#104729

As suggested in the [blog](https://dev-discuss.pytorch.org/t/torchinductor-update-5-cpu-backend-backend-performance-update-and-deep-dive-on-key-optimizations/1117#:~:text=It%20can%20be,sub%2Dclasses.), I subclassed the `VecISA` class and implemented a NEON version of the `vec_reduce_all()` function, to go along with the existing AVX2 and AVX512 versions. Any operation that calls `vec_reduce_all()` will also take the NEON path and benefit from its vectorization.

The `vec_reduce_all()` is invoked by Softmax and other operations like norms. Using the fast path results in 30% time savings for Softmax as compared to the previously taken slow path.

  | Slow path | Fast path (NEON intrinsics)
-- | -- | --
Softmax (100 passes, 1024 dimension) | 623.706ms | 452.011ms

Pull Request resolved: pytorch#105590
Approved by: https://github.com/jgong5, https://github.com/malfet
#include <c10/util/TypeCast.h>

-#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR)
+#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON)
Review comment (Collaborator):

For all `#elif defined(CPU_CAPABILITY_ZVECTOR)` in the masked_load functions in this file (e.g., https://github.com/pytorch/pytorch/pull/105590/files#diff-e384e0d2829ef483854a45c4f422979d1a2ca28495c7bc06e91eabc67c61a470R320), change them to `#else` so that the default path can work for NEON too.
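
For illustration, here is a hedged sketch of the guard restructuring being requested. The function below is a stand-in, not the real masked_load from this file; only the `#if`/`#elif`/`#else` shape matters:

```cpp
// Sketch of turning an ISA-specific `#elif` into a generic `#else`, so that
// NEON (and any other ISA without a specialized branch) takes the portable
// fallback. The function body is illustrative, not the PR's masked_load.
#include <cstring>

template <typename T>
void masked_load_sketch(T* dst, const T* src, int count) {
#if defined(CPU_CAPABILITY_AVX512)
  // AVX512-specific masked load would go here.
  std::memcpy(dst, src, sizeof(T) * count);
#elif defined(CPU_CAPABILITY_AVX2)
  // AVX2-specific masked load would go here.
  std::memcpy(dst, src, sizeof(T) * count);
#else
  // Default path (previously `#elif defined(CPU_CAPABILITY_ZVECTOR)`):
  // an element-wise copy that works for ZVECTOR, NEON, and scalar builds.
  for (int i = 0; i < count; ++i) {
    dst[i] = src[i];
  }
#endif
}
```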

malfet pushed a commit that referenced this pull request Mar 19, 2024
This is a re-land of #105590 but this time enabling it only for the Darwin platform, where those instructions are available by default
malfet pushed a commit that referenced this pull request Mar 22, 2024
This is a re-land of #105590 but this time enabling it only for the Darwin platform, where those instructions are available by default
malfet pushed a commit that referenced this pull request Mar 25, 2024
This is a re-land of #105590 but this time enabling it only for the Darwin platform, where those instructions are available by default
malfet added a commit that referenced this pull request Mar 25, 2024
This is a re-land of #105590 but this time enabling it only for the Darwin platform, where those instructions are available by default
pytorchmergebot pushed a commit that referenced this pull request Mar 26, 2024
This started as a re-land of #105590 focused on enabling it on macOS, but it quickly turned into landing very limited platform-specific acceleration at this time (i.e., this PR does not add any NEON-accelerated code at all; it just enables vectorized compilation for the existing abstractions).

Enabling the test harness uncovered a number of latent issues in CPU inductor that were fixed in the following PRs:
- #122511
- #122513
- #122580
- #122608

The following was added/changed to enable the vectorization code to work on macOS:
 - Added a VecNEON class to `_inductor/codecache.py` that is supported on all Apple Silicon Macs
 - Added `Vectorized::loadu_one_fourth` to `vec_base.h`, limited to 8-bit types (see the illustrative sketch after this commit message)
 - Changed the 64-bit integral type mapping to `int64_t`/`uint64_t` to align with the rest of the code, as on macOS `int64_t` is a `long long` rather than a `long` (see #118149 for more details)

See table below for perf changes with and without torch.compile using [gpt-fast](https://github.com/pytorch-labs/gpt-fast) running `stories15M` on M2 Pro:
| dtype  | Eager | Compile (before) | Compile (after) |
| ------ | ------ | --------- | --------- |
| bfloat16  | 120 tokens/sec  | 130 tokens/sec | 156 tokens/sec |
| float32  | 158 tokens/sec  | 140 tokens/sec | 236 tokens/sec |
| float16  | 235 tokens/sec  | 81 tokens/sec | 58 tokens/sec |

Pull Request resolved: #122217
Approved by: https://github.com/jansel
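
As a rough illustration of the `loadu_one_fourth` idea mentioned in the list above: the real method lives on ATen's `Vectorized` type and is restricted to 8-bit types, while the free function below uses an assumed 32-byte vector width to stay self-contained, so treat it as a sketch rather than the actual implementation.

```cpp
// Simplified stand-in for Vectorized<T>::loadu_one_fourth: load only the
// first quarter of a vector's worth of elements and zero-fill the rest.
// kVecBytes is an assumption for this sketch, not ATen's actual width.
#include <array>
#include <cstddef>
#include <cstring>

constexpr std::size_t kVecBytes = 32;

template <typename T>
std::array<T, kVecBytes / sizeof(T)> loadu_one_fourth_sketch(const void* ptr) {
  static_assert(sizeof(T) == 1, "limited to 8-bit types, as in the PR");
  std::array<T, kVecBytes / sizeof(T)> out{};       // zero-initialized lanes
  std::memcpy(out.data(), ptr, kVecBytes / 4);      // copy one fourth (8 bytes)
  return out;
}
```
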
@Rohanjames1997 (Contributor, Author) commented

Refer to #123584

pytorch-bot bot pushed a commit that referenced this pull request Apr 22, 2024

Labels

ciflow/inductor, ciflow/trunk, Merged, module: cpu, module: inductor, open source, release notes: inductor, Reverted, triaged

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add support for NEON ISA in the Inductor C++ backend

9 participants