Conversation
…g in Gemma4TextModel
Add Gemma 4 Fixes
Add Gemma 4 MoE sanitize and router quantization
Hey @Blaizzy — great work getting this up so fast; the full model implementation looks solid (MoE, KV sharing, PLE, all the norm variants). Wanted to flag #1103 by @0xSoftBoi, which targets your
Just a friendly heads-up :)
Feel free to send a PR to my branch, or I can cherry-pick and apply the fixes here, so that we don't duplicate work :)
Hey @Blaizzy @angeloskath — great work getting this moving so fast. I've been working on Gemma 4 support in #1103 and wanted to flag a few things I've landed there that might be useful to cherry-pick or integrate here:
All tests passing, pre-commit green, verified inference on 4-bit E2B (~50 tok/s) and LoRA fine-tuning (loss 2.66 → 0.54 over 200 iterations). Happy to submit any of this as a PR against your branch, cherry-pick commits, or help however gets Gemma 4 landed fastest. Let me know what works best.
Hey @Blaizzy — PR is up against your branch: Blaizzy#12. Rebased cleanly on top of For context, I originally opened #1103 because I was trying to fine-tune Gemma 4 with LoRA and kept hitting the sanitizer bug and Metal buffer crashes. I ended up fixing everything I ran into and figured I'd contribute it back. Appreciate you being open to integrating it!
Hi @0xSoftBoi! Can you share which devices encountered the 4GB limit?
Feel free to PR the tool parser (and tests) separately as well, because I am about to merge this. I'm testing the max buffer issue you mentioned on my base M2. Apart from that, everything works.
I investigated the issue, and it shouldn't happen if you run the quants, because all of them stay below the limit.
It will only happen if you run bf16; but if the device has a 4GB Metal buffer limit, it shouldn't be running bf16 anyway, since the peak memory for bf16 is 11.004 GB.

You read my mind, Angelos; that's exactly what I suggested as well :)
Seems ready to merge to me!
angeloskath
left a comment
Did a bit of cleanup and thorough testing. I will merge as is, but I left a couple of questions, and depending on the answers we could remove these two pieces of code.
```python
# Split the sequence dimension if this still holds too much
# memory. 260k vocab means the distance tensor would be ~1GB
# per 2k tokens in bf16.
#
# If the embedding is quantized we have to dequantize it anyway to
# perform the match test.
norms_embedding = self.embed_tokens.weight.square().sum(-1)
norms_input = input_embeddings.square().sum(-1)
distance = _complete_square(
    norms_embedding,
    norms_input,
    self.embed_tokens.as_linear(input_embeddings),
)
```
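As a quick sanity check on the "~1GB per 2k tokens" figure in the comment above, here is a back-of-envelope calculation. It assumes a 262,144-entry vocab as the round value behind "260k" and 2 bytes per bf16 element; both are assumptions, not values confirmed in this thread.

```python
# Size of the (seq_len, vocab) distance tensor in bf16.
# 262144 is an assumed round value for the "260k" vocab mentioned above.
seq_len = 2048
vocab = 262144
bytes_per_bf16 = 2

size_bytes = seq_len * vocab * bytes_per_bf16
print(f"{size_bytes / 1024**3:.2f} GiB per {seq_len} tokens")  # 1.00 GiB
```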
```python
# Checks can be added if needed but they necessarily break the GPU
# pipelining and force an eval.
#
# match_counts = (distance < eps).sum(-1)
#
input_ids = mx.argmin(distance, -1)
```
If I remember correctly, I asked N8 (on X, or on the PR he made) and he explained it was for prefill.
But I could be wrong, and I'm happy to take a look at it with fresh eyes.
It's needed in the extreme edge case where you pass input_embeddings to the model and the model needs the input_ids (not the embeddings) for the later layers' token embeddings, so the path is:
input embeddings -> token ids -> per-layer embeddings. I think it would be cleaner to just error; Codex recommended this as a fix for the edge case, and it made sense, but I'm not sure it's necessary.
Yep, exactly. It was input embeddings, not prefill 👌🏽
> i think it would be cleaner to just error
Totally. The current implementation is not horrible, but if there is no use case then it is simply dead code. I assumed it was for passing embeddings coming from a different source that doesn't map exactly to input_ids. However, the original code was checking for an exact match, so 🤷♂️
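For reference, the embeddings → token ids recovery being discussed can be sketched in a few lines. This is a hypothetical NumPy stand-in for the MLX code above; the names, shapes, and sizes are illustrative only, not the library API.

```python
import numpy as np

# Find each input embedding's nearest row of the embedding table via the
# completed square: ||e - x||^2 = ||e||^2 + ||x||^2 - 2 e.x, so the full
# (seq, vocab, dim) difference tensor never has to be materialized.
rng = np.random.default_rng(0)
embed_tokens = rng.normal(size=(16, 8))    # tiny vocab of 16, hidden dim 8
true_ids = np.array([3, 7, 7, 1])
input_embeddings = embed_tokens[true_ids]  # exact rows of the table

norms_embedding = (embed_tokens ** 2).sum(-1)   # (vocab,)
norms_input = (input_embeddings ** 2).sum(-1)   # (seq,)
dot = input_embeddings @ embed_tokens.T         # (seq, vocab)

# distance[i, j] = ||embed_tokens[j] - input_embeddings[i]||^2
distance = norms_embedding[None, :] + norms_input[:, None] - 2.0 * dot
recovered = distance.argmin(-1)
print(recovered.tolist())  # [3, 7, 7, 1]
```

When the embeddings are exact rows of the table, argmin recovers the original ids; embeddings from a different source would map to whichever row is nearest, which is exactly why erroring out may be the cleaner behavior.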
```python
if mask is not None and isinstance(mask, mx.array):
    if mask.shape[-1] != keys.shape[-2]:
        mask = mask[..., -keys.shape[-2]:]
```
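The sliding-window situation this trim handles can be sketched as follows. This is an illustrative NumPy reconstruction under assumed shapes, not the model's actual code path.

```python
import numpy as np

# The mask was built for the full sequence, but a rotating KV cache only
# holds the last few positions, so the mask's trailing columns are kept
# to line up with the cached keys.
mask = np.tril(np.ones((1, 1, 6, 6), dtype=bool))  # (batch, heads, q_len, total_len)
keys = np.zeros((1, 1, 4, 8))                       # (batch, heads, kv_len=4, head_dim)

if mask is not None and isinstance(mask, np.ndarray):
    if mask.shape[-1] != keys.shape[-2]:
        mask = mask[..., -keys.shape[-2]:]

print(mask.shape)  # (1, 1, 6, 4)
```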
Similarly, when does this ever happen? What prompted you to add it?
This is for sliding windows.
But I guess it can be done differently.
I wish they had separated the models into 2-3 types for clarity.
Cherry-picked from upstream/main 4469ad4 onto the b1-replay branch. Resolved a merge conflict in BatchRotatingKVCache.merge(): adopted upstream's bounds-safe length slicing with loop-variable unpacking.

Co-authored-by: Prince Canuma <[email protected]>
Co-authored-by: N8 <[email protected]>
Co-authored-by: Angelos Katharopoulos <[email protected]>
User hit 'Model type gemma4 not supported' on agent_runner subprocess startup after switching to Gemma 4 E4B as the default.

Root cause: mlx-lm 0.31.1 (what voksly was running) doesn't have the gemma4 model module. Gemma 4 support was added in 0.31.2, released April 7 (ml-explore/mlx-lm#1093).

Blocker for the straightforward upgrade: mlx-audio 0.4.2 hard-pins mlx-lm == 0.31.1 in its dependency spec, so a normal pyproject bump fails resolution with 'mlx-audio depends on mlx-lm == 0.31.1 but your project depends on mlx-lm >= 0.31.2'.

Workaround: uv override-dependencies. Added a [tool.uv] block in pyproject.toml that forces mlx-lm >= 0.31.2 across the resolver, overriding mlx-audio's transitive pin.

Verified end-to-end:
- `uv sync` swaps `mlx-lm==0.31.1` for `mlx-lm==0.31.2`
- `uv run python -c '...'` reports mlx_lm 0.31.2
- gemma4 module: ok (the gemma4 model class now imports)
- mlx_audio: ok (mlx-audio still loads with the newer mlx-lm)
- voksly imports: ok (no breakage in our code path)

The mlx-audio pin was overly conservative; the TTS code path uses nothing that changed between mlx-lm 0.31.1 and 0.31.2.

With the upgrade in place, flipped the default back to Gemma 4 E4B in all three default-resolution paths:
- voksly.llm.DEFAULT_MLX_MODEL_KEY = 'gemma-4-e4b'
- voksly.config.settings.llm.model = 'mlx-community/gemma-4-e4b-it-4bit'
- frontdesk LLM dropdown first option = Gemma 4 E4B

Also reordered KNOWN_MLX_MODELS so Gemma 4 E4B is iterated first (the dropdown UI follows dict insertion order).

Drop the [tool.uv] override block when mlx-audio releases a version with a relaxed mlx-lm constraint. Documented in pyproject.toml comments.
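For reference, the override block described above would look roughly like this in pyproject.toml. This is a sketch, not the project's actual file: the package names and version bounds come from the log above, and uv's `override-dependencies` setting under `[tool.uv]` is what forces the resolver past mlx-audio's transitive pin.

```toml
# Sketch of the workaround: force mlx-lm >= 0.31.2 across the whole
# resolution, overriding mlx-audio 0.4.2's hard pin on mlx-lm == 0.31.1.
# Drop this once mlx-audio relaxes its constraint.
[project]
name = "voksly"          # illustrative project metadata
version = "0.0.0"
dependencies = [
    "mlx-lm>=0.31.2",
    "mlx-audio==0.4.2",
]

[tool.uv]
override-dependencies = [
    "mlx-lm>=0.31.2",
]
```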