
[training, test] test: add MLA, MTP, and provider-override coverage for FLOPs calculator #3695

Open

lonexreb wants to merge 2 commits into NVIDIA-NeMo:main from lonexreb:training/test-mla-flops

Conversation

@lonexreb
Contributor

@lonexreb lonexreb commented May 5, 2026

What

Closes coverage gaps in tests/unit_tests/training/utils/test_flop_utils.py. The existing suite covers MoE, hybrid (Mamba / GDN / MTP), SWA, and attention-output-gate paths comprehensively, but the transformer_flops branches for Multi-Latent Attention (DeepSeek-V2/V3), explicit mtp_num_layers, and the model-provider override were untested.

The functional regression test (tests/functional_tests/test_groups/utils/test_flop_utils.py) covers only Llama 3 (8B, 70B) and Qwen3-MoE (30B-A3B, 235B-A22B). DeepSeek-V3 — the most active model family in this repo (recent issues #3475, #3499, #3538, #3597) — has no TFLOPS regression coverage, unit or functional.

This PR adds 10 unit tests across 4 new classes (337 lines, zero production-code changes).

Why

TFLOPS reporting accuracy is on the maintainers' radar (see #3498 — "VLM TFLOPS calculation is inaccurate", currently under draft fix in #3529). Any future fix to the calculator should be accompanied by airtight regression tests on the architecture variants. This PR builds that safety net for the MLA path before any production-code fix lands, so the next refactor of flop_utils.py won't silently regress DeepSeek-V3 TFLOPS.

The MLA arithmetic mirrors the closed form in flop_utils.py (lines 343-398) bit-for-bit. I verified the formulas by replicating the source path standalone and asserting equality on representative configs:

Source path total:  6,964,789,248.0
Test expected:      6,964,789,248.0
MATCH: True
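
For context, the check followed the pattern sketched below. This is a minimal illustration, not the actual verification script: the toy 2*m*n*k projection term stands in for the real MLA closed form, and the function names and numbers are placeholders. Plain float equality is the right check because both sides compute the same expression with the same operation order.

```python
# Illustrative sketch of the bit-exact verification pattern (NOT the real MLA
# formula from flop_utils.py). A single dense-projection GEMM count stands in
# for the full closed form; names and numbers are placeholders.

def source_path_total(b: int, s: int, h: int, out: int) -> float:
    # Stand-in for the re-typed source branch of transformer_flops.
    return 2.0 * b * s * h * out

def test_expected_total(b: int, s: int, h: int, out: int) -> float:
    # Stand-in for the independently hand-derived value used by the unit tests.
    return 2.0 * b * s * h * out

src = source_path_total(b=1, s=4096, h=1024, out=4096)
exp = test_expected_total(b=1, s=4096, h=1024, out=4096)
print(f"Source path total:  {src:,.1f}")
print(f"Test expected:      {exp:,.1f}")
print("MATCH:", src == exp)  # plain ==, no tolerance: same terms, same order
```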

New tests

| Class | Test | What it pins down |
| --- | --- | --- |
| TestMLAFlops | test_mla_with_q_lora_exact_formula | Bit-exact match for the DeepSeek-V3-style Q-LoRA + KV-LoRA path |
| | test_mla_without_q_lora_exact_formula | Bit-exact match for the no-Q-compression branch |
| | test_q_lora_reduces_q_projection_flops | Invariant: Q-LoRA shrinks attention FLOPs when rank is small |
| | test_mla_differs_from_standard_attention | MLA path != MHA path |
| | test_mla_batch_size_scales_linearly | f(B=4) == 4 * f(B=1) |
| | test_mla_seq_length_quadratic_growth | Doubling seq grows the total > 2× (core attention is O(s²)) |
| TestMLAWithMoE | test_mla_moe_combination_positive_and_distinct | DeepSeek-V3 shape (MLA + MoE) totals are sensible and distinct from MHA + MoE |
| TestExplicitMtpInTransformerPath | test_explicit_mtp_increases_flops | Directional invariant |
| | test_explicit_mtp_exact_delta | Bit-exact delta covering per-layer MLP, per-layer self-attention, MTP norms/proj, and (mtp+1) logit scaling |
| TestProviderOverride | test_provider_override_short_circuits | When the model exposes _get_num_floating_point_operations, the calculator is bypassed and the override is invoked exactly with the requested batch_size |
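
To make the last row concrete, here is a minimal, self-contained sketch of the contract that test pins down. calculate_flops below is a hypothetical stand-in written only for illustration; the real calculator in flop_utils.py has a different name, signature, and fallback arithmetic.

```python
# Self-contained sketch of the provider-override short-circuit. NOT the real
# flop_utils.py code: `calculate_flops` is a hypothetical stand-in that only
# illustrates the behavior the test asserts.
from unittest import mock


def calculate_flops(model_provider, batch_size: int) -> float:
    # If the provider supplies its own FLOPs method, call it with the
    # requested batch size and skip the generic transformer arithmetic.
    override = getattr(model_provider, "_get_num_floating_point_operations", None)
    if override is not None:
        return override(batch_size)
    raise NotImplementedError("generic transformer_flops path elided here")


def test_provider_override_short_circuits():
    provider = mock.Mock(spec=["_get_num_floating_point_operations"])
    provider._get_num_floating_point_operations.return_value = 123.0

    result = calculate_flops(provider, batch_size=4)

    assert result == 123.0  # override value passed through untouched
    provider._get_num_floating_point_operations.assert_called_once_with(4)


test_provider_override_short_circuits()
```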

Notes for reviewers

  • MockModelConfig is extended with q_lora_rank, kv_lora_rank, qk_head_dim, qk_pos_emb_head_dim, v_head_dim — all defaulted, so existing tests are unaffected (sketched after this list).
  • No production code under src/megatron/bridge/ is touched.
  • Local AST parse passes; ruff lint and ruff format pass.
  • Complementary to (not blocking on) perf(fix): accumulate per-microbatch FLOPS metadata for accurate… #3529.
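
Roughly, the MockModelConfig extension looks like the sketch below. The defaults shown are assumptions for illustration (the real fixture has more fields and may use different values); the point is that every new field defaults to something that keeps pre-existing non-MLA tests on their old code path.

```python
# Illustrative sketch of the extended MockModelConfig; pre-existing fields are
# abbreviated and the defaults shown here are assumptions, not the values in
# the actual test file.
from dataclasses import dataclass
from typing import Optional


@dataclass
class MockModelConfig:
    hidden_size: int = 1024
    num_attention_heads: int = 8
    num_layers: int = 2
    seq_length: int = 2048
    # ... other pre-existing fields elided ...
    # New MLA-related fields, all defaulted so existing tests are unaffected.
    q_lora_rank: Optional[int] = None        # None = no Q compression branch
    kv_lora_rank: Optional[int] = None
    qk_head_dim: int = 128
    qk_pos_emb_head_dim: int = 64
    v_head_dim: int = 128
```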

Test plan

  • python3 -m ast parses the file cleanly (1393 lines); see the snippet after this list
  • ruff check — all checks passed
  • ruff format --check — already formatted
  • Standalone arithmetic verification against flop_utils.py source (MLA q-LoRA, MLA no-q-LoRA, explicit MTP delta) — all bit-exact matches
  • CI to run Launch_Unit_Tests_Core (target file: tests/unit_tests/training/utils/test_flop_utils.py)
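
The first item in this list can be reproduced with the standard library alone:

```python
# Reproduces the "parses cleanly" check from the test plan; run from the
# repository root.
import ast
import pathlib

source = pathlib.Path("tests/unit_tests/training/utils/test_flop_utils.py").read_text()
ast.parse(source)  # raises SyntaxError if the file does not parse
print("parsed", len(source.splitlines()), "lines cleanly")
```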

Signed-off-by

Signed-off-by: lonexreb <[email protected]>

…or FLOPs calculator

Closes coverage gaps in tests/unit_tests/training/utils/test_flop_utils.py:
the existing suite covered MoE, hybrid (Mamba/GDN/MTP), SWA, and attention
output gate, but the transformer_flops paths for Multi-Latent Attention
(DeepSeek-V2/V3), explicit mtp_num_layers, and the model-provider override
were untested. Functional tests cover only Llama / Qwen3-MoE today, so
DeepSeek-V3 — the most active model family in this repo — has had no
TFLOPS regression coverage at all.

Adds 10 unit tests across 4 new classes (337 lines, no production code):

- TestMLAFlops (6 tests)
  - test_mla_with_q_lora_exact_formula: bit-exact match against the
    closed-form MLA FLOPs (DeepSeek-V3 Q-LoRA + KV-LoRA path)
  - test_mla_without_q_lora_exact_formula: bit-exact match for the
    no-Q-compression branch (h * n_heads * (qk + qk_pos))
  - test_q_lora_reduces_q_projection_flops: invariant — Q-LoRA
    compression must reduce attention FLOPs when rank is small enough
  - test_mla_differs_from_standard_attention: MLA != MHA path
  - test_mla_batch_size_scales_linearly: f(B=4) == 4 * f(B=1)
  - test_mla_seq_length_quadratic_growth: doubling s grows total > 2x
    (because core attention is O(s^2))

- TestMLAWithMoE (1 test)
  - test_mla_moe_combination_positive_and_distinct: DeepSeek-V3 shape
    (MLA + MoE) produces sensible totals distinct from MHA + MoE

- TestExplicitMtpInTransformerPath (2 tests)
  - test_explicit_mtp_increases_flops: directional invariant
  - test_explicit_mtp_exact_delta: bit-exact delta matching the
    per-layer MLP, per-layer attention, MTP norms/proj, and (mtp+1)
    logit-scaling contributions

- TestProviderOverride (1 test)
  - test_provider_override_short_circuits: when the model exposes
    _get_num_floating_point_operations, the calculator is bypassed
    and the override is invoked exactly with the requested batch_size

The MLA arithmetic mirrors the closed form in flop_utils.py
(num_floating_point_operations -> transformer_flops, lines 343-398),
so any future regression in that branch surfaces immediately.

MockModelConfig is extended with q_lora_rank, kv_lora_rank, qk_head_dim,
qk_pos_emb_head_dim, v_head_dim — all defaulted so existing tests are
unaffected.

Signed-off-by: lonexreb <[email protected]>
@copy-pr-bot

copy-pr-bot Bot commented May 5, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Contributor

@cuichenx cuichenx left a comment

LGTM

@cuichenx cuichenx added the ready-to-merge label (PR is approved, current, and only waiting for CI to pass before merge) on May 6, 2026
@cuichenx
Contributor

cuichenx commented May 6, 2026

/ok to test 2e6a881

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-customer label (Waiting on the original author to respond) on May 6, 2026

Labels

  • community-request
  • ready-to-merge: PR is approved, current, and only waiting for CI to pass before merge
  • waiting-on-customer: Waiting on the original author to respond
