
Conversation

@yonigozlan (Member) commented Jul 22, 2025

What does this PR do?

As discussed internally, this PR starts the process of making fast image processors the default in 🤗Transformers!

When instantiating a Processor or an Image Processor via AutoProcessor.from_pretrained or AutoImageProcessor.from_pretrained with a checkpoint that uses Qwen2VLImageProcessor, the behavior is now to load Qwen2VLImageProcessorFast, even if the processor was originally saved with a slow Qwen2VLImageProcessor.

For instance:
Old behavior:

```python
>>> from transformers import AutoProcessor
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
>>> print(type(processor.image_processor))
<class 'transformers.models.qwen2_vl.image_processing_qwen2_vl.Qwen2VLImageProcessor'>
```

New behavior:

```python
>>> from transformers import AutoProcessor
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
>>> print(type(processor.image_processor))
"""The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release."""
<class 'transformers.models.qwen2_vl.image_processing_qwen2_vl_fast.Qwen2VLImageProcessorFast'>
```

(The warning is emitted with `warning_once`, so it only appears once per session.)

This PR also comes with a long-overdue refactor, which should be 100% compatible with the slow image processor of Qwen2-VL and fixes some existing inconsistencies in the fast one. Cc @zucchini-nlp for that :)

🚨 The processed image outputs of the slow and fast image processors are slightly different! This is expected, as the torchvision and PIL image processing functions are not fully equivalent.
Users can still force the use of the slow processor by loading it with use_fast=False:

```python
>>> from transformers import AutoProcessor
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", use_fast=False)
>>> print(type(processor.image_processor))
<class 'transformers.models.qwen2_vl.image_processing_qwen2_vl.Qwen2VLImageProcessor'>
```
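To quantify the fast/slow difference on your own inputs, here is a minimal sketch (the local image path is a placeholder; the checkpoint is the one used above):

```python
from PIL import Image
from transformers import AutoImageProcessor

image = Image.open("example.jpg")  # placeholder: any local test image

# The default load now returns the fast processor; use_fast=False forces the slow one.
fast = AutoImageProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
slow = AutoImageProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", use_fast=False)

pixels_fast = fast(images=image, return_tensors="pt")["pixel_values"]
pixels_slow = slow(images=image, return_tensors="pt")["pixel_values"]

diff = (pixels_fast - pixels_slow).abs()
print(f"max diff: {diff.max().item():.3e}, mean diff: {diff.mean().item():.3e}")
```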

Here are some comparisons between the fast and slow processors with this refactor.
"mixed_various" means the input contains images of different sizes. The image used for these benchmarks is this one.

Summary of the summary: up to a 30x speedup, with average output pixel differences between 5e-8 and 3e-3 depending on the processing parameters and input image sizes.

Summary: Max Output Difference vs. Slow processor

This table shows the maximum difference at any single point between the output tensors of the Fast processors and the Slow processor.

| Batch Size | 1 | 4 | 8 | 16 | 32 | 64 |
|---|---|---|---|---|---|---|
| (mixed_various, Fast_cpu_grouping_disabled) | 2.384e-07 | 0.0292 | 0.0292 | 0.0292 | 0.0292 | 0.0292 |
| (mixed_various, Fast_cpu_grouping_enabled) | 2.384e-07 | 0.0292 | 0.0292 | 0.0292 | 0.0292 | 0.0292 |
| (uniform_1024x1024, Fast_cpu_grouping_disabled) | 0.0292 | 0.0292 | 0.0292 | 0.0292 | 0.0292 | 0.0292 |
| (uniform_1024x1024, Fast_cpu_grouping_enabled) | 0.0292 | 0.0292 | 0.0292 | 0.0292 | 0.0292 | 0.0292 |
| (uniform_224x224, Fast_cpu_grouping_disabled) | 2.384e-07 | 2.384e-07 | 2.384e-07 | 2.384e-07 | 2.384e-07 | 2.384e-07 |
| (uniform_224x224, Fast_cpu_grouping_enabled) | 2.384e-07 | 2.384e-07 | 2.384e-07 | 2.384e-07 | 2.384e-07 | 2.384e-07 |
| (uniform_512x512, Fast_cpu_grouping_disabled) | 0.01501 | 0.01501 | 0.01501 | 0.01501 | 0.01501 | 0.01501 |
| (uniform_512x512, Fast_cpu_grouping_enabled) | 0.01501 | 0.01501 | 0.01501 | 0.01501 | 0.01501 | 0.01501 |
| (mixed_various, Fast_cuda_grouping_disabled) | 2.384e-07 | 0.09005 | 0.09005 | 0.09005 | 0.09005 | 0.09005 |
| (mixed_various, Fast_cuda_grouping_enabled) | 2.384e-07 | 0.09005 | 0.09005 | 0.09005 | 0.09005 | 0.09005 |
| (uniform_1024x1024, Fast_cuda_grouping_disabled) | 0.04266 | 0.04266 | 0.04266 | 0.04266 | 0.04266 | 0.04266 |
| (uniform_1024x1024, Fast_cuda_grouping_enabled) | 0.04266 | 0.04266 | 0.04266 | 0.04266 | 0.04266 | 0.04266 |
| (uniform_224x224, Fast_cuda_grouping_disabled) | 2.384e-07 | 2.384e-07 | 2.384e-07 | 2.384e-07 | 2.384e-07 | 2.384e-07 |
| (uniform_224x224, Fast_cuda_grouping_enabled) | 2.384e-07 | 2.384e-07 | 2.384e-07 | 2.384e-07 | 2.384e-07 | 2.384e-07 |
| (uniform_512x512, Fast_cuda_grouping_disabled) | 0.09005 | 0.09005 | 0.09005 | 0.09005 | 0.09005 | 0.09005 |
| (uniform_512x512, Fast_cuda_grouping_enabled) | 0.09005 | 0.09005 | 0.09005 | 0.09005 | 0.09005 | 0.09005 |

Summary: Mean Absolute Output Difference vs. Slow processor

This table shows the mean absolute difference between the output tensors of the Fast processors and the Slow processor for each configuration and image scenario.

| Batch Size | 1 | 4 | 8 | 16 | 32 | 64 |
|---|---|---|---|---|---|---|
| (mixed_various, Fast_cpu_grouping_disabled) | 5.315e-08 | 7.732e-05 | 7.452e-05 | 7.956e-05 | 7.892e-05 | 8e-05 |
| (mixed_various, Fast_cpu_grouping_enabled) | 5.315e-08 | 7.732e-05 | 7.452e-05 | 7.956e-05 | 7.892e-05 | 8e-05 |
| (uniform_1024x1024, Fast_cpu_grouping_disabled) | 9.615e-05 | 9.615e-05 | 9.615e-05 | 9.615e-05 | 9.615e-05 | 9.615e-05 |
| (uniform_1024x1024, Fast_cpu_grouping_enabled) | 9.615e-05 | 9.615e-05 | 9.615e-05 | 9.615e-05 | 9.615e-05 | 9.615e-05 |
| (uniform_224x224, Fast_cpu_grouping_disabled) | 5.315e-08 | 5.315e-08 | 5.315e-08 | 5.315e-08 | 5.315e-08 | 5.315e-08 |
| (uniform_224x224, Fast_cpu_grouping_enabled) | 5.315e-08 | 5.315e-08 | 5.315e-08 | 5.315e-08 | 5.315e-08 | 5.315e-08 |
| (uniform_512x512, Fast_cpu_grouping_disabled) | 2.832e-05 | 2.832e-05 | 2.832e-05 | 2.832e-05 | 2.832e-05 | 2.832e-05 |
| (uniform_512x512, Fast_cpu_grouping_enabled) | 2.832e-05 | 2.832e-05 | 2.832e-05 | 2.832e-05 | 2.832e-05 | 2.832e-05 |
| (mixed_various, Fast_cuda_grouping_disabled) | 5.315e-08 | 0.002611 | 0.002679 | 0.002686 | 0.0027 | 0.002701 |
| (mixed_various, Fast_cuda_grouping_enabled) | 5.315e-08 | 0.002611 | 0.002679 | 0.002686 | 0.0027 | 0.002701 |
| (uniform_1024x1024, Fast_cuda_grouping_disabled) | 0.002783 | 0.002783 | 0.002783 | 0.002783 | 0.002783 | 0.002783 |
| (uniform_1024x1024, Fast_cuda_grouping_enabled) | 0.002783 | 0.002783 | 0.002783 | 0.002783 | 0.002783 | 0.002783 |
| (uniform_224x224, Fast_cuda_grouping_disabled) | 5.315e-08 | 5.315e-08 | 5.315e-08 | 5.315e-08 | 5.315e-08 | 5.315e-08 |
| (uniform_224x224, Fast_cuda_grouping_enabled) | 5.315e-08 | 5.315e-08 | 5.315e-08 | 5.315e-08 | 5.315e-08 | 5.315e-08 |
| (uniform_512x512, Fast_cuda_grouping_disabled) | 0.002913 | 0.002913 | 0.002913 | 0.002913 | 0.002913 | 0.002913 |
| (uniform_512x512, Fast_cuda_grouping_enabled) | 0.002913 | 0.002913 | 0.002913 | 0.002913 | 0.002913 | 0.002913 |
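The timing plots below report wall-clock time per image for each configuration. For reference, a rough sketch of how such per-image timings can be measured (this harness is an assumption, not the actual benchmark script):

```python
import time

def time_per_image(processor, images, n_warmup=2, n_runs=5):
    """Rough average wall-clock seconds per image for one processor configuration."""
    for _ in range(n_warmup):
        processor(images=images, return_tensors="pt")  # warm-up: caches, lazy init
    start = time.perf_counter()
    for _ in range(n_runs):
        processor(images=images, return_tensors="pt")
    return (time.perf_counter() - start) / (n_runs * len(images))
```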

Time per image:

[figure: time_per_image_all_configs]

With different image sizes:

[figures: time_per_image_all_configs, one per image-size scenario]

Speedups:

[figure: speedup_vs_slow]

With different image sizes:

[figures: speedup_vs_slow, one per image-size scenario]

Cc @qubvel @ArthurZucker @Cyrilvallez

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp (Member) left a comment

Yaaay, great work, happy to see fast processors being the default 🚀


@auto_docstring
def preprocess(
def _preprocess_videos(
Member:

I don't think we should keep maintaining a separate fn to preprocess videos with every new feature. WDYT if we feed videos to self._preprocess_image and set disable_grouping=False? AFAIK that's the only diff for Qwen.

Member (Author):

You're right! Removed it in the code. The only change necessary is to add the temporal dimension for images only, not for videos.
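(Roughly, the shared path treats an image as a single-frame video; a hypothetical illustration of the shape handling, not the actual code:)

```python
import torch

video = torch.rand(8, 3, 224, 224)   # (T, C, H, W): already has a temporal dim
image = torch.rand(3, 224, 224)      # (C, H, W): no temporal dim

# Add a singleton temporal dim so the image can flow through the video path.
image_as_video = image.unsqueeze(0)  # (1, C, H, W)
```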

processed_videos_grouped[shape] = flatten_patches
processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size

def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
Member:

This should not be removed! It is used by vLLM to infer the number of patches and placeholders without an image input. Can you add a tiny comment in the docstring as well, so we don't accidentally delete it again?

Member (Author):

Sorry, I must have been too eager to delete the old code 😅, will add a comment!

@qubvel (Contributor) commented Jul 23, 2025:

@zucchini-nlp, hah, you said no one would touch this code 😆 Do we plan to have some vLLM integration tests to check that the required methods/attributes still exist?

Member:

haha my bad, will add some tests for new helpers 😄
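(For context, a rough sketch of how a consumer such as vLLM might call this helper to size placeholders without a real image; the checkpoint name and image dimensions here are just examples:)

```python
from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

# Infer how many patches a 1024x768 image would yield, without loading any image.
num_patches = image_processor.get_number_of_image_patches(height=1024, width=768)
print(num_patches)
```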

@qubvel (Contributor) left a comment

Thanks for working on this 🤗 huge speed up!!

logger = logging.get_logger(__name__)


FORCE_FAST_IMAGE_PROCESSOR = ["Qwen2VLImageProcessor"]
Contributor:

👍

Collaborator:

we should probably give the option to opt out -> make it a from_pretrained arg?

Contributor:

@ArthurZucker use_fast=False would still work; this changes only the default behaviour, when use_fast is not provided to from_pretrained.

Member (Author):

Yes exactly as @qubvel said :)
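(In other words, the promotion to fast applies only when use_fast is not passed at all; a simplified sketch of the dispatch logic, not the actual implementation:)

```python
FORCE_FAST_IMAGE_PROCESSOR = ["Qwen2VLImageProcessor"]

def resolve_image_processor_class(saved_class_name, use_fast=None):
    # use_fast=None means the caller made no explicit choice: promote listed
    # slow processors to their fast variant (this is where the warning fires).
    if use_fast is None:
        if saved_class_name in FORCE_FAST_IMAGE_PROCESSOR:
            return saved_class_name + "Fast"
        return saved_class_name
    # An explicit use_fast=True/False always wins.
    return saved_class_name + "Fast" if use_fast else saved_class_name
```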

@yonigozlan (Member, Author) commented

Thanks for the review @qubvel @zucchini-nlp, made the modifications! I also needed to change quite a few processor tests for the Qwen VL/Omni models because of the switch to fast image processors by default, but it should be good now!

@zucchini-nlp (Member) left a comment

Thanks for iterating!

@github-actions (Contributor) commented

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, colqwen2, glm4v, qwen2_5_omni, qwen2_5_vl, qwen2_vl

@ArthurZucker (Collaborator) left a comment

Thanks 🤗


@yonigozlan yonigozlan merged commit 17f0210 into huggingface:main Jul 25, 2025
25 checks passed
@yhyang201 commented Aug 13, 2025

Hi
I noticed that this PR appears to change the behavior of the Qwen2-VL processor (I haven’t tested Qwen2.5-VL yet, but I suspect it shows a similar pattern).
Using the Qwen2.5-VL demo image (https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg) as an example:
- Before the PR, the processor produced pixel_values with shape [4988, 1176];
- After the PR, the shape becomes [14308, 1176].

In both cases, the model can correctly interpret the image.

Could you share whether this change is intentional and necessary? If model capability remains the same, a smaller pixel_values (i.e., a shorter sequence length) should lead to faster inference.

@yonigozlan (Member, Author) commented

Hi @yhyang201 , I'm not able to reproduce the issue, could you please provide a snippet to do so? Thanks.

@yhyang201 commented

Hi @yonigozlan , I sincerely apologize — I have double-checked and confirmed that this PR does not exhibit the issue I mentioned earlier.
Thank you very much for your contribution, and I’m truly sorry for the inconvenience and for taking up your time.

@yonigozlan (Member, Author) commented

@yhyang201 no worries 🤗
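(For reference, a minimal sketch of how to check pixel_values shapes on the demo image from the exchange above; the Qwen2-VL checkpoint name is an assumption:)

```python
import requests
from PIL import Image
from transformers import AutoProcessor

url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
inputs = processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)
```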

zaristei pushed a commit to zaristei/transformers that referenced this pull request Sep 9, 2025
…L + Refactor (huggingface#39591)

* init

* Force qwen2VL image proc to fast

* refactor qwen2 vl fast

* fix copies

* Update after PR review and update tests to use return_tensors="pt"

* fix processor tests

* add BC for min pixels/max pixels