InternVL fix (#38946)

* Image processor compile fix (#38540)

* Added a compile-friendly version of resize to BaseImageProcessorFast

* Changed the qwen2 processor to use its parent class's .resize

* Style

* Noted with a comment and a bool check that the underlying issue only happens on AMD

* Fixed some utils functions

* Fixed the same issue for bridgetower

* Fixed the same issue for llava_next

* Repo consistency for llava onevision

* Update src/transformers/image_processing_utils_fast.py

Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>

---------

Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>

* Added an Expectation to an internvl test

* Made qwen2_vl use the resize method of its parent class

* Changed to torch.where

---------

Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
Rémi Ouazan 2025-06-26 13:44:59 +02:00 committed by GitHub
parent f85b47d1b8
commit 25c44d4b68
7 changed files with 53 additions and 8 deletions

src/transformers/image_processing_utils_fast.py

@@ -49,6 +49,7 @@ from .utils import (
     is_vision_available,
     logging,
 )
+from .utils.import_utils import is_rocm_platform
 
 
 if is_vision_available():
@@ -280,8 +281,34 @@ class BaseImageProcessorFast(BaseImageProcessor):
                 "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
                 f" {size}."
             )
+        # This is a workaround to avoid a bug in torch.compile when dealing with uint8 on AMD MI3XX GPUs
+        # Tracked in PyTorch issue: https://github.com/pytorch/pytorch/issues/155209
+        # TODO: remove this once the bug is fixed (detected with torch==2.7.0+git1fee196, torchvision==0.22.0+9eb57cd)
+        if torch.compiler.is_compiling() and is_rocm_platform():
+            return self.compile_friendly_resize(image, new_size, interpolation, antialias)
         return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
 
+    @staticmethod
+    def compile_friendly_resize(
+        image: "torch.Tensor",
+        new_size: tuple[int, int],
+        interpolation: Optional["F.InterpolationMode"] = None,
+        antialias: bool = True,
+    ) -> "torch.Tensor":
+        """
+        A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor.
+        """
+        if image.dtype == torch.uint8:
+            image = image.float() / 256
+            image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
+            image = image * 256
+            image = torch.where(image > 255, 255, image)
+            image = torch.where(image < 0, 0, image)
+            image = image.round().to(torch.uint8)
+        else:
+            image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
+        return image
+
     def rescale(
         self,
         image: "torch.Tensor",

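For reference, the round-trip that compile_friendly_resize performs can be exercised outside the library. Below is a minimal, self-contained sketch assuming only torch and torchvision; the helper name resize_uint8_via_float is illustrative, not part of the commit. The torch.where clamp mirrors the "Changed to torch.where" bullet in the commit message.

# Minimal sketch of the uint8 round-trip used by compile_friendly_resize.
# Assumes torch and torchvision; resize_uint8_via_float is a hypothetical name.
import torch
from torchvision.transforms.functional import InterpolationMode, resize


def resize_uint8_via_float(image: torch.Tensor, new_size: tuple[int, int]) -> torch.Tensor:
    # Promote to float so the compiled resize never sees a uint8 tensor
    image = image.float() / 256
    image = resize(image, list(new_size), interpolation=InterpolationMode.BILINEAR, antialias=True)
    # Undo the scaling; torch.where keeps the clamp free of data-dependent branching
    image = image * 256
    image = torch.where(image > 255, 255, image)
    image = torch.where(image < 0, 0, image)
    return image.round().to(torch.uint8)


image = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
out = resize_uint8_via_float(image, (32, 32))
assert out.dtype == torch.uint8 and out.shape == (3, 32, 32)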
src/transformers/models/bridgetower/image_processing_bridgetower_fast.py

@@ -165,13 +165,18 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
             raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
         shorter = size.shortest_edge
         longer = int(1333 / 800 * shorter)
-        output_size = get_resize_output_image_size(
+        output_height, output_width = get_resize_output_image_size(
             image,
             shorter=shorter,
             longer=longer,
             size_divisor=size_divisor,
         )
-        return F.resize(image, output_size, interpolation=interpolation, antialias=antialias)
+        return super().resize(
+            image=image,
+            size=SizeDict(height=output_height, width=output_width),
+            interpolation=interpolation,
+            antialias=antialias,
+        )
 
     def center_crop(
         self,

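BridgeTower's sizing fixes the shorter edge and caps the longer edge at 1333/800 of it before delegating to the parent resize, which now routes through the AMD-safe path above. A rough, illustrative sketch of that kind of computation; constrained_output_size is a hypothetical stand-in for transformers' get_resize_output_image_size, not its exact logic.

# Illustrative sketch of the shorter/longer-edge constraint; the real logic
# lives in transformers' get_resize_output_image_size.
def constrained_output_size(
    height: int, width: int, shorter: int, longer: int, size_divisor: int = 32
) -> tuple[int, int]:
    scale = shorter / min(height, width)
    if max(height, width) * scale > longer:
        scale = longer / max(height, width)  # cap the longer edge
    new_h, new_w = int(height * scale), int(width * scale)
    # Snap both edges down to a multiple of size_divisor
    return new_h // size_divisor * size_divisor, new_w // size_divisor * size_divisor


print(constrained_output_size(480, 640, shorter=288, longer=int(1333 / 800 * 288)))  # (288, 384)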
src/transformers/models/llava_next/image_processing_llava_next_fast.py

@@ -137,7 +137,11 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast):
         new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
 
         # Resize the image
-        resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
+        resized_image = self.resize(
+            image=image,
+            size=SizeDict(height=new_height, width=new_width),
+            interpolation=interpolation,
+        )
 
         return resized_image

src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py

@@ -142,7 +142,11 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
         new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
 
         # Resize the image
-        resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
+        resized_image = self.resize(
+            image=image,
+            size=SizeDict(height=new_height, width=new_width),
+            interpolation=interpolation,
+        )
 
         return resized_image

src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py

@@ -203,8 +203,10 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
                     min_pixels=size["shortest_edge"],
                     max_pixels=size["longest_edge"],
                 )
-                stacked_images = F.resize(
-                    stacked_images, size=(resized_height, resized_width), interpolation=interpolation
+                stacked_images = self.resize(
+                    image=stacked_images,
+                    size=SizeDict(height=resized_height, width=resized_width),
+                    interpolation=interpolation,
                 )
             resized_images_grouped[shape] = stacked_images
         resized_images = reorder_images(resized_images_grouped, grouped_images_index)

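Both Qwen2-VL processors (images above, videos below) choose resized_height and resized_width from a pixel budget before calling self.resize; the library helper taking min_pixels/max_pixels is smart_resize. The sketch below is an illustrative approximation of that budgeted, patch-aligned sizing, not the exact implementation.

# Illustrative sketch of pixel-budget sizing as used by Qwen2-VL's resize;
# treat the details as an approximation of the library's smart_resize.
import math


def budgeted_size(
    height: int, width: int, min_pixels: int, max_pixels: int, factor: int = 28
) -> tuple[int, int]:
    # Round each edge to the nearest multiple of the patch factor
    h = max(factor, round(height / factor) * factor)
    w = max(factor, round(width / factor) * factor)
    if h * w > max_pixels:  # shrink to fit the budget
        beta = math.sqrt(height * width / max_pixels)
        h = math.floor(height / beta / factor) * factor
        w = math.floor(width / beta / factor) * factor
    elif h * w < min_pixels:  # grow to meet the minimum
        beta = math.sqrt(min_pixels / (height * width))
        h = math.ceil(height * beta / factor) * factor
        w = math.ceil(width * beta / factor) * factor
    return h, w


print(budgeted_size(531, 800, min_pixels=3136, max_pixels=1003520))  # illustrative budget values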
src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py

@@ -250,8 +250,10 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
                     min_pixels=min_pixels,
                     max_pixels=max_pixels,
                 )
-                stacked_videos = F.resize(
-                    stacked_videos, size=(resized_height, resized_width), interpolation=interpolation
+                stacked_videos = self.resize(
+                    image=stacked_videos,
+                    size=SizeDict(height=resized_height, width=resized_width),
+                    interpolation=interpolation,
                 )
             resized_videos_grouped[shape] = stacked_videos
         resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)

tests/models/internvl/test_modeling_internvl.py

@@ -705,6 +705,7 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
                 ("xpu", 3): torch.tensor([-9.8750, -0.5703, 1.4297, -10.3125, -10.3125], dtype=torch.float16),
                 ("cuda", 7): torch.tensor([-9.8750, -0.4861, 1.4648, -10.3359, -10.3359], dtype=torch.float16),
                 ("cuda", 8): torch.tensor([-9.8906, -0.4995, 1.4473, -10.3359, -10.3438], dtype=torch.float16),
+                ("rocm", (9, 5)): torch.tensor([-9.8906, -0.4976, 1.4502, -10.3359, -10.3438], dtype=torch.float16),
             }
         )  # fmt: skip
         expected_logits = torch.tensor(expected_logits_all.get_expectation(), dtype=torch.float16)
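The test keys expected logits by device type and version so each platform checks against its own reference values. A simplified sketch of that lookup idea follows; pick_expectation is hypothetical, and the real Expectations.get_expectation in transformers.testing_utils is more involved.

# Simplified sketch of per-device expectation lookup; the real Expectations
# class in transformers.testing_utils has more elaborate matching rules.
import torch


def pick_expectation(table: dict, device: str, major: int):
    # Prefer an exact (device, major) match, then a device-wide default
    for key in ((device, major), (device, None)):
        if key in table:
            return table[key]
    raise KeyError(f"no expectation for {device} (major={major})")


table = {
    ("cuda", 8): torch.tensor([-9.8906, -0.4995, 1.4473, -10.3359, -10.3438], dtype=torch.float16),
    ("cuda", None): torch.tensor([-9.8750, -0.4861, 1.4648, -10.3359, -10.3359], dtype=torch.float16),
}
print(pick_expectation(table, "cuda", 8))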