feat: GRPO + SFT Dtensor support for multimodal training #712
Conversation
copying over the last message from @rohitrango from #655 re: Remaining blockers:
I prefer handling this issue in a separate PR (and merging an initial support first) for
@rohitrango My understanding is that the logprob issue may currently come from input processing inside vLLM not matching the processing outside of it. The excerpt you shared relates to sampling, so it still remains to be seen whether this is a bug or expected behavior.
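One way to rule out an input-processing mismatch is to run the same raw sample through both preprocessing paths and diff the resulting tensors. This is a hypothetical sketch, not code from this PR; the dict keys (`input_ids`, `pixel_values`) are just illustrative examples of processor outputs:

```python
import numpy as np

def inputs_match(inside_vllm, outside_vllm, atol=1e-6):
    """Compare preprocessed model inputs (e.g. token ids, pixel values)
    produced inside vLLM vs by the external processor.

    Returns True only when both dicts have the same keys and every
    tensor agrees elementwise within `atol`.
    """
    if set(inside_vllm) != set(outside_vllm):
        return False
    return all(
        np.allclose(np.asarray(inside_vllm[k]),
                    np.asarray(outside_vllm[k]), atol=atol)
        for k in inside_vllm
    )

# Matching inputs compare equal; a perturbed pixel buffer does not.
a = {"input_ids": [1, 2, 3], "pixel_values": [[0.10, 0.20]]}
b = {"input_ids": [1, 2, 3], "pixel_values": [[0.10, 0.20]]}
print(inputs_match(a, b))  # True
```

If this check passes but logprobs still diverge, the discrepancy is more likely in sampling or numerics than in preprocessing.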
As far as keeping up with changes from main:
re: keeping up with changes from main: this basically means I have to debug a working GRPO/SFT training loop every two days after merging from main. The multimodal test cases are expected to block any such breaking changes.
What does this PR do?
Adds image/video VLM support for supervised fine-tuning and GRPO using the dtensor policy. Solves #85
Tested models:
Tested datasets:
🔪 Sharp Edges
Although training runs converge, the logprob error between the vLLM and HF models is consistently higher than 1.05. Issue tracked in #793.
Edit: this affects only Gemma3; the logprob issue is fixed in Llava, SmolVLM, Qwen2-VL, and Qwen2.5-VL.
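The discrepancy above can be quantified with a simple per-token comparison of the two backends' log-probabilities for the same sampled tokens. A minimal sketch (not the actual check used in this repo; the function name and inputs are illustrative), assuming both backends return aligned 1-D logprob sequences:

```python
import numpy as np

def max_logprob_error(vllm_logprobs, hf_logprobs):
    """Max absolute per-token logprob difference between two backends.

    Both inputs are per-token log-probabilities for the same sampled
    tokens; a large value suggests the backends disagree on input
    processing or numerics.
    """
    a = np.asarray(vllm_logprobs, dtype=np.float64)
    b = np.asarray(hf_logprobs, dtype=np.float64)
    assert a.shape == b.shape, "logprob sequences must align token-for-token"
    return float(np.max(np.abs(a - b)))

# Identical sequences give zero error.
print(max_logprob_error([-0.1, -2.3], [-0.1, -2.3]))  # 0.0
```

A threshold like the 1.05 mentioned above would then be a simple `max_logprob_error(...) > 1.05` check per generation.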
Usage
Before your PR is "Ready for review"
Pre checks: