feat: v0 VLM support + GRPO pipeline (#655)
Conversation
assertions for non-vlm keys
needs testing on larger machine)
@terrykong @ashors1 Created a draft PR (duplicate of #521) to see if CI passes on this instead.
- separated reward functions into a separate file (and made them composable from YAML files directly)
- added RefCOCO task
- ability to freeze HuggingFace models (language and vision tower), with fine-grained freezing using regexes
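For context on the fine-grained freezing mentioned above, here is a minimal sketch of regex-based parameter freezing. The function name and the `visual.` prefix are illustrative assumptions, not the PR's actual API:

```python
import re

import torch.nn as nn

def freeze_by_regex(model: nn.Module, patterns: list[str]) -> None:
    """Freeze every parameter whose name matches any regex in `patterns`."""
    for name, param in model.named_parameters():
        if any(re.search(p, name) for p in patterns):
            param.requires_grad = False

# e.g. freeze the vision tower while leaving the language model trainable
# (the "visual." prefix is an assumption about the parameter naming):
# freeze_by_regex(model, [r"^visual\."])
```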
```python
## this will have consequences for data sharding for VLM models
## (split along the batch dim but from [start_patch:end_patch])
keys_to_concat = []
# ...
if key in keys_to_concat:
```
`keys_to_concat` is always empty here?
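To make the sharding concern in the inline comment concrete, here is a hedged sketch of patch-range slicing for a Qwen2-VL-style batch. It assumes one image per sample and a `pixel_values` tensor flattened across images to `[total_patches, hidden]`; the helper name is hypothetical:

```python
import torch

def shard_vlm_batch(batch: dict, start: int, end: int) -> dict:
    """Slice samples [start:end) from a batch whose `pixel_values` is
    flattened across images instead of carrying a batch dimension."""
    # Patches per sample: product of the (t, h, w) grid for each image.
    patches_per_sample = batch["image_grid_thw"].prod(dim=-1)  # [B]
    offsets = torch.cumsum(patches_per_sample, dim=0)
    start_patch = 0 if start == 0 else int(offsets[start - 1])
    end_patch = int(offsets[end - 1])

    shard = {}
    for key, value in batch.items():
        if key == "pixel_values":
            shard[key] = value[start_patch:end_patch]  # patch-dim slice
        else:
            shard[key] = value[start:end]  # ordinary batch-dim slice
    return shard
```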
```python
if random.random() < img_flip_prob:
    flip = True
    resized_image = resized_image.transpose(Image.FLIP_LEFT_RIGHT)
```
Is this always safe to do for this dataset? Do any captions rely on the positions of objects in the original image (e.g. "A cat sitting to the left of a dog")?
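To illustrate the concern, here is a hypothetical sketch of what a position-aware flip would have to do for a referring-expression dataset like RefCOCO. The caption check and the `[x, y, w, h]` bbox convention are assumptions for illustration, not code from the PR:

```python
from PIL import Image

def hflip_sample(image: Image.Image, bbox: list[float], caption: str):
    """Horizontally flip an image and mirror its [x, y, w, h] bbox.
    Returns None when the caption is position-dependent, since "left"
    and "right" in the text become wrong after the flip."""
    if any(word in caption.lower() for word in ("left", "right")):
        return None  # flipping would invalidate the referring expression
    x, y, w, h = bbox
    flipped = image.transpose(Image.FLIP_LEFT_RIGHT)
    return flipped, [image.width - x - w, y, w, h]
```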
ashors1 left a comment
Thanks for your work on this PR! Two quick comments:
- Could you add your test case to the nightly suite (https://github.com/NVIDIA-NeMo/RL/blob/rohit/vlm_grpo/tests/test_suites/nightly.txt)?
- Throughout the code, there are a number of different ways of getting `vlm_keys` or `vlm_kwargs` from the data. This seems slightly verbose, but perhaps I don't have a good enough understanding of the code to see why all these different methods are required. Would it be possible to streamline the process of getting the vlm keys/kwargs? If not, could we add some documentation to explain the structure of the vlm keys in the data? That might help to clarify things a bit.
```python
user_message['token_ids'] = message['input_ids'][0]
# add all keys and values to the user message, and the list of keys
user_message['vlm_keys'] = []
for key, value in message.items():
```
Are the `vlm_keys` specific to the dataset? And are they applicable to all messages of that dataset? If so, can this be configured with the dataset (i.e. in `clevr.py` and `refcoco.py`)? This seems to assume all keys except `'input_ids'` and `'attention_mask'` are `vlm_keys`, which seems less safe than being explicit about which keys are vlm keys.
Can we get the whitelist of `vlm_keys` from `processor.image_processor.model_input_names`?
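A hedged sketch of that suggestion; the example output list is an assumption about what Qwen2.5-VL's image processor reports, and the helper name is hypothetical:

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
# e.g. roughly ["pixel_values", "image_grid_thw", ...] for Qwen2.5-VL
vlm_whitelist = set(processor.image_processor.model_input_names)

def get_vlm_keys(message: dict) -> list[str]:
    """Explicit whitelist instead of 'everything but input_ids/attention_mask'."""
    return [key for key in message if key in vlm_whitelist]
```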
- recycle computed vlm_kwargs for both unflattened and flattened batches
- remove potentially unsafe code for flipping images in refcoco
- rename `get_vlm_keys_from_clippedpgloss_batch` to `get_vlm_keys_from_flattened_batch` and move it to batched_data_dict.py
- add vlm grpo testcase to nightly
- improve documentation in CLEVR
I'm not entirely sure what causes the discrepancy. For the functional test cases, I had to choose higher thresholds for the checks to pass.
```python
# Add VL model imports
try:
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration, Qwen2_5_VLModel
```
Do we need these try/excepts? Can we instead make sure the transformers version we're using has these?
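A sketch of the reviewer's suggestion: verify the transformers release instead of guarding the import with try/except. The minimum version shown is an assumption for illustration, not a verified pin:

```python
import transformers
from packaging import version

# Assumption: the Qwen2.5-VL classes landed around transformers 4.49;
# take the exact pin from the release notes, not this sketch.
if version.parse(transformers.__version__) < version.parse("4.49.0"):
    raise ImportError("VLM support requires transformers>=4.49.0")

from transformers.models.qwen2_vl.modeling_qwen2_vl import (
    Qwen2VLForConditionalGeneration,
)
```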
One thought is that we do the preprocessing of the image before calling the model in the dtensor path, whereas vllm does preprocessing of the image internally (if I understand what is happening correctly). We may need to make sure whatever preprocessing vllm is doing matches exactly what we're doing in the dtensor path.
This will take me a while to analyse since I don't know exactly how the vllm engine processes the images internally. For the policy, the typical multimodal pipeline is to use the HuggingFace processor to tokenize text and preprocess images together. For vllm, the message log is simply reformatted into the format specified in this tutorial, and the same sequence of PIL Images is provided to the vLLM frontend.
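A hedged sketch of the two paths just described; the model name, prompt, and image path are placeholders, and the prompt would need the model's actual image placeholder tokens to run end to end:

```python
from PIL import Image
from transformers import AutoProcessor

image = Image.open("example.jpg")

# DTensor/HF path: the processor tokenizes text and preprocesses images
# explicitly, producing pixel_values (and grid metadata) up front.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
inputs = processor(text=["<prompt with image placeholder>"],
                   images=[image], return_tensors="pt")

# vLLM path: the same PIL image goes to the engine, which runs its own
# internal preprocessing (the step that must match the path above).
# from vllm import LLM
# llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct")
# out = llm.generate({"prompt": "<prompt>",
#                     "multi_modal_data": {"image": image}})
```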
From today's meeting, the remaining blockers on this PR:
re: Remaining blockers:
I prefer handling this issue in a separate PR (and merging initial support first).

What does this PR do?
Adds VLM support (Qwen2.5-VL) with a TP plan, DTensor policy, vLLM backend, and multi-GPU training.
Usage
Convergence
Training convergence on 2 H100 GPUs happens in about 60 iterations (the highest possible reward is 5).
Before your PR is "Ready for review"
Pre checks:
Additional Information