optimize scale correction in cactus_attention_f16_h64 #485
Conversation
Signed-off-by: jakmro <[email protected]>
Pull request overview
This PR optimizes the online softmax scale correction in cactus_attention_f16_h64, the specialised attention kernel for 64-dimensional heads. Instead of unconditionally rescaling all 16 NEON accumulator vectors every block, the new code branches: when the current block raises the running maximum it rescales the prior state (as before), and when it does not it defers the scale factor to a cheap per-element multiply at accumulation time — saving 16 vmulq_n_f32 calls in the common case. A stale #include <iostream> in kernel_nn.cpp is also cleaned up.
Changes:
- kernel_attention.cpp: Refactors the scale-correction loop in cactus_attention_f16_h64 to skip rescaling the 16 NEON accumulators when the current block's maximum does not exceed the running maximum, applying a scalar current_block_scale to the block's contribution instead.
- kernel_nn.cpp: Removes unused #include <iostream>.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| cactus/kernel/kernel_attention.cpp | Core optimization: conditional accumulator rescaling in cactus_attention_f16_h64 |
| cactus/kernel/kernel_nn.cpp | Removes leftover unused #include <iostream> |
```diff
@@ -318,11 +324,11 @@ static inline void cactus_attention_f16_h64(
     }

     for (size_t i = 0; i < kv1 - kv0; i++) {
-        float w = block_scores[i];
-        if (w == 0.f) continue;
+        const float attn_weight = block_scores[i] * current_block_scale;
+        if (attn_weight == 0.f) continue;

         const __fp16* v = values + batch*kv_batch_stride + (kv0+i)*kv_seq_stride + kv_head*HEAD_DIM;
-        float32x4_t wv = vdupq_n_f32(w);
+        float32x4_t wv = vdupq_n_f32(attn_weight);

         #pragma unroll
         for (int d = 0; d < 8; d++) {
@@ -332,8 +338,7 @@ static inline void cactus_attention_f16_h64(
         }
     }

-    running_sum += block_sum;
-    running_max = block_max;
+    running_sum += block_sum * current_block_scale;
```
The modified cactus_attention_f16_h64 function is exercised only when head_dim == 64, mask == nullptr, and window_size == 0 (see kernel_attention.cpp:387–394). However, the existing attention test in tests/test_kernel.cpp:187–213 uses head_dim = 8, and tests/test_graph.cpp:223–242 uses head_dim = 4 — neither covers this specialised path.
The optimization changes the flow into two distinct branches depending on whether block_max > running_max or not. Both branches need to be exercised to verify correctness. With BLOCK_SIZE=32 and typical inputs, the "block_max <= running_max" branch (the new else path with current_block_scale = expf(block_max - running_max)) is only reachable after a later block has a lower max than an earlier one. Without a numerical correctness test against a known reference using head_dim=64 (which would trigger the h64 path), it is hard to be confident the optimization is bug-free in all scenarios.
A test that:
- uses head_dim=64 (to trigger cactus_attention_f16_h64),
- has enough KV positions to span multiple blocks (e.g., > 32), and
- compares results with a reference implementation (e.g., the general path via head_dim != 64)

would give confidence in the correctness of both new branches.