
Conversation


@fadara01 fadara01 commented Sep 3, 2024

Optimized dynamic quantization for aarch64 was enabled by #126687 and #134897

This PR fixes an issue on aarch64 where, on a [cache miss](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp#L592) (e.g. when the input dimensions change), [ideep::matmul_forward::compute](https://github.com/intel/ideep/blob/pytorch-rls-v3.5.3-2/include/ideep/operators/matmul.hpp#L160) (wrongly) runs with the [default lowp_kind (u8s8)](https://github.com/intel/ideep/blob/pytorch-rls-v3.5.3-2/include/ideep/operators/matmul.hpp#L174), which is not supported by oneDNN+ACL (Arm Compute Library). This causes the workload to fall back to a much slower oneDNN gemm:jit kernel.

Example:

```python
import torch

DIM = 4096
INPUT_SIZE1 = 32
INPUT_SIZE2 = 16

class LinearNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(DIM, DIM, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        return x

input1 = torch.randn(size=(INPUT_SIZE1, DIM))
input2 = torch.randn(size=(INPUT_SIZE2, DIM))

with torch.no_grad():
    model = LinearNet()
    model = torch.ao.quantization.quantize_dynamic(model, {torch.nn.Linear})

    model(input1)   # this goes to ACL lowp_gemm
    print("=" * 50)
    model(input2)   # this goes to gemm:jit without this PR, and to ACL with this PR
```

In the code snippet above:

  • The matmul from model(input1) goes to oneDNN+ACL (in both cases, with and without the PR).
  • The matmul from model(input2): without this PR, there is a cache miss (the input shapes differ) and matmul_forward::compute runs with the default lowp_kind (u8s8), so the matmul falls back to gemm:jit in oneDNN. With this PR, the matmul goes to oneDNN+ACL, which is around 10x faster than oneDNN+jit. One way to check which kernel is dispatched is oneDNN's verbose output, as sketched below.
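
The sketch below is not part of the PR; it assumes an aarch64 PyTorch build where dynamic quantization dispatches to oneDNN+ACL, and that ONEDNN_VERBOSE is picked up when set before torch is imported (setting it in the shell, e.g. ONEDNN_VERBOSE=1 python repro.py, works just as well):

```python
# Hypothetical verification script, not part of this PR: enable oneDNN's verbose
# mode to see which kernel handles each dynamically quantized matmul.
import os
os.environ.setdefault("ONEDNN_VERBOSE", "1")  # set before the first oneDNN primitive runs

import torch

class LinearNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4096, 4096, bias=False)

    def forward(self, x):
        return self.fc1(x)

with torch.no_grad():
    model = torch.ao.quantization.quantize_dynamic(LinearNet(), {torch.nn.Linear})
    model(torch.randn(32, 4096))  # expect a verbose line for an ACL lowp_gemm kernel
    model(torch.randn(16, 4096))  # without this PR: gemm:jit; with this PR: ACL again
```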

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @malfet @snadampal @milpuz01 @aditew01 @nikhil-arm

This fixes an issue for aarch64 where on a cache miss (e.g. if input dimensions change)
ideep::matmul_forward::compute runs with the default lowp_kind (u8s8) which
is not supported by oneDNN+ACL, causing the workload to fall back to a much slower
oneDNN gemm:jit kernel

pytorch-bot bot commented Sep 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135058

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 511af4e with merge base e7731b3:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.


linux-foundation-easycla bot commented Sep 3, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: fadara01 / name: Fadi Arafeh (511af4e)

@pytorch-bot pytorch-bot bot added module: cpu CPU specific problem (e.g., perf, algorithm) release notes: quantization release notes category labels Sep 3, 2024

fadara01 commented Sep 3, 2024

cc @malfet @atalman @jondea @cfRod @milpuz01 Please kindly review.


fadara01 commented Sep 3, 2024

@pytorchbot label "module: arm"

@pytorch-bot pytorch-bot bot added the module: arm Related to ARM architectures builds of PyTorch. Includes Apple M1 label Sep 3, 2024

@jondea jondea left a comment


Great find, thank you!

@soulitzer soulitzer added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Sep 6, 2024
@fadara01

@pytorchbot label "ciflow/linux-aarch64"


pytorch-bot bot commented Sep 12, 2024

Can't add following labels to PR: ciflow/linux-aarch64. Please ping one of the reviewers for help.


cfRod commented Sep 12, 2024

@pytorchbot label "ciflow/linux-aarch64"


pytorch-bot bot commented Sep 12, 2024

Can't add following labels to PR: ciflow/linux-aarch64. Please ping one of the reviewers for help.


cfRod commented Sep 12, 2024

@malfet We can't seem to add CI labels


malfet commented Sep 12, 2024

@malfet We can't seem to add CI labels

@cfRod you'll need to approve the workflow run first.

@malfet malfet added topic: bug fixes topic category ciflow/linux-aarch64 linux aarch64 CI workflow labels Sep 12, 2024

malfet commented Sep 12, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 12, 2024
@pytorchmergebot

Merge failed

Reason: Approvers from one of the following sets are needed:

  • CPU ATen backend (mingfeima, XiaobingSuper, jgong5, vfdev-5, leslie-fang-intel)
  • CPU inductor (leslie-fang-intel, jgong5, EikanWang)
  • superuser (pytorch/metamates)
  • Core Reviewers (mruberry, lezcano, Skylion007, ngimel, peterbell10, ...)
  • Core Maintainers (soumith, gchanan, ezyang, dzhulgakov, malfet, ...)
Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers


@malfet malfet left a comment


Thank you for the detailed PR description. My only question: will it still work (although slowly) on older ARMv8 platforms like Cortex A75?

@fadara01

@malfet Thank you for the detailed PR description. My only question: will it still work (although slowly) on older ARMv8 platforms like Cortex A75?

Yes, it should work for older Arm platforms too.


malfet commented Sep 13, 2024

@pytorchbot revert -m "It regresses x86 performance" -c nosignal

@pytorchmergebot

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Sep 13, 2024
@pytorchmergebot

@fadara01 your PR has been successfully reverted.


fadara01 commented Sep 17, 2024

@malfet could we please get a reproducer for the regression on x86?
My understanding is that ideep::matmul_forward::compute will be called with the same default arguments on x86 with and without this PR.


malfet commented Sep 18, 2024

could we please get a reproducer for the regression on x86?

Sorry, this is an internal test, I can not share full reproducer.

My understanding is that ideep::matmul_forward::compute will be called with the same default arguments on x86 with and without this PR.

I'm not very familiar with this code, but it was calling https://github.com/intel/ideep/blob/pytorch-rls-v3.5.3/include/ideep/operators/matmul.hpp#L63 before your PR, wasn't it?

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
…rch#135058)

Pull Request resolved: pytorch#135058
Approved by: https://github.com/jondea, https://github.com/malfet
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024

fadara01 commented Nov 4, 2024

I'm not very familiar with this code, but it was calling https://github.com/intel/ideep/blob/pytorch-rls-v3.5.3/include/ideep/operators/matmul.hpp#L63 before your PR, wasn't it?

@malfet I don't think this is the case, as the ideep::matmul_forward::compute function you mentioned does not match the arguments we're passing from qlinear_dynamic.cpp in the cache miss (else) branch.

Are you using a vanilla wheel for your internal tests with x86?

If I pip install a wheel on x86, this path (ideep/oneDNN) is not even selected for dynamic quantization (this can be deduced from the lack of oneDNN verbose output when running with the environment variable ONEDNN_VERBOSE=1).
I think the fbgemm path is chosen instead; a quick way to check the active quantized engine is sketched below.
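
A minimal sketch (not part of this PR) for checking which quantized engine a given build selects; the exact set of engines reported is an assumption and will differ between builds:

```python
# Hypothetical check, not part of this PR: list the quantized engines a build supports
# and the one that is currently selected.
import torch

print(torch.backends.quantized.supported_engines)  # e.g. includes 'fbgemm'/'x86' on x86 wheels
print(torch.backends.quantized.engine)             # the currently selected engine
```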

@Chao1Han do you have any insights on why this might cause regressions on x86?

I'm happy to wrap the new arguments with an #ifdef __aarch64__ but I still can't understand why that is necessary.

@fadara01

@malfet, any updates or thoughts?


fadara01 commented Dec 28, 2024

@malfet, I built torch on x86 with USE_FBGEMM=0 to start exercising this ideep path and confirmed with print statements that this function gets called before and after my PR. Hence, I do not understand how this is causing regressions.

Could you please have a more serious look at this?
It brings 10x speedups for LLMs on aarch64, since all matmuls after prefill are currently getting dispatched to sub-optimal implementations due to the [wrong] default u8s8 lowp_kind.
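
For illustration only (not part of the PR, and the shapes are made up): in an LLM-style workload, the dynamically quantized linear is first hit with a prefill-sized input and then with decode-sized inputs, which is exactly the shape change that lands on the cache-miss path described above:

```python
# Illustrative sketch, not part of this PR: the first (prefill) call is served with the
# prompt-length shape, so the decode-shaped calls that follow take the cache-miss path.
import torch

hidden = 4096
mlp = torch.nn.Sequential(torch.nn.Linear(hidden, hidden, bias=False))
mlp = torch.ao.quantization.quantize_dynamic(mlp, {torch.nn.Linear})

with torch.no_grad():
    prompt_len = 32
    mlp(torch.randn(prompt_len, hidden))  # prefill: one matmul over the whole prompt
    for _ in range(4):
        mlp(torch.randn(1, hidden))       # decode: shape differs from prefill on every step
```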

@robert-hardwick

Raised an issue that we have seen, which will be fixed by this change: #145216

Not a regression: those tests had not previously been enabled in CI.

@robert-hardwick

@pytorchbot rebase

@pytorchmergebot

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot

Rebase failed due to

Aborting rebase because rebasing the branch resulted in the same sha as the target branch.
This usually happens because the PR has already been merged.  Please rebase locally and push.

Raised by https://github.com/pytorch/pytorch/actions/runs/12928955569

@fadara01 fadara01 requested a review from malfet January 23, 2025 15:25
@robert-hardwick

I think the rebase failed because this PR was previously merged and subsequently reverted.

@github-actions

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Mar 24, 2025
@github-actions github-actions bot closed this Apr 23, 2025

Labels

  • ciflow/linux-aarch64 (linux aarch64 CI workflow)
  • ciflow/trunk (Trigger trunk jobs on your pull request)
  • Merged
  • module: arm (Related to ARM architectures builds of PyTorch. Includes Apple M1)
  • module: cpu (CPU specific problem (e.g., perf, algorithm))
  • open source
  • release notes: quantization (release notes category)
  • Reverted
  • Stale
  • topic: bug fixes (topic category)
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
