perf(fix): accumulate per-microbatch FLOPS metadata for accurate… #3529
SophusDavid wants to merge 8 commits into NVIDIA-NeMo:main
Conversation
… VLM TFLOPS

- Replace static `cfg.model.seq_length` with accumulated actual padded lengths from each micro-batch (`_flops_seqlen_sum`/`sq_sum`).
- Add ViT encoder FLOPS via `_flops_vision_patches` from the visual grid THW.
- Scale per-rank accumulators by `dp_size` for global FLOPS estimation.
- Revert the `train_step` return signature; zero API-surface changes.
- Also support `trace_prefix` in PyTorch profiler initialization.

Signed-off-by: dingtianwei.dtw <[email protected]>
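A rough sketch of the per-micro-batch accumulator pattern this commit describes (the class wrapper and method names below are hypothetical; only the three counters correspond to the `_flops_seqlen_sum`, `_flops_seqlen_sq_sum`, and `_flops_vision_patches` attributes named in the commit):

```python
# Illustrative sketch, not the PR's actual code: accumulate sequence-length
# statistics per micro-batch, reset before each train step, and scale by
# data-parallel size for a global FLOPS estimate.

class FlopsAccumulator:
    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # cleared before each wrapped_train_step call
        self.seqlen_sum = 0
        self.seqlen_sq_sum = 0
        self.vision_patches = 0

    def observe(self, padded_len: int, num_patches: int = 0) -> None:
        # called once per micro-batch, with tokens.shape[1] as padded_len
        self.seqlen_sum += padded_len
        self.seqlen_sq_sum += padded_len * padded_len
        self.vision_patches += num_patches

    def globalize(self, dp_size: int) -> tuple[int, int, int]:
        # scale per-rank sums by data-parallel size for a global estimate
        return (self.seqlen_sum * dp_size,
                self.seqlen_sq_sum * dp_size,
                self.vision_patches * dp_size)

acc = FlopsAccumulator()
for padded_len, patches in [(1200, 256), (900, 0)]:
    acc.observe(padded_len, patches)
print(acc.globalize(dp_size=8))  # -> (16800, 18000000, 2048)
```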
force-pushed from 0a5474e to 7f1ff31
/claude review
Review

Bug — Missing tests: The PR description says "Added …
1. Wire `seqlen_squared_sum` into core-attn via `core_attn_seq_factor` (was dead code).
2. Refactor `vit_flops` to take `(cfg, batch_size, num_patches)` so ViT hyperparameters are read from `cfg.model.vision_config`.
3. Align `training_log` TFLOP/s/GPU with the main-loop FLOPS accumulator for variable-length batches.
4. Tighten `seqlen_sum` typing to `int | None` and add some unit tests.

Signed-off-by: dingtianwei.dtw <[email protected]>
…uared_sum".

Signed-off-by: dingtianwei.dtw <[email protected]>
force-pushed from 84c7535 to f5d422b
/ok to test f5d422b
@SophusDavid: please run pre-commit locally to resolve the lint issues
/ok to test 624d5a7
@cuichenx, there was an error processing your request. See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/
/ok to test 4a7f00b
@SophusDavid could you fix the failed test? Thank you!
Signed-off-by: dingtianwei.dtw <[email protected]>
OK! Resolved in commit dfc2fd4.
perf(training): improve VLM TFLOPS calculation with actual sequence length and ViT
What does this PR do?
Fixes inaccurate TFLOPS reporting for VLM training, where the reported FLOPS was based on the statically configured `cfg.model.seq_length` instead of the dynamically padded sequence lengths that FlashAttention actually processes. This led to significant overestimation (up to ~2.6x for typical VLM SFT scenarios with short sequences plus images).

The PR introduces a per-micro-batch accumulator pattern that captures the real padded sequence length and vision patch count from each `forward_step` invocation, scales them by data-parallel size for global estimation, and feeds them into an extended `num_floating_point_operations()` API.

Closes #3498
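The effect can be illustrated with a toy calculation (the helper and all numbers below are invented for illustration; only the quadratic attention term is counted, which is why the ratio here is much larger than the ~2.6x end-to-end figure):

```python
# Toy illustration, not the PR's code: quadratic attention FLOPS should use
# the sum of squared per-micro-batch padded lengths, not a static config
# length. Only the attention score/value matmuls are counted here.

def core_attn_flops(seqlen_sq_sum: int, hidden: int) -> int:
    # QK^T and PV: 2 matmuls x 2 FLOPS per multiply-add x s^2 * h,
    # summed over sequences via seqlen_sq_sum
    return 4 * seqlen_sq_sum * hidden

static_len = 4096                  # cfg.model.seq_length
actual_lens = [1200, 900, 1500]    # padded lengths FlashAttention really saw

overestimated = core_attn_flops(len(actual_lens) * static_len**2, hidden=1024)
accurate = core_attn_flops(sum(s * s for s in actual_lens), hidden=1024)
print(f"attention-term overestimate: {overestimated / accurate:.1f}x")
```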
Changelog
`vlm_step.py`
- Accumulate `_flops_seqlen_sum` and `_flops_seqlen_sq_sum` across micro-batches using `tokens.shape[1]` (actual padded length).
- Capture `_flops_vision_patches` from `visual_inputs.image_grid_thw`/`video_grid_thw`.
- Add `"total_tokens"` to `packed_seq_params` for downstream visibility.

`train.py`
- Reset the `_flops_*` accumulators to 0 before each `wrapped_train_step` call.
- Scale per-rank sums by `dp_size` for global FLOPS, and pass the effective `seq_length` to `training_log`.
- Drop the `seq_length` return value from `train_step` (reverting to the original 8-element return signature).

`flop_utils.py`
- Add a `vit_flops()` helper for Vision Transformer encoder FLOPS (bidirectional attention + GELU MLP + patch merger).
- Extend the `num_floating_point_operations()` signature with optional `seqlen_sum`, `seqlen_squared_sum`, and `num_vision_patches` parameters.
- Add `_compute_vit_flops()`, which derives the per-image patch count from total batch patches for correct quadratic attention scaling.

`train_utils.py`
- `training_log()` now accepts a `seq_length` parameter and uses the accumulated `_flops_vision_patches` (scaled by `data_parallel_size`) for consistent per-log FLOPS recomputation.

Before your PR is "Ready for review"
Pre checks:
- Added `test_flops_calculation.py` with 13 verification tests covering VLM and non-VLM scenarios.

Additional Information
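As context for the ViT FLOPS changes in the changelog, here is a hedged back-of-the-envelope sketch of ViT encoder FLOPS. It is not the PR's `vit_flops()`: the hyperparameters are invented, the patch merger term is omitted, and it only loosely mirrors the per-image patch derivation used for quadratic attention scaling.

```python
# Rough, illustrative ViT encoder FLOPS estimate (not the PR's exact
# vit_flops()). Counts the dominant matmuls per transformer layer; the
# patch merger is omitted for brevity.

def vit_flops(total_patches: int, patches_per_image: int,
              hidden: int, ffn_hidden: int, layers: int) -> int:
    images = total_patches // patches_per_image
    per_layer = (
        4 * total_patches * hidden * hidden              # QKV + output projections
        + 2 * images * patches_per_image**2 * hidden     # QK^T and PV, quadratic per image
        + 2 * total_patches * hidden * ffn_hidden        # MLP up + down projections
    )
    return 2 * layers * per_layer  # x2: a multiply-add is two FLOPS

# four images of 1024 patches each, with made-up ViT hyperparameters
print(f"{vit_flops(4096, 1024, 1280, 5120, 32) / 1e12:.2f} TFLOPS")
```

Note that the quadratic attention term needs the per-image patch count, not just the batch total, which is why the real `_compute_vit_flops()` derives it explicitly.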