Skip to content

[training, test] feat: add unit tests for pg_utils helpers#3650

Open
lonexreb wants to merge 2 commits intoNVIDIA-NeMo:mainfrom
lonexreb:training/test-pg-utils
Open

[training, test] feat: add unit tests for pg_utils helpers#3650
lonexreb wants to merge 2 commits intoNVIDIA-NeMo:mainfrom
lonexreb:training/test-pg-utils

Conversation

@lonexreb
Copy link
Copy Markdown
Contributor

@lonexreb lonexreb commented May 4, 2026

Summary

src/megatron/bridge/training/utils/pg_utils.py contains two key helpers used across the training stack but had no direct unit-test coverage:

  • get_pg_collection(model) — accessor that returns a model's pg_collection, with a documented MPU fallback.
  • DistTrainProcessGroupCollection — process-group wrapper used by the dist-train multi-module pipeline.

Existing references in tests/unit_tests/training/test_decentralized_pg.py only isinstance-check the result — they don't exercise the constructor's field-copy logic, the language-model branching, or get_pg_collection's control flow.

This PR adds 10 focused unit tests for both helpers. Tests-only — no production code changes.

What's covered

get_pg_collection (4 tests)

Test Path covered
test_returns_pg_collection_attached_to_single_model Happy path: model has pg_collection → return it
test_uses_first_chunk_when_passed_a_list List-of-chunks path (used when models wrap in lists for VPP) → model[0] is referenced
test_falls_back_to_mpu_when_pg_collection_attribute_missing The documented contract: get_attr_wrapped_model raises with the exact "couldn't find attribute pg_collection" substring → ProcessGroupCollection.use_mpu_process_groups() is invoked
test_reraises_runtime_error_when_message_does_not_match An unrelated RuntimeError propagates unchanged AND the MPU fallback is not invoked

DistTrainProcessGroupCollection (6 tests)

Test Behavior locked in
test_inherits_process_group_collection Subclass relationship (so callers can isinstance)
test_copies_set_fields_from_source_collection All set fields (tp, pp, cp) on the source are copied to the wrapper
test_unset_fields_default_to_none Fields never set on the source default to None on the wrapper (constructor uses getattr(..., None))
test_no_language_model_by_default language_model_module_name=None keeps has_language_model() False and language_model_collection None
test_language_model_attaches_source_collection language_model_module_name="llm" flips has_language_model() True and exposes the source as language_model_collection
test_get_language_model_cp_size_returns_lm_cp_size Returns cp.size() on the LM collection
test_get_language_model_cp_size_raises_without_lm Raises ValueError("No language model specified") when no LM is configured

Why this matters

  • The MPU fallback behavior in get_pg_collection is brittle: it depends on a substring match against RuntimeError's message. If MCore ever changes the error message, this fallback silently breaks. A test that asserts the substring contract makes that breakage loud.
  • DistTrainProcessGroupCollection's constructor uses getattr(..., None) for every field — without coverage, regressions that change defaults (e.g., raising vs returning None) would slip through.

Test plan

  • python3 -m ast parse clean
  • ruff check clean
  • ruff format --check clean
  • CI: cicd-unit-tests-core picks up the new module automatically (lives under tests/unit_tests/training/utils/)

Risk

Zero — tests only. No production code, no public API touched. The contracts under test have been stable since the file was added; this just locks them in so future refactors fail fast if behavior drifts.

Notes for reviewers

  • The MPU fallback test patches ProcessGroupCollection.use_mpu_process_groups to a sentinel return rather than letting it touch parallel_state, so the test runs cleanly without distributed init.
  • The wrapper test uses bare ProcessGroupCollection() (all fields are init=False) plus setattr for the fields we actually care about — keeps the test surface small and resilient to ProcessGroupCollection field additions in MCore.

`src/megatron/bridge/training/utils/pg_utils.py` had no direct test
coverage. Existing references in `tests/unit_tests/training/test_decentralized_pg.py`
only check `isinstance(result, DistTrainProcessGroupCollection)` — they
don't exercise the constructor field-copy logic, the `language_model_*`
behavior, or `get_pg_collection`'s control flow.

Add focused unit tests covering:

`get_pg_collection`:
  - Returns the model's pg_collection when present.
  - Uses `model[0]` when given a list of chunks (the recipe path used
    when models are wrapped in lists for VPP).
  - Falls back to `ProcessGroupCollection.use_mpu_process_groups()` when
    `get_attr_wrapped_model` raises with the exact "couldn't find
    attribute pg_collection" message — the documented contract.
  - Re-raises any other RuntimeError unchanged (and does NOT invoke the
    MPU fallback in that case).

`DistTrainProcessGroupCollection`:
  - Inherits from `ProcessGroupCollection` (callers `isinstance` it).
  - Copies set fields from the source pg_collection (`tp`, `pp`, `cp`).
  - Unset source fields default to `None` on the wrapper (the
    constructor uses `getattr(..., None)`).
  - `language_model_module_name=None` keeps `has_language_model()` False
    and `language_model_collection` None.
  - `language_model_module_name="llm"` flips `has_language_model()` to
    True and exposes the source pg_collection as `language_model_collection`.
  - `get_language_model_cp_size()` returns the LM CP group's `.size()`.
  - `get_language_model_cp_size()` raises ValueError with a clear
    message when no LM is configured.

Tests-only — no production code changes. The contract under test has
been stable since the file was added; this just locks it in so future
refactors fail fast if the behavior drifts.

Signed-off-by: lonexreb <[email protected]>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 4, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@cuichenx
Copy link
Copy Markdown
Contributor

cuichenx commented May 5, 2026

/ok to test b0bd1e2

@cuichenx cuichenx added the ready-to-merge PR is approved, current, and only waiting for CI to pass before merge label May 5, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-request ready-to-merge PR is approved, current, and only waiting for CI to pass before merge

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants