optimize scale correction in cactus_attention_f16_h64 #485

Merged

HenryNdubuaku merged 1 commit into main from stabilize_attention
Mar 4, 2026

Conversation

@jakmro
Collaborator

@jakmro jakmro commented Mar 3, 2026

No description provided.

Contributor

Copilot AI left a comment


Pull request overview

This PR optimizes the online softmax scale correction in cactus_attention_f16_h64, the specialised attention kernel for 64-dimensional heads. Instead of unconditionally rescaling all 16 NEON accumulator vectors every block, the new code branches: when the current block raises the running maximum it rescales the prior state (as before), and when it does not it defers the scale factor to a cheap per-element multiply at accumulation time — saving 16 vmulq_n_f32 calls in the common case. A stale #include <iostream> in kernel_nn.cpp is also cleaned up.

Changes:

  • kernel_attention.cpp: Refactors the scale-correction loop in cactus_attention_f16_h64 to skip rescaling the 16 NEON accumulators when the current block's maximum does not exceed the running maximum, applying a scalar current_block_scale to the block's contribution instead.
  • kernel_nn.cpp: Removes unused #include <iostream>.
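The conditional rescaling described above can be sketched in scalar C++. This is a hypothetical stand-in, not the kernel's actual code: `acc` plays the role of the 16 NEON accumulator vectors, and `online_softmax_weighted_sum` is an invented helper name.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar sketch of the optimized online-softmax bookkeeping. When a block
// raises the running maximum, the prior state is rescaled as before;
// otherwise the correction is deferred into a single current_block_scale
// multiply on the block's own contribution.
inline float online_softmax_weighted_sum(const std::vector<float>& scores,
                                         const std::vector<float>& values,
                                         std::size_t block_size) {
    float running_max = -INFINITY;
    float running_sum = 0.f;
    float acc = 0.f;  // stands in for the 16 NEON accumulator vectors

    for (std::size_t b = 0; b < scores.size(); b += block_size) {
        const std::size_t end = std::min(b + block_size, scores.size());
        float block_max = -INFINITY;
        for (std::size_t i = b; i < end; i++)
            block_max = std::max(block_max, scores[i]);

        float current_block_scale;
        if (block_max > running_max) {
            // Expensive path: rescale all accumulated state (16 vmulq_n_f32
            // calls in the real kernel); the block then contributes at scale 1.
            const float correction = std::exp(running_max - block_max);
            acc *= correction;
            running_sum *= correction;
            running_max = block_max;
            current_block_scale = 1.f;
        } else {
            // Cheap path: running state is untouched; only this block's
            // contribution is scaled down.
            current_block_scale = std::exp(block_max - running_max);
        }

        float block_sum = 0.f;
        for (std::size_t i = b; i < end; i++) {
            const float w = std::exp(scores[i] - block_max);
            block_sum += w;
            acc += (w * current_block_scale) * values[i];
        }
        running_sum += block_sum * current_block_scale;
    }
    return acc / running_sum;
}
```

Both branches produce the same effective weights exp(score - global_max), so the result matches a two-pass softmax; only the per-block bookkeeping cost differs.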

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File | Description
cactus/kernel/kernel_attention.cpp | Core optimization: conditional accumulator rescaling in cactus_attention_f16_h64
cactus/kernel/kernel_nn.cpp | Removes leftover unused #include <iostream>


Comment on lines 305 to +341
@@ -318,11 +324,11 @@ static inline void cactus_attention_f16_h64(
     }

     for (size_t i = 0; i < kv1 - kv0; i++) {
-        float w = block_scores[i];
-        if (w == 0.f) continue;
+        const float attn_weight = block_scores[i] * current_block_scale;
+        if (attn_weight == 0.f) continue;

         const __fp16* v = values + batch*kv_batch_stride + (kv0+i)*kv_seq_stride + kv_head*HEAD_DIM;
-        float32x4_t wv = vdupq_n_f32(w);
+        float32x4_t wv = vdupq_n_f32(attn_weight);

 #pragma unroll
         for (int d = 0; d < 8; d++) {
@@ -332,8 +338,7 @@ static inline void cactus_attention_f16_h64(
         }
     }

-    running_sum += block_sum;
-    running_max = block_max;
+    running_sum += block_sum * current_block_scale;

Copilot AI Mar 3, 2026


The modified cactus_attention_f16_h64 function is exercised only when head_dim == 64, mask == nullptr, and window_size == 0 (see kernel_attention.cpp:387–394). However, the existing attention test in tests/test_kernel.cpp:187–213 uses head_dim = 8, and tests/test_graph.cpp:223–242 uses head_dim = 4 — neither covers this specialised path.

The optimization changes the flow into two distinct branches depending on whether block_max > running_max or not. Both branches need to be exercised to verify correctness. With BLOCK_SIZE=32 and typical inputs, the "block_max <= running_max" branch (the new else path with current_block_scale = expf(block_max - running_max)) is only reachable after a later block has a lower max than an earlier one. Without a numerical correctness test against a known reference using head_dim=64 (which would trigger the h64 path), it is hard to be confident the optimization is bug-free in all scenarios.

A test that:

  1. Uses head_dim=64 (to trigger cactus_attention_f16_h64),
  2. Has enough KV positions to span multiple blocks (e.g., > 32), and
  3. Compares results with a reference implementation (e.g., the general path via head_dim != 64)

would give confidence in the correctness of both new branches.
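As a rough illustration (a scalar sketch with invented names, not the repo's test harness), the decreasing-max scenario the comment describes can be forced by pinning the global maximum into the first block, so every later block takes the deferred-scale branch:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical check: with the global maximum in block 0, every subsequent
// 32-wide block takes the "block_max <= running_max" branch. The streaming
// normalizer must still match a two-pass softmax reference.
inline bool deferred_branch_matches_reference() {
    const std::size_t block = 32, n = 96;   // spans multiple blocks
    std::mt19937 rng(0);
    std::uniform_real_distribution<float> dist(-1.f, 1.f);
    std::vector<float> scores(n);
    for (auto& s : scores) s = dist(rng);
    scores[5] = 10.f;                       // global max lives in block 0

    // Two-pass reference normalizer.
    const float m = *std::max_element(scores.begin(), scores.end());
    float ref = 0.f;
    for (float s : scores) ref += std::exp(s - m);

    // Streaming pass with the deferred scale correction.
    float running_max = -INFINITY, running_sum = 0.f;
    for (std::size_t b = 0; b < n; b += block) {
        const std::size_t end = std::min(b + block, n);
        float block_max = -INFINITY;
        for (std::size_t i = b; i < end; i++)
            block_max = std::max(block_max, scores[i]);

        float scale;
        if (block_max > running_max) {
            running_sum *= std::exp(running_max - block_max);
            running_max = block_max;
            scale = 1.f;
        } else {
            scale = std::exp(block_max - running_max);  // branch under test
        }
        float block_sum = 0.f;
        for (std::size_t i = b; i < end; i++)
            block_sum += std::exp(scores[i] - block_max);
        running_sum += block_sum * scale;
    }
    return std::fabs(running_sum - ref) < 1e-4f * ref;
}
```

A real test along these lines would additionally route the same inputs through the head_dim=64 kernel and the general path, as the review suggests.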

@HenryNdubuaku HenryNdubuaku merged commit fcf0ea9 into main Mar 4, 2026
8 of 10 checks passed