Skip to content

Conversation

@yanbing-j
Copy link
Collaborator

@yanbing-j yanbing-j commented Aug 31, 2024

This PR is per ARM request, which is in intel/ideep#334.

Context for the request is: Arm team has upstreamed the dynamic quantization changes, all the PRs were merged (torch, ideep, oneDNN), but without this ideep submodule update, the feature will not work. The change is isolated to only matmul operator and quantization path alone.

cc @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10 @min-jean-cho @Guobing-Chen @Xia-Weiwen @snadampal @malfet @milpuz01

@pytorch-bot
Copy link

pytorch-bot bot commented Aug 31, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134897

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit eee651e with merge base 85fa019 (image):

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/linux-aarch64 linux aarch64 CI workflow module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration topic: not user facing topic category labels Aug 31, 2024
@yanbing-j yanbing-j added ciflow/trunk Trigger trunk jobs on your pull request intel This tag is for PR from Intel labels Aug 31, 2024
@yanbing-j yanbing-j self-assigned this Aug 31, 2024
@milpuz01
Copy link
Contributor

milpuz01 commented Sep 2, 2024

@pytorchbot label "arm"

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 2, 2024

Didn't find following labels among repository labels: arm

@milpuz01
Copy link
Contributor

milpuz01 commented Sep 2, 2024

@pytorchbot label "module: arm"

@pytorch-bot pytorch-bot bot added the module: arm Related to ARM architectures builds of PyTorch. Includes Apple M1 label Sep 2, 2024
@yanbing-j yanbing-j marked this pull request as ready for review September 3, 2024 01:03
@yanbing-j yanbing-j requested review from atalman and malfet September 3, 2024 01:04
@colesbury colesbury added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Sep 3, 2024
@yanbing-j
Copy link
Collaborator Author

Hi @snadampal @milpuz01 Please provide ARM test results.
Hi @malfet @atalman Please kindly review.

@yanbing-j
Copy link
Collaborator Author

@pytorchbot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Successfully rebased yanbing/update_ideep onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout yanbing/update_ideep && git pull --rebase)

@fadara01
Copy link
Collaborator

fadara01 commented Sep 4, 2024

Our acceptance tests currently do not include dynamic quantization.
Given that this change only affects the dynamic quantization path and does not change the version of oneDNN, running the full acceptance tests is not required.

I manually verified this PR and can confirm that (as expected) oneDNN calls Arm Compute Library's optimized lowp gemm kernels and on 16 Neoverse-V1 cores, the speedup for bert-large is as follows:

context length bert-large speedup with this PR
8 20.5x
16 26.9x
32 31.1x
64 40.7x
128 53.1x
256 50.0x
512 27.2x

I think the manual tests above and the CI tests are enough for us to to merge this PR

cc: @snadampal @milpuz01 @malfet @atalman @yanbing-j

@snadampal
Copy link
Collaborator

@fadara01 thanks for the data! the two CI failures don't seem be to related to this PR.

@yanbing-j
Copy link
Collaborator Author

@fadara01 Thanks for the data!

@malfet @atalman Could you please help review this PR? Thanks!

Copy link
Contributor

@atalman atalman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@snadampal snadampal self-requested a review September 6, 2024 14:54
Copy link
Collaborator

@snadampal snadampal left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good to me.

@atalman
Copy link
Contributor

atalman commented Sep 6, 2024

@pytorchmergebot merge -f "failures are not related"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Sep 12, 2024
)

Optimized dynamic quantization for aarch64 was enabled by #126687 and #134897

