
[model] feat: add bridge support for OLMo-2 dense causal LMs #3698

Draft
lonexreb wants to merge 1 commit into NVIDIA-NeMo:main from lonexreb:feat/olmo2-bridge

Conversation

@lonexreb (Contributor) commented May 5, 2026

Summary

Adds AllenAI's OLMo-2 dense causal-LM family (1B / 7B / 13B), the Allen Institute's second-generation fully open language models (Yang et al., 2024, "2 OLMo 2 Furious").

OLMo-2 is the first post-norm architecture in the bridge — distinct from existing pre-norm bridges (Llama/Qwen3/Mistral) and from Gemma2's sandwich-norm. It is heavily used in academic research because the entire training pipeline (data, code, weights, intermediate checkpoints) is open.

This is opened as a draft because the layer spec cannot be validated locally without a CUDA runtime; the CI L0 functional tests will be the first true forward-pass run. Architecture/mapping review is welcome now.

Architectural distinction

OLMo-2's decoder block:

```python
# from transformers/src/transformers/models/olmo2/modeling_olmo2.py
residual = hidden_states
hidden_states, _ = self.self_attn(...)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states

residual = hidden_states
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
```

There is no input_layernorm and no pre_feedforward_layernorm. Compared to existing bridges:

| Property                  | Llama | Qwen3 | Gemma2 (sandwich) | OLMo-2 |
|---------------------------|-------|-------|-------------------|--------|
| Pre-attn / pre-MLP norm   | yes   | yes   | yes               | no     |
| Post-attn / post-MLP norm | no    | no    | yes               | yes    |
| QK-RMSNorm                | no    | yes   | no                | yes    |
| Logit soft-capping        | no    | no    | yes               | no     |
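
To make the table concrete, here is a minimal PyTorch contrast of the two residual patterns. This is an illustrative sketch only, not code from this PR (`nn.RMSNorm` requires torch >= 2.4):

```python
# Minimal contrast sketch; illustrative only, not code from this PR.
import torch.nn as nn


class PreNormBlock(nn.Module):
    """Llama/Qwen3 pattern: normalize the sublayer *input*."""

    def __init__(self, attn: nn.Module, mlp: nn.Module, dim: int, eps: float = 1e-6):
        super().__init__()
        self.attn, self.mlp = attn, mlp
        self.input_layernorm = nn.RMSNorm(dim, eps=eps)
        self.pre_mlp_layernorm = nn.RMSNorm(dim, eps=eps)

    def forward(self, x):
        x = x + self.attn(self.input_layernorm(x))
        return x + self.mlp(self.pre_mlp_layernorm(x))


class PostNormBlock(nn.Module):
    """OLMo-2 pattern: normalize the sublayer *output* before the residual add."""

    def __init__(self, attn: nn.Module, mlp: nn.Module, dim: int, eps: float = 1e-6):
        super().__init__()
        self.attn, self.mlp = attn, mlp
        self.post_attention_layernorm = nn.RMSNorm(dim, eps=eps)
        self.post_feedforward_layernorm = nn.RMSNorm(dim, eps=eps)

    def forward(self, x):
        x = x + self.post_attention_layernorm(self.attn(x))
        return x + self.post_feedforward_layernorm(self.mlp(x))
```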

Implementation (src/megatron/bridge/models/olmo2/)

Provider + custom layer spec (olmo2_provider.py)

  • olmo2_layer_spec uses IdentityOp for input_layernorm and pre_mlp_layernorm so the pre-block normalizations are no-ops (a condensed sketch of the resulting spec follows this list).
  • linear_proj and linear_fc2 are wrapped in TERowParallelLinearPostLN, a local RMSNorm-after subclass of TERowParallelLinear. This is structurally identical to Gemma2's TERowParallelLinearLayerNorm (see src/megatron/bridge/models/gemma/gemma2_provider.py:255-277). It is defined locally per the project's "keep model-specific logic in the family directory" guideline; happy to promote it to a shared utility if maintainers prefer.
  • QK-RMSNorm flows through the standard q_layernorm / k_layernorm submodule slots, activated by provider.qk_layernorm = True (same pattern as Qwen3 and OLMoE).
  • Size variants Olmo2ModelProvider1B, ...7B, ...13B with dimensions matching the HF default configs (layers / hidden / heads / FFN; 1B: 16/2048/16/8192, 7B: 32/4096/32/11008, 13B: 40/5120/40/13824).
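
For orientation, a condensed sketch of the spec described in the list above. This is illustrative, not the PR's literal code: import paths follow current megatron.core conventions and may differ, and the TERowParallelLinearPostLN body is elided (the real subclass lives in olmo2_provider.py):

```python
# Condensed sketch of the layer spec; illustrative, not the PR's literal code.
# Import paths follow current megatron.core conventions and may differ.
from megatron.core.extensions.transformer_engine import (
    TEColumnParallelLinear,
    TEDotProductAttention,
    TENorm,
    TERowParallelLinear,
)
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import (
    TransformerLayer,
    TransformerLayerSubmodules,
)


class TERowParallelLinearPostLN(TERowParallelLinear):
    """Row-parallel linear followed by RMSNorm on its output (body elided).

    The real subclass in olmo2_provider.py registers a `post_layernorm`
    submodule; its weight receives the HF post_attention / post_feedforward
    layernorm weights via the bridge mappings below.
    """


def olmo2_layer_spec() -> ModuleSpec:
    """Post-norm decoder block: both pre-block norms are no-ops."""
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=IdentityOp,  # no pre-attention norm in OLMo-2
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=TEColumnParallelLinear,  # plain: no fused pre-LN
                    core_attention=TEDotProductAttention,
                    linear_proj=TERowParallelLinearPostLN,  # post-attn norm here
                    q_layernorm=TENorm,  # QK-RMSNorm (provider.qk_layernorm=True)
                    k_layernorm=TENorm,
                ),
            ),
            pre_mlp_layernorm=IdentityOp,  # no pre-MLP norm in OLMo-2
            mlp=ModuleSpec(
                module=MLP,
                submodules=MLPSubmodules(
                    linear_fc1=TEColumnParallelLinear,
                    linear_fc2=TERowParallelLinearPostLN,  # post-MLP norm here
                ),
            ),
        ),
    )
```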

Bridge + weight mappings (olmo2_bridge.py)

The mapping_registry routes OLMo-2's output norms into the correct slots:

| HF parameter | Megatron parameter |
|---|---|
| model.layers.*.post_attention_layernorm.weight | decoder.layers.*.self_attention.linear_proj.post_layernorm.weight |
| model.layers.*.post_feedforward_layernorm.weight | decoder.layers.*.mlp.linear_fc2.post_layernorm.weight |
| model.layers.*.self_attn.q_norm.weight | decoder.layers.*.self_attention.q_layernorm.weight |
| model.layers.*.self_attn.k_norm.weight | decoder.layers.*.self_attention.k_layernorm.weight |

Crucially, these do NOT map to linear_qkv.layer_norm_weight (the Llama/Qwen3 pre-attention slot) or linear_fc1.layer_norm_weight (the pre-MLP slot): wrong routing there would silently produce a wrong-architecture model that loads weights without raising. Negative tests guard against this regression.

QKV (q/k/v) and gated-SwiGLU MLP (gate/up) are fused via QKVMapping and GatedMLPMapping exactly as in other dense bridges. No biases (attention_bias=False).

source is registered as the string "Olmo2ForCausalLM" so the bridge module imports cleanly on transformers versions that predate OLMo-2 support; tests guard with a pytest.importorskip-style skip.
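
Concretely, the routing above plus the fused projections reduce to a registry of roughly this shape. This is a hedged sketch: the import paths and the keyword names of AutoMapping / QKVMapping / GatedMLPMapping are assumed from the patterns in the existing dense bridges, not copied from the PR:

```python
# Sketch only; import paths and mapping-constructor keywords are assumptions
# modeled on the existing dense bridges, not the PR's literal code.
from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
from megatron.bridge.models.conversion.param_mapping import (
    AutoMapping,
    GatedMLPMapping,
    QKVMapping,
)


def mapping_registry() -> MegatronMappingRegistry:
    return MegatronMappingRegistry(
        # Output norms go to the post_layernorm slots; NEVER to
        # linear_qkv.layer_norm_weight / linear_fc1.layer_norm_weight.
        AutoMapping(
            megatron_param="decoder.layers.*.self_attention.linear_proj.post_layernorm.weight",
            hf_param="model.layers.*.post_attention_layernorm.weight",
        ),
        AutoMapping(
            megatron_param="decoder.layers.*.mlp.linear_fc2.post_layernorm.weight",
            hf_param="model.layers.*.post_feedforward_layernorm.weight",
        ),
        AutoMapping(
            megatron_param="decoder.layers.*.self_attention.q_layernorm.weight",
            hf_param="model.layers.*.self_attn.q_norm.weight",
        ),
        AutoMapping(
            megatron_param="decoder.layers.*.self_attention.k_layernorm.weight",
            hf_param="model.layers.*.self_attn.k_norm.weight",
        ),
        # Fused projections, exactly as in the other dense bridges.
        QKVMapping(
            megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
            q="model.layers.*.self_attn.q_proj.weight",
            k="model.layers.*.self_attn.k_proj.weight",
            v="model.layers.*.self_attn.v_proj.weight",
        ),
        GatedMLPMapping(
            megatron_param="decoder.layers.*.mlp.linear_fc1.weight",
            gate="model.layers.*.mlp.gate_proj.weight",
            up="model.layers.*.mlp.up_proj.weight",
        ),
    )
```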

Tests (tests/unit_tests/models/olmo2/test_olmo2_bridge.py, +431 LoC)

All CPU-only. Coverage:

| Test class | What it pins |
|---|---|
| TestOlmo2BridgeRegistration | Bridge inherits MegatronModelBridge; source class binding |
| TestOlmo2ProviderBridgeArchitecturalFlags | qk_layernorm=True, post-norm layer spec selected, no biases, SwiGLU, RMSNorm, RoPE, untied embeddings |
| TestOlmo2ProviderBridgeShapeFields | Numerical config translation for 1B and 7B HF configs; kv_channels derivation when head_dim is missing AND when present |
| TestOlmo2MappingRegistry | Every weight name routes correctly; negative tests confirm no mapping ever writes linear_qkv.layer_norm_weight or linear_fc1.layer_norm_weight (would indicate accidental Llama-style routing); QKV / gated-MLP mappings present with correct slot names; no QKV bias mappings |
| TestOlmo2LayerSpec | input_layernorm and pre_mlp_layernorm are IdentityOp; q_layernorm/k_layernorm are TENorm; linear_qkv/linear_fc1 are plain TEColumnParallelLinear; linear_proj/linear_fc2 are TERowParallelLinearPostLN |
| TestOlmo2ModelProviderSizeVariants | 1B/7B/13B variants inherit OLMo-2 defaults; dimensions match HF configs |
| TestOlmo2ProviderBaseDefaults | Base provider picks up post-norm spec and persist_layer_norm |
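
As a flavor of the negative tests, a sketch of the pre-norm routing guard. The import path follows the PR's described layout, and the way mappings are enumerated from the registry (`.mappings`, `.megatron_param`) is an assumption, not the test file's literal code:

```python
# Sketch of the negative-routing guard; registry iteration details are
# assumptions based on the PR's described layout.
from megatron.bridge.models.olmo2.olmo2_bridge import Olmo2Bridge

FORBIDDEN_PRE_NORM_SLOTS = (
    "linear_qkv.layer_norm_weight",  # Llama-style pre-attention slot
    "linear_fc1.layer_norm_weight",  # Llama-style pre-MLP slot
)


def test_no_mapping_targets_pre_norm_slots():
    """Post-norm weights routed into a pre-norm slot would load without
    raising and silently build the wrong architecture, so pin their absence."""
    registry = Olmo2Bridge().mapping_registry()
    for mapping in registry.mappings:  # attribute name assumed
        for slot in FORBIDDEN_PRE_NORM_SLOTS:
            assert slot not in mapping.megatron_param
```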

Validation

  • ruff check clean across all 6 new/modified files
  • ruff format --check clean
  • Python AST parse clean
  • Architectural details cross-checked against transformers/models/olmo2/modeling_olmo2.py (HF reference) and allenai/OLMo-2-{0425-1B, 1124-7B, 1124-13B}/config.json
  • CI L0 unit tests — picks up new module automatically (path-based discovery)
  • CI functional / smoke test on a real HF checkpoint — out of scope for this PR; recipes deferred until the bridge has CI green

Risk

  • Build risk: low. No changes to existing bridges, conversion infrastructure, or shared params; only additions in a new family directory.
  • Architectural risk: medium. The layer spec is structurally close to Gemma2's, but I cannot validate the full forward pass on macOS. Maintainer review of the layer spec (olmo2_provider.py:olmo2_layer_spec) is the highest-value review I'm requesting.
  • Mapping risk: low. Mapping registry verified against HF state-dict naming via direct read of modeling_olmo2.py; negative tests guard against the most likely silent failure mode (writing the post-norm weights into the pre-norm slot).

Out of scope (follow-ups)

  • Training recipes (recipes/olmo2/olmo2.py) — should land after this PR has CI green
  • examples/models/olmo2/ conversion + inference scripts on a real checkpoint
  • Functional roundtrip test (tests/functional_tests/models/olmo2/)
  • Optional: promote TERowParallelLinearPostLN to a shared utility used by both Gemma2 and OLMo-2

References

  • Yang et al., 2024, "2 OLMo 2 Furious", https://arxiv.org/abs/2501.00656
  • transformers/src/transformers/models/olmo2/modeling_olmo2.py (HF reference implementation)
  • allenai/OLMo-2-{0425-1B, 1124-7B, 1124-13B} config.json

Commit message

Adds AllenAI's OLMo-2 family (1B / 7B / 13B) to the supported model
registry. OLMo-2 is the second-generation fully-open language model
(Yang et al., 2024, https://arxiv.org/abs/2501.00656). It is the first
post-norm architecture in the bridge — distinct from the existing
pre-norm and Gemma2 sandwich-norm patterns.

## Architectural distinction

OLMo-2's decoder block is pure post-norm:

    x = x + post_attention_layernorm(self_attn(x))
    x = x + post_feedforward_layernorm(mlp(x))

with **no** input_layernorm or pre_feedforward_layernorm. Compared to
existing bridges:

| Property             | Llama | Qwen3 | Gemma2 (sandwich) | OLMo-2 |
|----------------------|-------|-------|-------------------|--------|
| Pre-attn / pre-MLP   | yes   | yes   | yes               | no     |
| Post-attn / post-MLP | no    | no    | yes               | yes    |
| QK-RMSNorm           | no    | yes   | no                | yes    |

## Implementation

* `olmo2_provider.py` — defines `olmo2_layer_spec` which uses
  `IdentityOp` for `input_layernorm` / `pre_mlp_layernorm` and wraps
  `linear_proj` / `linear_fc2` in `TERowParallelLinearPostLN` (a local
  RMSNorm-after subclass of `TERowParallelLinear`, structurally
  equivalent to Gemma2's `TERowParallelLinearLayerNorm`). QK-RMSNorm
  flows through the standard `q_layernorm` / `k_layernorm` submodule
  slots activated by `provider.qk_layernorm = True`.
* `olmo2_bridge.py` — registers `Olmo2Bridge` for HF
  `Olmo2ForCausalLM`. The mapping registry routes
  `post_attention_layernorm` / `post_feedforward_layernorm` weights
  into the `linear_proj.post_layernorm` / `linear_fc2.post_layernorm`
  slots — NOT into the Llama-style `linear_qkv.layer_norm_weight`
  slot, which would silently produce a wrong-architecture model.
* Pre-built size variants: `Olmo2ModelProvider1B`, `7B`, `13B` with
  dimensions matching the HF default configs.

## Tests

`tests/unit_tests/models/olmo2/test_olmo2_bridge.py` (+431 lines, all
CPU, no GPU). Pins:

* QK-layernorm and post-norm layer spec are flagged on the provider.
* Architectural defaults (RMSNorm, SwiGLU, no biases, RoPE, untied
  embeddings, layernorm_epsilon=1e-6, rotary_base=500000).
* Numerical config translation for both 1B and 7B HF configs,
  including head_dim derivation when it is missing (the kv_channels
  rule is sketched after this list).
* Mapping registry routes the OLMo-2-specific output norms into the
  correct `post_layernorm` slots.
* Negative tests: no mapping ever writes to
  `linear_qkv.layer_norm_weight` or `linear_fc1.layer_norm_weight`
  (would indicate accidental Llama-style routing).
* Layer spec structure: pre-norm slots are IdentityOp; q/k_layernorm
  slots are TENorm; linear_qkv / linear_fc1 are plain
  TEColumnParallelLinear (no embedded LN); linear_proj / linear_fc2
  are TERowParallelLinearPostLN.
* QKV and gated-MLP weights are fused via QKVMapping / GatedMLPMapping.
* Size-variant providers (1B/7B/13B) inherit OLMo-2 defaults.
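
As a concrete example of the head_dim handling pinned above, the
derivation rule reduces to the following (illustrative helper, not the
PR's literal code; field names are the HF Olmo2Config attributes):

    # Illustrative helper, not the PR's literal code.
    def derive_kv_channels(hf_config) -> int:
        head_dim = getattr(hf_config, "head_dim", None)
        if head_dim is not None:
            return head_dim  # explicit head_dim wins when present
        return hf_config.hidden_size // hf_config.num_attention_heads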

## Validation

Local: `ruff check` + `ruff format --check` clean across all 6 new/
modified files. Full pytest cannot run on macOS (CUDA-only
dependencies); CI L0 will pick up the new test module automatically.

## Notes for reviewers

* `TERowParallelLinearPostLN` is structurally identical to Gemma2's
  `TERowParallelLinearLayerNorm`. Per the project's "keep
  model-specific logic in the family directory" guideline I
  defined it locally rather than importing across families. If
  maintainers prefer to promote it to a shared utility I can do that
  in a follow-up.
* `source` is registered as the string `"Olmo2ForCausalLM"` so the
  bridge does not hard-fail on transformers versions that predate
  OLMo-2 support; tests skip themselves under those versions.
* Recipes are intentionally not included in this PR — they should be
  a follow-up once the bridge has CI validation.

Signed-off-by: lonexreb <[email protected]>
@copy-pr-bot (Bot) commented May 5, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.
