
Accelerate FP16 matmul via cblas_sgemm for Apple AMX#340

Merged
HenryNdubuaku merged 1 commit into cactus-compute:main from KayaanT:amx-matmul-f16
Feb 11, 2026

Conversation

@KayaanT (Contributor) commented Feb 11, 2026

Summary

  • Routes large FP16 matmuls (K>=256, M>=4) through Apple Accelerate (cblas_sgemm), which uses AMX internally
  • Converts FP16 to FP32, calls cblas_sgemm, converts result back to FP16
  • Small matrices fall through to existing NEON path unchanged
  • Same pattern used by the existing conv1d kernel (kernel_conv.cpp)
  • Reference: cblas_sgemm

Contributes to #298

Benchmark (M4 Pro, 1024x1024x1024)

| Benchmark | Before (NEON) | After (AMX) | Speedup |
|---|---|---|---|
| MatMul F16 1024³ | 9.997 ms / 215 GFLOPS | 2.524 ms / 851 GFLOPS | 4.0x |
| MatMul 1024³ CPU (graph) | 10.438 ms / 206 GFLOPS | 2.419 ms / 888 GFLOPS | 4.3x |
| MatMul F16 1x1024x1024 (GEMV) | 0.046 ms / 46 GFLOPS | 0.040 ms / 52 GFLOPS | same (NEON path) |

Copilot AI review requested due to automatic review settings February 11, 2026 03:56

Copilot AI left a comment


Pull request overview

This PR adds an Apple-specific fast path for large FP16 matrix multiplications by converting inputs to FP32, calling Accelerate’s cblas_sgemm (leveraging AMX internally), then converting the FP32 result back to FP16—while keeping the existing NEON implementation for smaller shapes.

Changes:

  • Add __APPLE__-guarded Accelerate (cblas_sgemm) matmul path for large FP16 matmuls.
  • Convert __fp16 inputs to float buffers and cast the FP32 output back to __fp16.
  • Leave existing NEON tiled/parallel implementation as fallback for smaller matrices.


Comment on lines +171 to +174

```cpp
std::vector<float> A_f32(a_len);
std::vector<float> BT_f32(b_len);
std::vector<float> C_f32(c_len);
```

Copilot AI Feb 11, 2026


The Accelerate path allocates three large std::vector<float> buffers on every call (A, B^T, C). For large matrices this can add significant allocator overhead and memory pressure. Consider reusing scratch buffers (e.g., thread-local or a caller-provided workspace) and/or using Accelerate/vDSP conversion routines to reduce per-call overhead.

Suggested change

```diff
-std::vector<float> A_f32(a_len);
-std::vector<float> BT_f32(b_len);
-std::vector<float> C_f32(c_len);
+static thread_local std::vector<float> A_f32;
+static thread_local std::vector<float> BT_f32;
+static thread_local std::vector<float> C_f32;
+A_f32.resize(a_len);
+BT_f32.resize(b_len);
+C_f32.resize(c_len);
```

Comment on lines +166 to +167

```cpp
if (K >= 256 && M >= 4) {
    const size_t a_len = M * K;
```

Copilot AI Feb 11, 2026


The thresholds K >= 256 and M >= 4 are hard-coded magic numbers. To make this easier to tune and keep consistent with other Accelerate thresholds (e.g., in kernel_conv.cpp), define them as named constexpr constants (and ideally document the rationale/benchmark behind them).
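A minimal sketch of what the named-constant approach could look like. Only the values 256 and 4 come from the PR; the identifiers and helper function are illustrative, not from the codebase:

```cpp
#include <cstddef>

// Thresholds from the PR's `K >= 256 && M >= 4` check, lifted into named
// constants as the review suggests. Names are illustrative, not the PR's code.
constexpr size_t kAccelerateMinK = 256;  // reduction dim where AMX beats NEON per benchmarks
constexpr size_t kAccelerateMinM = 4;    // below this, GEMV-like shapes stay on NEON

// Single predicate so matmul and any future kernels apply the same cutoff.
inline bool use_accelerate_path(size_t M, size_t K) {
    return K >= kAccelerateMinK && M >= kAccelerateMinM;
}
```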

Comment on lines +165 to +186

```cpp
#ifdef __APPLE__
    if (K >= 256 && M >= 4) {
        const size_t a_len = M * K;
        const size_t b_len = N * K;
        const size_t c_len = M * N;

        std::vector<float> A_f32(a_len);
        std::vector<float> BT_f32(b_len);
        std::vector<float> C_f32(c_len);

        for (size_t i = 0; i < a_len; i++) A_f32[i] = (float)a[i];
        for (size_t i = 0; i < b_len; i++) BT_f32[i] = (float)b_transposed[i];

        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                    (int)M, (int)N, (int)K,
                    1.0f, A_f32.data(), (int)K,
                    BT_f32.data(), (int)K,
                    0.0f, C_f32.data(), (int)N);

        for (size_t i = 0; i < c_len; i++) c[i] = (__fp16)C_f32[i];
        return;
    }
```

Copilot AI Feb 11, 2026


This change introduces a new Apple-only execution path with different numerics (FP32 accumulate + FP16 cast) and a different backend (cblas_sgemm), but there are currently no correctness tests for FP16 matmul. Adding a guarded (#ifdef __APPLE__) test that exercises the Accelerate threshold region (e.g., M>=4, K>=256) and compares against a reference implementation would help prevent silent regressions.
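A sketch of the kind of reference comparison such a test could use. A real test would call the project's FP16 matmul kernel under `#ifdef __APPLE__`; here `matmul_f32` stands in for the FP32-accumulate sgemm numerics so the harness is portable. All names are illustrative, not from the codebase:

```cpp
#include <cmath>
#include <cstdlib>
#include <vector>

// Reference: naive triple loop with double accumulation. A is M x K row-major,
// BT is N x K row-major (B pre-transposed, matching the kernel's layout).
static std::vector<double> matmul_ref(const std::vector<float>& A,
                                      const std::vector<float>& BT,
                                      size_t M, size_t N, size_t K) {
    std::vector<double> C(M * N);
    for (size_t i = 0; i < M; i++)
        for (size_t j = 0; j < N; j++) {
            double acc = 0.0;
            for (size_t k = 0; k < K; k++)
                acc += (double)A[i * K + k] * (double)BT[j * K + k];
            C[i * N + j] = acc;
        }
    return C;
}

// Candidate: FP32 accumulation, i.e. the numerics cblas_sgemm would produce.
// On Apple this is where the Accelerate-backed kernel would be invoked.
static std::vector<float> matmul_f32(const std::vector<float>& A,
                                     const std::vector<float>& BT,
                                     size_t M, size_t N, size_t K) {
    std::vector<float> C(M * N);
    for (size_t i = 0; i < M; i++)
        for (size_t j = 0; j < N; j++) {
            float acc = 0.0f;
            for (size_t k = 0; k < K; k++)
                acc += A[i * K + k] * BT[j * K + k];
            C[i * N + j] = acc;
        }
    return C;
}

// True if candidate and reference agree within `tol` relative error for a
// shape inside the Accelerate threshold region (M >= 4, K >= 256).
bool matmul_agrees(size_t M, size_t N, size_t K, double tol) {
    std::vector<float> A(M * K), BT(N * K);
    srand(42);  // deterministic inputs in [-0.5, 0.5]
    for (auto& x : A)  x = (float)rand() / (float)RAND_MAX - 0.5f;
    for (auto& x : BT) x = (float)rand() / (float)RAND_MAX - 0.5f;
    std::vector<double> ref = matmul_ref(A, BT, M, N, K);
    std::vector<float>  got = matmul_f32(A, BT, M, N, K);
    for (size_t i = 0; i < M * N; i++) {
        double denom = std::fmax(std::fabs(ref[i]), 1.0);
        if (std::fabs((double)got[i] - ref[i]) / denom > tol) return false;
    }
    return true;
}
```

The tolerance would be set at FP16 precision (around 1e-3 relative) since the final cast back to `__fp16` dominates the error, not the FP32 accumulation.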

Comment on lines +178 to +182

```cpp
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
            (int)M, (int)N, (int)K,
            1.0f, A_f32.data(), (int)K,
            BT_f32.data(), (int)K,
            0.0f, C_f32.data(), (int)N);
```

Copilot AI Feb 11, 2026


cblas_sgemm takes int dimensions/leading dimensions, but this code casts from size_t without bounds checks. If M/N/K exceed INT_MAX (or even just K/N for lda/ldb/ldc), the cast can overflow and lead to incorrect results or memory errors. Add a guard (e.g., if (M > INT_MAX || N > INT_MAX || K > INT_MAX) fall back to the existing NEON path) before calling BLAS.
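A sketch of the guard the comment proposes: verify every dimension fits in a BLAS `int` (and that the scratch-buffer sizes do not overflow `size_t`) before casting. The function name is illustrative; callers would fall back to the NEON path when it returns false:

```cpp
#include <climits>
#include <cstdint>
#include <cstddef>

// Pre-flight check before casting size_t dimensions to int for cblas_sgemm.
// Illustrative name, not from the PR.
static bool dims_fit_blas_int(size_t M, size_t N, size_t K) {
    if (M == 0 || N == 0 || K == 0) return false;  // nothing for BLAS to do
    const size_t lim = (size_t)INT_MAX;
    if (M > lim || N > lim || K > lim) return false;
    // M*K, N*K, and M*N must also fit in size_t for the float scratch buffers.
    return M <= SIZE_MAX / K && N <= SIZE_MAX / K && M <= SIZE_MAX / N;
}
```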

Routes large FP16 matmuls (K>=256, M>=4) through Apple's Accelerate
framework which uses AMX internally. Converts FP16→FP32, calls
cblas_sgemm, converts back. Small matrices fall through to existing
NEON path. Benchmarked 4x speedup on 1024³ (215→851 GFLOPS).

Signed-off-by: Kayaan Tharani <[email protected]>
@HenryNdubuaku
Collaborator

Thanks for this @KayaanT, you've now learnt how we love code contributions; most tasks are designed to need no more than 1-3 file changes, so I'll merge this.

@HenryNdubuaku HenryNdubuaku merged commit 0e7cdfd into cactus-compute:main Feb 11, 2026
1 of 2 checks passed
KayaanT added a commit to KayaanT/cactus that referenced this pull request Feb 13, 2026
Batch query positions into real GEMMs so cblas_sgemm can use AMX,
same approach as the matmul kernel (cactus-compute#340). Parallelizes over
batch * num_q_heads instead of batch * num_q_heads * seq_len,
with each work item processing all seq_len positions for one head.

Benchmark (1024x16x64): 8.7ms / 246 GFLOPS -> 4.4ms / 490 GFLOPS (~2x)

Falls back to existing NEON path for seq_len < 64.

Signed-off-by: Kayaan Tharani <[email protected]>
HenryNdubuaku added a commit that referenced this pull request Feb 18, 2026
* Accelerate FP16 attention via cblas_sgemm for Apple AMX

Batch query positions into real GEMMs so cblas_sgemm can use AMX,
same approach as the matmul kernel (#340). Parallelizes over
batch * num_q_heads instead of batch * num_q_heads * seq_len,
with each work item processing all seq_len positions for one head.

Benchmark (1024x16x64): 8.7ms / 246 GFLOPS -> 4.4ms / 490 GFLOPS (~2x)

Falls back to existing NEON path for seq_len < 64.

Signed-off-by: Kayaan Tharani <[email protected]>

* Add ACCELERATE_NEW_LAPACK definition for improved LAPACK support

Signed-off-by: HenryNdubuaku <[email protected]>

---------

Signed-off-by: Kayaan Tharani <[email protected]>
Signed-off-by: HenryNdubuaku <[email protected]>
Co-authored-by: HenryNdubuaku <[email protected]>
ncylich pushed a commit that referenced this pull request Feb 24, 2026
cattermelon1234 pushed a commit to cattermelon1234/cactus that referenced this pull request Feb 28, 2026
…te#346)