Skip to content

Commit 4469ad4

Browse files
BlaizzyN8angeloskath
authored
Add gemma 4 (#1093)
Co-authored-by: N8 <[email protected]> Co-authored-by: Angelos Katharopoulos <[email protected]>
1 parent f79dba7 commit 4469ad4

File tree

5 files changed

+1238
-2
lines changed

5 files changed

+1238
-2
lines changed

mlx_lm/models/cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,8 +1424,8 @@ def merge(cls, caches):
14241424
for i, (p, l, c) in enumerate(zip(padding, lengths, caches)):
14251425
if c.keys is None:
14261426
continue
1427-
keys[i : i + 1, :, p : p + l] = c._temporal_order(c.keys)
1428-
values[i : i + 1, :, p : p + l] = c._temporal_order(c.values)
1427+
keys[i : i + 1, :, p : p + l] = c._temporal_order(c.keys)[..., -l:, :]
1428+
values[i : i + 1, :, p : p + l] = c._temporal_order(c.values)[..., -l:, :]
14291429

14301430
cache = cls(caches[0].max_size, padding)
14311431
cache.keys = keys

mlx_lm/models/gemma4.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright © 2025 Apple Inc.
2+
3+
from dataclasses import dataclass
4+
from typing import Optional
5+
6+
import mlx.core as mx
7+
import mlx.nn as nn
8+
from mlx.utils import tree_flatten, tree_unflatten
9+
10+
from . import gemma4_text
11+
from .base import BaseModelArgs
12+
13+
14+
@dataclass
15+
class ModelArgs(BaseModelArgs):
16+
model_type: str = "gemma4"
17+
text_config: dict = None
18+
vocab_size: int = 262144
19+
20+
def __post_init__(self):
21+
if self.text_config is None:
22+
self.text_config = {}
23+
self.text_config["vocab_size"] = self.vocab_size
24+
self.text_config["num_attention_heads"] = self.text_config.get(
25+
"num_attention_heads", 8
26+
)
27+
self.text_config["num_key_value_heads"] = self.text_config.get(
28+
"num_key_value_heads", 1
29+
)
30+
31+
32+
class Model(nn.Module):
33+
def __init__(self, args: ModelArgs):
34+
super().__init__()
35+
self.args = args
36+
self.model_type = args.model_type
37+
self.language_model = gemma4_text.Model(
38+
gemma4_text.ModelArgs.from_dict(args.text_config)
39+
)
40+
41+
def __call__(
42+
self,
43+
inputs: mx.array,
44+
cache=None,
45+
input_embeddings: Optional[mx.array] = None,
46+
per_layer_inputs: Optional[mx.array] = None,
47+
):
48+
return self.language_model(
49+
inputs,
50+
cache=cache,
51+
input_embeddings=input_embeddings,
52+
per_layer_inputs=per_layer_inputs,
53+
)
54+
55+
def sanitize(self, weights):
56+
new_weights = {}
57+
for k, v in weights.items():
58+
starts_w_model = k.startswith("model.")
59+
60+
k = k.removeprefix("model.")
61+
if k.startswith(
62+
(
63+
"vision_tower",
64+
"multi_modal_projector",
65+
"audio_tower",
66+
"embed_audio",
67+
"embed_vision",
68+
)
69+
):
70+
continue
71+
72+
if not starts_w_model:
73+
new_weights[k] = v
74+
continue
75+
76+
if k.startswith("language_model"):
77+
k = k.replace("language_model.", "language_model.model.")
78+
79+
new_weights[k] = v
80+
81+
return self.language_model.sanitize(new_weights)
82+
83+
@property
84+
def layers(self):
85+
return self.language_model.layers
86+
87+
@property
88+
def quant_predicate(self):
89+
return self.language_model.quant_predicate
90+
91+
def make_cache(self):
92+
return self.language_model.make_cache()

0 commit comments

Comments
 (0)