
[training] fix: aggregate TB memory metrics across PP group #3645

Open

lonexreb wants to merge 2 commits into NVIDIA-NeMo:main from lonexreb:training/fix-tb-memory-pp-max-3167

Conversation

@lonexreb
Contributor

@lonexreb lonexreb commented May 4, 2026

Summary

  • Fix TensorBoard / W&B / MLFlow / Comet memory metrics under-reporting by reducing report_memory(...) across the pipeline-parallel group with MAX before the writer rank emits.
  • Add reduce_max_memory_across_pp_group helper and call it from training_log on every PP rank, gated by the existing log_memory_to_tensorboard config flag plus the tensorboard log interval — single bulk all-reduce per logged step on PP > 1, no-op otherwise.
  • Counter-style integer keys (e.g. alloc_retries) are preserved as int so existing dashboards continue to render correctly.
  • Adds 6 unit tests for the helper covering the no-op fallbacks, happy-path bulk MAX, and the int-preservation contract.

Refs #3167. Partially addresses the TB-rank concern in #3164.

Background

tensorboard_logger, wandb_logger, mlflow_logger, and comet_logger are lazily initialized only on the last rank (world_size - 1), see state.py:182-228. With pipeline parallelism, that's the last PP rank — but peak GPU memory typically lives on the first PP stage (activation buildup), so the dashboards under-report true peak headroom and obscure how close a job is to OOM.

The cleanest fix is to aggregate per-rank torch.cuda.memory_stats() across the PP group with MAX, so the writer rank always emits the per-metric peak across the entire pipeline — strictly more useful than picking any single PP rank.
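A minimal sketch of the helper's shape, assuming the name and behavior described in this PR (the actual implementation in the diff may differ in details such as dtype handling and error reporting):

```python
import torch
import torch.distributed as dist


def reduce_max_memory_across_pp_group(memory_report, pp_group):
    """Return a copy of memory_report with each value replaced by its MAX
    across the pipeline-parallel group; falls back to a no-op when there is
    nothing to reduce."""
    # Defensive short-circuits keep unit-test paths and single-stage runs
    # unaffected.
    if not memory_report:
        return memory_report
    if not dist.is_available() or not dist.is_initialized():
        return memory_report
    size_fn = getattr(pp_group, "size", None)
    if not callable(size_fn) or size_fn() <= 1:
        return memory_report

    # One bulk all-reduce over a stacked float64 tensor instead of one
    # collective per metric. Key order is identical on every rank because
    # all ranks build the report the same way.
    keys = list(memory_report)
    device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
    stacked = torch.tensor(
        [float(memory_report[k]) for k in keys], dtype=torch.float64, device=device
    )
    dist.all_reduce(stacked, op=dist.ReduceOp.MAX, group=pp_group)

    reduced = {}
    for key, value in zip(keys, stacked.tolist()):
        original = memory_report[key]
        # Preserve counter-style ints (e.g. alloc_retries); bool is a
        # subclass of int, so exclude it explicitly.
        if isinstance(original, int) and not isinstance(original, bool):
            reduced[key] = int(value)
        else:
            reduced[key] = value
    return reduced
```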

Implementation Notes

  • The all-reduce must run on every PP rank, not just the writer rank. The new pre-block runs unconditionally when log_memory_to_tensorboard=True (a config flag, identical on all ranks) and the iteration matches the tensorboard log interval. The existing loggers_exist-gated section then consumes the precomputed dict on the writer rank only; a call-site sketch follows this list.
  • Single bulk all-reduce over a stacked float64 tensor instead of one collective per metric — minimizes overhead.
  • Helper is defensive: short-circuits if memory_report is empty, torch.distributed is uninitialized, the PP group has size 1, or the group object lacks a callable .size. Keeps unit-test paths and single-stage runs unaffected.
  • Original memory_report dict is not mutated.
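Roughly how the call site described in the first bullet is arranged. Every name below (args, report_memory_fn, writer, and the wrapper function itself) is a placeholder for whatever training_log actually uses, not the real signature:

```python
def _maybe_log_memory(args, iteration, loggers_exist, writer, report_memory_fn, pp_group):
    """Illustrative call-site shape only; the real code lives inline in training_log."""
    memory_report = None
    if args.log_memory_to_tensorboard and iteration % args.tensorboard_log_interval == 0:
        # Runs on every PP rank so all ranks join the same collective.
        memory_report = reduce_max_memory_across_pp_group(report_memory_fn(), pp_group)

    if loggers_exist and memory_report is not None:
        # Only the writer rank (world_size - 1) has loggers; it emits the
        # per-metric MAX across the whole pipeline.
        for name, value in memory_report.items():
            writer.add_scalar(f"memory/{name}", value, iteration)
```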

Test plan

  • python3 -m ast parse of changed files
  • ruff check clean on changed files
  • ruff format --check clean on changed files
  • 6 new unit tests added under TestReduceMaxMemoryAcrossPpGroup (two are sketched after this list):
    • empty report short-circuits
    • distributed uninitialized → no-op
    • pp.size() == 1 → no-op (no all-reduce call)
    • PP group missing .size → defensive no-op
    • max reduction replaces values; original dict not mutated
    • int metrics (alloc_retries) preserved as int
  • CI: existing test_memory_tensorboard_logging continues to pass (helper short-circuits when torch.distributed.is_initialized() is False in unit tests)
  • CI: L0/L1 functional tests on H100 / GB200
  • Manual: with PP > 1, observe TB memory/peak_* scalars now reflect the per-metric max across pipeline stages instead of the last stage only
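For illustration, two of the listed cases could be written like this against the helper sketched above (the real tests under TestReduceMaxMemoryAcrossPpGroup may be structured differently):

```python
from unittest import mock


class _FakePPGroup:
    """Stand-in for a pipeline-parallel process group with a fixed size."""

    def __init__(self, world_size):
        self._world_size = world_size

    def size(self):
        return self._world_size


def test_uninitialized_distributed_is_noop():
    report = {"allocated.all.peak": 123.0}
    with mock.patch("torch.distributed.is_initialized", return_value=False):
        assert reduce_max_memory_across_pp_group(report, _FakePPGroup(2)) == report


def test_int_metrics_preserved_as_int():
    report = {"alloc_retries": 3}
    with mock.patch("torch.distributed.is_initialized", return_value=True), \
         mock.patch("torch.distributed.all_reduce") as fake_all_reduce:
        reduced = reduce_max_memory_across_pp_group(report, _FakePPGroup(2))
    fake_all_reduce.assert_called_once()
    assert reduced == {"alloc_retries": 3}
    assert isinstance(reduced["alloc_retries"], int)
    assert report == {"alloc_retries": 3}  # original dict not mutated
```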

Risk

Low. The change is isolated to training_log's memory-logging block:

  • No behavior change when PP size == 1 or distributed is uninitialized.
  • No behavior change when log_memory_to_tensorboard=False.
  • Every PP rank participates in the all-reduce only on iterations where the writer rank would log, so no new collectives are issued outside the existing log interval.
  • No public API change.

Memory metrics emitted to TensorBoard / W&B / MLFlow / Comet are written
only by the last rank (`world_size - 1`), but with pipeline parallelism
peak GPU memory typically lives on the first PP stage (activation
buildup). Without aggregation the dashboards reflect only the last PP
stage and under-report true peak headroom.

Add `reduce_max_memory_across_pp_group` and call it from `training_log`
on every PP rank before the writer-gated section, so the writer rank
emits the per-metric MAX across the pipeline. The reduction is a single
bulk all-reduce gated by the existing `log_memory_to_tensorboard`
config flag plus the tensorboard log interval, so the cost is one
collective per logged step on PP > 1, and a no-op otherwise.

Counter-style integer metrics (e.g. `alloc_retries`) are preserved as
`int` so existing dashboards continue to render them correctly. Helper
no-ops when distributed is uninitialized, the report is empty, or the
PP group has a single rank, keeping unit-test paths and single-stage
runs unaffected.

Adds unit tests for the helper covering the no-op fallbacks (empty
report, dist not initialized, pp size == 1, group missing `.size`),
the happy-path bulk MAX reduction, and the int-preservation contract.

Refs issue NVIDIA-NeMo#3167 (and partially addresses the TB-rank concern in NVIDIA-NeMo#3164).

Signed-off-by: lonexreb <[email protected]>
@copy-pr-bot

copy-pr-bot Bot commented May 4, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@cuichenx
Contributor

cuichenx commented May 5, 2026

/ok to test 2f7f48f

cuichenx
cuichenx previously approved these changes May 5, 2026
Contributor

@cuichenx cuichenx left a comment


LGTM, @dingqingy-nv could you also review?

@cuichenx cuichenx added the ready-to-merge label (PR is approved, current, and only waiting for CI to pass before merge) May 5, 2026
@cuichenx cuichenx linked an issue May 5, 2026 that may be closed by this pull request
@dingqingy-nv
Contributor

/claude review

@claude
Contributor

claude Bot commented May 5, 2026

Light Code Review

LGTM — clean implementation with solid test coverage.

Notes

  • The all-reduce placement is correct: it runs on all PP ranks (outside the loggers_exist gate) so every rank participates in the collective, while the writer rank consumes the result inside the gated block. The gating condition (log_memory_to_tensorboard + iteration check) is config-derived and identical on all ranks, so no deadlock risk.
  • The bool guard (not isinstance(original, bool) and isinstance(original, int)) is good defensive coding since bool is a subclass of int in Python, even though report_memory currently never returns bools (a short illustration follows this list).
  • Dict key ordering is deterministic (Python 3.7+ insertion order) and identical across ranks since all ranks call the same report_memory function, so the stacked tensor aligns correctly.
  • Original dict is not mutated — verified by test.
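A quick illustration of the bool/int point above (standard Python behavior, independent of this PR):

```python
# bool is a subclass of int, so a bare isinstance(x, int) check would also
# match True/False; the extra guard keeps booleans out of the int-preserving path.
assert isinstance(True, int)
assert not (isinstance(True, int) and not isinstance(True, bool))
```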

Suggested test cases

No perf tests impacted.

dingqingy-nv
dingqingy-nv previously approved these changes May 5, 2026
@cuichenx cuichenx enabled auto-merge (squash) May 5, 2026 18:32
@cuichenx cuichenx dismissed stale reviews from dingqingy-nv and themself via 1f86d45 May 5, 2026 18:33
@cuichenx
Contributor

cuichenx commented May 5, 2026

/ok to test 1f86d45


Labels

community-request, ready-to-merge (PR is approved, current, and only waiting for CI to pass before merge)


Development

Successfully merging this pull request may close these issues.

[bug] improve TB memory logging
