[bug] OOM observed with fp4_param_gather=true for "DEEPSEEK_V3_PRETRAIN_CONFIG_B200_NVFP4_V2" #3597

@malay-nagda

Problem

Pre-training with DEEPSEEK_V3_PRETRAIN_CONFIG_B200_NVFP4_V2 runs out of memory (OOM) when fp4_param_gather=true.

Minimal repro

python3 -m venv bridge_venv
source bridge_venv/bin/activate

pip3 install git+https://github.com/NVIDIA-NeMo/Run.git
git clone https://github.com/NVIDIA-NeMo/Megatron-Bridge.git
cd Megatron-Bridge

git checkout fcb4c287b312c78585d114c5a6a71574dbfc290f

python3 scripts/performance/setup_experiment.py --model_family_name deepseek --model_recipe_name deepseek_v3 --num_gpus 256 --hf_token <HF_TOKEN> --account <SLURM_ACCOUNT> --partition <SLURM_PARTITION> --time_limit 01:00:00 --container_image nvcr.io/nvidian/nemo:nightly-2026-04-28 --gpu gb200 --compute_dtype nvfp4 mixed_precision.fp4_param_gather=True

Expected behavior

Pre-training should run without exhausting GPU memory. Instead, an Out of Memory (OOM) error occurs after task initialization on the cluster.

Affected area

area:model area:perf

Regression?

Yes

Environment

Commit: fcb4c287b312c78585d114c5a6a71574dbfc290f
Container: nvcr.io/nvidian/nemo:nightly-2026-04-28
GPU: B200

Logs

p4_v2_perf/0 [rank54]: Traceback (most recent call last):
p4_v2_perf/0 [rank54]:   File "Megatron-Bridge/scripts/performance/run_script.py", line 131, in <module>
p4_v2_perf/0 [rank54]:     main()
p4_v2_perf/0 [rank54]:   File "Megatron-Bridge/scripts/performance/run_script.py", line 123, in main
p4_v2_perf/0 [rank54]:     pretrain(config=recipe, forward_step_func=forward_step_func)
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/src/megatron/bridge/utils/decorators.py", line 39, in wrapper
p4_v2_perf/0 [rank54]:     return func(*args, **kwargs)
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/src/megatron/bridge/training/pretrain.py", line 98, in pretrain
p4_v2_perf/0 [rank54]:     _pretrain(state=state, forward_step_func=forward_step_func, callback_manager=callback_manager)
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/src/megatron/bridge/training/pretrain.py", line 142, in _pretrain
p4_v2_perf/0 [rank54]:     train(
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/src/megatron/bridge/training/train.py", line 445, in train
p4_v2_perf/0 [rank54]:     ) = wrapped_train_step(
p4_v2_perf/0 [rank54]:         ^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/src/megatron/bridge/training/train.py", line 841, in train_step
p4_v2_perf/0 [rank54]:     losses_reduced = forward_backward_func(
p4_v2_perf/0 [rank54]:                      ^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 1506, in forward_backward_pipelining_with_interleaving
p4_v2_perf/0 [rank54]:     output_tensor, _ = forward_backward_helper_wrapper(
p4_v2_perf/0 [rank54]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 1370, in forward_backward_helper_wrapper
p4_v2_perf/0 [rank54]:     return combined_1f1b_schedule_for_interleaved_pipelining(
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/pipeline_parallel/combined_1f1b.py", line 204, in combined_1f1b_schedule_for_interleaved_pipelining
p4_v2_perf/0 [rank54]:     output_tensor, num_tokens, input_tensor_grad = combined_forward_backward_step(
p4_v2_perf/0 [rank54]:                                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/pipeline_parallel/combined_1f1b.py", line 397, in combined_forward_backward_step
p4_v2_perf/0 [rank54]:     output_tensor = type(f_schedule_plan or b_schedule_plan).run(
p4_v2_perf/0 [rank54]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/models/common/model_chunk_schedule_plan.py", line 507, in run
p4_v2_perf/0 [rank54]:     f_input, _ = TransformerLayerSchedulePlan.run(f_layer, None, f_input=f_input)
p4_v2_perf/0 [rank54]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/models/common/model_chunk_schedule_plan.py", line 239, in run
p4_v2_perf/0 [rank54]:     f_input = f_layer.mlp.forward(f_input)
p4_v2_perf/0 [rank54]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/pipeline_parallel/utils.py", line 205, in forward
p4_v2_perf/0 [rank54]:     return self._forward(*inputs)
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/pipeline_parallel/utils.py", line 218, in _forward
p4_v2_perf/0 [rank54]:     data = self.forward_func(*data)
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/models/gpt/fine_grained_callables.py", line 41, in wrapped_func
p4_v2_perf/0 [rank54]:     return method_ref()(*args, **kwarg)
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/models/gpt/fine_grained_callables.py", line 312, in forward_impl
p4_v2_perf/0 [rank54]:     return self.submodule(self, *args)
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/models/gpt/fine_grained_callables.py", line 569, in submodule_moe_forward
p4_v2_perf/0 [rank54]:     expert_output, _ = layer.mlp.routed_experts_compute(dispatched_tokens, dispatched_probs)
p4_v2_perf/0 [rank54]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/moe/moe_layer.py", line 507, in routed_experts_compute
p4_v2_perf/0 [rank54]:     expert_output, mlp_bias = apply_module(self.experts)(
p4_v2_perf/0 [rank54]:                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
p4_v2_perf/0 [rank54]:     return self._call_impl(*args, **kwargs)
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1789, in _call_impl
p4_v2_perf/0 [rank54]:     return forward_call(*args, **kwargs)
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/moe/experts.py", line 382, in forward
p4_v2_perf/0 [rank54]:     bias_act_output = self.activation_checkpoint.checkpoint(
p4_v2_perf/0 [rank54]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/tensor_parallel/random.py", line 729, in checkpoint
p4_v2_perf/0 [rank54]:     outputs = CheckpointWithoutOutputFunction.apply(run_function, self, *args)
p4_v2_perf/0 [rank54]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 583, in apply
p4_v2_perf/0 [rank54]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/tensor_parallel/random.py", line 670, in forward
p4_v2_perf/0 [rank54]:     outputs = run_function(*args)
p4_v2_perf/0 [rank54]:               ^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/moe/experts.py", line 285, in bias_act_func
p4_v2_perf/0 [rank54]:     intermediate_parallel = weighted_bias_swiglu_impl(
p4_v2_perf/0 [rank54]:                             ^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/fusions/fused_bias_swiglu.py", line 249, in weighted_bias_swiglu_impl
p4_v2_perf/0 [rank54]:     output = WeightedSwiGLUFunction.apply(input, weights, fp8_input_store)
p4_v2_perf/0 [rank54]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 583, in apply
p4_v2_perf/0 [rank54]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/fusions/fused_bias_swiglu.py", line 199, in forward
p4_v2_perf/0 [rank54]:     return weighted_swiglu(input, weights)
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 985, in compile_wrapper
p4_v2_perf/0 [rank54]:     return fn(*args, **kwargs)
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/fusions/fused_bias_swiglu.py", line 44, in weighted_swiglu
p4_v2_perf/0 [rank54]:     @jit_fuser
p4_v2_perf/0 [rank54]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1220, in _fn
p4_v2_perf/0 [rank54]:     return fn(*args, **kwargs)
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 1151, in forward
p4_v2_perf/0 [rank54]:     return compiled_fn(full_args)
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 536, in runtime_wrapper
p4_v2_perf/0 [rank54]:     all_outs = call_func_at_runtime_with_args(
p4_v2_perf/0 [rank54]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 134, in call_func_at_runtime_with_args
p4_v2_perf/0 [rank54]:     out = normalize_as_list(f(args))
p4_v2_perf/0 [rank54]:                             ^^^^^^^
p4_v2_perf/0 [rank54]:   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2163, in __call__
p4_v2_perf/0 [rank54]:     return self.compiled_fn(*args, **kwargs)
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 710, in wrapper
p4_v2_perf/0 [rank54]:     return compiled_fn(runtime_args)
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 930, in inner_fn
p4_v2_perf/0 [rank54]:     outs = compiled_fn(args)
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/output_code.py", line 638, in __call__
p4_v2_perf/0 [rank54]:     return self.current_callable(inputs)
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/utils.py", line 3331, in run
p4_v2_perf/0 [rank54]:     out = model(new_inputs)
p4_v2_perf/0 [rank54]:           ^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]:   File "/tmp/torchinductor_okoenig/ya/cyaqqe6spgzohk3vewriifskcuery5wi6pbioxgk2fuhxezzrjld.py", line 122, in call
p4_v2_perf/0 [rank54]:     buf0 = empty_strided_cuda((s17, 2048), (2048, 1), torch.bfloat16)
p4_v2_perf/0 [rank54]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p4_v2_perf/0 [rank54]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 130.00 MiB. GPU 6 has a total capacity of 178.35 GiB of which 117.19 MiB is free. Including non-PyTorch memory, this process has 178.21 GiB memory in use. Of the allocated memory 173.41 GiB is allocated by PyTorch, and 559.08 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
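For triage, the figures in the final OOM line can be pulled out and compared. The sketch below is only an illustration: the helper names (`parse_oom`, `max_rows`) are hypothetical and not part of Megatron-Bridge or PyTorch. It assumes the message wording shown in the log above and 2 bytes per element for `torch.bfloat16`. The row count it derives for `buf0` is an upper bound, since the allocator may round the request up.

```python
import re

# Final line of the traceback above, trimmed to the fields used here.
OOM_MESSAGE = (
    "torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 130.00 MiB. "
    "GPU 6 has a total capacity of 178.35 GiB of which 117.19 MiB is free. "
    "Including non-PyTorch memory, this process has 178.21 GiB memory in use. "
    "Of the allocated memory 173.41 GiB is allocated by PyTorch, and 559.08 MiB "
    "is reserved by PyTorch but unallocated."
)

# Everything is normalized to MiB so the fields can be compared directly.
MIB_PER_UNIT = {"MiB": 1.0, "GiB": 1024.0}

# One regex per field; each captures a number and its unit.
PATTERNS = {
    "requested": r"Tried to allocate ([\d.]+) (MiB|GiB)",
    "capacity": r"total capacity of ([\d.]+) (MiB|GiB)",
    "free": r"of which ([\d.]+) (MiB|GiB) is free",
    "process_in_use": r"this process has ([\d.]+) (MiB|GiB) memory in use",
    "allocated": r"([\d.]+) (MiB|GiB) is allocated by PyTorch",
    "reserved_unallocated": r"([\d.]+) (MiB|GiB) is reserved by PyTorch but unallocated",
}


def parse_oom(message: str) -> dict[str, float]:
    """Return the memory figures from an OOM message, all in MiB."""
    stats = {}
    for name, pattern in PATTERNS.items():
        match = re.search(pattern, message)
        if match is None:
            raise ValueError(f"field {name!r} not found in OOM message")
        value, unit = match.groups()
        stats[name] = float(value) * MIB_PER_UNIT[unit]
    return stats


def max_rows(request_mib: float, row_width: int, bytes_per_element: int) -> int:
    """Largest row count of a (rows, row_width) tensor that fits in request_mib."""
    return int(request_mib * 1024 * 1024) // (row_width * bytes_per_element)


if __name__ == "__main__":
    stats = parse_oom(OOM_MESSAGE)
    shortfall = stats["requested"] - stats["free"]
    cached_share = stats["reserved_unallocated"] / stats["capacity"]
    allocated_share = stats["allocated"] / stats["capacity"]
    print(f"shortfall vs. free memory: {shortfall:.2f} MiB")
    print(f"reserved but unallocated: {cached_share:.2%} of capacity")
    print(f"allocated by PyTorch: {allocated_share:.2%} of capacity")
    # buf0 has shape (s17, 2048) in bfloat16, i.e. 4096 bytes per row.
    print(f"upper bound on s17: {max_rows(stats['requested'], 2048, 2)}")
```

On these numbers the request misses the free memory by about 12.8 MiB, while the reserved-but-unallocated pool is roughly 0.3% of capacity. That suggests the rank is close to capacity-bound rather than heavily fragmented, so `expandable_segments:True` may give only limited relief; this is an inference from a single rank's message, not a measured result.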

Labels

area:perf (Performance optimizations and benchmarking)
bug (Something isn't working)
needs-triage (New item needs classification and ownership)