This PR fixes an issue for aarch64 where on a [cache miss](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp#L592) (e.g. if input dimensions change) [ideep::matmul_forward::compute ](https://github.com/intel/ideep/blob/pytorch-rls-v3.5.3-2/include/ideep/operators/matmul.hpp#L160) (wrongly) runs with the [default lowp_kind (u8s8)](https://github.com/intel/ideep/blob/pytorch-rls-v3.5.3-2/include/ideep/operators/matmul.hpp#L174) which is not supported by oneDNN+ACL (Arm Compute Library), causing the workload to fall back to a much slower oneDNN gemm:jit kernel

Example:
```python
import torch

DIM = 4096
INPUT_SIZE1 = 32
INPUT_SIZE2 = 16

class LinearNet(torch.nn.Module):
   def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(DIM, DIM, bias=False)

   def forward(self, x):
        x = self.fc1(x)
        return x

input1 = torch.randn(size=(INPUT_SIZE1, DIM))
input2 = torch.randn(size=(INPUT_SIZE2, DIM))

with torch.no_grad():
    model = LinearNet()
    model =  torch.ao.quantization.quantize_dynamic(model,{torch.nn.Linear})

    model(input1)   # this goes to ACL lowp_gemm
    print("="*50)
    model(input2)   # this goes to gemm:jit without this PR, and to ACL with this PR
```
In the code snippet above:
- The matmul from `model(input1)` goes to oneDNN+ACL (in both cases, with and without the PR)
- The matmul from `model(input2)`: **Without this PR**: there's a cache miss (different input shapes) and matmul_forward::compute is run with the default lowp_kind (u8s8). Hence the matmul falls back to gemm:jit in oneDNN. However, **With this PR** the matmul goes to oneDNN+ACL which is around 10x faster than oneDNN+jit.

Pull Request resolved: #135058
Approved by: https://github.com/jondea, https://github.com/malfet
tolleybot pushed a commit to tolleybot/pytorch that referenced this pull request Sep 14, 2024
This PR is per ARM request, which is in intel/ideep#334.

Context for the request is: Arm team has upstreamed the dynamic quantization changes, all the PRs were merged (torch, ideep, oneDNN), but without this ideep submodule update, the feature will not work. The change is isolated to only matmul operator and quantization path alone.

Pull Request resolved: pytorch#134897
Approved by: https://github.com/jgong5, https://github.com/atalman, https://github.com/snadampal
tolleybot pushed a commit to tolleybot/pytorch that referenced this pull request Sep 14, 2024
…rch#135058)

Optimized dynamic quantization for aarch64 was enabled by pytorch#126687 and pytorch#134897

This PR fixes an issue for aarch64 where on a [cache miss](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp#L592) (e.g. if input dimensions change) [ideep::matmul_forward::compute ](https://github.com/intel/ideep/blob/pytorch-rls-v3.5.3-2/include/ideep/operators/matmul.hpp#L160) (wrongly) runs with the [default lowp_kind (u8s8)](https://github.com/intel/ideep/blob/pytorch-rls-v3.5.3-2/include/ideep/operators/matmul.hpp#L174) which is not supported by oneDNN+ACL (Arm Compute Library), causing the workload to fall back to a much slower oneDNN gemm:jit kernel

Example:
```python
import torch

DIM = 4096
INPUT_SIZE1 = 32
INPUT_SIZE2 = 16

class LinearNet(torch.nn.Module):
   def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(DIM, DIM, bias=False)

   def forward(self, x):
        x = self.fc1(x)
        return x

input1 = torch.randn(size=(INPUT_SIZE1, DIM))
input2 = torch.randn(size=(INPUT_SIZE2, DIM))

with torch.no_grad():
    model = LinearNet()
    model =  torch.ao.quantization.quantize_dynamic(model,{torch.nn.Linear})

    model(input1)   # this goes to ACL lowp_gemm
    print("="*50)
    model(input2)   # this goes to gemm:jit without this PR, and to ACL with this PR
```
In the code snippet above:
- The matmul from `model(input1)` goes to oneDNN+ACL (in both cases, with and without the PR)
- The matmul from `model(input2)`: **Without this PR**: there's a cache miss (different input shapes) and matmul_forward::compute is run with the default lowp_kind (u8s8). Hence the matmul falls back to gemm:jit in oneDNN. However, **With this PR** the matmul goes to oneDNN+ACL which is around 10x faster than oneDNN+jit.

Pull Request resolved: pytorch#135058
Approved by: https://github.com/jondea, https://github.com/malfet
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
This PR is per ARM request, which is in intel/ideep#334.

Context for the request is: Arm team has upstreamed the dynamic quantization changes, all the PRs were merged (torch, ideep, oneDNN), but without this ideep submodule update, the feature will not work. The change is isolated to only matmul operator and quantization path alone.

Pull Request resolved: pytorch#134897
Approved by: https://github.com/jgong5, https://github.com/atalman, https://github.com/snadampal
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
…rch#135058)

Optimized dynamic quantization for aarch64 was enabled by pytorch#126687 and pytorch#134897

This PR fixes an issue for aarch64 where on a [cache miss](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp#L592) (e.g. if input dimensions change) [ideep::matmul_forward::compute ](https://github.com/intel/ideep/blob/pytorch-rls-v3.5.3-2/include/ideep/operators/matmul.hpp#L160) (wrongly) runs with the [default lowp_kind (u8s8)](https://github.com/intel/ideep/blob/pytorch-rls-v3.5.3-2/include/ideep/operators/matmul.hpp#L174) which is not supported by oneDNN+ACL (Arm Compute Library), causing the workload to fall back to a much slower oneDNN gemm:jit kernel

Example:
```python
import torch

DIM = 4096
INPUT_SIZE1 = 32
INPUT_SIZE2 = 16

class LinearNet(torch.nn.Module):
   def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(DIM, DIM, bias=False)

   def forward(self, x):
        x = self.fc1(x)
        return x

input1 = torch.randn(size=(INPUT_SIZE1, DIM))
input2 = torch.randn(size=(INPUT_SIZE2, DIM))

with torch.no_grad():
    model = LinearNet()
    model =  torch.ao.quantization.quantize_dynamic(model,{torch.nn.Linear})

    model(input1)   # this goes to ACL lowp_gemm
    print("="*50)
    model(input2)   # this goes to gemm:jit without this PR, and to ACL with this PR
```
In the code snippet above:
- The matmul from `model(input1)` goes to oneDNN+ACL (in both cases, with and without the PR)
- The matmul from `model(input2)`: **Without this PR**: there's a cache miss (different input shapes) and matmul_forward::compute is run with the default lowp_kind (u8s8). Hence the matmul falls back to gemm:jit in oneDNN. However, **With this PR** the matmul goes to oneDNN+ACL which is around 10x faster than oneDNN+jit.

Pull Request resolved: pytorch#135058
Approved by: https://github.com/jondea, https://github.com/malfet
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/linux-aarch64 linux aarch64 CI workflow ciflow/trunk Trigger trunk jobs on your pull request intel This tag is for PR from Intel Merged module: arm Related to ARM architectures builds of PyTorch. Includes Apple M1 module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration open source topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants