Add gemma 4 #1093

Merged

angeloskath merged 27 commits into ml-explore:main from Blaizzy:pc/add-gemma-4
Apr 4, 2026

Conversation

@Blaizzy
Contributor

@Blaizzy Blaizzy commented Apr 2, 2026

No description provided.

N8 and others added 2 commits April 2, 2026 20:55
@weklund-agent

weklund-agent commented Apr 3, 2026

Hey @Blaizzy — great work getting this up so fast. The full model implementation looks solid (MoE, KV sharing, PLE, all the norm variants).

Wanted to flag #1103 by @0xSoftBoi, which targets your gemma4 branch with some fixes and additions that build on top of this:

  • Sanitizer bug fix — the multimodal wrapper's sanitize was double-prepending model. to language model weights, causing ValueError when loading checkpoints like mlx-community/gemma-4-e2b-it-4bit
  • PLE embedding split — splits the single large per-layer embedding into individual nn.Embedding per layer to stay under Metal's 4GB buffer limit
  • Tool call parser — function_gemma4.py for Gemma 4's native <|tool_call>...<tool_call|> JSON format, plus _infer_tool_parser integration
  • Additional tests — PLE split sanitize, MoE + k_eq_v combined test matching the real 26B-A4B config

Just a friendly heads up :)

@Blaizzy
Contributor Author

Blaizzy commented Apr 3, 2026

Hey @weklund-agent @0xSoftBoi

Feel free to send a PR to my branch, or I can cherry-pick and apply the fixes here. That way we don't duplicate work :)

@0xSoftBoi

Hey @Blaizzy @angeloskath — great work getting this moving so fast. I've been working on Gemma 4 support in #1103 and wanted to flag a few things I've landed there that might be useful to cherry-pick or integrate here:

  1. Multimodal sanitizer fix — The current sanitizer double-prepends model. to weights that already contain it, which prevents loading multimodal checkpoints like mlx-community/gemma-4-e2b-it-4bit. Fix is a 2-line change in gemma4.py.

  2. Gemma 4 tool parser + auto-detection — A dedicated function_gemma4 parser for the <|tool_call>...<tool_call|> format with <|"|> quote unescaping, plus detection logic in tokenizer_utils.py that checks for both opening and closing delimiters. This closes #1096 ("Gemma 4 native tool calls are not parsed, so the OpenAI-compatible tool_calls field stays empty").

  3. PLE split with Metal buffer guard — Splits HF's combined [vocab, num_layers * ple_dim] embedding tensor into per-layer chunks at sanitize time to stay under the 4GB Metal buffer limit. Includes tests for both the split path and the already-split passthrough.

  4. Test coverage — 255 lines of new tests: MoE variant (SwitchGLU routing), k_eq_v (shared K/V projections), PLE sanitize split, and a combined MoE + k_eq_v test matching the real 26B-A4B config.

All tests passing, pre-commit green, verified inference on 4-bit E2B (~50 tok/s) and LoRA fine-tuning (loss 2.66 → 0.54 over 200 iterations).

Happy to submit any of this as a PR against your branch, cherry-pick commits, or help however gets Gemma 4 landed fastest. Let me know what works best.
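For anyone curious how the PLE split in item 3 works, here is a minimal NumPy sketch of the idea. The weight keys and shapes are hypothetical stand-ins, not the actual checkpoint keys used in #1103:

```python
import numpy as np

def split_ple(weights, num_layers, key="model.per_layer_embeddings.weight"):
    """Split a combined [vocab, num_layers * ple_dim] per-layer-embedding
    tensor into one [vocab, ple_dim] chunk per layer, so no single buffer
    approaches Metal's 4GB allocation limit. Key names are illustrative."""
    if key not in weights:
        return weights  # already split: pass through unchanged
    combined = weights.pop(key)
    vocab, total = combined.shape
    ple_dim = total // num_layers
    for i in range(num_layers):
        # Each layer gets its contiguous column slice of the combined tensor.
        weights[f"model.ple_embeddings.{i}.weight"] = combined[
            :, i * ple_dim : (i + 1) * ple_dim
        ]
    return weights

# Tiny demo with stand-in sizes: vocab=4, 3 layers, ple_dim=2
w = {"model.per_layer_embeddings.weight": np.arange(4 * 6).reshape(4, 6)}
out = split_ple(w, num_layers=3)
```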

@0xSoftBoi

Hey @Blaizzy — PR is up against your branch: Blaizzy#12

Rebased clean on top of pc/add-gemma-4, no conflicts, single commit. Should be a straightforward merge whenever you're ready.

For context — I originally opened #1103 because I was trying to fine-tune Gemma 4 with LoRA and kept hitting the sanitizer bug and Metal buffer crashes. Ended up fixing everything I ran into and figured I'd contribute it back. Appreciate you being open to integrating it!

@angeloskath
Member

Hi @0xSoftBoi! Can you share which devices encountered the 4GB limit?

@angeloskath
Member

Feel free to PR the tool parser (and tests) separately as well, because I am about to merge this. I'm testing the max buffer issue you mentioned on my M2 base; except for that, everything works.

@Blaizzy
Contributor Author

Blaizzy commented Apr 4, 2026

@angeloskath @0xSoftBoi

I investigated the issue, and it shouldn't happen if you run the quants, because all of them are below the limit.

| Format | PLE total |
| ------ | --------- |
| bf16   | 4.70 GB   |
| 8-bit  | 2.50 GB   |
| 4-bit  | 1.32 GB   |

This will only happen if you run bf16, but on a device with a 4GB Metal buffer limit bf16 shouldn't run anyway, because its peak memory is 11.004 GB.

> Feel free to PR the tool parser (and tests) separately as well cause I am about to merge this. Testing the max buffer issue you mentioned on my M2 base. Except for that everything works.

Read my mind Angelos, that's exactly what I suggested as well :)

@N8python
Contributor

N8python commented Apr 4, 2026

Seems ready to merge to me!!!

Member

@angeloskath angeloskath left a comment


Did a bit of cleanup and thorough testing. I will merge as is, but I left a couple of questions, and depending on the answers we could remove these two pieces of code.

Comment on lines +479 to +498
```python
        # Split the sequence dimension if this still holds too much
        # memory. 260k vocab means the distance tensor would be ~1GB
        # per 2k tokens in bf16.
        #
        # If the embedding is quantized we have to dequantize it anyway to
        # perform the match test.
        norms_embedding = self.embed_tokens.weight.square().sum(-1)
        norms_input = input_embeddings.square().sum(-1)
        distance = _complete_square(
            norms_embedding,
            norms_input,
            self.embed_tokens.as_linear(input_embeddings),
        )

        # Checks can be added if needed but they necessarily break the GPU
        # pipelining and force an eval.
        #
        # match_counts = (distance < eps).sum(-1)
        #
        input_ids = mx.argmin(distance, -1)
```
Member


@Blaizzy and @N8python out of curiosity when is the reverse lookup ever needed?

Contributor Author


If I remember correctly, I asked N8 on X or on the PR he made, and he explained it was for prefill.

But I could be wrong, and I'm happy to take a look at it with fresh eyes.

Contributor


It's needed in the extreme edge case where you pass input_embeddings to the model and the model needs the input_ids (not embeddings) for the later layers' token embeddings, so this requires:

input embeddings -> token ids -> per-layer embeddings. I think it would be cleaner to just error. Codex recommended it as a fix for this edge case, and it made sense, but I'm not sure it's necessary.

Contributor Author


Yep, exactly. It was input embeddings, not prefill 👌🏽

Member


> I think it would be cleaner to just error

Totally. The current implementation is not horrible, but if there is no use-case then it is simply dead code. I assumed it was for passing embeddings coming from a different source that doesn't map exactly to input_ids. However, the original code was checking for an exact match, so 🤷‍♂️

Comment on lines +281 to +283
```python
        if mask is not None and isinstance(mask, mx.array):
            if mask.shape[-1] != keys.shape[-2]:
                mask = mask[..., -keys.shape[-2] :]
```
Member


Similarly when does this ever happen? What prompted you to add that?

Contributor Author


This is for sliding windows.

But I guess it can be done differently.

I wish they had separated the models into 2-3 types for clarity.
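(Context for this thread: the slicing under discussion trims an attention mask built for the full sequence down to the last keys.shape[-2] positions, matching the keys a sliding-window cache actually keeps. A minimal NumPy sketch with stand-in shapes:)

```python
import numpy as np

def trim_mask(mask, keys):
    """Keep only the mask columns for the keys still in the cache:
    a sliding-window cache retains the last keys.shape[-2] positions."""
    if mask is not None and mask.shape[-1] != keys.shape[-2]:
        mask = mask[..., -keys.shape[-2] :]
    return mask

mask = np.tril(np.ones((1, 1, 4, 10)))  # mask built for 10 key positions
keys = np.zeros((1, 2, 6, 64))          # cache holds only the last 6 keys
trimmed = trim_mask(mask, keys)         # -> mask narrowed to 6 key columns
```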

@angeloskath angeloskath merged commit 4469ad4 into ml-explore:main Apr 4, 2026
2 checks passed
lyonsno added a commit to lyonsno/mlx-lm that referenced this pull request Apr 4, 2026
Cherry-picked from upstream/main 4469ad4 onto b1-replay branch.
Resolved merge conflict in BatchRotatingKVCache.merge() — adopted
upstream's bounds-safe length slicing with loop variable unpacking.

Co-authored-by: Prince Canuma <[email protected]>
Co-authored-by: N8 <[email protected]>
Co-authored-by: Angelos Katharopoulos <[email protected]>
AustinJiangH added a commit to AustinJiangH/voixful that referenced this pull request Apr 8, 2026
User hit 'Model type gemma4 not supported' on the agent_runner
subprocess startup after switching to Gemma 4 E4B as the default.
Root cause: mlx-lm 0.31.1 (what voksly was running) doesn't have
the gemma4 model module. Gemma 4 support was added in 0.31.2,
released April 7 (ml-explore/mlx-lm#1093).

Blocker for the straightforward upgrade: mlx-audio 0.4.2 hard-pins
mlx-lm == 0.31.1 in its dependency spec, so a normal pyproject
bump fails resolution with 'mlx-audio depends on mlx-lm == 0.31.1
but your project depends on mlx-lm >= 0.31.2'.

Workaround: uv override-dependencies. Added a [tool.uv] block in
pyproject.toml that forces mlx-lm >= 0.31.2 across the resolver,
overriding mlx-audio's transitive pin. Verified end-to-end:

  $ uv sync
  - mlx-lm==0.31.1
  + mlx-lm==0.31.2

  $ uv run python -c '...'
  mlx_lm: 0.31.2
  gemma4 module: ok       ← gemma4 model class now imports
  mlx_audio: ok           ← mlx-audio still loads with newer mlx-lm
  voksly imports: ok      ← no breakage in our code path

The mlx-audio pin was overly conservative; the TTS code path uses
nothing that changed between mlx-lm 0.31.1 and 0.31.2.

With the upgrade in place, flipped the default back to Gemma 4 E4B
in all three default-resolution paths:
  - voksly.llm.DEFAULT_MLX_MODEL_KEY = 'gemma-4-e4b'
  - voksly.config.settings.llm.model = 'mlx-community/gemma-4-e4b-it-4bit'
  - frontdesk LLM dropdown first option = Gemma 4 E4B
And reordered KNOWN_MLX_MODELS so Gemma 4 E4B is iterated first
(the dropdown UI follows dict insertion order).

Drop the [tool.uv] override block when mlx-audio releases a version
with a relaxed mlx-lm constraint. Documented in pyproject.toml
comments.
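For reference, the [tool.uv] override described in the message above would look roughly like this in pyproject.toml (the version bound mirrors the commit text; treat the exact snippet as a sketch, not a verbatim copy of the repo's file):

```toml
# pyproject.toml — force the resolver past mlx-audio's mlx-lm==0.31.1 pin.
# Remove once mlx-audio releases with a relaxed mlx-lm constraint.
[tool.uv]
override-dependencies = ["mlx-lm>=0.31.2"]
```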