
Conversation

@fadara01
Collaborator

@fadara01 fadara01 commented Jan 29, 2025

This enables a fast path for eager mode dynamic quantization for AArch64 through Arm Compute Library (ACL) directly.

Context: PR #126687 enabled an optimized implementation for qlinear_dynamic for AArch64 through ideep → oneDNN → ACL, which improved performance by ~10x compared to the previous implementation.

However, the current qlinear_dynamic path (ideep → oneDNN → ACL) suffers from high overhead due to the API friction between the stateless oneDNN API and the stateful ACL low-precision GEMM (lowp_gemm) API. For example, ACL's lowp_gemm objects cache information such as the weights reduction and the weights in an optimized memory format, which oneDNN does not allow due to its stateless nature. Hence, ACL currently runs a (redundant) sum of columns and pre-transposition (to the GEMM kernel's optimal format) for each GEMM operation.

This PR addresses the sub-optimalities above by integrating ACL directly with qlinear_dynamic. This approach yields an average speedup (averaged over context lengths of 2^3 up to 2^9) of ~50% for bert-base-uncased, bert-large-uncased, roberta-base, and distilbert-base-uncased with 16 threads on a Neoverse-V1 (with transformers==4.48) - see the benchmark code below. To achieve this, we:

  • Use ACL, which is already built with PyTorch as a shared library when USE_MKLDNN_ACL is set.
  • Add ACL to ATen's CPU include and dependency libs.
  • Introduce PackedLinearWeightsACL (as a subclass of PackedLinearWeightsOnednn) with an implementation of qlinear_dynamic that uses ACL directly, while qlinear still follows the oneDNN path.
  • A future PR will introduce a direct ACL implementation of qlinear, which will allow us to remove the dependence on PackedLinearWeightsOnednn.
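
To make the oneDNN/ACL friction concrete, here is a minimal, self-contained C++ sketch (hypothetical types and names, not the actual ACL or oneDNN API) of why a stateful GEMM object helps: the weight pre-transposition and column sums are computed once when the object is created and reused on every run, whereas a stateless interface has to redo both on every call.

#include <cstdint>
#include <vector>

// Hypothetical stand-in for ACL's stateful lowp_gemm object.
struct StatefulLowpGemm {
  std::vector<int8_t> packed_weights;   // weights kept in the kernel's optimal layout
  std::vector<int32_t> weight_col_sums; // per-column reduction used for zero-point correction

  StatefulLowpGemm(const std::vector<int8_t>& w, int64_t k, int64_t n)
      : packed_weights(w), weight_col_sums(n, 0) {
    // One-off work, cached for the lifetime of the object (a real implementation
    // would also re-layout packed_weights here):
    for (int64_t j = 0; j < n; ++j) {
      for (int64_t i = 0; i < k; ++i) {
        weight_col_sums[j] += w[i * n + j];
      }
    }
  }

  // Reuses the cached state. A stateless API cannot hold this state between
  // calls, so the packing/reduction gets repeated per GEMM - the overhead this
  // PR removes by calling ACL directly.
  void run(const std::vector<int8_t>& src, std::vector<int32_t>& dst) {
    (void)src; (void)dst; // ... int8 GEMM using packed_weights and weight_col_sums ...
  }
};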

The following code was used to benchmark qlinear_dynamic performance:

# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <[email protected]>
# SPDX-License-Identifier: BSD-3-Clause
import torch
from transformers import AutoModel, AutoConfig
import time
import numpy as np
from argparse import ArgumentParser

class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__(description="huggingface model")
        self.add_argument("--context_length",
                            help="context length - number of input tokens",
                            type=int,
                            default=64
        )
        self.add_argument("--model",
                            help="model checkpoint - i.e. 'bert-base-uncased'",
                            type=str,
                            default=None)
        self.add_argument("--iters",
                          help="benchmark iterations",
                          default=500)

if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()
    model_name = args.model
    config = AutoConfig.from_pretrained(model_name)
    batch_size = 1
    model = AutoModel.from_pretrained(model_name)
    model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    model.eval()
    inputs = torch.randint(config.vocab_size, (batch_size, args.context_length), dtype=torch.long, device="cpu")
    times = []
    with torch.no_grad():
        # warmup
        for _ in range(10):
            model(inputs)
        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model(inputs)
            times.append((time.time_ns() - s) / 1e6)

    print("Model = ", model_name)         
    print("Context Length = ", args.context_length)
    print("Min (ms) = ", min(times))
    print("Mean (ms) = ", np.mean(times))  

Fixes #ISSUE_NUMBER

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @malfet @snadampal @milpuz01

@pytorch-bot pytorch-bot bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Jan 29, 2025
@pytorch-bot

pytorch-bot bot commented Jan 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/145942

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ac5618b with merge base 6c3492b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added release notes: quantization release notes category release notes: releng release notes category labels Jan 29, 2025
@fadara01
Collaborator Author

@pytorchbot label "module: arm"

@pytorch-bot pytorch-bot bot added the module: arm Related to ARM architectures builds of PyTorch. Includes Apple M1 label Jan 29, 2025
@fadara01 fadara01 force-pushed the acl_qlinear_dynamic branch from 7b44387 to 802e2a6 Compare January 30, 2025 11:35
@fadara01
Collaborator Author

@pytorchbot label "ciflow/linux-aarch64"

@pytorch-bot pytorch-bot bot added the ciflow/linux-aarch64 linux aarch64 CI workflow label Jan 30, 2025
@bdhirsh bdhirsh added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jan 30, 2025
@fadara01
Collaborator Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased acl_qlinear_dynamic onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout acl_qlinear_dynamic && git pull --rebase)

Contributor

@malfet malfet left a comment

Thank you for the PR, it mostly looks good, though if possible, please submit a separate PR that updates the ACL version.

#pragma once

#include <ATen/Config.h>
#if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
Contributor

Why do you need an arch guard there?

Suggested change
#if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
#if AT_MKLDNN_ACL_ENABLED()

Collaborator Author

We do not, I removed it. AT_MKLDNN_ACL_ENABLED is enough

int64_t // NUM_THREADS
>;

enum ACLDynamicQuantMatmulCacheKeyIndex {
Contributor

Suggested change
enum ACLDynamicQuantMatmulCacheKeyIndex {
enum class ACLDynamicQuantMatmulCacheKeyIndex {

Collaborator Author

Done!
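
For illustration, a hedged sketch of how the scoped enum can index the tuple-based cache key (the element order and names here are illustrative, not necessarily the exact ones in the PR):

#include <cstddef>
#include <cstdint>
#include <tuple>

// Illustrative cache key: (m, k, fuse_relu, num_threads).
using ACLDynamicQuantMatmulCacheKey = std::tuple<int64_t, int64_t, bool, int64_t>;

enum class ACLDynamicQuantMatmulCacheKeyIndex : std::size_t {
  M = 0,
  K = 1,
  FUSE_RELU = 2,
  NUM_THREADS = 3,
};

// std::get needs a compile-time index, so the scoped enumerator is cast back to size_t.
inline int64_t num_threads(const ACLDynamicQuantMatmulCacheKey& key) {
  return std::get<static_cast<std::size_t>(
      ACLDynamicQuantMatmulCacheKeyIndex::NUM_THREADS)>(key);
}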

return dim == 2 ? output : output.reshape(output_size);
}

#if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
Contributor

Why does this need an arch check?

Suggested change
#if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
#if AT_MKLDNN_ACL_ENABLED()

Collaborator Author

We do not need it, I removed it. AT_MKLDNN_ACL_ENABLED is enough

Comment on lines 61 to 62
if (with_bias) {
bia_tensor.allocator()->free();
}
Contributor

Wouldn't it be better to express something like that by defining the bias tensor as std::optional<arm_compute::Tensor> bias_tensor;?

Collaborator Author

@fadara01 fadara01 Feb 10, 2025

Done, I now use std::optional for bia_tensor and bia_tensor_info
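
For illustration, a rough sketch of that shape (assuming ACL's arm_compute::Tensor/TensorInfo types; the member names are illustrative):

#include <optional>
#include <arm_compute/core/TensorInfo.h>
#include <arm_compute/runtime/Tensor.h>

struct ACLQuantMatmulState {
  // Empty when the linear layer has no bias, so there is no separate with_bias
  // flag and no tensor object for a bias that does not exist.
  std::optional<arm_compute::Tensor> bia_tensor;
  std::optional<arm_compute::TensorInfo> bia_tensor_info;
};

// Cleanup/usage then only touches the tensor when it was actually configured:
//   if (state.bia_tensor.has_value()) {
//     state.bia_tensor->allocator()->free();
//   }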

fadara01 added a commit to fadara01/pytorch that referenced this pull request Feb 20, 2025
…tly.

This enables a fast path for eager mode static quantization for AArch64 through Arm Compute Library (ACL) directly.

PR pytorch#145942 addressed the high overhead in qlinear_dynamic on AArch64 (due to redundant weight pretranspositions and reductions) by enabling a path that calls ACL directly.
This does the same thing but for (static) qlinear.
fadara01 added a commit to fadara01/pytorch that referenced this pull request Feb 20, 2025
fadara01 added a commit to fadara01/pytorch that referenced this pull request Feb 26, 2025
fadara01 added a commit to fadara01/pytorch that referenced this pull request Feb 27, 2025
…tly.

This enables a fast path for eager mode static quantization for AArch64 through Arm Compute Library (ACL) directly.

PR pytorch#145942 addressed the high overhead in qlinear_dynamic on AArch64 (due to redundant weight pretranspositions and reductions) by enabling a path that calls ACL directly.
This does the same thing but for (static) qlinear.
ACL is already built with PyTorch as a shared library when USE_MKLDNN_ACL is set.
Currently, it is only used indirectly in ATen via oneDNN for AArch64 targets. However, there are cases where it makes sense to utilize ACL directly without oneDNN as an intermediary - e.g. quantization. See pytorch#145942, pytorch#147337, pytorch#146620.
This patch enables such use cases by exposing ACL to ATen.
This enables a fast path for eager mode dynamic quantization for AArch64 through Arm Compute Library (ACL) directly.

Context: PR pytorch#126687 enabled an optimized implementation for qlinear_dynamic for aarch64 through ideep → oneDNN → ACL which improved performance by ~10x compared to the previous implementation.
However, the current qlinear_dynamic path (ideep → oneDNN → ACL) suffers from high overhead due to the API friction between the stateless oneDNN API and the stateful ACL low-precision GEMM (lowp_gemm) API - for example, ACL's lowp_gemm objects cache information like weights reduction or weights in optimized memory format which oneDNN does not allow due to its stateless nature.
Hence, ACL currently runs a (redundant) sum of columns and pre-transposition (to the GEMM kernel's optimal format) for each GEMM operation.
This PR addresses the sub-optimalities above by integrating ACL directly with qlinear_dynamic. This approach yields an average speedup (averaged over context_lengths of 2^3 up to 2^9) of ~ 50% for bert-base-uncased, bert-large-uncased, roberta-base, distilbert-base-uncased with 16 threads on a Neoverse-V1 (with transformers==4.48).
To achieve this we introduce PackedLinearWeightsACL (as a subclass of PackedLinearWeightsOnednn) with an implementation of qlinear_dynamic that uses ACL directly, while qlinear still follows the oneDNN path.
@fadara01 fadara01 force-pushed the acl_qlinear_dynamic branch from 1542c78 to ac5618b Compare March 5, 2025 11:29
fadara01 added a commit to fadara01/pytorch that referenced this pull request Mar 5, 2025
@fadara01
Collaborator Author

fadara01 commented Mar 5, 2025

I created a standalone PR #148542 for all CMake-related changes and removed them from here to ease the review process.

Contributor

@malfet malfet left a comment

I don't know much about this particular codepath, but requesting changes solely for the integration strategy. (Speaking of strategy, it would be good to have an RFC issue outlining the ACL/oneDNN integration - i.e. what is the end goal: fully decouple ACL from oneDNN, keep some direct usage until the oneDNN integration is done, or is it about something else?)

So back to integration:

  • Please move the logic that searches for ACL into a separate PR (you have write permissions, so you can use ghstack, can't you?) and use modern CMake (that defines a target rather than global variables) to introduce the new dependency
  • Avoid explicit memory management (i.e. if something needs to free the memory, wrap it into a simple unique_ptr)
  • Avoid implementing methods in headers unless those are inline methods or templates
  • Also, as much as possible please use the Torch memory allocator rather than mixing ACL and Torch ones, as it will make memory tracking/reporting easier

Last but not least: you've added the script that benchmarks the perf, but did not share the numbers before and after; that would help one understand the benefits this PR brings


int64_t k_;
int64_t n_;
int64_t wei_zero_point_;
Contributor

Is there a document outlining the naming convention? A trailing underscore usually means a private variable, but these are public, as this is a struct.

Collaborator Author

Good catch, I meant to actually make them private.
I addressed this in the new ghstack PR, please see this line

# FindACL
# ----------
#
# Finds the Arm Compute Library
Contributor

Does this file exist somewhere else? If so, please reference where it was copied from.
If this was created exclusively for PyTorch, please use modern CMake, i.e. instead of (or in addition to) defining global variables, add libraries/targets, something like:

    add_library(ArmComputeLib INTERFACE)
    target_link_libraries(ArmComputeLib INTERFACE ${ACL_LIBRARIES})
    target_include_directories(ArmComputeLib INTERFACE ${ACL_INCLUDE_DIRS})

Collaborator Author

This file was not created exclusively for PyTorch; it was copied from oneDNN: https://github.com/oneapi-src/oneDNN/blob/main/cmake/FindACL.cmake

I referenced that in the new ghstack PR - see this line

~ACLDynamicQuantMatmul() {
// this will free memory allocated for the quantized src tensor since the
// allocation happened through ACL: src_s8_tensor.allocator()->allocate()
src_s8_tensor.allocator()->free();
Contributor

This is a very unsafe programming model: there is no default constructor, so this structure could be allocated uninitialized and then freed. I have not tried it, but it's very likely that if someone writes something like

{
   ACLDynamicQuantMatmul v;
}

it will crash in the destructor, as wei_tensor has not been allocated, but its allocator()->free() method is called unconditionally.

Collaborator Author

Actually, sorry for the confusion, my comment above is not right.
tensor.allocator()->free() never frees any memory (whether that memory was allocated by ACL or not). It just tells ACL that we're no longer using the pointer - See here - this can't lead to crashes.

The memory allocated by ACL is freed automatically.

I agree the structure here is not nice.
I added constructors and made sure all memory allocations happen through PyTorch - See here
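
A hedged sketch (illustrative, not the PR's actual class) of the safer shape that came out of this thread: no default constructor, so the object can only exist fully configured, and teardown that only untracks memory ACL never owned:

#include <cstdint>
#include <arm_compute/runtime/Tensor.h>

class ACLDynamicQuantMatmul {
 public:
  ACLDynamicQuantMatmul() = delete;  // cannot be created uninitialized
  ACLDynamicQuantMatmul(int64_t k, int64_t n) : k_(k), n_(n) {
    // configure tensor infos / kernels here; every member is valid once the
    // constructor returns
  }
  ~ACLDynamicQuantMatmul() {
    // allocator()->free() only tells ACL the pointer is no longer in use; with
    // the backing memory owned by PyTorch tensors, nothing is deallocated here.
    src_s8_tensor.allocator()->free();
  }

 private:
  arm_compute::Tensor src_s8_tensor;
  int64_t k_;
  int64_t n_;
};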

@@ -0,0 +1,257 @@
#pragma once
Contributor

It would be good to add some sort of comment explaining what ACL is (in computing, this acronym is most commonly associated with access control lists, see https://en.wikipedia.org/wiki/ACL) and what the classes/functions defined in this header are supposed to do.

Collaborator Author

Good idea, done in the ghstack PR here

at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false)
override;

std::shared_ptr<ACLDynamicQuantMatmul> get_acl_dynamic_quant_matmul(
Contributor

Just curious, what is the thought process here behind having the implementation in the header vs. a respective cpp file?

Collaborator Author

My initial PoC was simple enough that it did not require a cpp file for implementation.
I agree that the implementation has gotten complex enough that it needs a cpp file.
I addressed this in the new ghstack PR, please see here


std::shared_ptr<ACLDynamicQuantMatmul> get_acl_dynamic_quant_matmul(
const ACLDynamicQuantMatmulCacheKey& key) {
// We're only maintaining a 2 element LRU cache
Contributor

Sorry for the nitpicks. A few thoughts:

  • The LRU cache idea does not seem to be unique to ACL, so an implementation should exist someplace else. If not, please add the implementation to, say, c10/utils/lru_cache.h and then use it here
  • Can the variable name be shorter here (just cache or quant_cache)?
  • Again, naming convention: as the variable is private, shouldn't its name end with _?

Collaborator Author

Indeed, the LRU cache is not unique to ACL. I could not find an existing implementation to use, and given that our LRU cache impl only keeps track of two elements, I don't see the point of making it global to PyTorch outside ACLUtils.h.
If we end up implementing a (real) more complex LRU cache in the future, we'll add it to c10/utils/lru_cache.h and use it from there.

I agree with your comments about the name, it is now cache_ in the new ghstack PR, please see this line

Comment on lines +113 to +116
std::rotate(
acl_dynamic_quant_cache.begin(),
acl_dynamic_quant_cache.begin() + 1,
acl_dynamic_quant_cache.end());
Contributor

If your cache size is two, wouldn't std::swap(cache[0], cache[1]) be equivalent?

Collaborator Author

Good idea, thank you!
I addressed this in the new ghstack PR here
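
For illustration, a minimal sketch (names are illustrative) of the two-element cache described in this thread, with std::swap promoting a hit in the second slot:

#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

template <typename Key, typename Value>
class TwoElementLRUCache {
 public:
  std::shared_ptr<Value> get(const Key& key) {
    for (std::size_t i = 0; i < cache_.size(); ++i) {
      if (cache_[i].first == key) {
        if (i == 1) {
          std::swap(cache_[0], cache_[1]);  // promote to the most-recently-used slot
        }
        return cache_[0].second;
      }
    }
    return nullptr;  // miss: the caller constructs the GEMM object and calls put()
  }

  void put(Key key, std::shared_ptr<Value> value) {
    if (cache_.size() == 2) {
      cache_.pop_back();  // evict the least-recently-used entry
    }
    cache_.insert(cache_.begin(), {std::move(key), std::move(value)});
  }

 private:
  std::vector<std::pair<Key, std::shared_ptr<Value>>> cache_;
};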

&acl_gemm->dst_tensor_info,
&acl_gemm->dst_tensor_info,
acl_gemm->acl_relu_info);
if (relu_status.error_code() != arm_compute::ErrorCode::OK) {
Contributor

Don't you want to add TORCH_WARN or something, so that users know something went wrong?

Collaborator Author

Great idea, done here
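
For reference, the agreed-on pattern looks roughly like the following fragment (continuing the snippet quoted above; arm_compute::Status::error_description() is part of ACL, and the fallback return value is assumed):

if (relu_status.error_code() != arm_compute::ErrorCode::OK) {
  TORCH_WARN(
      "Arm Compute Library could not fuse ReLU into the dynamically quantized matmul, "
      "falling back to the default path: ",
      relu_status.error_description());
  return nullptr;
}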


// validate that ACL can handle the given problem and inputs.
if (fuse_relu) {
arm_compute::Status relu_status =
Contributor

Why not use auto there?

Suggested change
arm_compute::Status relu_status =
auto relu_status =

Collaborator Author

Done here


// allocate memory only for the quantized tensor, the rest will use memory
// already available from PyTorch
acl_gemm->src_s8_tensor.allocator()->allocate();
Contributor

Just curious, why use the ACL allocator instead of an existing PyTorch one? So that the memory tracking story is cleaner, i.e. all memory is allocated/tracked and freed by the PyTorch caching allocator?

Collaborator Author

I agree with you.
I now use PyTorch for all allocations - ACL just imports pointers but does not explicitly allocate/deallocate any memory.
See here
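
A rough sketch of that arrangement (assuming ATen and ACL's TensorAllocator::import_memory; the names are illustrative): the at::Tensor owns and tracks the memory, and the ACL tensor merely wraps the pointer.

#include <ATen/ATen.h>
#include <arm_compute/runtime/Tensor.h>

struct QuantizedSrcBuffer {
  at::Tensor storage;            // allocated, tracked, and freed by PyTorch
  arm_compute::Tensor acl_view;  // ACL only wraps the pointer, it never allocates

  QuantizedSrcBuffer(int64_t m, int64_t k) {
    storage = at::empty({m, k}, at::kChar);
    // acl_view's TensorInfo is assumed to have been configured elsewhere with a
    // matching shape/data type before the pointer is imported.
    acl_view.allocator()->import_memory(storage.data_ptr());
  }
};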

@fadara01
Collaborator Author

fadara01 commented Mar 5, 2025

Last but not least: you've added the script that benchmarks the perf, but did not share the numbers before and after; that would help one understand the benefits this PR brings

Please check the PR description above, where I say:

This approach yields an average speedup (averaged over context_lengths of 2^3 up to 2^9) of ~ 50% for bert-base-uncased, bert-large-uncased, roberta-base, distilbert-base-uncased with 16 threads on a Neoverse-V1 (with transformers==4.48) - See benchmark code below.

fadara01 added a commit that referenced this pull request Mar 5, 2025
ACL is already built with PyTorch as a shared library when USE_MKLDNN_ACL is set.
Currently, it is only used indirectly in ATen via oneDNN for AArch64 targets. However, there are cases where it makes sense to utilize ACL directly without oneDNN as an intermediary - e.g. quantization. See #145942, #147337, #146620.
This patch enables such use cases by exposing ACL to ATen

ghstack-source-id: 266c621
Pull Request resolved: #148581
fadara01 added a commit that referenced this pull request Mar 5, 2025
…tly.

This enables a fast path for eager mode static quantization for AArch64 through Arm Compute Library (ACL) directly.

PR #145942 addressed the high overhead in qlinear_dynamic on AArch64 (due to redundant weight pretranspositions and reductions) by enabling a path that calls ACL directly.
This does the same thing but for (static) qlinear.

ghstack-source-id: 05435a0
Pull Request resolved: #148586
davsva01 added a commit to davsva01/pytorch that referenced this pull request Mar 5, 2025
This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly.
It's based on changes in PR pytorch#145942 which enables the use of ACL directly in ATen.
Relative performance improvement using OMP_NUM_THREADS=1 is ~15x, using OMP_NUM_THREADS=32 it’s ~5.4x.
@fadara01
Collaborator Author

fadara01 commented Mar 6, 2025

Please move the logic that searches for ACL into a separate PR (you have write permissions, so you can use ghstack, can't you?) and use modern CMake (that defines a target rather than global variables) to introduce the new dependency

I created ghstack PRs for this:

I addressed reviews on this PR in #148584 - its equivalent ghstack PR.

Unfortunately, I couldn't convert these PRs into ghstack PRs because they're on feature branches belonging to my PyTorch fork, hence the new PRs.

@fadara01
Collaborator Author

fadara01 commented Mar 6, 2025

what is the end goal: fully decouple ACL from oneDNN, keep some direct usage until the oneDNN integration is done, or is it about something else?)

The plan for now is to enable this fast path until a direct fast path to ACL from oneDNN is implemented.

@fadara01 fadara01 requested a review from malfet March 7, 2025 14:06
pytorchmergebot pushed a commit that referenced this pull request Mar 10, 2025
ACL is already built with PyTorch as a shared library when USE_MKLDNN_ACL is set.
Currently, it is only used indirectly in ATen via oneDNN for AArch64 targets. However, there are cases where it makes sense to utilize ACL directly without oneDNN as an intermediary - e.g. quantization. See #145942, #147337, #146620.
This patch enables such use cases by exposing ACL to ATen

Pull Request resolved: #148584
Approved by: https://github.com/malfet
@fadara01
Collaborator Author

Closing in favor of ghstack PR #148585, which has all comments addressed.

@fadara01 fadara01 closed this Mar 10, 2025
Labels

  • arm priority
  • ciflow/linux-aarch64 - linux aarch64 CI workflow
  • module: arm - Related to ARM architectures builds of PyTorch. Includes Apple M1
  • module: cpu - CPU specific problem (e.g., perf, algorithm)
  • open source
  • release notes: quantization - release notes category
  • release notes: releng - release notes category
  • triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
