Accelerate FP16 attention via cblas_sgemm for Apple AMX #346

HenryNdubuaku merged 3 commits into cactus-compute:main
Conversation
Batch query positions into real GEMMs so cblas_sgemm can use AMX, the same approach as the matmul kernel (cactus-compute#340). Parallelizes over batch * num_q_heads instead of batch * num_q_heads * seq_len, with each work item processing all seq_len positions for one head.

Benchmark (1024x16x64): 8.7ms / 246 GFLOPS -> 4.4ms / 490 GFLOPS (~2x)

Falls back to the existing NEON path for seq_len < 64.

Signed-off-by: Kayaan Tharani <[email protected]>
Pull request overview
This PR accelerates FP16 attention computation on Apple Silicon by leveraging the AMX coprocessor through the Accelerate framework's cblas_sgemm. It follows the same optimization strategy as PR #340 (matmul kernel), converting FP16 to FP32, batching operations into large GEMMs, and converting back to FP16.
Changes:
- Adds Apple Accelerate-optimized attention kernel for seq_len >= 64
- Batches query positions into matrix operations to enable AMX acceleration via cblas_sgemm
- Implements online softmax with NEON fast exp approximation for numerical stability
- Achieves ~2x speedup (8.7ms → 4.4ms, 246 → 490 GFLOPS) on the benchmark
```cpp
std::vector<float> Q_f32(seq_len * HEAD_DIM);
std::vector<float> K_f32(BLOCK_SIZE * HEAD_DIM);
std::vector<float> V_f32(BLOCK_SIZE * HEAD_DIM);
std::vector<float> scores(seq_len * BLOCK_SIZE);
std::vector<float> acc(seq_len * HEAD_DIM);
std::vector<float> row_max(seq_len);
std::vector<float> row_sum(seq_len);
```
The std::vector buffers (Q_f32, K_f32, V_f32, scores, acc, row_max, row_sum) are allocated inside the parallel_for lambda, meaning they are allocated and deallocated on every call to this function. For the accelerate path, these buffers are quite large (e.g., Q_f32 is seq_len * 64 floats, which for seq_len=1024 is 256KB). Consider using thread-local storage or a buffer pool to amortize allocation costs across multiple calls, similar to how some other kernels in the codebase might handle this.
```cpp
constexpr float NEG_INF = -INFINITY;

#ifdef __APPLE__
    if (seq_len >= 64) {
```
The threshold value 64 for seq_len is hardcoded in the conditional. Following the convention established in kernel_matmul.cpp and kernel_conv.cpp, this should be defined as a named constant like ACCELERATE_SEQ_LEN_THRESHOLD at file scope for better maintainability and clarity.
* Accelerate FP16 attention via cblas_sgemm for Apple AMX

  Batch query positions into real GEMMs so cblas_sgemm can use AMX, the same approach as the matmul kernel (#340). Parallelizes over batch * num_q_heads instead of batch * num_q_heads * seq_len, with each work item processing all seq_len positions for one head.

  Benchmark (1024x16x64): 8.7ms / 246 GFLOPS -> 4.4ms / 490 GFLOPS (~2x)

  Falls back to the existing NEON path for seq_len < 64.

  Signed-off-by: Kayaan Tharani <[email protected]>

* Add ACCELERATE_NEW_LAPACK definition for improved LAPACK support

  Signed-off-by: HenryNdubuaku <[email protected]>

---------

Signed-off-by: Kayaan Tharani <[email protected]>
Signed-off-by: HenryNdubuaku <[email protected]>
Co-authored-by: HenryNdubuaku <[email protected]>
Same approach as #340 (matmul kernel) applied to attention. Batches query positions into GEMMs so cblas_sgemm hits AMX instead of doing 1x64 dot products one at a time.
Parallelizes over batch * num_q_heads (one head per work item processes all seq_len positions).

Benchmark on M4 Pro (1024x16x64): 8.7ms / 246 GFLOPS -> 4.4ms / 490 GFLOPS (~2x)
Part of #298