Skip to content

Conversation

@fadara01
Copy link
Collaborator

@fadara01 fadara01 commented Feb 17, 2025

This enables a fast path for eager mode statically quantized matmuls for AArch64 through Arm Compute Library (ACL) directly.

PR #145942 addressed the high overhead in qlinear_dynamic on AArch64 (due to redundant weight pre-transpositions and reductions) by enabling a path that calls ACL directly.
This does the same thing and addresses the same overheads for (static) qlinear.

I benchmarked this PR (ACL direct integration for static quantization in ATen) against the current state of PyTorch (with #147498 which updates oneDNN to v3.7 included because it's a much stronger baseline than the current oneDNN version in PyTorch which is v3.5.3). See benchmarking script below.
My benchmark runs statically quantized linears for all combinations of M = [8, 16, ..., 512], K = [768, 1024, 2048, 4096], N = [768, 1024, 2048, 4096].

This PR gives an average speedup of 2x for signed activations (s8s8s8) and 95x for unsigned activations (u8s8u8) on a Neoverse-V1 with 16 threads.
The astronomical speedup for unsigned activation is because oneDNN v3.7 does not have an optimized implementation for u8s8u8 on AArch64.

# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <[email protected]>
# SPDX-License-Identifier: BSD-3-Clause
import torch
import torch.nn as nn
from torch.quantization import QConfig
from torch.ao.quantization.observer import HistogramObserver, default_weight_observer
import torch
import torch.nn as nn
import numpy as np
import random
from argparse import ArgumentParser
import time

class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__()
        self.add_argument("--M",
                            help="M dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--K",
                            help="K dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--N",
                            help="N dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--signed_input",
                            help="Use (signed) torch.qint8 for inputs instead of (unsigned) torch.quint8",
                            action="store_true"
        )
        self.add_argument("--seed",
                          help="Random seed",
                          type=int,
                          default=42
        )
        self.add_argument("--iters",
                          help="benchmark iterations",
                          default=500)

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

class LinearModel(nn.Module):
    def __init__(self, K, N):
        super(LinearModel, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(K, N)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        x = self.dequant(x)
        return x

def quantize_model(model, args):
    qconfig = QConfig(
            activation=HistogramObserver.with_args(reduce_range=False,
            dtype=torch.qint8 if args.signed_input else torch.quint8),
            weight=default_weight_observer,
    )
    # Prepare the model for static quantization
    # Specify quantization configurations
    model.qconfig = qconfig
    model_prepared = torch.quantization.prepare(model_fp32)

    # Calibrate the model with sample inputs
    # Example input data for calibration
    with torch.no_grad():
        sample_data = torch.randn(args.M, args.K)
        model_prepared(sample_data)
    # Convert the prepared model to a quantized model
    model_quantized = torch.quantization.convert(model_prepared)
    return model_quantized


if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()
    
    set_seed(args.seed)
    model_fp32 = LinearModel(args.K, args.N)
    model_quantized = quantize_model(model_fp32, args)
    
    inputs = torch.randn(args.M, args.K)
    times = []
    with torch.no_grad():
        # warmup
        for _ in range(10):
            model_quantized(inputs)
        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model_quantized(inputs)
            times.append((time.time_ns() - s) / 1e6)

    print("M,K,N,signed = ", args.M, args.K, args.N, args.signed_input)
    print("Min Times (ms) = ", min(times))
    print("Mean Times (ms) = ", np.mean(times))

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @malfet @snadampal @milpuz01

@pytorch-bot
Copy link

pytorch-bot bot commented Feb 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147337

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0b06ae1 with merge base 6c3492b (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added module: cpu CPU specific problem (e.g., perf, algorithm) release notes: quantization release notes category release notes: releng release notes category labels Feb 17, 2025
@fadara01
Copy link
Collaborator Author

@pytorchbot label "module: arm"

@pytorch-bot pytorch-bot bot added the module: arm Related to ARM architectures builds of PyTorch. Includes Apple M1 label Feb 17, 2025
@fadara01
Copy link
Collaborator Author

@pytorchbot label "ciflow/linux-aarch64"

@pytorch-bot pytorch-bot bot added the ciflow/linux-aarch64 linux aarch64 CI workflow label Feb 17, 2025
pytorchmergebot pushed a commit that referenced this pull request Feb 19, 2025
Among many things, this version of ACL fixes the redundant declaration  warning that we're blocked on in (#145942, #146620, #147337) and introduces better scheduling heuristics for GEMMs

Fixes #ISSUE_NUMBER

Pull Request resolved: #147454
Approved by: https://github.com/malfet
@fadara01 fadara01 requested a review from malfet February 20, 2025 12:41
@fadara01 fadara01 force-pushed the acl_qlinear_static branch 2 times, most recently from bccb8ed to 457a00b Compare February 20, 2025 13:19
@fadara01
Copy link
Collaborator Author

@pytorchbot label "arm priority"

@fadara01 fadara01 force-pushed the acl_qlinear_static branch 2 times, most recently from 3310ef2 to 4c484f2 Compare February 20, 2025 19:07
@soulitzer soulitzer added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Feb 24, 2025
Copy link
Contributor

@digantdesai digantdesai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good at a high level. Left some comments.

Do we have existing python tests and CI setups which guarantee we test these new impls?

}
}

at::Tensor PackedLinearWeightsACL::apply(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be yet another option in addition to existing alternatives to run this quantized operator on Arm CPUs. What are your thoughts, at a high level from maintenance point of view, on consolidating some of these options? Or even directly using Kleidi kernels here with at::parallel and other native utils?

Copy link
Collaborator Author

@fadara01 fadara01 Feb 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The motivation for this approach is to address the overheads from the incompatibilities between the oneDNN and ACL lowp_gemm APIs in the current path (without this PR): PyTorch -> ideep -> oneDNN -> ACL (See motivation on #145942).

oneDNN's API is stateless, while ACL's lowp_gemm API is stateful. The current path utilizes ACL's lowp_gemm operators through oneDNN. Given that oneDNN is stateless, ACL's lowp_gemm cannot have state and will need to run (redundant) weights pre-transpositions and reductions on every forward call. This PR enables ACL's lowp_gemm directly from PyTorch, which allows ACL to cache pre-transpositions and reductions for each layer, instead of recomputing them on every forward call.

What are your thoughts, at a high level from maintenance point of view, on consolidating some of these options? Or even directly using Kleidi kernels here with at::parallel and other native utils?

Unfortunately, KleidiAI currently has minimal support for int8 GEMMs in general, and no kernels for int8 statically quantized GEMMs.
Once those kernels are in KleidiAI they will be part of ACL anyway and will be utilized by the path this PR enables.
We will then benchmark the (direct) PyTorch -> KleidiAI path against the PyTorch -> ACL -> KleidiAI path. If it turns out that ACL is adding overhead, we'll integrate KleidiAI kernels directly here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess my concern was around maintainability, so many different ways to run this going through build and runtime selection just for Arm CPUs. While I understand the rationale, we should also aim to actively trim dead paths.

For example, I saw somewhere that it is possible to fallback to onednn when ACL path isn't valid, can we support those through ACL and remove onednn path.

Copy link
Collaborator Author

@fadara01 fadara01 Feb 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should also aim to actively trim dead paths.

Yes, I agree with you.
In this PR, we're extending the existing PackedLinearWeightsOneDNN with our new PackedLinearWeightsACL rather than replacing it entirely. This is because ACL's support for per-channel quantized weights is still incomplete. As you can see here, we only prepack to PackedLinearWeightsACL when weights are per-tensor affine quantized. Once ACL provides full support for per-channel quantized weights, we will drop the dependency on PackedLinearWeightsOneDNN and make ACL the default (and only) option for this path - i.e. PackedLinearWeightsACL extends LinearPackedParamsBase directly

it is possible to fallback to onednn when ACL path isn't valid, can we support those through ACL and remove onednn path.

Regarding the fallback to PackedLinearWeightsOneDNN::apply or PackedLinearWeightsOneDNN::apply_dynamic, it's not strictly needed as I haven’t encountered any workload causing ACL's validation to fail. I included these fallbacks purely as a precaution. Since our implementation extends PackedLinearWeightsOneDNN, I wanted to guarantee that the behavior remains exactly the same as the previous implementation in all cases (plus the perf improvements)

@fadara01
Copy link
Collaborator Author

fadara01 commented Feb 26, 2025

Do we have existing python tests and CI setups which guarantee we test these new impls?

Yes, current tests exercise this impl when AT_MKLDNN_ACL_ENABLED is set (it's set to 1 by default for AArch64 builds).
This PR fixes a currently (disabled on AArch64) but failing quantization test - See #145216

CMakeLists.txt Outdated
if(USE_MKLDNN_ACL)
find_package(ACL REQUIRED)
if(ACL_FOUND)
include_directories(${ACL_INCLUDE_DIRS})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a really bad pattern, what's wrong with target_include_directories(torch_cpu PRIVATE $(ACL_INCLUDE_DIRS})?

Copy link
Collaborator Author

@fadara01 fadara01 Feb 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack, I moved this from the root CMakeLists.txt to the caffe2 one where torch_cpu is defined.
I removed the redundant if(ACL_FOUND) check and used target_include_directories(torch_cpu PRIVATE $(ACL_INCLUDE_DIRS}) instead of include_directories(${ACL_INCLUDE_DIRS}) as you suggested.

Does it look better now?

@fadara01
Copy link
Collaborator Author

fadara01 commented Mar 4, 2025

Thanks for your feedback @digantdesai , @malfet
I addressed your comments, could you please have another look?

cc: @milpuz01

ACL is already built with PyTorch as a shared library when USE_MKLDNN_ACL is set.
Currently, it is only used indirectly in ATen via oneDNN for AArch64 targets. However there are cases where it makes sense to utilize ACL directly without  oneDNN as an intermediary - e.g. quantization. See pytorch#145942, pytorch#147337, pytorch#146620.
This patch enables such use cases by exposing ACL to ATen
fadara01 added 2 commits March 5, 2025 11:27
This enables a fast path for eager mode dynamic quantization for AArch64 through Arm Compute Library (ACL) directly.

Context: PR pytorch#126687 enabled an optimized implementation for qlinear_dynamic for aarch64 through ideep → oneDNN → ACL which improved performance by ~10x compared to the previous implementation.
However, the current qlinear_dynamic path (ideep → oneDNN → ACL) suffers from high overhead due to the API friction between the stateless oneDNN API and the stateful ACL low-precision GEMM (lowp_gemm) API - for example, ACL's lowp_gemm objects cache information like weights reduction or weights in optimized memory format which oneDNN does not allow due to its stateless nature.
Hence, ACL currently runs a (redundant) sum of columns and pre-transposition (to the gemm kerne's optimal format) for each GEMM operation.
This PR addresses the sub-optimalities above by integrating ACL directly with qlinear_dynamic. This approach yields an average speedup (averaged over context_lengths of 2^3 up to 2^9) of ~ 50% for bert-base-uncased, bert-large-uncased, roberta-base, distilbert-base-uncased with 16 threads on a Neoverse-V1 (with transformers==4.48).
To achieve this we introduce PackedLinearWeightsACL (as a subclasses of PackedLinearWeightsOnednn ) with an implementation of qlinear_dynamic that uses ACL directly, while qlinear still follows the oneDNN path.
…tly.

This enables a fast path for eager mode static quantization for AArch64 through Arm Compute Library (ACL) directly.

PR pytorch#145942 addressed the high overhead in qlinear_dynamic on AArch64 (due to redundant weight pretranspositions and reductions) by enabling a path that calls ACL directly.
This does the same thing but for (static) qlinear.
@fadara01 fadara01 force-pushed the acl_qlinear_static branch from 460ed57 to 0b06ae1 Compare March 5, 2025 11:43
@fadara01
Copy link
Collaborator Author

fadara01 commented Mar 5, 2025

@malfet I created a standalone PR #148542 for all cmake related changes and removed them from here to ease the review process.

fadara01 added a commit that referenced this pull request Mar 5, 2025
ACL is already built with PyTorch as a shared library when USE_MKLDNN_ACL is set.
Currently, it is only used indirectly in ATen via oneDNN for AArch64 targets. However there are cases where it makes sense to utilize ACL directly without  oneDNN as an intermediary - e.g. quantization. See #145942, #147337, #146620.
This patch enables such use cases by exposing ACL to ATen

ghstack-source-id: 266c621
Pull Request resolved: #148581
pytorchmergebot pushed a commit that referenced this pull request Mar 10, 2025
ACL is already built with PyTorch as a shared library when USE_MKLDNN_ACL is set.
Currently, it is only used indirectly in ATen via oneDNN for AArch64 targets. However there are cases where it makes sense to utilize ACL directly without  oneDNN as an intermediary - e.g. quantization. See #145942, #147337, #146620.
This patch enables such use cases by exposing ACL to ATen

Pull Request resolved: #148584
Approved by: https://github.com/malfet
@fadara01
Copy link
Collaborator Author

Closing in favor of ghstack PR #148585 which has all comments addressed

@fadara01 fadara01 closed this Mar 10, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

arm priority ciflow/linux-aarch64 linux aarch64 CI workflow module: arm Related to ARM architectures builds of PyTorch. Includes Apple M1 module: cpu CPU specific problem (e.g., perf, algorithm) open source release notes: quantization release notes category release notes: releng release notes category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants