perf(fix): accumulate per-microbatch FLOPS metadata for accurate… #3529

Open

SophusDavid wants to merge 8 commits into NVIDIA-NeMo:main from SophusDavid:merge-to-upstream

Conversation

SophusDavid commented Apr 26, 2026

perf(training): improve VLM TFLOPS calculation with actual sequence length and ViT

What does this PR do?

Fixes inaccurate TFLOPS reporting for VLM training where the reported FLOPS was based on the statically configured cfg.model.seq_length instead of the dynamically padded sequence lengths that FlashAttention actually processes. This led to significant overestimation (up to ~2.6x for typical VLM SFT scenarios with short sequences + images).

The PR introduces a per-micro-batch accumulator pattern that captures the real padded sequence length and vision patch count from each forward_step invocation, scales them by data-parallel size for global estimation, and feeds them into an extended num_floating_point_operations() API.

Closes #3498

Changelog

  • vlm_step.py

    • Accumulate _flops_seqlen_sum and _flops_seqlen_sq_sum across micro-batches using tokens.shape[1] (actual padded length); the full accumulate/reset/scale flow is sketched after this changelog.
    • Accumulate _flops_vision_patches from visual_inputs.image_grid_thw / video_grid_thw.
    • Add "total_tokens" to packed_seq_params for downstream visibility.
  • train.py

    • Reset _flops_* accumulators to 0 before each wrapped_train_step call.
    • After the step, read accumulated per-rank totals, scale by dp_size for global FLOPS, and pass effective seq_length to training_log.
    • Remove the earlier seq_length return value from train_step (reverting to original 8-element return signature).
  • flop_utils.py

    • Add vit_flops() helper for Vision Transformer encoder FLOPS (bidirectional attention + GELU MLP + patch merger).
    • Extend num_floating_point_operations() signature with optional seqlen_sum, seqlen_squared_sum, and num_vision_patches parameters.
    • Add _compute_vit_flops() that derives per-image patch count from total batch patches for correct quadratic attention scaling.
  • train_utils.py

    • training_log() now accepts a seq_length parameter and uses the accumulated _flops_vision_patches (scaled by data_parallel_size) for consistent per-log FLOPS recomputation.
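
The accumulate/reset/scale flow described in this changelog, as a rough sketch (the helper names, the state object, and the dp_size plumbing are illustrative assumptions rather than the exact vlm_step.py / train.py code):

```python
# Illustrative sketch only; not the exact implementation.

# --- vlm_step.py: once per micro-batch inside forward_step ---
def accumulate_flops_metadata(state, tokens, visual_inputs):
    seqlen = tokens.shape[1]  # actual padded length seen by FlashAttention
    state._flops_seqlen_sum += seqlen
    state._flops_seqlen_sq_sum += seqlen * seqlen
    for grid in (getattr(visual_inputs, "image_grid_thw", None),
                 getattr(visual_inputs, "video_grid_thw", None)):
        if grid is not None:
            # each (t, h, w) row is one image/video; t*h*w is its patch count
            state._flops_vision_patches += int(grid.prod(dim=-1).sum())

# --- train.py: around one wrapped_train_step call ---
def run_train_step(state, wrapped_train_step, dp_size):
    state._flops_seqlen_sum = 0
    state._flops_seqlen_sq_sum = 0
    state._flops_vision_patches = 0
    losses = wrapped_train_step()
    # Per-rank totals -> global estimate; these are fed to the extended
    # num_floating_point_operations(seqlen_sum=..., seqlen_squared_sum=...,
    # num_vision_patches=...) and to training_log.
    seqlen_sum = state._flops_seqlen_sum * dp_size
    seqlen_squared_sum = state._flops_seqlen_sq_sum * dp_size
    num_vision_patches = state._flops_vision_patches * dp_size
    return losses, seqlen_sum, seqlen_squared_sum, num_vision_patches
```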

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
    • Added test_flops_calculation.py with 13 verification tests covering VLM and non-VLM scenarios.
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

Additional Information


copy-pr-bot Bot commented Apr 26, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

… VLM TFLOPS

- Replace static cfg.model.seq_length with accumulated actual padded
  lengths from each micro-batch (_flops_seqlen_sum/sq_sum).
- Add ViT encoder FLOPS via _flops_vision_patches from visual grid THW.
- Scale per-rank accumulators by dp_size for global FLOPS estimation.
- Revert train_step return signature; zero API surface changes.
- Also support trace_prefix in pytorch profiler initialization.

Signed-off-by: dingtianwei.dtw <[email protected]>
yaoyu-33 added the area:perf, bug, and needs-review labels on Apr 27, 2026
@yaoyu-33 (Contributor)

/claude review

Comment thread on src/megatron/bridge/training/utils/flop_utils.py (outdated)

claude Bot commented Apr 27, 2026

Review

Bug — effective_seq_length_sq is computed but never used (flop_utils.py)

seqlen_squared_sum is accepted, converted to effective_seq_length_sq (lines 116–119), but no attention calculation references it. All quadratic attention terms still use effective_seq_length. The entire seqlen_squared_sum accumulation pipeline (vlm_step.py → train.py → train_utils.py → flop_utils.py) is dead code. See inline comment for details.

Missing tests

The PR description says "Added test_flops_calculation.py with 13 verification tests" but no test file is included in the diff. Were they omitted from the commit?

Comment thread on src/megatron/bridge/training/utils/flop_utils.py (outdated)
@yaoyu-33 (Contributor)

Bug: seqlen_squared_sum plumbing is dead — attention FLOPs underestimated for variable-length batches

In src/megatron/bridge/training/utils/flop_utils.py, effective_seq_length_sq is computed at lines 116-119 from seqlen_squared_sum, but is never used anywhere else in the function.

The attention core term in both the MLA branch (lines 480-486) and the MHA/GQA / SWA branch (lines 526-538) is written as ... * effective_seq_length / 2 ..., and the outer expression multiplies by seqlen_sum (line 618). The product is:

seqlen_sum × effective_seq_length / 2  =  (Σ Lᵢ) × mean(Lᵢ) / 2  =  batch_size × mean(L)² / 2

But the physically correct attention cost is Σ Lᵢ² / 2 = batch_size × mean(L²) / 2. By Jensen's inequality mean(L)² ≤ mean(L²), so whenever sequence lengths actually vary across the micro-batches accumulated in vlm_step.py (lines 439-440) — exactly the case this PR adds plumbing for — the attention FLOPs are systematically under-counted. For fixed-length batches the two are equal, which is why a homogeneous-batch test wouldn't surface this.

Suggested fix

In the core-attention terms, replace effective_seq_length with effective_seq_length_sq / effective_seq_length (equivalently seqlen_squared_sum / seqlen_sum). After the outer multiplication by seqlen_sum, the attention contribution becomes seqlen_squared_sum × (...) / 2, which is what seqlen_squared_sum was added to enable.

Concretely:

  • MLA branch (lines 480, 486): effective_seq_length → effective_seq_length_sq / effective_seq_length
  • MHA/GQA full_core (line 538) and SWA full_core (line 526): same substitution; swa_core (line 527) stays as-is since SWA attention is linear in seq_len once the window saturates.

Worth adding a unit test that feeds heterogeneous (seqlen_sum, seqlen_squared_sum) and asserts the attention term scales with Σ L², not (Σ L)²/B.
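
A rough sketch of such a test, assuming only the extended keyword parameters named in the changelog; the cfg fixture and any remaining arguments of num_floating_point_operations are placeholders:

```python
def test_attention_flops_scale_with_sum_of_squares(cfg):
    # Same total token count (sum L = 2048), different spread:
    # uniform lengths [1024, 1024] vs skewed lengths [512, 1536].
    uniform = num_floating_point_operations(
        cfg, seqlen_sum=2048, seqlen_squared_sum=2 * 1024**2
    )
    skewed = num_floating_point_operations(
        cfg, seqlen_sum=2048, seqlen_squared_sum=512**2 + 1536**2
    )
    # sum(L^2) is strictly larger for the skewed batch, so its attention term
    # (and total) must be larger too; under the old (sum L) * mean(L) form both
    # calls would return the same value and this assertion would fail.
    assert skewed > uniform
```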

yaoyu-33 added the waiting-on-customer label and removed the needs-review label on Apr 28, 2026
SophusDavid and others added 3 commits April 28, 2026 11:57
1. Wire seqlen_squared_sum into core-attn via core_attn_seq_factor (was dead
code); see the sketch after this commit message.
2. Refactor vit_flops to take (cfg, batch_size, num_patches) so ViT
hyperparameters are read from cfg.model.vision_config.
3. Align training_log TFLOP/s/GPU with the main-loop FLOPS accumulator for
variable-length batches.
4. Tighten seqlen_sum typing to int | None and add some unit tests.

Signed-off-by: dingtianwei.dtw <[email protected]>
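
The core_attn_seq_factor wiring from item 1 amounts to swapping the batch-mean length for a length-weighted mean in the quadratic core-attention term; a minimal sketch, assuming only the names taken from the commit message (the fallback argument is hypothetical):

```python
def core_attn_seq_factor(seqlen_sum, seqlen_squared_sum, fallback_seq_length):
    """Sequence-length factor for the quadratic core-attention term.

    When per-micro-batch accumulators are available this is sum(L^2) / sum(L),
    so the outer multiplication by seqlen_sum yields sum(L^2) overall.
    For fixed-length batches it reduces to the plain sequence length.
    """
    if seqlen_sum and seqlen_squared_sum:
        return seqlen_squared_sum / seqlen_sum
    return fallback_seq_length
```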
svcnvidia-nemo-ci removed the waiting-on-customer label on Apr 28, 2026
svcnvidia-nemo-ci added the waiting-on-maintainers label on Apr 30, 2026

cuichenx commented May 1, 2026

/ok to test f5d422b

svcnvidia-nemo-ci removed the waiting-on-maintainers label on May 2, 2026
yaoyu-33 previously approved these changes on May 3, 2026
yaoyu-33 added the ready-to-merge label on May 3, 2026

yaoyu-33 commented May 3, 2026

@SophusDavid: please run pre-commit locally to resolve lint issues

svcnvidia-nemo-ci added the waiting-on-customer and waiting-on-maintainers labels and removed the waiting-on-customer label on May 3, 2026

cuichenx commented May 5, 2026

/ok to test 624d5a7


copy-pr-bot Bot commented May 5, 2026

/ok to test 624d5a7

@cuichenx, there was an error processing your request: E2

See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/

@weijiac0619 (Contributor)

/ok to test 4a7f00b

@weijiac0619 (Contributor)

@SophusDavid could you fix the failed test? Thank you!

@SophusDavid (Author)

> @SophusDavid: please run pre-commit locally to resolve lint issues

Ok! Resolved in commit dfc2fd4.

svcnvidia-nemo-ci removed the waiting-on-maintainers label on May 6, 2026

Labels

area:perf (Performance optimizations and benchmarking), bug (Something isn't working), community-request, ready-to-merge (PR is approved, current, and only waiting for CI to pass before merge)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] VLM TFLOPS calculation is inaccurate: uses config seq_length instead of actual FA sequence length, and omits ViT encoder FLOPs

5 participants