
Conversation

@fadara01 (Collaborator) commented Mar 18, 2025

This is a backport of the PRs that enable a fast path for eager-mode static/dynamic quantized matmuls and quantized add on AArch64 through Arm Compute Library (ACL) directly: #148585, #148653.

PR #148584 is the base for both of the above and has already made its way into release/2.7, but we need the two PRs above to capitalize on it.

It would mean a lot for us to have these changes in v2.7. They directly enable business partners to adopt PyTorch on Arm by accelerating MLPerf's recommender model by ~14x.

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @malfet @snadampal @milpuz01

pytorch-bot bot commented Mar 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149435

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 601e078 with merge base 924a247:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot bot added the `module: cpu` and `release notes: quantization` labels, Mar 18, 2025
@fadara01 (Collaborator, Author) commented Mar 18, 2025

@pytorchbot label "module: arm"

pytorch-bot bot added the `module: arm` label (related to ARM architecture builds of PyTorch; includes Apple M1), Mar 18, 2025
@fadara01 (Collaborator, Author) commented Mar 18, 2025

@pytorchbot label "ciflow/linux-aarch64"

pytorch-bot bot added the `ciflow/linux-aarch64` label (linux aarch64 CI workflow), Mar 18, 2025
@malfet (Contributor) left a comment

Let's wait until #149417 lands first and then it needs to be integrated into the cherry-pick

@fadara01 (Collaborator, Author) commented Mar 18, 2025

> Let's wait until #149417 lands first and then it needs to be integrated into the cherry-pick

Cool, I'll include this here as soon as it lands.

@fadara01 (Collaborator, Author) commented:

Hi @malfet - I added #149417 to this cherry-pick.

@fadara01 (Collaborator, Author) commented:

The `test_call_jax_pytree` failure doesn't seem related.

fadara01 requested a review from malfet, March 20, 2025 13:01
@malfet (Contributor) commented Mar 27, 2025

@pytorchbot rebase -b release/2.7

@pytorchmergebot (Collaborator) commented Mar 27, 2025

@pytorchbot started a rebase job onto refs/remotes/origin/release/2.7. Check the current status here

fadara01 and others added 3 commits March 27, 2025 00:22
…tly (pytorch#148585)

This enables a fast path for eager mode static/dynamic quantization for AArch64 through Arm Compute Library (ACL) directly.

Context: PRs pytorch#126687 and pytorch#139887 enabled an optimized implementation for `qlinear` and `qlinear_dynamic` on AArch64 through `ideep → oneDNN → ACL`, which improved performance by ~10x compared to the previous implementation.
However, the current `qlinear` and `qlinear_dynamic` path (`ideep → oneDNN → ACL`) suffers from high overhead due to the API friction between the stateless oneDNN API and the stateful ACL low-precision GEMM (`lowp_gemm`) API. For example, ACL's `lowp_gemm` objects cache information such as the weight reduction (column sums) and the weights in an optimized memory format, which oneDNN does not allow due to its stateless nature.
Hence, ACL currently runs a (redundant) sum of columns and a pre-transposition (to the GEMM kernel's optimal format) for each GEMM operation.
This PR addresses the sub-optimalities above by integrating ACL directly with `qlinear` and `qlinear_dynamic`.
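
To make the statefulness point concrete, here is a minimal NumPy sketch (illustrative only; the class and helper names are hypothetical, and this is not the actual ACL or oneDNN API) of why caching the weight re-layout and the weight column sums across calls removes the redundant per-GEMM work:

```
import numpy as np

# Illustrative sketch only (plain NumPy, not ACL/oneDNN code): with a
# stateful GEMM object, the weight re-layout and the weight column sums
# (needed for zero-point correction in low-precision GEMM) are computed
# once at construction and reused across calls, instead of being redone
# for every GEMM as a stateless API forces.
class StatefulLowpGemm:
    def __init__(self, w):                       # w: (K, N) int8 weights
        self.w_opt = np.ascontiguousarray(w)     # stand-in for the kernel's optimal layout
        self.col_sums = w.astype(np.int32).sum(axis=0)  # weight reduction, done once

    def __call__(self, x, x_zero_point=0):       # x: (M, K) quantized activations
        acc = x.astype(np.int32) @ self.w_opt.astype(np.int32)
        return acc - x_zero_point * self.col_sums  # zero-point correction

w = np.random.randint(-128, 128, (768, 768), dtype=np.int8)
x = np.random.randint(0, 256, (8, 768), dtype=np.uint8)
gemm = StatefulLowpGemm(w)        # setup cost paid once
y1 = gemm(x, x_zero_point=128)    # fast path: reuses cached state
y2 = gemm(x, x_zero_point=128)
```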

- **For `qlinear_dynamic` (dynamically quantized matmuls):**

This PR yields an **average speedup of ~50%** (averaged over context lengths of 2^3 up to 2^9) for `bert-base-uncased`, `bert-large-uncased`, `roberta-base`, and `distilbert-base-uncased` with 16 threads on a Neoverse-V1 (with transformers==4.48) for the benchmarking script below:
```
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <[email protected]>
# SPDX-License-Identifier: BSD-3-Clause
import torch
from transformers import AutoModel, AutoConfig
import time
import numpy as np
from argparse import ArgumentParser

class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__(description="huggingface model")
        self.add_argument("--context_length",
                            help="context length - number of input tokens",
                            type=int,
                            default=64
        )
        self.add_argument("--model",
                            help="model checkpoint - i.e. 'bert-base-uncased'",
                            type=str,
                            default=None)
        self.add_argument("--iters",
                          help="benchmark iterations",
                          default=500)

if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()
    model_name = args.model
    config = AutoConfig.from_pretrained(model_name)
    batch_size = 1
    model = AutoModel.from_pretrained(model_name)
    model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    model.eval()
    inputs = torch.randint(config.vocab_size, (batch_size, args.context_length), dtype=torch.long, device="cpu")
    times = []
    with torch.no_grad():
        # warmup
        for _ in range(10):
            model(inputs)
        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model(inputs)
            times.append((time.time_ns() - s) / 1e6)

    print("Model = ", model_name)
    print("Context Length = ", args.context_length)
    print("Min (ms) = ", min(times))
    print("Mean (ms) = ", np.mean(times))
```
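
For reference, assuming the script above is saved as `bench_qlinear_dynamic.py` (a hypothetical filename), a single configuration can be benchmarked with:

```
python bench_qlinear_dynamic.py --model bert-base-uncased --context_length 64 --iters 500
```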

- **For `qlinear` (statically quantized matmuls):**

This PR yields an **average speedup of 2x for signed activations (`s8s8s8`) and 95x for unsigned activations (`u8s8u8`)** on a Neoverse-V1 with 16 threads for the benchmarking script below.
The averages are over all combinations of `M = [8, 16, ..., 512]`, `K = [768, 1024, 2048, 4096]`, `N = [768, 1024, 2048, 4096]`.
The astronomical speedup for unsigned activations comes from the fact that oneDNN v3.7 does not have an optimized implementation for `u8s8u8` on AArch64.

```
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <[email protected]>
# SPDX-License-Identifier: BSD-3-Clause
import random
import time
from argparse import ArgumentParser

import numpy as np
import torch
import torch.nn as nn
from torch.quantization import QConfig
from torch.ao.quantization.observer import HistogramObserver, default_weight_observer

class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__()
        self.add_argument("--M",
                            help="M dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--K",
                            help="K dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--N",
                            help="N dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--signed_input",
                            help="Use (signed) torch.qint8 for inputs instead of (unsigned) torch.quint8",
                            action="store_true"
        )
        self.add_argument("--seed",
                          help="Random seed",
                          type=int,
                          default=42
        )
        self.add_argument("--iters",
                          help="benchmark iterations",
                          default=500)

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

class LinearModel(nn.Module):
    def __init__(self, K, N):
        super(LinearModel, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(K, N)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        x = self.dequant(x)
        return x

def quantize_model(model, args):
    qconfig = QConfig(
            activation=HistogramObserver.with_args(reduce_range=False,
            dtype=torch.qint8 if args.signed_input else torch.quint8),
            weight=default_weight_observer,
    )
    # Attach the quantization configuration and prepare the model for
    # static quantization (this inserts observers)
    model.qconfig = qconfig
    model_prepared = torch.quantization.prepare(model)

    # Calibrate the model with sample inputs
    # Example input data for calibration
    with torch.no_grad():
        sample_data = torch.randn(args.M, args.K)
        model_prepared(sample_data)
    # Convert the prepared model to a quantized model
    model_quantized = torch.quantization.convert(model_prepared)
    return model_quantized

if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()

    set_seed(args.seed)
    model_fp32 = LinearModel(args.K, args.N)
    model_quantized = quantize_model(model_fp32, args)

    inputs = torch.randn(args.M, args.K)
    times = []
    with torch.no_grad():
        # warmup
        for _ in range(10):
            model_quantized(inputs)
        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model_quantized(inputs)
            times.append((time.time_ns() - s) / 1e6)

    print("M,K,N,signed = ", args.M, args.K, args.N, args.signed_input)
    print("Min Times (ms) = ", min(times))
    print("Mean Times (ms) = ", np.mean(times))
```
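
Similarly, assuming the script above is saved as `bench_qlinear_static.py` (a hypothetical filename), one `(M, K, N)` combination with signed activations can be benchmarked with:

```
python bench_qlinear_static.py --M 64 --K 1024 --N 1024 --signed_input
```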

Pull Request resolved: pytorch#148585
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <[email protected]>
(cherry picked from commit 08a644a)
…48653)

This enables `qint8` and `quint8` add for AArch64 through Arm Compute Library (ACL) directly.
The relative performance improvement is ~15x with OMP_NUM_THREADS=1 and ~5.4x with OMP_NUM_THREADS=32.
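
For context, here is a minimal sketch of the eager-mode op this accelerates, using the standard PyTorch quantization API (the shapes and quantization parameters below are arbitrary):

```
import torch

# Minimal sketch of an eager-mode quantized add; this is the kind of op
# the ACL fast path accelerates. Scales/zero-points are arbitrary.
a = torch.quantize_per_tensor(torch.randn(1024, 1024), scale=0.05, zero_point=0, dtype=torch.qint8)
b = torch.quantize_per_tensor(torch.randn(1024, 1024), scale=0.05, zero_point=0, dtype=torch.qint8)
qadd = torch.ao.nn.quantized.QFunctional()  # wraps torch.ops.quantized.add
out = qadd.add(a, b)
```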

Co-authored-by: David Svantesson <[email protected]>

Pull Request resolved: pytorch#148653
Approved by: https://github.com/malfet
ghstack dependencies: pytorch#148585

(cherry picked from commit 6c2db8f)
Fixes internal build failures where per-op headers are not generated.
We really should have a lint for something like that.

Test Plan: CI

Reviewed By: izaitsevfb

Differential Revision: D71406882

Pull Request resolved: pytorch#149417
Approved by: https://github.com/Skylion007, https://github.com/izaitsevfb

(cherry picked from commit 5db3a4a)
@pytorchmergebot (Collaborator) commented Mar 27, 2025

Successfully rebased qlinear_and_qadd_backports onto refs/remotes/origin/release/2.7, please pull locally before adding more changes (for example, via git checkout qlinear_and_qadd_backports && git pull --rebase)

pytorchmergebot force-pushed the qlinear_and_qadd_backports branch from 0c843ad to 601e078, March 27, 2025 00:22
malfet merged commit 1b84fd1 into pytorch:release/2.7 on Mar 28, 2025 (108 of 110 checks passed)
@fadara01 (Collaborator, Author) commented:

Thank you so much @malfet!

