Faster FlashMLA prompt processing #246
Conversation
The current MLA implementation computes

`wv_b * (k_cache * softmax(k_cache * (wk_b*q)))`

This leads to 3.4X more multiply-adds (madds) compared to standard attention. Due to the resulting tensor shapes, TG is still faster than standard attention because the `k_cache*(wk_b*q)` and `k_cache*(softmax(k_cache * (wk_b*q)))` multiplications become GEMMs, so the additional madds are more than compensated for by the much higher performance of GEMMs compared to GEMVs. But for PP, where we are dealing with GEMMs in both cases, the additional madds needed for MLA lead to lower performance, with the performance gap increasing with context length.

So, then, when we are dealing with PP, we can rearrange the above to

`(wv_b * k_cache) * softmax( (wk_b^T*k_cache) * q)`

thus transforming it into the standard attention mechanism. We do need two additional matrix multiplications (which in practice are done as a single `wkv_b * k_cache` GEMM) with the *entire* K cache. But this is still cheaper than MLA, as we end up with 1.8X the madds required by standard attention. Oh, these figures are for the DeepSeek-V3/R1/Lite attention architecture. This leads to a significant PP performance increase compared to standard MLA with FA.

There are many upsides to this:
* If we only apply the above trick when we are processing more than X tokens (with suitably chosen X), TG performance stays the same as MLA with FA
* We still need to store just the K-cache, so 576 entries per layer for DeepSeek-V3/R1/Lite
* We get significantly better PP performance
* We can use MLA+FA on CUDA. It already works with this commit for PP; something is not yet quite right for TG.

The downside is that it only works with fp16 cache (for now). This is because we need to convert the cache to fp32, else we cannot do the `wkv_b * k_cache` matrix multiplication (which in ggml requires the second operand to be fp32). But converting (copying) to fp32 only works for f16, bf16 and f32 tensors, so no luck with quantized cache. Another reason we need to convert to fp32 is that the cache contains the RoPE'd portion, which we need to concatenate to the result of the `wkv_b * k_cache` matrix multiplication, and that op also works only when the tensors being concatenated are both fp32. So much for ggml being a general purpose ML library.
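To make the reordering concrete, here is a minimal numpy sketch (not code from this PR): a single attention head, with the RoPE'd portion, scaling and masking left out, and with sizes (512 latent dimensions, 128 head dimensions, 64 cached tokens) chosen purely for illustration. It only checks that the two multiplication orders give the same result:

```python
import numpy as np

rng = np.random.default_rng(0)
n_tok, d_latent, d_head = 64, 512, 128

c    = rng.standard_normal((n_tok, d_latent))   # K-cache: one latent row per cached token
q    = rng.standard_normal(d_head)              # query for one head
wk_b = rng.standard_normal((d_latent, d_head))  # per-head query -> latent projection ("wk_b" above)
wv_b = rng.standard_normal((d_head, d_latent))  # per-head latent -> value projection ("wv_b" above)

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# MLA order (used for TG): attention works directly on the latent cache
out_mla = wv_b @ (c.T @ softmax(c @ (wk_b @ q)))

# Reordered as in this PR (used for PP): decompress K and V from the *entire*
# cache first, then run standard attention on the decompressed tensors
K = c @ wk_b                                    # n_tok x d_head
V = c @ wv_b.T                                  # n_tok x d_head
out_std = V.T @ softmax(K @ q)

print(np.allclose(out_mla, out_std))            # True
```

The two orders compute the same output; they differ only in which intermediates get materialized, and hence in the madd count and in the GEMM/GEMV shapes, which is exactly the PP/TG trade-off described above.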
except for `q8_KV` (`q8_KV` has row metadata, and there is still some confusion with row sizes because of that).
Getting a linking error on
Are you using
Ah, I think that'll fix it. I was using the
MLA as used in the DeepSeek models is great for token generation (TG), but prompt processing (PP) speed is much lower compared to standard attention even with FA enabled.
This PR improves FlashMLA speed by a large margin. FlashMLA is CPU only, but the PR paves the way to perhaps also get it on CUDA (but this is left for a future PR).
The following table compares FlashMLA PP speed for DeepSeek-Lite quantized as `IQ4_NL` between the main branch and this PR. CPU is Ryzen-7950X, the cache is quantized with `Q8_0`, `fmoe` is on.

For reference, here is a comparison between standard attention with FA enabled and FlashMLA with this PR.
I.e., still quite a bit slower than standard attention with FA enabled for long contexts, but much better than the original implementation.
The new functionality is enabled via `-mla 2 -fa` as command line arguments. I know, it is getting confusing, so here is a summary of what happens with the different `mla` and `fa` combinations (a rough cache-size sketch follows the list):

* `mla = 0, fa = 0`: standard attention without FA. Works on the CPU and on CUDA. Large K- and V-cache required. The V cache cannot be quantized.
* `mla = 0, fa = 1`: standard attention with FA. Works on the CPU and on CUDA. Large K- and V-cache required. The V cache can be quantized. Best PP performance; TG performance is slightly lower than standard attention without FA.
* `mla = 1, fa = 0`: MLA attention. Works on the CPU and on CUDA. Smaller K- and smaller transposed V cache required. The V cache cannot be quantized. Great TG performance, pathetic PP performance.
* `mla = 1, fa = 1`: FlashMLA. Works only on the CPU. Only the small K cache required. Great TG performance, slightly less pathetic PP performance.
* `mla = 2, fa = 0`: FlashMLA. Works on the CPU and on CUDA. Only the small K cache required (the transposed V cache is computed on the fly). Great TG performance (but slightly lower than `mla = 1` for long contexts), pathetic PP performance.
* `mla = 2, fa = 1`: FlashMLA from this PR. Works only on the CPU. Only the small K cache required. Great TG performance, more acceptable PP performance.
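To put rough numbers on the cache-size differences in the list above, here is a small sketch (not from the PR; the head count and head dimensions are assumptions in the ballpark of DeepSeek-Lite):

```python
# Per-token, per-layer cache entries (assumed DeepSeek-Lite-like sizes:
# 16 heads, 192-dim K and 128-dim V per head for standard attention,
# 512 latent + 64 RoPE = 576 entries for the MLA K cache).
n_head, d_k, d_v, d_mla = 16, 192, 128, 576

std_entries = n_head * (d_k + d_v)  # mla = 0: separate K and V caches
mla_entries = d_mla                 # mla = 2, fa = 1: single small K cache

print(std_entries, mla_entries, round(std_entries / mla_entries, 1))
# 5120 576 8.9
```

This only counts cache entries; the actual memory footprint also depends on the cache type (`f16`, `Q8_0`, ...).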
Background

Let $X$ and $Q$ be the activations and the query after projection with their corresponding MQA tensors and after applying rotational position encoding (RoPE). In standard attention one computes (apart from scaling factors and masks that I'll omit for simplicity)

$$K = W_k X, \qquad V = W_v X, \qquad O = V_{\rm cache} \cdot {\rm softmax}\left(K_{\rm cache} \cdot Q\right)$$
In practice the $W_k$ and $W_v$ tensors are combined into $W_{kv}$ (the tensor `wkv_b` in `llama.cpp`): one computes $Y = W_{kv} X$, and the tensors $K$ and $V$ are views into $Y$. The matrix multiplication with $W_{kv}$ is performed only for the tokens in the batch being processed, the results are stored in the cache, and the tensors $V_{\rm cache}$ and $K_{\rm cache}$ are views into the KV cache.

With MLA one computes

$$Q' = W_{kb}\, Q, \qquad O = W_{vb} \left[ V_{\rm cache} \cdot {\rm softmax}\left( K_{\rm cache} \cdot Q' \right) \right],$$
where one stores $X$ directly into the K-cache, and $K_{\rm cache}$ is an appropriate view into the cache. $V_{\rm cache}$ is a transposed version of $K_{\rm cache}$ when FA is not used, or a slightly different view into the K-cache with FA or `mla=2`. The benefit of doing this reordering of the operations is that the cache becomes much smaller. But as these are not square matrices, the amount of multiply-adds (madds in the following) does depend on the order of the matrix multiplications. If we denote the number of madds in the standard attention implementation with $N$, for the DeepSeek models the number of madds with MLA is $(576 + 512)/(192 + 128) \times N = 3.4 \times N$.

Why is TG with MLA faster than with standard attention if one needs to do more computation? The difference comes from the shapes of the various matrices involved. TG with standard attention results in the tensor $Q$ being of shape $M \times 1 \times L$, so all multiplications are matrix-vector multiplications (a.k.a. GEMV), which are memory bound on basically any modern system (CPU or GPU). With MLA the shape of $Q'$ is $M' \times L$, so the calculation involves matrix-matrix multiplications (a.k.a. GEMM), which are much faster per madd, so one ends up with better performance despite having computed more madds.

But for PP we are dealing with GEMMs in both cases, so the 3.4X more madds make MLA PP processing slower. As an example, for 8k tokens with standard attention and FA, about 25% of the time is spent in the flash attention computation. We can estimate the expected MLA PP performance to be $0.75 + 0.25 \times 3.4 = 1.6$ times slower. From the above tables we see that in practice it is 515 t/s / 293 t/s = 1.75 times slower. As there are some other differences in the performed matrix multiplications, our back-of-the-envelope estimate comes quite close to the observed behavior.

So, how can we improve? We can rearrange the computation back to standard attention. The only difference: as we are storing $X$ into the cache, we need to multiply $W_{kv}$ with the entire content of the cache. This seems pretty stupid at first glance (and I had had the idea to rearrange the multiplications quite a while ago but discarded it because of that), but if one sits down and counts the actual madds that are required, one finds that for DeepSeek this results in $(192 + 3 \times 128)/(192 + 128) \times N = 1.8 \times N$ madds. I.e., we still need more madds than standard attention, but significantly fewer than the existing MLA implementation. What about TG? We save the day by applying the rearranged matrix multiplications only if the number of tokens in the batch is greater than 1 (or some suitably chosen threshold). In this way we keep the good TG performance, keep the reduced cache size, and get improved prompt processing speed.
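For what it's worth, the back-of-the-envelope numbers above are easy to reproduce with a few lines (the dimensions are the DeepSeek values quoted in this PR, and the 25% attention share at 8k tokens is the estimate from the text):

```python
# madds per (query token, cached token, head), following the counting above
std_madds       = 192 + 128      # standard attention: K dot product + V accumulation
mla_madds       = 576 + 512      # MLA: 576-dim score + 512-dim latent accumulation
reordered_madds = 192 + 3 * 128  # this PR's reordered computation, as counted above

print(mla_madds / std_madds)        # 3.4
print(reordered_madds / std_madds)  # 1.8

# expected PP slowdown of MLA vs. standard attention + FA at 8k tokens,
# assuming ~25% of the time goes into the (flash) attention part
print(0.75 + 0.25 * (mla_madds / std_madds))  # ~1.6 (measured above: ~1.75)
```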