
Kernel_attention optimization #319

Merged
HenryNdubuaku merged 2 commits into cactus-compute:main from Ayan9074:kernel_optimizations
Feb 9, 2026

Conversation

Contributor

Ayan9074 commented Feb 3, 2026

Description

Adds a specialized FP16 attention kernel for head_dim == 64, inspired by Qualcomm Hexagon / HVX optimization principles.
Instead of relying on the generic attention path (which handles arbitrary head sizes, masks, and windowing), this change introduces a fixed-shape h64 fast path that:

  • Uses compile-time constants for head dimension (64)
  • Uses statically allocated accumulators (no dynamic vectors)
  • Unrolls the inner dot-product loop
  • Applies block-wise numerically stable softmax
  • Avoids mask / window logic on decode path
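The bullets above can be sketched in scalar form. This is a minimal illustration of the technique only, not the actual cactus kernel: the name `attn_h64_decode` is hypothetical, and plain float arithmetic stands in for the real FP16/NEON implementation.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical sketch of an h64 decode-attention fast path:
// the head dimension is a compile-time constant, the accumulators
// are stack arrays (no dynamic vectors), the dot product has a
// fixed trip count the compiler can fully unroll, and softmax is
// applied in a numerically stable online (block-wise) fashion, so
// no mask/window branches are needed on the decode path.
constexpr int HEAD_DIM = 64;

void attn_h64_decode(const float* q,   // [HEAD_DIM] query for one head
                     const float* k,   // [seq_len][HEAD_DIM] keys
                     const float* v,   // [seq_len][HEAD_DIM] values
                     std::size_t seq_len,
                     float* out) {     // [HEAD_DIM] output
    float acc[HEAD_DIM] = {0.0f};      // statically allocated accumulator
    float running_max = -INFINITY;
    float running_sum = 0.0f;
    const float scale = 1.0f / std::sqrt(static_cast<float>(HEAD_DIM));

    for (std::size_t t = 0; t < seq_len; ++t) {
        const float* kt = k + t * HEAD_DIM;
        float score = 0.0f;
        for (int d = 0; d < HEAD_DIM; ++d)  // fixed-length, unrollable
            score += q[d] * kt[d];
        score *= scale;

        // Online softmax: when a new running max appears, rescale the
        // previous sum and accumulator instead of revisiting old scores.
        float new_max = score > running_max ? score : running_max;
        float correction = std::exp(running_max - new_max);
        float w = std::exp(score - new_max);
        running_sum = running_sum * correction + w;
        const float* vt = v + t * HEAD_DIM;
        for (int d = 0; d < HEAD_DIM; ++d)
            acc[d] = acc[d] * correction + w * vt[d];
        running_max = new_max;
    }
    for (int d = 0; d < HEAD_DIM; ++d)
        out[d] = acc[d] / running_sum;
}
```

With `seq_len == 1` the single softmax weight is 1, so the output equals the value row; with two identical keys the weights are 0.5 each, so the output is the mean of the two value rows.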

Performance impact (M4 MacBook Air, averaged across 10 runs):
Attention 1024×16×64:
Before: ~19.7 ms (~109 GFLOPS)
After: ~13.0 ms (~165 GFLOPS)
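As a sanity check on the arithmetic above: the quoted figures are consistent with counting 2·S²·H·D floating-point operations (one multiply and one add per element of the S×S score matrix, per head). This convention is inferred from the numbers, not stated in the PR.

```cpp
#include <cassert>
#include <cmath>

// Throughput implied by the benchmark, assuming the FLOP count is
// 2 * seq^2 * heads * head_dim (QK^T scores only; counting the PV
// product as well would double both GFLOPS figures).
double attn_gflops(double seq, double heads, double head_dim, double ms) {
    double flops = 2.0 * seq * seq * heads * head_dim;
    return flops / (ms * 1e-3) / 1e9;
}
```

For the 1024×16×64 case this gives roughly 109 GFLOPS at 19.7 ms and 165 GFLOPS at 13.0 ms, matching the reported numbers.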

Limitations:

  • The h64 kernel does not support masks or sliding window attention; such inputs fall back to the generic kernel
  • This path is taken only when:
    • head_dim == 64
    • No attention mask is present
    • No windowed attention is used
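The conditions above can be sketched as a simple dispatch predicate. This is a hypothetical illustration: the struct and function names are invented, and the real kernels likely express the check differently.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical dispatch check for the h64 fast path: it is taken
// only for head_dim == 64 with no attention mask and no sliding
// window; everything else uses the generic attention path.
struct AttnParams {
    std::size_t head_dim;
    const float* mask;        // nullptr when no attention mask is present
    std::size_t window_size;  // 0 means no windowed attention
};

bool use_h64_fast_path(const AttnParams& p) {
    return p.head_dim == 64 && p.mask == nullptr && p.window_size == 0;
}
```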

Type of Change

  • Bug fix
  • New feature
  • Performance improvement
  • Documentation update

Testing

  • Tests pass locally
  • Tested on ARM hardware
  • Benchmarked performance impact

Checklist

  • All commits are signed-off (DCO)
  • Code follows project style
  • Comments added where necessary
  • Documentation updated if needed

I'm not sure if this is what was being looked for with issue #300, and I haven't updated any documentation (not sure if that's needed). Should this code be adapted to handle mask and window_size? Should I add this change to cactus_attention_hybrid_int8_fp16 as well? Or should I just change cactus_attention_f16 rather than creating an extra h64 function?

@HenryNdubuaku
Collaborator

Your head's in the right space. Does any model benefit from this? Read their attention kernels to see. Also, have you had a look at the HVX issue?

Contributor Author

Ayan9074 commented Feb 4, 2026

Yeah, some models benefit: FP16 ones like whisper/moonshine/nomic/siglip2, and prefill for gemma/qwen/lfm2 too. But decode for gemma/qwen/lfm2 uses the int8 hybrid path, so this new path won't be used unless I add something for cactus_attention_hybrid_int8_fp16. Also, almost all models have head_dim set to 64, so they should hit this path, and I believe the only model that explicitly passes a window size greater than 0 is gemma, though most models still have a sliding KV cache by default. Not sure what the HVX issue refers to?

@HenryNdubuaku
Collaborator

Great @Ayan9074, do you have benchmark numbers before and after?

Contributor Author

Ayan9074 commented Feb 5, 2026

All tests on a MacBook Air M4.

Performance benchmarks on tests:
Attention 1024×16×64:
Before: ~19.7 ms (~109 GFLOPS)
After: ~13.0 ms (~165 GFLOPS)

But since it only really hits audio models, the main impact right now is that decode tps on my laptop rises by around 10% for tests like PCM buffer transcription and transcription, and total test runtime falls by around 100 ms on average. For the longer stream transcription test the total time falls from around 7400 ms to 6700 ms.

I will work on adapting this so it also affects chat models.

@HenryNdubuaku
Collaborator

@Ayan9074 remember to update me once this is ready

Contributor Author

Ayan9074 commented Feb 8, 2026

@HenryNdubuaku Improvements do occur for Whisper and other FP16 audio models (numbers above), so I believe this PR is ready.

For text models, they use the hybrid INT8 path; I tested an h64 hybrid fast path as well, but decode is dominated by INT8 matmul / memory bandwidth, so it didn't show meaningful gains. I will keep working on optimizing the attention kernel over the next week to see if there are any other gains to be made, either through SME2, AMX, or HVX-style implementations.

Tested on an M4 MacBook Air:

Attention 1024×16×64:
Before: ~19.7 ms (~109 GFLOPS)
After: ~13.0 ms (~165 GFLOPS)

Decode tps on my laptop rises by around 10% for tests like PCM buffer transcription and transcription, and total test runtime falls by around 100 ms on average.
For the longer stream transcription test the total time falls from around 7400ms to 6600ms.

@HenryNdubuaku HenryNdubuaku merged commit 40309fa into cactus-compute:main Feb 9, 2026
1 of 2 checks passed
ncylich pushed a commit that referenced this pull request Feb 24, 2026
* cactus attention hexagon optimization for h64

Signed-off-by: Ayan9074 <[email protected]>

* h64 implementation for cactus_attention_f16

Signed-off-by: Ayan9074 <[email protected]>

---------

Signed-off-by: Ayan9074 <[email protected]>
