[model] feat: add bridge support for OLMo-2 dense causal LMs #3698
Draft
lonexreb wants to merge 1 commit into NVIDIA-NeMo:main from
Conversation
Adds AllenAI's OLMo-2 family (1B / 7B / 13B) to the supported model registry. OLMo-2 is the second-generation fully-open language model (Team OLMo, 2024, https://arxiv.org/abs/2501.00656). It is the first post-norm architecture in the bridge — distinct from the existing pre-norm and Gemma2 sandwich-norm patterns.

## Architectural distinction

OLMo-2's decoder block is pure post-norm:

```
x = x + post_attention_layernorm(self_attn(x))
x = x + post_feedforward_layernorm(mlp(x))
```

with **no** `input_layernorm` or `pre_feedforward_layernorm`. Compared to existing bridges:

| Property | Llama | Qwen3 | Gemma2 (sandwich) | OLMo-2 |
|----------------------|-------|-------|-------------------|--------|
| Pre-attn / pre-MLP | yes | yes | yes | no |
| Post-attn / post-MLP | no | no | yes | yes |
| QK-RMSNorm | no | yes | no | yes |

## Implementation

* `olmo2_provider.py` — defines `olmo2_layer_spec`, which uses `IdentityOp` for `input_layernorm` / `pre_mlp_layernorm` and wraps `linear_proj` / `linear_fc2` in `TERowParallelLinearPostLN` (a local RMSNorm-after subclass of `TERowParallelLinear`, structurally equivalent to Gemma2's `TERowParallelLinearLayerNorm`). QK-RMSNorm flows through the standard `q_layernorm` / `k_layernorm` submodule slots activated by `provider.qk_layernorm = True`.
* `olmo2_bridge.py` — registers `Olmo2Bridge` for HF `Olmo2ForCausalLM`. The mapping registry routes `post_attention_layernorm` / `post_feedforward_layernorm` weights into the `linear_proj.post_layernorm` / `linear_fc2.post_layernorm` slots — NOT into the Llama-style `linear_qkv.layer_norm_weight` slot, which would silently produce a wrong-architecture model.
* Pre-built size variants: `Olmo2ModelProvider1B`, `7B`, `13B` with dimensions matching the HF default configs.

## Tests

`tests/unit_tests/models/olmo2/test_olmo2_bridge.py` (+431 lines, all CPU, no GPU). Pins:

* QK-layernorm and the post-norm layer spec are flagged on the provider.
* Architectural defaults (RMSNorm, SwiGLU, no biases, RoPE, untied embeddings, `layernorm_epsilon=1e-6`, `rotary_base=500000`).
* Numerical config translation for both 1B and 7B HF configs, including `head_dim` derivation when missing.
* Mapping registry routes the OLMo-2-specific output norms into the correct `post_layernorm` slots.
* Negative tests: no mapping ever writes to `linear_qkv.layer_norm_weight` or `linear_fc1.layer_norm_weight` (would indicate accidental Llama-style routing).
* Layer spec structure: pre-norm slots are `IdentityOp`; `q_layernorm`/`k_layernorm` slots are `TENorm`; `linear_qkv` / `linear_fc1` are plain `TEColumnParallelLinear` (no embedded LN); `linear_proj` / `linear_fc2` are `TERowParallelLinearPostLN`.
* QKV and gated-MLP weights are fused via `QKVMapping` / `GatedMLPMapping`.
* Size-variant providers (1B/7B/13B) inherit OLMo-2 defaults.

## Validation

Local: `ruff check` + `ruff format --check` clean across all 6 new/modified files. Full pytest cannot run on macOS (CUDA-only dependencies); CI L0 will pick up the new test module automatically.

## Notes for reviewers

* `TERowParallelLinearPostLN` is structurally identical to Gemma2's `TERowParallelLinearLayerNorm`. Per the project's "keep model-specific logic in the family directory" guideline I defined it locally rather than importing across families. If maintainers prefer to promote it to a shared utility, I can do that in a follow-up.
* `source` is registered as the string `"Olmo2ForCausalLM"` so the bridge does not hard-fail on transformers versions that predate OLMo-2 support; tests skip themselves under those versions.
* Recipes are intentionally not included in this PR — they should be a follow-up once the bridge has CI validation.

Signed-off-by: lonexreb <[email protected]>
Summary
Adds the OLMo-2 dense family (1B / 7B / 13B) — the second-generation fully-open language models from the Allen Institute (Team OLMo, 2024, "2 OLMo 2 Furious", https://arxiv.org/abs/2501.00656).
OLMo-2 is the first post-norm architecture in the bridge — distinct from existing pre-norm bridges (Llama/Qwen3/Mistral) and from Gemma2's sandwich-norm. It is heavily used in academic research because the entire training pipeline (data, code, weights, intermediate checkpoints) is open.
This is opened as a draft because the layer spec cannot be validated without a CUDA runtime; CI L0 functional tests will be the first true forward-pass run. Architecture/mapping review is welcome now.
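For orientation, a minimal sketch of how the bridge would be exercised once merged, assuming Megatron-Bridge's `AutoBridge` entry point dispatches `Olmo2ForCausalLM` checkpoints to the new bridge (the checkpoint name and accessor calls below are illustrative, not validated here):

```python
# Hypothetical usage sketch — assumes AutoBridge resolves Olmo2ForCausalLM
# to the new Olmo2Bridge; not yet validated without a CUDA runtime.
from megatron.bridge import AutoBridge

# Pull the HF checkpoint and build the Megatron-side provider.
bridge = AutoBridge.from_hf_pretrained("allenai/OLMo-2-1124-7B")
provider = bridge.to_megatron_provider()

# OLMo-2-specific defaults this PR wires up:
assert provider.qk_layernorm is True   # QK-RMSNorm enabled
assert provider.rotary_base == 500000  # per the HF config
```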
Architectural distinction
OLMo-2's decoder block:
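```
x = x + post_attention_layernorm(self_attn(x))
x = x + post_feedforward_layernorm(mlp(x))
```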
There is no `input_layernorm` and no `pre_feedforward_layernorm`. Compared to existing bridges:

| Property | Llama | Qwen3 | Gemma2 (sandwich) | OLMo-2 |
|----------------------|-------|-------|-------------------|--------|
| Pre-attn / pre-MLP | yes | yes | yes | no |
| Post-attn / post-MLP | no | no | yes | yes |
| QK-RMSNorm | no | yes | no | yes |
Implementation (`src/megatron/bridge/models/olmo2/`)

Provider + custom layer spec (`olmo2_provider.py`):

* `olmo2_layer_spec` uses `IdentityOp` for `input_layernorm` and `pre_mlp_layernorm`, so the pre-block normalizations are no-ops (a condensed sketch of the spec follows this list).
* `linear_proj` and `linear_fc2` are wrapped in `TERowParallelLinearPostLN` — a local RMSNorm-after subclass of `TERowParallelLinear`. This is structurally identical to Gemma2's `TERowParallelLinearLayerNorm` (see `src/megatron/bridge/models/gemma/gemma2_provider.py:255-277`). It is defined locally per the project's "keep model-specific logic in the family directory" guideline; happy to promote it to a shared utility if maintainers prefer.
* QK-RMSNorm flows through the standard `q_layernorm` / `k_layernorm` submodule slots, activated by `provider.qk_layernorm = True` (same pattern as Qwen3 and OLMoE).
* Pre-built size variants: `Olmo2ModelProvider1B`, `...7B`, `...13B`, with dimensions matching the HF default configs (1B: 16/2048/16/8192, 7B: 32/4096/32/11008, 13B: 40/5120/40/13824).
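For reviewers skimming without the diff, here is a condensed sketch of the spec described above — a sketch under stated assumptions, not the PR's code verbatim: import paths follow recent `megatron.core` (older releases expose the TE wrappers under `megatron.core.transformer.custom_layers.transformer_engine`), and the `TERowParallelLinearPostLN` body is inferred from the Gemma2 wrapper it mirrors.

```python
# Condensed sketch — not the PR's code verbatim; see assumptions above.
from megatron.core.extensions.transformer_engine import (
    TEColumnParallelLinear,
    TEDotProductAttention,
    TENorm,
    TERowParallelLinear,
)
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules


class TERowParallelLinearPostLN(TERowParallelLinear):
    """Sketch of the local wrapper (inferred from the Gemma2 pattern): RMSNorm
    applied to the linear output, so HF post-attn / post-MLP norm weights load
    into `<module>.post_layernorm.weight`."""

    def __init__(self, *args, config, **kwargs):
        super().__init__(*args, config=config, **kwargs)
        self.post_layernorm = TENorm(config, config.hidden_size, config.layernorm_epsilon)

    def forward(self, x):
        out, bias = super().forward(x)  # TE wrappers return (output, bias)
        return self.post_layernorm(out), bias


olmo2_layer_spec = ModuleSpec(
    module=TransformerLayer,
    submodules=TransformerLayerSubmodules(
        input_layernorm=IdentityOp,  # post-norm: no norm before attention
        self_attention=ModuleSpec(
            module=SelfAttention,
            params={"attn_mask_type": AttnMaskType.causal},
            submodules=SelfAttentionSubmodules(
                linear_qkv=TEColumnParallelLinear,  # plain: no fused pre-LN
                core_attention=TEDotProductAttention,
                linear_proj=TERowParallelLinearPostLN,  # post_attention_layernorm
                q_layernorm=TENorm,  # QK-RMSNorm
                k_layernorm=TENorm,
            ),
        ),
        self_attn_bda=get_bias_dropout_add,
        pre_mlp_layernorm=IdentityOp,  # post-norm: no norm before MLP
        mlp=ModuleSpec(
            module=MLP,
            submodules=MLPSubmodules(
                linear_fc1=TEColumnParallelLinear,
                linear_fc2=TERowParallelLinearPostLN,  # post_feedforward_layernorm
            ),
        ),
        mlp_bda=get_bias_dropout_add,
    ),
)
```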
Bridge + weight mappings (`olmo2_bridge.py`):

The `mapping_registry` routes OLMo-2's output norms into the correct slots:

| HF weight | Megatron slot |
|---|---|
| `model.layers.*.post_attention_layernorm.weight` | `decoder.layers.*.self_attention.linear_proj.post_layernorm.weight` |
| `model.layers.*.post_feedforward_layernorm.weight` | `decoder.layers.*.mlp.linear_fc2.post_layernorm.weight` |
| `model.layers.*.self_attn.q_norm.weight` | `decoder.layers.*.self_attention.q_layernorm.weight` |
| `model.layers.*.self_attn.k_norm.weight` | `decoder.layers.*.self_attention.k_layernorm.weight` |
Crucially, these do NOT map to `linear_qkv.layer_norm_weight` (the Llama/Qwen3 pre-attention slot) — wrong routing there would silently produce a wrong-architecture model that loads weights without raising. Negative tests guard against this regression.
* QKV (q/k/v) and gated-SwiGLU MLP (gate/up) weights are fused via `QKVMapping` and `GatedMLPMapping`, exactly as in other dense bridges. No biases (`attention_bias=False`). A condensed sketch of the registry follows this list.
* `source` is registered as the string `"Olmo2ForCausalLM"` so the bridge module imports cleanly on transformers versions that predate OLMo-2 support; tests guard with a `pytest.importorskip`-style skip.
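A condensed sketch of how the registry entries line up — the helper classes are those named in this PR, but their import paths and keyword names below are assumptions, not verified API:

```python
# Condensed sketch — import paths and keyword names are assumptions.
from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
from megatron.bridge.models.conversion.param_mapping import (
    AutoMapping,
    GatedMLPMapping,
    QKVMapping,
)

mapping_registry = MegatronMappingRegistry(
    # OLMo-2 post-norms go into the post_layernorm slots, NOT into
    # linear_qkv.layer_norm_weight / linear_fc1.layer_norm_weight.
    AutoMapping(
        megatron_param="decoder.layers.*.self_attention.linear_proj.post_layernorm.weight",
        hf_param="model.layers.*.post_attention_layernorm.weight",
    ),
    AutoMapping(
        megatron_param="decoder.layers.*.mlp.linear_fc2.post_layernorm.weight",
        hf_param="model.layers.*.post_feedforward_layernorm.weight",
    ),
    # QK-RMSNorm
    AutoMapping(
        megatron_param="decoder.layers.*.self_attention.q_layernorm.weight",
        hf_param="model.layers.*.self_attn.q_norm.weight",
    ),
    AutoMapping(
        megatron_param="decoder.layers.*.self_attention.k_layernorm.weight",
        hf_param="model.layers.*.self_attn.k_norm.weight",
    ),
    # Fused QKV and gated-MLP weights, as in the other dense bridges.
    QKVMapping(
        megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
        q="model.layers.*.self_attn.q_proj.weight",
        k="model.layers.*.self_attn.k_proj.weight",
        v="model.layers.*.self_attn.v_proj.weight",
    ),
    GatedMLPMapping(
        megatron_param="decoder.layers.*.mlp.linear_fc1.weight",
        gate="model.layers.*.mlp.gate_proj.weight",
        up="model.layers.*.mlp.up_proj.weight",
    ),
)
```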
Tests (`tests/unit_tests/models/olmo2/test_olmo2_bridge.py`, +431 LoC)

All CPU-only. Coverage:

* `TestOlmo2BridgeRegistration` — bridge registered as a `MegatronModelBridge`; source class binding.
* `TestOlmo2ProviderBridgeArchitecturalFlags` — `qk_layernorm=True`, post-norm layer spec selected, no biases, SwiGLU, RMSNorm, RoPE, untied embeddings.
* `TestOlmo2ProviderBridgeShapeFields` — `kv_channels` derivation when `head_dim` is missing AND when present.
* `TestOlmo2MappingRegistry` — no mapping writes to `linear_qkv.layer_norm_weight` or `linear_fc1.layer_norm_weight` (would indicate accidental Llama-style routing); QKV/Gated-MLP mappings present with correct slot names; no QKV bias mappings. A sketch of this negative check follows the list.
* `TestOlmo2LayerSpec` — `input_layernorm` and `pre_mlp_layernorm` are `IdentityOp`; `q_layernorm`/`k_layernorm` are `TENorm`; `linear_qkv`/`linear_fc1` are plain `TEColumnParallelLinear`; `linear_proj`/`linear_fc2` are `TERowParallelLinearPostLN`.
* `TestOlmo2ModelProviderSizeVariants` — size variants (1B/7B/13B) inherit OLMo-2 defaults.
* `TestOlmo2ProviderBaseDefaults` — base provider defaults such as `persist_layer_norm`.
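As an illustration of the negative routing check, a minimal sketch — the registry accessor and mapping attribute names below are assumptions, not the actual test code:

```python
# Illustrative sketch of the negative routing test described above; the real
# module is tests/unit_tests/models/olmo2/test_olmo2_bridge.py. The
# mapping_registry() accessor and .megatron_param attribute are assumed.
from megatron.bridge.models.olmo2.olmo2_bridge import Olmo2Bridge

FORBIDDEN_PRE_NORM_SLOTS = (
    "linear_qkv.layer_norm_weight",  # Llama/Qwen3 fused pre-attention norm
    "linear_fc1.layer_norm_weight",  # Llama/Qwen3 fused pre-MLP norm
)


def test_no_llama_style_norm_routing():
    registry = Olmo2Bridge().mapping_registry()
    for mapping in registry.mappings:
        for slot in FORBIDDEN_PRE_NORM_SLOTS:
            assert slot not in mapping.megatron_param, (
                f"post-norm weight routed into a pre-norm slot: {mapping.megatron_param}"
            )
```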
Validation

* `ruff check` clean across all 6 new/modified files.
* `ruff format --check` clean.
* Cross-checked against `transformers/models/olmo2/modeling_olmo2.py` (the HF reference) and the `config.json` of `allenai/OLMo-2-{0425-1B, 1124-7B, 1124-13B}`.
Risk

* The custom layer spec (`olmo2_provider.py:olmo2_layer_spec`) is the highest-value review I'm requesting.
* Weight routing mirrors `modeling_olmo2.py`; negative tests guard against the most likely silent failure mode (writing the post-norm weights into the pre-norm slot).
Out of scope (follow-ups)

* Recipes (`recipes/olmo2/olmo2.py`) — should land after this PR has CI green.
* Examples (`examples/models/olmo2/`) — conversion + inference scripts on a real checkpoint.
* Functional tests (`tests/functional_tests/models/olmo2/`).
* Promoting `TERowParallelLinearPostLN` to a shared utility used by both Gemma2 and OLMo-2.

References
* `src/megatron/bridge/models/gemma/gemma2_provider.py` (post-norm wrapper pattern)
* `src/megatron/bridge/models/qwen/qwen3_bridge.py` (QK-norm pattern)