fix greedy sampler ignoring logit suppression #486
Conversation
Signed-off-by: jakmro <[email protected]>
Pull request overview
Fixes Whisper decoding to respect token-suppression via logit bias during greedy sampling, and updates model publishing/docs to include additional Whisper variants.
Changes:
- Apply per-step logit bias (token suppression) when sampling the next token in `WhisperModel::decode_with_audio`.
- Adjust greedy-sampling implementations so the argmax is taken over bias-adjusted logits.
- Add `openai/whisper-tiny` and `openai/whisper-base` to the supported/published model lists.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `cactus/models/model_whisper.cpp` | Initializes suppression bias maps and passes them into `CactusGraph::sample()` (first-step vs later steps). |
| `cactus/models/model.h` | Adds cached suppression-bias maps to `WhisperModel`. |
| `cactus/kernel/kernel_nn.cpp` | Makes greedy sampling respect bias-adjusted logits for FP32/FP16 paths (and refactors argmax accordingly). |
| `README.md` | Documents additional Whisper model variants. |
| `.github/workflows/publish_to_hf.yml` | Adds Whisper tiny/base to the Hugging Face publish matrix/config. |
```diff
@@ -546,6 +528,16 @@ void cactus_sample_f32(const float* logits, uint32_t* output, size_t vocab_size,
         }
     }

+    if (temperature == 0.0f && top_p <= 0.0f && top_k == 0) {
+        if (vocab_size == 0) {
+            output[0] = 0;
+            return;
+        }
+        auto it = std::max_element(filtered_logits.begin(), filtered_logits.end());
+        output[0] = static_cast<uint32_t>(std::distance(filtered_logits.begin(), it));
+        return;
```
Greedy sampling in cactus_sample_f32 now always allocates/copies a full filtered_logits vector even when temperature==0 && top_p<=0 && top_k==0. Since greedy decoding is common (and vocab is large), this can be a noticeable regression. Consider a fast-path that computes the argmax directly from logits while applying only the provided sparse bias, avoiding the full allocation/copy.
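A fast path along these lines could be sketched roughly as follows; `greedy_argmax_biased` and its signature are hypothetical illustrations of the suggestion, not the project's actual API. The idea is to scan the raw logits once and consult only a small sparse-bias lookup, avoiding the full-vocab copy:

```cpp
#include <cstddef>
#include <cstdint>
#include <unordered_map>

// Sketch of a greedy fast path: argmax directly over raw logits with a sparse
// bias lookup, no O(vocab_size) allocation/copy. Illustrative names only.
inline uint32_t greedy_argmax_biased(const float* logits, size_t vocab_size,
                                     const uint32_t* bias_indices,
                                     const float* bias_values,
                                     size_t bias_count) {
    if (vocab_size == 0) return 0;
    // Sparse lookup: bias_count is typically tiny compared to vocab_size.
    std::unordered_map<uint32_t, float> bias;
    if (bias_indices && bias_values) {
        for (size_t i = 0; i < bias_count; ++i) {
            if (bias_indices[i] < vocab_size) bias[bias_indices[i]] += bias_values[i];
        }
    }
    size_t best = 0;
    float best_val = logits[0];
    if (auto it = bias.find(0u); it != bias.end()) best_val += it->second;
    for (size_t i = 1; i < vocab_size; ++i) {
        float v = logits[i];
        if (auto it = bias.find(static_cast<uint32_t>(i)); it != bias.end()) v += it->second;
        if (v > best_val) { best_val = v; best = i; }
    }
    return static_cast<uint32_t>(best);
}
```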
```diff
+    std::vector<float> biased(vocab_size);
+    for (size_t i = 0; i < vocab_size; ++i) {
+        biased[i] = static_cast<float>(logits[i]);
+    }
+    if (bias_values && bias_indices && bias_count > 0) {
+        for (size_t i = 0; i < bias_count; ++i) {
+            if (bias_indices[i] < vocab_size) {
+                biased[bias_indices[i]] += bias_values[i];
+            }
+        }
+    }
-    output[0] = static_cast<uint32_t>(best_idx);
+    auto it = std::max_element(biased.begin(), biased.end());
+    output[0] = static_cast<uint32_t>(std::distance(biased.begin(), it));
```
The FP16 greedy path allocates and fills a std::vector<float> of size vocab_size on every call just to compute an argmax with bias. If this runs per-token during decoding, it can be expensive; consider computing the biased argmax without materializing a full float buffer (e.g., scan logits and incorporate a sparse bias lookup structure).
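The same no-buffer scan could serve the FP16 path by converting each element to float on the fly instead of materializing a full `std::vector<float>`. The sketch below is hypothetical (templated so one implementation covers both element types), assuming the half type converts via `static_cast<float>` as the existing code does:

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <unordered_map>

// Sketch only: buffer-free biased argmax for FP32 or FP16 element types.
// Per-element conversion replaces the full float buffer; the sparse bias
// lives in an O(bias_count) lookup structure. Illustrative names only.
template <typename ElemT>
uint32_t biased_argmax_no_alloc(const ElemT* logits, size_t vocab_size,
                                const uint32_t* bias_indices,
                                const float* bias_values, size_t bias_count) {
    if (vocab_size == 0) return 0;
    std::unordered_map<uint32_t, float> bias;
    if (bias_indices && bias_values) {
        for (size_t i = 0; i < bias_count; ++i) {
            if (bias_indices[i] < vocab_size) bias[bias_indices[i]] += bias_values[i];
        }
    }
    size_t best = 0;
    float best_val = -std::numeric_limits<float>::infinity();
    for (size_t i = 0; i < vocab_size; ++i) {
        float v = static_cast<float>(logits[i]);  // convert one element, no buffer
        if (auto it = bias.find(static_cast<uint32_t>(i)); it != bias.end()) v += it->second;
        if (v > best_val) { best_val = v; best = i; }
    }
    return static_cast<uint32_t>(best);
}
```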
```diff
+    const float neg_inf = -std::numeric_limits<float>::infinity();
+    for (size_t tok : suppress_tokens_) {
+        suppress_bias_[static_cast<uint32_t>(tok)] = neg_inf;
+    }
+    suppress_bias_first_step_ = suppress_bias_;
+    for (size_t tok : begin_suppress_tokens_) {
+        suppress_bias_first_step_[static_cast<uint32_t>(tok)] = neg_inf;
+    }
```
std::numeric_limits<float> is used here but this translation unit doesn’t include <limits> (and it’s not pulled in by the current includes). Please add #include <limits> to avoid a build break / reliance on transitive headers.
```diff
+    std::vector<float> biased(vocab_size);
+    for (size_t i = 0; i < vocab_size; ++i) {
+        biased[i] = static_cast<float>(logits[i]);
+    }
+    if (bias_values && bias_indices && bias_count > 0) {
```
In cactus_sample_f16, bias handling is now added for the greedy path, but the non-greedy path still effectively drops bias_values when applying temperature scaling / the else branch (it overwrites filtered_logits from logits). This means logit suppression/bias won’t work for FP16 sampling when temperature > 0 and/or top_p/top_k are used; please ensure scaling and copies operate on the biased filtered_logits (or re-apply bias after scaling).
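One way to structure the fix this comment asks for, sketched as a hypothetical free function (`bias_then_scale` is not the project's API): apply the sparse bias into the working buffer exactly once, then run temperature scaling on those already-biased values, so the suppression survives into the softmax/top-k path:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch: bias first, then scale the *biased* values. Overwriting the buffer
// from `logits` after biasing (as the reviewed code does) silently drops the
// bias for the non-greedy path. Illustrative function name and shape.
std::vector<float> bias_then_scale(const float* logits, size_t vocab_size,
                                   const uint32_t* bias_indices,
                                   const float* bias_values, size_t bias_count,
                                   float temperature) {
    std::vector<float> filtered(vocab_size);
    for (size_t i = 0; i < vocab_size; ++i) filtered[i] = logits[i];
    if (bias_values && bias_indices) {
        for (size_t i = 0; i < bias_count; ++i) {
            if (bias_indices[i] < vocab_size) filtered[bias_indices[i]] += bias_values[i];
        }
    }
    // Temperature scaling operates on the biased buffer, not raw logits.
    if (temperature > 0.0f) {
        for (float& v : filtered) v /= temperature;
    }
    return filtered;
}
```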