[training, test] feat: add unit tests for pg_utils helpers #3650
Open
lonexreb wants to merge 2 commits into NVIDIA-NeMo:main from
`src/megatron/bridge/training/utils/pg_utils.py` had no direct test
coverage. Existing references in `tests/unit_tests/training/test_decentralized_pg.py`
only check `isinstance(result, DistTrainProcessGroupCollection)` — they
don't exercise the constructor field-copy logic, the `language_model_*`
behavior, or `get_pg_collection`'s control flow.
Add focused unit tests covering:
`get_pg_collection`:
- Returns the model's pg_collection when present.
- Uses `model[0]` when given a list of chunks (the recipe path used
when models are wrapped in lists for VPP).
- Falls back to `ProcessGroupCollection.use_mpu_process_groups()` when
`get_attr_wrapped_model` raises with the exact "couldn't find
attribute pg_collection" message — the documented contract.
- Re-raises any other RuntimeError unchanged (and does NOT invoke the
MPU fallback in that case).
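The four `get_pg_collection` cases above can be sketched without Megatron installed. This is a minimal illustration, not the test file from this PR: `FakeProcessGroupCollection`, `fake_get_attr_wrapped_model`, `get_pg_collection_sketch`, and `Chunk` are hypothetical stand-ins, and the control flow is an assumption reconstructed from the description rather than copied from `pg_utils.py`.

```python
from unittest import mock

# Substring the fallback keys on, quoted from the PR description.
MISSING_MSG = "couldn't find attribute pg_collection"


class FakeProcessGroupCollection:
    """Hypothetical stand-in for MCore's ProcessGroupCollection."""

    @classmethod
    def use_mpu_process_groups(cls):
        # The real method reads global parallel state; tests replace it.
        raise AssertionError("expected to be patched")


def fake_get_attr_wrapped_model(model, attr):
    """Hypothetical stand-in: raise RuntimeError when `attr` is absent."""
    if not hasattr(model, attr):
        raise RuntimeError(f"couldn't find attribute {attr}")
    return getattr(model, attr)


def get_pg_collection_sketch(model, lookup=fake_get_attr_wrapped_model):
    """Assumed control flow; not copied from pg_utils.py."""
    target = model[0] if isinstance(model, list) else model
    try:
        return lookup(target, "pg_collection")
    except RuntimeError as err:
        if MISSING_MSG in str(err):
            return FakeProcessGroupCollection.use_mpu_process_groups()
        raise


class Chunk:
    """Bare object that may or may not carry a pg_collection."""


def check_fallback_contract():
    attached = object()
    chunk = Chunk()
    chunk.pg_collection = attached
    # Single-model and list-of-chunks paths both return the attached object.
    assert get_pg_collection_sketch(chunk) is attached
    assert get_pg_collection_sketch([chunk, Chunk()]) is attached

    sentinel = object()
    with mock.patch.object(
        FakeProcessGroupCollection, "use_mpu_process_groups", return_value=sentinel
    ) as mpu:
        # Missing attribute with the matching message: the fallback is used.
        assert get_pg_collection_sketch(Chunk()) is sentinel
        mpu.assert_called_once_with()
        mpu.reset_mock()

        def boom(_model, _attr):
            raise RuntimeError("unrelated failure")

        # Any other RuntimeError propagates and the fallback stays untouched.
        try:
            get_pg_collection_sketch(Chunk(), lookup=boom)
        except RuntimeError as err:
            assert str(err) == "unrelated failure"
        else:
            raise AssertionError("RuntimeError was swallowed")
        mpu.assert_not_called()
    return True


check_fallback_contract()
```

Patching the fallback to return a sentinel lets the assertion check identity rather than inspect a real process-group object, so no distributed state is needed.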
`DistTrainProcessGroupCollection`:
- Inherits from `ProcessGroupCollection` (callers `isinstance` it).
- Copies set fields from the source pg_collection (`tp`, `pp`, `cp`).
- Unset source fields default to `None` on the wrapper (the
constructor uses `getattr(..., None)`).
- `language_model_module_name=None` keeps `has_language_model()` False
and `language_model_collection` None.
- `language_model_module_name="llm"` flips `has_language_model()` to
True and exposes the source pg_collection as `language_model_collection`.
- `get_language_model_cp_size()` returns the LM CP group's `.size()`.
- `get_language_model_cp_size()` raises ValueError with a clear
message when no LM is configured.
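As a rough illustration of the wrapper behavior listed above, the sketch below mirrors the described contract with stand-ins. `WrapperSketch`, `FakeGroup`, `COPIED_FIELDS`, and the constructor signature are assumptions made for this example; the real `DistTrainProcessGroupCollection` subclasses `ProcessGroupCollection` and may differ in detail.

```python
from types import SimpleNamespace

# Only the fields named in the PR text; the real class presumably has more.
COPIED_FIELDS = ("tp", "pp", "cp")


class FakeGroup:
    """Stand-in for a process group exposing size()."""

    def __init__(self, size):
        self._size = size

    def size(self):
        return self._size


class WrapperSketch:
    """Assumed shape of the wrapper; not the real DistTrainProcessGroupCollection."""

    def __init__(self, pg_collection, language_model_module_name=None):
        for name in COPIED_FIELDS:
            # getattr(..., None) turns an unset source field into None.
            setattr(self, name, getattr(pg_collection, name, None))
        self.language_model_module_name = language_model_module_name
        self.language_model_collection = (
            pg_collection if language_model_module_name is not None else None
        )

    def has_language_model(self):
        return self.language_model_module_name is not None

    def get_language_model_cp_size(self):
        if not self.has_language_model():
            raise ValueError("No language model specified")
        return self.language_model_collection.cp.size()


source = SimpleNamespace(tp=FakeGroup(2), cp=FakeGroup(4))  # pp left unset

plain = WrapperSketch(source)
assert plain.tp is source.tp and plain.cp is source.cp
assert plain.pp is None
assert not plain.has_language_model()
assert plain.language_model_collection is None

with_lm = WrapperSketch(source, language_model_module_name="llm")
assert with_lm.has_language_model()
assert with_lm.language_model_collection is source
assert with_lm.get_language_model_cp_size() == 4
```

Calling `plain.get_language_model_cp_size()` in this sketch raises `ValueError`, which is the case the last bullet pins down.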
Tests-only — no production code changes. The contract under test has
been stable since the file was added; this just locks it in so future
refactors fail fast if the behavior drifts.
Signed-off-by: lonexreb <[email protected]>
Contributor comment: /ok to test b0bd1e2
cuichenx approved these changes on May 5, 2026
Summary
`src/megatron/bridge/training/utils/pg_utils.py` contains two key helpers used across the training stack but had no direct unit-test coverage:
- `get_pg_collection(model)`: accessor that returns a model's `pg_collection`, with a documented MPU fallback.
- `DistTrainProcessGroupCollection`: process-group wrapper used by the dist-train multi-module pipeline.

Existing references in `tests/unit_tests/training/test_decentralized_pg.py` only `isinstance`-check the result. They don't exercise the constructor's field-copy logic, the language-model branching, or `get_pg_collection`'s control flow.

This PR adds 11 focused unit tests for both helpers. Tests only, with no production code changes.

What's covered
`get_pg_collection` (4 tests)

| Test | Covers |
| --- | --- |
| `test_returns_pg_collection_attached_to_single_model` | Model has a `pg_collection` → return it |
| `test_uses_first_chunk_when_passed_a_list` | Given a list of chunks, `model[0]` is referenced |
| `test_falls_back_to_mpu_when_pg_collection_attribute_missing` | `get_attr_wrapped_model` raises with the exact `"couldn't find attribute pg_collection"` substring → `ProcessGroupCollection.use_mpu_process_groups()` is invoked |
| `test_reraises_runtime_error_when_message_does_not_match` | Any other `RuntimeError` propagates unchanged and the MPU fallback is not invoked |

`DistTrainProcessGroupCollection` (7 tests)

| Test | Covers |
| --- | --- |
| `test_inherits_process_group_collection` | Inherits from `ProcessGroupCollection` (callers `isinstance` it) |
| `test_copies_set_fields_from_source_collection` | Set fields (`tp`, `pp`, `cp`) on the source are copied to the wrapper |
| `test_unset_fields_default_to_none` | Unset source fields default to `None` on the wrapper (constructor uses `getattr(..., None)`) |
| `test_no_language_model_by_default` | `language_model_module_name=None` keeps `has_language_model()` False and `language_model_collection` None |
| `test_language_model_attaches_source_collection` | `language_model_module_name="llm"` flips `has_language_model()` to True and exposes the source as `language_model_collection` |
| `test_get_language_model_cp_size_returns_lm_cp_size` | Returns `cp.size()` on the LM collection |
| `test_get_language_model_cp_size_raises_without_lm` | Raises `ValueError("No language model specified")` when no LM is configured |

Why this matters
- `get_pg_collection` is brittle: it depends on a substring match against the `RuntimeError` message. If MCore ever changes the error message, this fallback silently breaks. A test that asserts the substring contract makes that breakage loud.
- `DistTrainProcessGroupCollection`'s constructor uses `getattr(..., None)` for every field. Without coverage, regressions that change defaults (e.g., raising vs. returning `None`) would slip through.

Test plan
- `python3 -m ast` parse clean
- `ruff check` clean
- `ruff format --check` clean
- `cicd-unit-tests-core` picks up the new module automatically (lives under `tests/unit_tests/training/utils/`)

Risk
Zero: tests only. No production code, no public API touched. The contracts under test have been stable since the file was added; this just locks them in so future refactors fail fast if behavior drifts.

Notes for reviewers
- Patched `ProcessGroupCollection.use_mpu_process_groups` to a sentinel return rather than letting it touch `parallel_state`, so the test runs cleanly without distributed init.
- Source collections use a bare `ProcessGroupCollection()` (all fields are `init=False`) plus `setattr` for the fields we actually care about, which keeps the test surface small and resilient to `ProcessGroupCollection` field additions in MCore.
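
The bare-constructor-plus-`setattr` approach from the reviewer notes can be illustrated with a plain dataclass. `FakeCollection` and `make_source` are hypothetical names for this sketch, and it assumes the fields carry no defaults; whether the real `ProcessGroupCollection` declares defaults is not stated here.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeCollection:
    """Hypothetical stand-in: every field is init=False with no default."""

    tp: Optional[object] = field(init=False)
    pp: Optional[object] = field(init=False)
    cp: Optional[object] = field(init=False)


def make_source(**groups):
    """Build a bare collection, then set only the fields a test needs."""
    source = FakeCollection()  # takes no arguments because of init=False
    for name, group in groups.items():
        setattr(source, name, group)
    return source


tp_group = object()
source = make_source(tp=tp_group)

assert source.tp is tp_group
# Fields never set do not exist on the instance at all...
assert not hasattr(source, "pp")
# ...which is why reading them needs a getattr default.
assert getattr(source, "cp", None) is None
```

Because the builder names only the fields a test touches, adding a new field to the collection class would not require changing existing fixtures.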