Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Internvl fix (#38946)
* Image processor compile fix (#38540)

  * Added a compile-friendly version of resize to BaseImageProcessorFast
  * Changed qwen2 processor to use its parent class .resize
  * Style
  * Underlined that the issue only happens on AMD, w/ comment and bool check
  * Fixed some utils functions
  * Fixed the same issue for bridgetower
  * Fixed the same issue for llava_next
  * Repo consistency for llava onevision
  * Update src/transformers/image_processing_utils_fast.py

  Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>

* Added an Expectation to an internvl test

* Made qwen2_vl use the resize method of its parent class

* Changed to torch.where

---------

Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
This commit is contained in: parent f85b47d1b8, commit 25c44d4b68
@@ -49,6 +49,7 @@ from .utils import (
     is_vision_available,
     logging,
 )
+from .utils.import_utils import is_rocm_platform


 if is_vision_available():
@@ -280,8 +281,34 @@ class BaseImageProcessorFast(BaseImageProcessor):
                 "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
                 f" {size}."
             )
+        # This is a workaround to avoid a bug in torch.compile when dealing with uint8 on AMD MI3XX GPUs
+        # Tracked in PyTorch issue: https://github.com/pytorch/pytorch/issues/155209
+        # TODO: remove this once the bug is fixed (detected with torch==2.7.0+git1fee196, torchvision==0.22.0+9eb57cd)
+        if torch.compiler.is_compiling() and is_rocm_platform():
+            return self.compile_friendly_resize(image, new_size, interpolation, antialias)
         return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)

+    @staticmethod
+    def compile_friendly_resize(
+        image: "torch.Tensor",
+        new_size: tuple[int, int],
+        interpolation: Optional["F.InterpolationMode"] = None,
+        antialias: bool = True,
+    ) -> "torch.Tensor":
+        """
+        A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor.
+        """
+        if image.dtype == torch.uint8:
+            image = image.float() / 256
+            image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
+            image = image * 256
+            image = torch.where(image > 255, 255, image)
+            image = torch.where(image < 0, 0, image)
+            image = image.round().to(torch.uint8)
+        else:
+            image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
+        return image
+
     def rescale(
         self,
         image: "torch.Tensor",
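For reference, a minimal sketch of how the new static method can be exercised on its own (the sample tensor shape, the interpolation choice, and running it eagerly rather than under torch.compile are my own choices, not part of the diff). The wrapper round-trips uint8 input through float, clamps the result back into [0, 255] with torch.where, and returns uint8 again:

    # Sketch only: exercising compile_friendly_resize directly on a uint8 tensor.
    import torch
    from torchvision.transforms import InterpolationMode
    from transformers.image_processing_utils_fast import BaseImageProcessorFast

    image = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)
    resized = BaseImageProcessorFast.compile_friendly_resize(
        image, (224, 224), interpolation=InterpolationMode.BILINEAR, antialias=True
    )
    # dtype and value range are preserved by the float round-trip and clamping
    assert resized.dtype == torch.uint8
    assert tuple(resized.shape[-2:]) == (224, 224)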
@@ -165,13 +165,18 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
            raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
         shorter = size.shortest_edge
         longer = int(1333 / 800 * shorter)
-        output_size = get_resize_output_image_size(
+        output_height, output_width = get_resize_output_image_size(
             image,
             shorter=shorter,
             longer=longer,
             size_divisor=size_divisor,
         )
-        return F.resize(image, output_size, interpolation=interpolation, antialias=antialias)
+        return super().resize(
+            image=image,
+            size=SizeDict(height=output_height, width=output_width),
+            interpolation=interpolation,
+            antialias=antialias,
+        )

     def center_crop(
         self,
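As a rough sketch of the size computation that now feeds super().resize() (my paraphrase of the shorter/longer/size_divisor convention, not the actual get_resize_output_image_size helper): the shortest edge is scaled to `shorter`, the scale is capped so the longest edge stays within `longer`, and both edges are rounded to a multiple of `size_divisor`:

    # Approximation only; the real logic lives in get_resize_output_image_size and may differ in details.
    def approx_output_size(height, width, shorter=288, size_divisor=32):
        longer = int(1333 / 800 * shorter)
        scale = shorter / min(height, width)
        if max(height, width) * scale > longer:  # cap the longest edge
            scale = longer / max(height, width)
        new_h = max(size_divisor, round(height * scale / size_divisor) * size_divisor)
        new_w = max(size_divisor, round(width * scale / size_divisor) * size_divisor)
        return new_h, new_w

    print(approx_output_size(600, 900))  # (288, 448) with these defaults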
@@ -137,7 +137,11 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast):
         new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)

         # Resize the image
-        resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
+        resized_image = self.resize(
+            image=image,
+            size=SizeDict(height=new_height, width=new_width),
+            interpolation=interpolation,
+        )

         return resized_image
@@ -142,7 +142,11 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
         new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)

         # Resize the image
-        resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
+        resized_image = self.resize(
+            image=image,
+            size=SizeDict(height=new_height, width=new_width),
+            interpolation=interpolation,
+        )

         return resized_image
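The llava_next and llava_onevision changes are the same pattern: instead of calling F.resize directly, the subclass routes through the resize method it inherits from BaseImageProcessorFast, so the ROCm/torch.compile workaround above applies automatically. A minimal toy sketch of that dispatch (class and method names here are illustrative, not transformers code):

    # Toy illustration of why delegating to the inherited resize picks up the workaround.
    class ToyBaseProcessor:
        def resize(self, image, size):
            if self._on_problematic_backend():
                return self._compile_friendly_resize(image, size)
            return self._plain_resize(image, size)

        def _on_problematic_backend(self):
            return False  # stand-in for torch.compiler.is_compiling() and is_rocm_platform()

        def _plain_resize(self, image, size):
            return f"plain resize of {image} to {size}"

        def _compile_friendly_resize(self, image, size):
            return f"compile-friendly resize of {image} to {size}"

    class ToyLlavaProcessor(ToyBaseProcessor):
        def resize_for_patching(self, image, size):
            # Before the fix this step called the backend resize directly,
            # bypassing the dispatch in ToyBaseProcessor.resize.
            return self.resize(image, size)

    print(ToyLlavaProcessor().resize_for_patching("img", (336, 336)))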
@@ -203,8 +203,10 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
                 min_pixels=size["shortest_edge"],
                 max_pixels=size["longest_edge"],
             )
-            stacked_images = F.resize(
-                stacked_images, size=(resized_height, resized_width), interpolation=interpolation
+            stacked_images = self.resize(
+                image=stacked_images,
+                size=SizeDict(height=resized_height, width=resized_width),
+                interpolation=interpolation,
             )
             resized_images_grouped[shape] = stacked_images
         resized_images = reorder_images(resized_images_grouped, grouped_images_index)
@@ -250,8 +250,10 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
                 min_pixels=min_pixels,
                 max_pixels=max_pixels,
             )
-            stacked_videos = F.resize(
-                stacked_videos, size=(resized_height, resized_width), interpolation=interpolation
+            stacked_videos = self.resize(
+                image=stacked_videos,
+                size=SizeDict(height=resized_height, width=resized_width),
+                interpolation=interpolation,
             )
             resized_videos_grouped[shape] = stacked_videos
         resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)
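For context, the surrounding Qwen2-VL code groups inputs by shape, resizes each group as one stacked tensor, and then restores the original order (reorder_images / reorder_videos in the context lines above; a matching grouping helper on the transformers side is assumed). A standalone toy version of that group-and-reorder pattern, written with plain dicts rather than the transformers helpers:

    # Toy sketch of the group-by-shape / reorder pattern, not the transformers helpers.
    import torch

    def group_by_shape(images):
        grouped, index = {}, []
        for img in images:
            shape = tuple(img.shape)
            grouped.setdefault(shape, []).append(img)
            index.append((shape, len(grouped[shape]) - 1))
        return {s: torch.stack(v) for s, v in grouped.items()}, index

    def reorder(grouped, index):
        return [grouped[shape][i] for shape, i in index]

    images = [torch.zeros(3, 224, 224), torch.zeros(3, 336, 336), torch.zeros(3, 224, 224)]
    grouped, index = group_by_shape(images)
    # each stacked batch would be resized here in a single call (e.g. via self.resize)
    restored = reorder(grouped, index)
    assert all(r.shape == o.shape for r, o in zip(restored, images))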
@@ -705,6 +705,7 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
                 ("xpu", 3): torch.tensor([-9.8750, -0.5703, 1.4297, -10.3125, -10.3125], dtype=torch.float16),
                 ("cuda", 7): torch.tensor([-9.8750, -0.4861, 1.4648, -10.3359, -10.3359], dtype=torch.float16),
                 ("cuda", 8): torch.tensor([-9.8906, -0.4995, 1.4473, -10.3359, -10.3438], dtype=torch.float16),
+                ("rocm", (9, 5)): torch.tensor([-9.8906, -0.4976, 1.4502, -10.3359, -10.3438], dtype=torch.float16),
             }
         )  # fmt: skip
         expected_logits = torch.tensor(expected_logits_all.get_expectation(), dtype=torch.float16)
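The test change adds a ROCm entry to the device-keyed expected values. A short sketch of how that lookup is used, assuming the Expectations helper lives in transformers.testing_utils (the import path is my assumption; the get_expectation() call and the key/value layout are taken from the diff):

    # Sketch: device-keyed reference values, resolved for the hardware running the test.
    import torch
    from transformers.testing_utils import Expectations  # assumed import path

    expected_logits_all = Expectations(
        {
            ("cuda", 8): torch.tensor([-9.8906, -0.4995, 1.4473, -10.3359, -10.3438], dtype=torch.float16),
            ("rocm", (9, 5)): torch.tensor([-9.8906, -0.4976, 1.4502, -10.3359, -10.3438], dtype=torch.float16),
        }
    )
    # get_expectation() picks the entry matching the current device, as in the test above.
    expected_logits = torch.tensor(expected_logits_all.get_expectation(), dtype=torch.float16)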