
Conversation

@remi-or (Collaborator) commented Jun 3, 2025

The tests:

tests/models/internvl/test_video_processor_internvl.py::test_can_compile_fast_video_processor
tests/models/qwen2_vl/test_video_processing_qwen2_vl.py::test_can_compile_fast_video_processor

are failing on AMD because the version of torch.compile we are using does not properly graph the F.resize function when the input is uint8. For now, we fix this by wrapping calls to F.resize in a compile-friendly function.

This fixes the first test right off the bat; for the second, we redirected the .resize call in the qwen2 processor to use the parent class's version.
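For reference, here is a minimal sketch of such a compile-friendly wrapper. It is an illustration only, not the exact code merged in the PR; F is assumed to be torchvision's functional resize:

import torch
import torchvision.transforms.functional as F

def compile_friendly_resize(image: torch.Tensor, size, interpolation, antialias: bool = True) -> torch.Tensor:
    # Sketch only: upcast uint8 inputs to float before resizing so torch.compile
    # does not mis-graph the uint8 path, then bring values back into [0, 255]
    # explicitly before casting, so the clamping cannot be dropped from the graph.
    if image.dtype == torch.uint8:
        out = F.resize(image.float() / 256, size, interpolation=interpolation, antialias=antialias) * 256
        out = torch.where(out > 255.0, torch.full_like(out, 255.0), out)
        out = torch.where(out < 0.0, torch.zeros_like(out), out)
        return out.round().to(torch.uint8)
    return F.resize(image, size, interpolation=interpolation, antialias=antialias)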


@mht-sharma mht-sharma self-requested a review June 4, 2025 10:25
@remi-or remi-or force-pushed the img-process-compile-fix branch from 2feef65 to 9dfb902 on June 4, 2025 12:19
@remi-or remi-or marked this pull request as ready for review June 4, 2025 12:35
@mht-sharma (Contributor) commented:
@zucchini-nlp for review

"Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
f" {size}."
)
# This is a workaround to avoid a bug in torch.compile when dealing with uint8
Contributor:

Is there a link to the issue? And does it affect both NVIDIA and AMD?

Collaborator Author:

It did not happen on the A100 when I tested it. There is no issue opened, but I do have a reproducible script: https://gist.github.com/remi-or/33359e5435c4de74d8146d85ba50e485

Collaborator Author:

The fact that it didn't happen on NVIDIA might be due to a different torch version. From what I have seen, when the compile graph sees (clamp to uint8 range) -> (round) -> (convert to uint8), it skips the clamping (probably assuming it is done during the conversion), so values like -2 end up converted to 254. Hence the use of masked_fill to avoid this behavior.
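To make the failure mode concrete, here is a simplified sketch of the pattern that gets mis-compiled and the masked_fill-based workaround described above (toy values, not the exact code in the PR):

import torch

resized = torch.tensor([-2.0, 0.4, 255.7])  # toy float output of a resize

# Pattern that can be mis-graphed by torch.compile on ROCm: the clamp may be
# dropped before the uint8 cast, so -2.0 wraps around to 254.
buggy = resized.clamp(0, 255).round().to(torch.uint8)

# Workaround: push out-of-range values back into [0, 255] with masked_fill
# before the cast, so there is no clamp left for the graph to skip.
safe = resized.round()
safe = safe.masked_fill(safe < 0, 0)
safe = safe.masked_fill(safe > 255, 255)
safe = safe.to(torch.uint8)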

Member:

Could you specify in the comment that, as far as we know, the error occurs on AMD only? It could be useful for future debugging.

@mht-sharma (Contributor) commented Jun 5, 2025:

Since this issue is specific to AMD and occurs on both the latest stable and nightly releases (which we're likely using since we're building from source), I would suggest:

  • File an issue in the PyTorch repo or with the AMD team, and link it here for future reference.
  • Apply the change only for the ROCm path to avoid potential issues with NVIDIA compatibility upstream.

Member:

+1 on the above comment. I cannot reproduce this on NVIDIA with the latest torch. An issue in the torch repo, kept as a reference, will remind us to revert this workaround once the torch team ships a fix.

Collaborator Author:

Added an is_rocm_platform check to the condition and an "on AMD ..." note in the comment; here is the issue: pytorch/pytorch#155209
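For context, one plausible implementation of such a check (the helper name comes from the comment above; the actual code in the PR may differ) relies on torch.version.hip, which is only set on ROCm builds:

import torch

def is_rocm_platform() -> bool:
    # torch.version.hip is a version string on ROCm builds of PyTorch and None otherwise.
    return torch.version.hip is not None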

@yonigozlan (Member) left a comment:

Happy to merge this if it can fix compile errors with AMD; let's just make sure that it doesn't break the compile equivalence tests (test_can_compile_fast_image_processor/test_can_compile_fast_video_processor) for the other fast image/video processors.

You can run this utility from your transformers folder to check:

import argparse
import concurrent.futures
import os
import subprocess

COLOR_RED = "\033[91m"
COLOR_GREEN = "\033[92m"
COLOR_YELLOW = "\033[93m"
COLOR_CYAN = "\033[96m"
COLOR_RESET = "\033[0m"


def colored(text, color_code):
    return f"{color_code}{text}{COLOR_RESET}"


def run_test_file(test_file):
    command = f"RUN_SLOW=1 python -m pytest {test_file} -k 'test_can_compile_fast_image_processor or test_can_compile_fast_video_processor'"
    print(colored(f"Worker running: {command}", COLOR_CYAN))
    process = subprocess.run(command, shell=True, capture_output=True, text=True)
    failed_in_file = False
    for line in process.stdout.splitlines():
        if "FAILED" in line:
            failed_in_file = True
            break

    if failed_in_file:
        print(colored(f"Test file {test_file} finished: {COLOR_RED}FAILED{COLOR_RESET}", COLOR_YELLOW))
        return [
            colored(f"{test_file}: {line.strip()}", COLOR_RED)
            for line in process.stdout.splitlines()
            if "FAILED" in line
        ]
    else:
        print(colored(f"Test file {test_file} finished: {COLOR_GREEN}PASSED{COLOR_RESET}", COLOR_YELLOW))
        return []


def find_test_files(search_dir="tests/models"):
    test_files = []
    for root, _, files in os.walk(search_dir):
        for file in files:
            if (
                file.startswith("test_image_processing") or file.startswith("test_video_processing")
            ) and file.endswith(".py"):
                test_files.append(os.path.join(root, file))
    return test_files


def main():
    parser = argparse.ArgumentParser(description=colored("Run image processing tests in parallel.", COLOR_GREEN))
    parser.add_argument(
        "-w",
        "--workers",
        type=int,
        default=8,
        help=colored("Maximum number of concurrent test workers (default: 8).", COLOR_YELLOW),
    )
    args = parser.parse_args()
    max_workers = args.workers

    test_files = find_test_files()

    if not test_files:
        print(colored("No test_image_processing files found.", COLOR_YELLOW))
        return

    failed_tests_all = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(run_test_file, test_file) for test_file in test_files]
        for future in concurrent.futures.as_completed(futures):
            failed_tests_all.extend(future.result())

    if failed_tests_all:
        print(colored("\n--- Summary of Failed Tests ---", COLOR_RED))
        for failed_test in failed_tests_all:
            print(failed_test)
    else:
        print(colored("\nAll tests passed!", COLOR_GREEN))


if __name__ == "__main__":
    main()
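For example, if saved as check_compile_tests.py (a hypothetical name), the script can be run from the transformers root with python check_compile_tests.py --workers 4.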

Comment on lines +299 to +302
image = image.float() / 256
image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
image = image * 256
Member:

Just curious, why divide by 256 and not 255?

Collaborator Author:

I tried with 255 and got a numerical difference that I did not get with 256:
Script: https://gist.github.com/remi-or/eb8936ca093d54c186fb5b67f15334eb
Output:

Max difference with 255: 1.0
Max difference with 256: 0.0
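The linked gist compares both scalings against the plain uint8 resize; a rough sketch of that comparison (exact values depend on the torch/torchvision versions in use) could look like this:

import torch
import torchvision.transforms.functional as F

image = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
size = [112, 112]

reference = F.resize(image, size)  # plain uint8 resize in eager mode

for scale in (255.0, 256.0):
    out = F.resize(image.float() / scale, size) * scale
    out = out.round().clamp(0, 255).to(torch.uint8)
    diff = (out.int() - reference.int()).abs().max().item()
    print(f"Max difference with {int(scale)}: {diff}")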

Comment on lines +146 to +149
stacked_videos = self.resize(
image=stacked_videos,
size=SizeDict(height=resized_height, width=resized_width),
interpolation=interpolation,
Member:

Thanks for fixing that

Collaborator Author:

Thanks! It will be fixed in a side branch though, not yet in main

Member:

Oh, didn't catch that, thanks.

Member:

btw, we can fix qwen2-vl image processor as well, it has the same issue

"Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
f" {size}."
)
# This is a workaround to avoid a bug in torch.compile when dealing with uint8
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you specify that the error occurs on AMD only as far as we know in the comment? Could be useful for future debugging

@zucchini-nlp (Member) left a comment:

LGTM, as long as we make it an AMD-specific workaround and add a comment with the reference issue.

"Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
f" {size}."
)
# This is a workaround to avoid a bug in torch.compile when dealing with uint8
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 on the above comment. I cannot reproduce this on NVIDIA with the latest torch. An issue with torch repo as a reference will remind us to revert this workaround once the torch team makes a fix


@remi-or (Collaborator Author) commented Jun 5, 2025:

@yonigozlan just ran your script and got "All tests passed!" (added two commits to fix llava_next and bridgetower).

@remi-or remi-or force-pushed the img-process-compile-fix branch from 9dfb902 to e4bae98 on June 5, 2025 12:15
@mht-sharma (Contributor) left a comment:

LGTM! Thanks

@mht-sharma mht-sharma merged commit 18f810f into huggingface:amd-hf-ci-branch Jun 6, 2025
20 checks passed
@remi-or remi-or mentioned this pull request Jun 23, 2025
remi-or added a commit that referenced this pull request Jun 26, 2025
* Image processor compile fix (#38540)

* Added a compile-friendly version of resize to BaseImgProcessorFast

* Changed qwen2 processor to use its parent class .resize

* Style

* underlined issue only happens on AMD w/ comment and bool check

* Fixed some utils functions

* Fixed the same issue for bridgetower

* Fixed the same issue for llava_next

* Repo consistency for llava onevision

* Update src/transformers/image_processing_utils_fast.py

Co-authored-by: Mohit Sharma <[email protected]>

---------

Co-authored-by: Mohit Sharma <[email protected]>

* Added an Expectation to an internvl test

* Made qwen2_vl use the resize method of its parent class

* Changed to torch.where

---------

Co-authored-by: Mohit Sharma <[email protected]>
zaristei pushed a commit to zaristei/transformers that referenced this pull request Sep 9, 2025