Conversation
This commit introduces several Metal kernel functions for TurboQuant, including:
- `_mse_score_kernel`
- `_pack_lowbit_kernel`
- `_unpack_lowbit_kernel`
- `_qjl_score_kernel`
- `_prod_score_kernel`

Additionally, the `scaled_dot_product_attention` function in `base.py` is updated to handle a special case for single-query inputs, allowing for efficient decoding from the TurboQuant KV cache. New tests are added to ensure the correctness of the TurboQuant decoding process and its performance against dequantized attention outputs.
… (64 tok/s) This commit introduces new Metal kernel functions, including:
- `_prod_score_multi_kernel`
- `_mse_weighted_rot_multi_kernel`
- `_prod_score_repeat_kernel`

These kernels enhance the TurboQuant functionality by optimizing multi-head processing and residual norm calculations. Additionally, a new test is added to verify the integer TurboQuant decode fast path, ensuring it correctly bypasses unnecessary preparation steps. This improves performance and correctness in TurboQuant operations.
This commit introduces two new Metal kernel functions: `_polar_prod_score_kernel` and `_polar_turbo_score_repeat_kernel`, enhancing the TurboQuant functionality for 4-bit operations. Additionally, new tests are added to verify the correct behavior of TurboQuant decoding paths, ensuring performance optimizations and correctness in handling 4-bit quantization. This update improves the overall efficiency of TurboQuant operations and expands its capabilities.
Replace broken Hadamard transform with proper QR-based Haar
rotation (used by all reference implementations). Fix shape
handling for arbitrary batch dims.
Refs: tonbistudio/turboquant-pytorch, TheTom/turboquant_plus,
Blaizzy/mlx-vlm#858
Signed-off-by: lishunyang <[email protected]>
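The QR-based Haar rotation mentioned in this commit can be sketched in a few lines (a minimal NumPy illustration, not the repo's MLX code). The sign correction on R's diagonal is the step that makes a plain QR of a Gaussian matrix exactly Haar-distributed:

```python
import numpy as np

def haar_rotation(dim: int, rng: np.random.Generator) -> np.ndarray:
    """Sample a Haar-distributed random orthogonal matrix via QR.

    A plain QR of a Gaussian matrix is NOT Haar-distributed unless the
    signs of R's diagonal are fixed, which is the usual correction.
    """
    gauss = rng.standard_normal((dim, dim))
    q, r = np.linalg.qr(gauss)
    # Fix the sign ambiguity: scale column j by sign(r[j, j]).
    q *= np.sign(np.diag(r))
    return q

rng = np.random.default_rng(0)
rot = haar_rotation(8, rng)
# Orthogonality: rot @ rot.T == I, so rotating preserves norms exactly.
assert np.allclose(rot @ rot.T, np.eye(8), atol=1e-10)
```

Because the matrix is orthogonal, rotating keys and queries by it leaves all inner products (and hence attention scores) unchanged in exact arithmetic.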
I ran a mini eval that tests perplexity (WikiText-2) and passkey retrieval on Qwen2.5-3B (by accident) and 3.5-4B, and noticed something interesting: TurboQuant works great for hybrid attention models where only a small number of layers use the KV cache, but causes huge losses when all layers are quantized (the key takeaway is: don't use this for all models).
I think letting the user choose which layers to quantize, as an optional parameter (starting with the ones in the middle, since the first and last seem to be the most impactful), could solve it. It did! Here are the results of running several tests:

Layer-Selective TurboQuant Results — Qwen2.5-3B (36 attn layers, 4-bit)
Skipping just the first and last 2 layers takes PPL from horrible (+243) to (+0.93). It's not as good as 3.5, but it's still quantizing 89% of the layers, which is a lot more. Maybe the best solution is to let users define the % of layers to quantize, with a recommendation?
Hey @kikoncuo, thanks for reaching out and for the report! I actually tested with full attention models and fixed the bug in another branch. This one was optimized for hybrid attention models like Qwen3.5-4B and beyond. I will merge that branch here once I get the perf I'm looking for.
Was the bug related to quantizing all layers vs a few (due to the compound impact), or was it something else? Ping me if you want me to run tests again. I'm also calculating the decode speed; I'm not showing it because I'm on the outdated, slower version. This is super cool!!
It was related to defaults and an approach that favours/assumes a small number of layers being full attention.
I see, so the problem is related to applying the same rotation matrix to all layers: errors were accumulating across all layers, and that's why when more layers were impacted, the error grew linearly (e1+e2+...+e36 = 36e). If the errors point in different directions they should almost cancel each other out (Claude says it's sqrt(36)·e). This means the performance of 3.5 will also improve with your new version (a small delta, because it's already very good, I imagine). It also means the memory savings will be a lot bigger in non-hybrid models (performance will still be a bit worse since there are way more layers; maybe we want to make it configurable). Still, not quantizing the first 2 and last 2 layers (or doing them last) was super effective at reducing the error; if we ever make the number or % of layers to quantize configurable, that's something we'll likely want to add. I can create a quick PR when you have the new version.
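The linear-vs-square-root intuition is easy to check in a toy simulation (a hedged NumPy sketch, not the actual codec: identical error directions stand in for a shared rotation matrix, and independent random directions stand in for per-layer rotations):

```python
import numpy as np

rng = np.random.default_rng(0)
dim, n_layers, e = 64, 36, 0.1

# Correlated case: every layer errs in the SAME direction
# (a stand-in for sharing one rotation matrix across all layers).
direction = rng.standard_normal(dim)
direction /= np.linalg.norm(direction)
correlated = sum(e * direction for _ in range(n_layers))

# Independent case: each layer errs in its own random direction
# (a stand-in for per-layer rotations decorrelating the error).
independent = sum(e * (v / np.linalg.norm(v))
                  for v in rng.standard_normal((n_layers, dim)))

print(np.linalg.norm(correlated))   # = 36 * e = 3.6
print(np.linalg.norm(independent))  # ~ sqrt(36) * e = 0.6
```

The correlated total grows like N·e while the independent total grows like sqrt(N)·e, matching the 36e vs sqrt(36)·e estimate above.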
…TurboQuant This commit updates the `_reserve_state_capacity` function to round up the new capacity to the next step boundary, preventing excessive growth. Additionally, the `quantize` methods in both `_TurboQuantMSECodec` and `_TurboQuantProdCodec` classes are simplified by directly using the maximum of norms and a small epsilon value for unit vector normalization, enhancing code clarity and performance. Cached state handling in `TurboQuantKVCache` is also improved to optimize state retrieval and management.
…n logic This commit modifies the data type for norms and residual norms in the `_TurboQuantMSECodec`, `_TurboQuantPolarProdCodec`, and `_TurboQuantProdCodec` classes to `mx.float16`, enhancing memory efficiency. Additionally, the evaluation logic in `TurboQuantKVCache` is refined to only trigger during multi-token decoding, preventing unnecessary computation graph buildup during single-token steps. These changes improve performance and resource management in TurboQuant operations.
…ze to minimize overhead. This change allows for processing all tokens in one pass, effectively eliminating chunking and improving memory management during decoding.
…for improved efficiency. This update modifies the `_mse_scores_weighted_rot_repeat_kernel` and `_mse_scores_weighted_rot_sum_repeat_kernel` functions to utilize max scores directly, reducing the need for separate token-dimension passes. Additionally, adjustments are made in the `_metal_mse_weighted_sum_from_scores` function to accommodate the new max scores input, optimizing the overall computation process.
This commit introduces a new function, `estimate_kv_size_gb`, to calculate the estimated size of the key-value (KV) cache based on model configuration and token count. The function is integrated into the `run_single` and `run_multi` methods, allowing for the reporting of KV cache size in the test results. Additionally, the output format is updated to include KV size in the printed summaries, enhancing the visibility of memory usage during tests.
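The repo's `estimate_kv_size_gb` is not reproduced here, but a back-of-envelope version might look like the following (parameter names are hypothetical; assumes an unquantized cache stores one key tensor and one value tensor per layer):

```python
def estimate_kv_size_gb(num_layers: int, num_kv_heads: int, head_dim: int,
                        num_tokens: int, bytes_per_element: float = 2.0) -> float:
    """Rough KV-cache size in GiB: keys + values across all layers.

    bytes_per_element is 2.0 for fp16/bf16; a 4-bit quantized cache
    would be roughly 0.5 plus a small overhead for scales/norms.
    """
    kv_pair = 2  # one tensor for keys, one for values
    total_bytes = (kv_pair * num_layers * num_kv_heads * head_dim
                   * num_tokens * bytes_per_element)
    return total_bytes / (1024 ** 3)

# e.g. a hypothetical 36-layer model, 8 KV heads of dim 128, 64k tokens, bf16:
print(round(estimate_kv_size_gb(36, 8, 128, 65536), 2))  # -> 9.0
```

At 4 bits the same configuration would shrink to about a quarter of that, which is the kind of saving this PR is after.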
This commit refactors the `score_prepared` method in the `_SplitCodec` class to launch both sub-codec scores concurrently before synchronization, improving efficiency. Additionally, comments are added to clarify the purpose of launching operations before concatenation in the `weighted_sum_from_scores` and `weighted_sum_stats_from_scores` methods, further enhancing performance through GPU overlap.
Major optimization: replace 4 separate Metal kernel dispatches (2 score + 2 weighted_sum for SplitCodec) with a single fused kernel that tiles both token and value dimensions. Per-layer decode at 128k: 2.17ms (3.3x SDPA), down from 5.84ms (9x SDPA) — a 2.7x speedup.

Kernel architecture:
- Grid: (32, num_val_tiles, B*H*num_tok_tiles) for massive GPU parallelism
- 32 SIMD lanes cooperate on key scoring (simd_sum across dims), then each lane accumulates its own value dim with online softmax
- Token tiling (1024 tokens/tile) gives 128 threadgroups at 128k context, enough to saturate the GPU and hide memory latency
- Cross-tile online-softmax reduction merges partial results

Additional optimizations in this commit:
- Float16 norms storage (was float32) — saves 9% KV memory
- Conditional mx.eval in update_and_fetch (skip for single-token decode)
- Single-pass weighted_sum kernels (precomputed max scores)
- Removed decode chunking (was 65536 limit)
- State property caching to avoid re-slicing
- Conservative capacity growth in _reserve_state_capacity
- QuantizedStateProxy for model compatibility (keys.shape access)
- Skip RotatingKVCache in quantize_entry (sliding window already compact)
- Pre-convert empty KVCache to TurboQuantKVCache before prefill
- NIAH runner: reports active memory + KV size, generates heatmaps
- PPL runner: uses generate.py functions, sliding window evaluation
- NIAH dataset generator: single + multi needle across 2k-256k

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
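The cross-tile online-softmax reduction this commit relies on can be sketched in NumPy (a schematic of the math only; the real work happens inside the Metal kernel). Each tile carries a running max, a sum of exponentials, and an exp-weighted value accumulator, and rescaling on merge makes the result exactly equal to a single-pass softmax:

```python
import numpy as np

def merge_tiles(partials):
    """Merge per-tile online-softmax partials into the exact softmax output.

    Each partial is (m, s, acc): the tile's max score, sum of exp(score - m),
    and exp-weighted value accumulator. Rescaling by exp(m - m_new) when
    merging keeps everything numerically consistent with one global max.
    """
    m, s, acc = partials[0]
    for m2, s2, acc2 in partials[1:]:
        m_new = max(m, m2)
        a, b = np.exp(m - m_new), np.exp(m2 - m_new)
        s = a * s + b * s2
        acc = a * acc + b * acc2
        m = m_new
    return acc / s

rng = np.random.default_rng(0)
scores = rng.standard_normal(4096)
values = rng.standard_normal((4096, 8))

# Build partials tile by tile, 1024 tokens per tile (as in the commit).
partials = []
for tile in range(0, 4096, 1024):
    sc, va = scores[tile:tile + 1024], values[tile:tile + 1024]
    m = sc.max()
    w = np.exp(sc - m)
    partials.append((m, w.sum(), w @ va))

shifted = np.exp(scores - scores.max())
reference = (shifted / shifted.sum()) @ values
assert np.allclose(merge_tiles(partials), reference, atol=1e-10)
```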
This commit updates the input handling in several Metal kernel functions by removing the conversion of state indices and signs to `mx.uint32`. This change simplifies the code and improves performance by directly using the original data types. Additionally, a new fused integer decode kernel is introduced, optimizing the scoring and weighted sum processes for TurboQuant operations, which enhances overall efficiency.
This commit introduces a combined query transformation for the TurboQuant decoding process, optimizing the handling of queries by replacing multiple matrix operations with a single matrix multiplication. This change reduces computational overhead and improves performance. Additionally, the fused integer decode kernel is further refined to streamline output handling, ensuring efficient accumulation of results across multiple dimensions. Comments are added for clarity on the new structure and its benefits.
This commit introduces a new single-tile fused integer decode kernel in the TurboQuant module, optimizing the decoding process for long contexts by reducing key read redundancy and improving memory bandwidth utilization. The kernel is designed to handle multiple value dimensions per lane, enhancing performance in scenarios with high token counts. Additionally, the token tile size is adjusted based on context length to further minimize cross-tile overhead. Comments are added for clarity on the new implementation and its benefits.
…e from split kernel for 3.5 This commit updates the `_fused_integer_decode_kernel` and `_fused_integer_decode_single_tile_kernel` functions to accept an additional `key_mse_bits` parameter, allowing for more flexible handling of key and value bit configurations. The `_ensure_codecs` method is also modified to compute key and value bits based on the fractional nature of the `self.bits` attribute, optimizing codec initialization. These changes improve the efficiency and adaptability of the TurboQuant decoding process.
…roved performance This commit enhances the `_TurboQuantMSECodec` class by introducing precomputed midpoints for fast comparison-based quantization, significantly reducing computational complexity. The quantization process is optimized to use midpoints instead of broadcasting, resulting in a performance improvement of 11-28x. Additionally, comments are added to clarify the benefits of these changes. The TurboQuantKVCache class is also updated to refine conditions for using single-tile kernels based on value tile redundancy and context length.
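Comparison-based quantization with precomputed midpoints can be illustrated as follows (a NumPy sketch with a made-up 2-bit codebook, not the codec's actual levels). A value belongs to the nearest level exactly when it falls between the midpoints of adjacent levels, so a sorted search against N-1 midpoints replaces the full |x| × N distance matrix of a broadcast-argmin approach:

```python
import numpy as np

def quantize_with_midpoints(x, levels):
    """Map each value to the index of its nearest level using midpoints.

    `levels` must be sorted. Comparing against the N-1 midpoints via
    searchsorted avoids materialising the full |x| x N distance matrix
    that a broadcast-argmin approach needs.
    """
    midpoints = (levels[:-1] + levels[1:]) / 2
    return np.searchsorted(midpoints, x)

levels = np.array([-1.5, -0.5, 0.5, 1.5])   # hypothetical 2-bit codebook
x = np.array([-2.0, -0.4, 0.1, 0.6, 9.0])

fast = quantize_with_midpoints(x, levels)
slow = np.abs(x[:, None] - levels[None, :]).argmin(axis=1)  # broadcast baseline
assert np.array_equal(fast, slow)
```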
…code method in TurboQuantKVCache (2x speedup) This commit introduces a new `_single_tile_value_weighted_sum_kernel` function, optimizing the value weighted sum process for large dimensions by leveraging a single-tile approach. Additionally, a `_separate_score_value_decode` method is added to the `TurboQuantKVCache` class, which combines fast key scoring with the new weighted sum kernel, achieving significant performance improvements over the previous fused kernel approach. The implementation includes detailed comments for clarity on the new functionality and its benefits.
…ce and clarity This commit refines the `_single_tile_value_weighted_sum_kernel` function by optimizing variable handling and enhancing comments for better understanding. The changes include a shift from handling multiple value dimensions per lane to a more efficient single value dimension approach, significantly improving latency hiding. Additionally, the grid and threadgroup configurations are updated to maximize occupancy, further boosting performance. These adjustments contribute to a more streamlined and efficient kernel implementation.
… softmax weights for enhanced performance This commit updates the `_single_tile_value_weighted_sum_kernel` function to leverage precomputed softmax weights, eliminating the need for exponential calculations in the inner loop. The changes improve the kernel's efficiency, achieving a 2x speedup compared to the previous implementation. Additionally, the input and output handling has been streamlined, and comments have been enhanced for better clarity on the new functionality.
…refill path

- Guard Metal kernel fast path with L==1 check in weighted_sum_stats_from_scores to prevent reshape crash with multi-query (L>1) prefill tensors
- Keep dequantize+SDPA for prefill (faster than quantized_attention at all sizes due to einsum fallback being slower than SDPA for multi-query)
- Verified: quantized_attention is 1.6x slower than dequantize+SDPA for prefill across all cache sizes (2k-64k) due to chunked einsum overhead

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
…, simplifying the attention mechanism by relying solely on dequantization and SDPA. This change enhances clarity and maintains performance consistency.
- Multi-query score kernel: unpack key data ONCE per token, loop over L query positions. Avoids R*L repeat explosion (28ms→12ms at 64k).
- prefill_attention: uses MQ score + precomputed softmax + TG=D value kernel for prefill, avoiding O(T*D²) dequantize rotation matmul.
- Falls back to dequantize+SDPA when MQ path not available.
- 1.9x faster per-layer prefill attention at small cache sizes.
- Fix weighted_sum_stats_from_scores L>1 reshape crash.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
…ility

- Updated the _beta_pdf function to utilize lgamma for coefficient calculation, preventing overflow issues with large dimensions.
- Introduced log_pdf for normalization, enhancing performance and stability when computing the probability density function.
- Adjusted clipping to ensure numerical safety during calculations.
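A log-space Beta density along these lines might look as follows (a hedged sketch; the actual `_beta_pdf` signature and clipping in the repo may differ). The point is that gamma(a + b) overflows float64 once a + b exceeds roughly 170, while lgamma stays well-behaved:

```python
import numpy as np
from math import lgamma

def beta_pdf(x, a, b, eps=1e-12):
    """Beta(a, b) density computed in log space.

    The normalising coefficient gamma(a+b) / (gamma(a) * gamma(b))
    overflows float64 for large a + b, so it is built from lgamma and
    exponentiated only at the end. x is clipped away from {0, 1} for
    numerical safety before taking logs.
    """
    x = np.clip(np.asarray(x, dtype=np.float64), eps, 1.0 - eps)
    log_coef = lgamma(a + b) - lgamma(a) - lgamma(b)
    log_pdf = log_coef + (a - 1) * np.log(x) + (b - 1) * np.log(1.0 - x)
    return np.exp(log_pdf)

# Sanity check against the closed form for small parameters:
# Beta(2, 2) has pdf 6 * x * (1 - x), so beta_pdf(0.5, 2, 2) == 1.5.
assert np.isclose(beta_pdf(0.5, 2.0, 2.0), 1.5)
# Parameters that would overflow a naive gamma-based coefficient:
assert np.isfinite(beta_pdf(0.5, 400.0, 400.0))
```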
Summary
This PR adds TurboQuant, a compression method that achieves a large reduction in stored-vector size with near-zero accuracy loss, making it well suited to both key-value (KV) cache compression and vector search. It accomplishes this via two key steps:
High-quality compression (the PolarQuant method): TurboQuant starts by randomly rotating the data vectors. This step simplifies the data's geometry, making it easy to apply a standard, high-quality scalar quantizer (a tool that maps a large set of continuous values, like precise decimals, to a smaller, discrete set of symbols or numbers, like integers; examples include audio quantization and JPEG compression) to each coordinate of the vector individually. This first stage spends most of the compression power (the majority of the bits) on capturing the overall direction and magnitude of the original vector.
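The rotate-then-quantize idea can be demonstrated in a few lines (a toy NumPy illustration, not the PR's codec: a crude 16-level uniform quantizer and a QR-sampled rotation). A spiky vector wastes the quantizer's range on one outlier; after a random rotation the energy is spread across coordinates and the same quantizer does much better:

```python
import numpy as np

def uniform_quantize(x, n_levels=16):
    """Toy uniform scalar quantizer over [-max|x|, max|x|] (16 levels ~ 4 bits)."""
    levels = np.linspace(-np.abs(x).max(), np.abs(x).max(), n_levels)
    idx = np.abs(x[:, None] - levels[None, :]).argmin(axis=1)
    return levels[idx]

rng = np.random.default_rng(0)
dim = 64
# A "spiky" vector: one outlier dominates the dynamic range.
x = rng.normal(0.0, 0.1, dim)
x[0] = 10.0

# Haar-style rotation via QR of a Gaussian matrix (sign-corrected).
q, r = np.linalg.qr(rng.standard_normal((dim, dim)))
q *= np.sign(np.diag(r))

plain_err = np.mean((uniform_quantize(x) - x) ** 2)
rotated = q @ x
# Rotations are orthogonal, so error measured in rotated space equals the
# error after rotating back with q.T.
rot_err = np.mean((uniform_quantize(rotated) - rotated) ** 2)

print(plain_err, rot_err)  # rotation flattens the outlier, shrinking the error
assert rot_err < plain_err
```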
Eliminating hidden errors: TurboQuant uses a small, residual amount of compression power (just 1 bit) to apply the QJL algorithm to the small error left over from the first stage. The QJL stage acts as a mathematical error-checker that eliminates bias, leading to more accurate attention scores.
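The 1-bit QJL step admits a short sketch (a hedged NumPy illustration under the assumption of Gaussian random projections; this is not the repo's kernel). Storing only the sign of each projection of the residual, plus its norm, still allows an unbiased inner-product estimate via the identity E[sign(s·k)(s·q)] = sqrt(2/pi)·(q·k)/||k|| for s ~ N(0, I):

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 64, 50_000          # m projections; estimator error shrinks like 1/sqrt(m)
S = rng.standard_normal((m, d))

k = rng.standard_normal(d)  # stands in for a residual "key"
q = rng.standard_normal(d)  # stands in for a query

# Store only 1 bit per projection of k, plus k's norm.
k_bits = np.sign(S @ k)
k_norm = np.linalg.norm(k)

# Rescale by sqrt(pi/2) * ||k|| to undo the sqrt(2/pi) factor above,
# giving an unbiased estimate of <q, k> from sign bits alone.
estimate = np.sqrt(np.pi / 2) * k_norm * np.mean(k_bits * (S @ q))
exact = q @ k
print(estimate, exact)  # close for large m
```

Unbiasedness is the key property claimed for QJL here: the sign quantization introduces variance but no systematic drift in the attention scores.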
Source:
https://research.google/blog/turboquant-redefining-ai-efficiency-with-extreme-compression/
Results
Model: Qwen3.5-35B-A3B (bf16)

Device: M3 Max 96GB
Context: 8k, 32k, 64k
Test: Needle-in-a-haystack style
Example
Note ⚠️: This implementation is far from optimal; I'm still working on reaching the claimed speedup results. If you have improvement suggestions, feel free to open a PR against this branch. In particular, I don't yet see the prefill and decode performance matching the claimed 8x speedup.