
feat: GRPO + SFT Dtensor support for multimodal training #712

Merged
terrykong merged 32 commits into main from rohit/sft_vlm
Aug 19, 2025

Conversation

@rohitrango
Contributor

@rohitrango rohitrango commented Jul 22, 2025

What does this PR do ?

Adds image / video VLM support for supervised finetuning and GRPO using dtensor policy. Solves #85

Tested models:

  • Qwen2VL / Qwen2.5VL
  • Llava 1.5 / Llava Next / Llava Next Video / Llava OneVision
  • Huggingface SmolVLM2-2.2B-Instruct
  • Gemma3 4B

Tested datasets:

  • Geometry3k
  • CLEVR
  • RefCOCO

🔪 Sharp Edges

Although training runs converge, the logprob error between the vllm and HF models is consistently above 1.05. Issue tracked in #793.

Edit: this affects only Gemma3. The logprob issue is fixed for Llava, SmolVLM, and Qwen2/2.5VL.
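For context, a logprob "error" above 1.0 is usually a multiplicative probability ratio derived from the absolute difference of per-token logprobs. The sketch below illustrates that convention with plain Python lists; the function name and the exact metric are assumptions for illustration, not code from this PR:

```python
import math

def logprob_ratio_error(policy_logprobs, vllm_logprobs):
    """Worst-case multiplicative probability ratio between two logprob lists.

    A value of 1.0 means the two models agree exactly on every token;
    the PR reports values consistently above 1.05 for Gemma3.
    """
    assert len(policy_logprobs) == len(vllm_logprobs)
    # exp(|lp_a - lp_b|) is the factor by which the token probabilities differ.
    return max(math.exp(abs(a - b)) for a, b in zip(policy_logprobs, vllm_logprobs))

# Identical logprobs give a ratio of exactly 1.0.
print(logprob_ratio_error([-1.0, -2.0], [-1.0, -2.0]))  # → 1.0
```

In a real comparison the two lists would come from scoring the same sampled tokens under the training policy and under vllm.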

Usage

```shell
uv run examples/run_sft.py --config examples/configs/sft_clevr.yaml cluster.gpus_per_node=4
uv run examples/run_vlm_grpo.py cluster.gpus_per_node=4
```

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

@rohitrango rohitrango changed the base branch from rohit/vlm_grpo to main July 29, 2025 01:49
@rohitrango rohitrango changed the title from "feat: SFT support for multimodal training (VLM)" to "feat: GRPO + SFT support for multimodal training" Jul 29, 2025
@rohitrango rohitrango marked this pull request as ready for review July 29, 2025 18:30
@github-actions

❌ Submodule Fast-Forward Check Failed

Check based on commit: cc2986f (PR #712 from rohit/sft_vlm)

❌ Submodules that need attention:

Megatron-LM: ❌ PR branch is BEHIND main branch
TARGET (main branch): https://github.com/terrykong/Megatron-LM/commits/2ff0f099ffc30ffd152e3e29e921a1609d00855c/
CURRENT (PR #712 from rohit/sft_vlm): https://github.com/terrykong/Megatron-LM/commits/ed5c792f2a8ffe357c871f4547a8fe905a09b835/

NeMo: ❌ PR branch is BEHIND main branch
TARGET (main branch): https://github.com/NVIDIA/NeMo/commits/33259f2540af6eef375d43fc48bdcbd7ec490c29/
CURRENT (PR #712 from rohit/sft_vlm): https://github.com/NVIDIA/NeMo/commits/0e0894300e09aca042bc07859f660f22858f0a9f/

Please ensure all submodule commits are fast-forwards of the main branch before merging.

@github-actions

❌ Submodule Fast-Forward Check Failed

Check based on commit: 919a7ce (PR #712 from rohit/sft_vlm). Same Megatron-LM and NeMo submodule mismatches as the check above.

@terrykong terrykong changed the title from "feat: GRPO + SFT support for multimodal training" to "feat: GRPO + SFT Dtensor support for multimodal training" Jul 29, 2025
@terrykong
Collaborator

terrykong commented Jul 29, 2025

copying over the last message from @rohitrango from #655

re: Remaining blockers:

  1. understanding the logprob error: I want to chalk this up to how vllm loads multimodal image embeddings in its image processor. For LLM-only models, I noted that vllm takes the same list of token_ids (a list of int values) that the policy consumes (i.e. going through the same text embedding layer, etc.). However, for multimodal inputs, vllm processes the images internally. There could also be differences in how sampling is done. I found the following excerpt from the vllm docs https://docs.vllm.ai/en/v0.9.1/usage/v1_guide.html#feature-model

> Logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e. before applying any logits post-processing such as temperature scaling or penalty adjustments). As a result, the returned logprobs do not reflect the final adjusted probabilities used during sampling.
> Support for logprobs with post-sampling adjustments is in progress and will be added in future updates.

I prefer handling this issue in a separate PR (and merging initial support first) for four reasons:

  • this discrepancy is isolated to multimodal models only, so a "fix" can be shipped independently
  • multiple VLMs converge on three different datasets despite the apparent discrepancy. It is equivalent to training GRPO with a slightly off-policy model, but it does not seem to be very unstable or destructive to the learning process
  • other PRs break multimodal support regularly (every 2-3 days) and I have to rollback / fix those changes in my PR to make my scripts work. Merging this PR or at least the test cases will prevent other PRs from breaking multimodal support
  • the PR has gotten very big as it is, and adding more fixes will add additional overhead to the review process
  2. The PR has now migrated (again) to feat: GRPO + SFT Dtensor support for multimodal training #712, and is tested on 4 families of multimodal models and 3 datasets. This rolls back the passing of the vlm_kwargs list throughout the training process and instead proposes a PackedGenericDataItem to handle non-sequence data items (most of which are multimodal tensors). The single implementation works for multiple multimodal models without any additional modifications to the config.
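The pack/unpack idea behind a PackedGenericDataItem-style container can be sketched as follows. The PR's actual implementation operates on torch tensors and its interface is not shown here; this toy version uses plain lists, and the class and method names are illustrative assumptions:

```python
from dataclasses import dataclass

@dataclass
class PackedItem:
    """Toy stand-in for a PackedGenericDataItem-style container.

    Non-sequence payloads of varying sizes (e.g. flattened image tensors)
    are concatenated into one flat buffer, with per-item lengths recorded
    so individual items can be recovered after collation.
    """
    flat: list     # concatenation of all items
    lengths: list  # length of each original item

    @classmethod
    def pack(cls, items):
        flat = [x for item in items for x in item]
        return cls(flat=flat, lengths=[len(item) for item in items])

    def unpack(self):
        out, start = [], 0
        for n in self.lengths:
            out.append(self.flat[start:start + n])
            start += n
        return out

packed = PackedItem.pack([[1, 2, 3], [4], [5, 6]])
print(packed.unpack())  # → [[1, 2, 3], [4], [5, 6]]
```

Packing this way lets the training loop move one flat buffer through collation and sharding without caring about per-model tensor shapes, which is consistent with the claim that one implementation covers multiple VLMs.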

@github-actions

❌ Submodule Fast-Forward Check Failed

Check based on commit: 80d9ff5 (PR #712 from rohit/sft_vlm). Same Megatron-LM and NeMo submodule mismatches as the check above.

@terrykong
Collaborator

@rohitrango My understanding is that the logprob issue may come from input processing inside vllm not matching the processing outside it. The excerpt you shared relates to sampling, so I think it remains to be seen whether this is a bug or expected behavior.

@terrykong
Collaborator

As far as keeping up with changes from main: if you rebase and encounter conflicts, it's advised to squash your commits first, since you'll then hit each conflict once rather than once per commit in the branch that touched that area.

@rohitrango
Contributor Author

re: keeping up with changes from main, the merge commits are not as big of an issue. The bigger issue is changes that break the multimodal training loop, like adding extra parameters to the dtensor path that are only supported for LLMs, or introducing a model.lm_head somewhere (for VLMs the module is model.language_model.lm_head).

This basically means I have to debug a working GRPO/SFT training loop every 2 days after merging from main.

The multimodal test cases are expected to block all such changes.

rohitrango and others added 2 commits August 16, 2025 16:21
chtruong814
chtruong814 previously approved these changes Aug 17, 2025
rohitrango and others added 2 commits August 16, 2025 21:30
chtruong814
chtruong814 previously approved these changes Aug 17, 2025
terrykong
terrykong previously approved these changes Aug 18, 2025
Signed-off-by: Yi-Fu Wu <[email protected]>
yfw
yfw previously approved these changes Aug 18, 2025
yfw added 10 commits August 19, 2025 00:28
Signed-off-by: Yi-Fu Wu <[email protected]>
This reverts commit 60b6e82.

Signed-off-by: Yi-Fu Wu <[email protected]>
This reverts commit 2a44965.

Signed-off-by: Yi-Fu Wu <[email protected]>
Signed-off-by: Yi-Fu Wu <[email protected]>
Signed-off-by: Yi-Fu Wu <[email protected]>
Signed-off-by: Yi-Fu Wu <[email protected]>
Signed-off-by: Yi-Fu Wu <[email protected]>
Signed-off-by: Yi-Fu Wu <[email protected]>

Labels

CI:L1 — Run doctests, unit tests, and functional tests; CI — Relating to CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add multimodal support (Image + Text VLM) to the Huggingface FSDP path + vllm

7 participants