mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Fix Pan and Scan on batched images Gemma3 (#36864)
* process flattened images in fast image proc * process flattened images in low proc and add tests * remove print * add unbalanced batch test pas image proc * fix integration tests
This commit is contained in:
parent
dd3933dd65
commit
beb9b5b022
@ -35,7 +35,7 @@ from ...image_utils import (
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
make_nested_list_of_images,
|
||||
make_flat_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
@ -334,9 +334,9 @@ class Gemma3ImageProcessor(BaseImageProcessor):
|
||||
else self.pan_and_scan_min_ratio_to_activate
|
||||
)
|
||||
|
||||
images_list = make_nested_list_of_images(images)
|
||||
images = make_flat_list_of_images(images)
|
||||
|
||||
if not valid_images(images_list[0]):
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
@ -353,12 +353,12 @@ class Gemma3ImageProcessor(BaseImageProcessor):
|
||||
resample=resample,
|
||||
)
|
||||
if do_convert_rgb:
|
||||
images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
|
||||
images = [convert_to_rgb(image) for image in images]
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images_list = [[to_numpy_array(image) for image in images] for images in images_list]
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if do_rescale and is_scaled_image(images_list[0][0]):
|
||||
if do_rescale and is_scaled_image(images[0]):
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
@ -366,11 +366,10 @@ class Gemma3ImageProcessor(BaseImageProcessor):
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images_list[0][0])
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_pan_and_scan:
|
||||
images_list_and_num_crops = [
|
||||
self._process_images_for_pan_and_scan(
|
||||
images, num_crops = self._process_images_for_pan_and_scan(
|
||||
images=images,
|
||||
do_pan_and_scan=do_pan_and_scan,
|
||||
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
|
||||
@ -379,15 +378,11 @@ class Gemma3ImageProcessor(BaseImageProcessor):
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for images in images_list
|
||||
]
|
||||
images_list = [images for images, _ in images_list_and_num_crops]
|
||||
num_crops = [num_crops for _, num_crops in images_list_and_num_crops]
|
||||
|
||||
else:
|
||||
num_crops = [[0] for _ in images_list]
|
||||
num_crops = [0 for _ in images]
|
||||
|
||||
processed_images = []
|
||||
for images in images_list:
|
||||
for image in images:
|
||||
if do_resize:
|
||||
height, width = size["height"], size["width"]
|
||||
|
@ -16,7 +16,6 @@
|
||||
|
||||
import itertools
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...image_processing_utils_fast import (
|
||||
@ -31,11 +30,8 @@ from ...image_processing_utils_fast import (
|
||||
from ...image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
SizeDict,
|
||||
get_image_size,
|
||||
make_nested_list_of_images,
|
||||
)
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
@ -103,52 +99,9 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
|
||||
def __init__(self, **kwargs: Unpack[Gemma3FastImageProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _prepare_images_structure(
|
||||
def pan_and_scan_batched(
|
||||
self,
|
||||
images: ImageInput,
|
||||
) -> ImageInput:
|
||||
"""
|
||||
Prepare the images structure for processing.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
The input images to process.
|
||||
|
||||
Returns:
|
||||
`ImageInput`: The images with a valid nesting.
|
||||
"""
|
||||
return make_nested_list_of_images(images)
|
||||
|
||||
def _prepare_input_images(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_convert_rgb: bool = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
device: Optional["torch.device"] = None,
|
||||
) -> List["torch.Tensor"]:
|
||||
"""
|
||||
Prepare the input images for processing.
|
||||
"""
|
||||
batch_images = self._prepare_images_structure(images)
|
||||
process_image_fn = partial(
|
||||
self._process_image,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
)
|
||||
# todo: yoni - check if we can parallelize this efficiently
|
||||
batch_processed_images = []
|
||||
for image_list in batch_images:
|
||||
processed_images = []
|
||||
for image in image_list:
|
||||
processed_images.append(process_image_fn(image))
|
||||
batch_processed_images.append(processed_images)
|
||||
|
||||
return batch_processed_images
|
||||
|
||||
def pan_and_scan(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
images: "torch.Tensor",
|
||||
pan_and_scan_min_crop_size: int,
|
||||
pan_and_scan_max_num_crops: int,
|
||||
pan_and_scan_min_ratio_to_activate: float,
|
||||
@ -167,7 +120,7 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
|
||||
pan_and_scan_min_ratio_to_activate (`float`, *optional*):
|
||||
Minimum aspect ratio to activate pan and scan.
|
||||
"""
|
||||
height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
||||
height, width = images.shape[-2:]
|
||||
|
||||
# Square or landscape image.
|
||||
if width >= height:
|
||||
@ -210,7 +163,7 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
|
||||
crop_positions_h = [crop_size_h * i for i in range(num_crops_h)]
|
||||
|
||||
return [
|
||||
image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
|
||||
images[..., pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
|
||||
for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w)
|
||||
]
|
||||
|
||||
@ -222,18 +175,14 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
|
||||
pan_and_scan_max_num_crops: int,
|
||||
pan_and_scan_min_ratio_to_activate: float,
|
||||
):
|
||||
pas_images_list = []
|
||||
num_crops = []
|
||||
for image in images:
|
||||
pas_images = self.pan_and_scan(
|
||||
image=image,
|
||||
pas_images = self.pan_and_scan_batched(
|
||||
images=images,
|
||||
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
|
||||
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
|
||||
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
|
||||
)
|
||||
pas_images_list.extend([image] + pas_images)
|
||||
num_crops.append(len(pas_images))
|
||||
return pas_images_list, num_crops
|
||||
num_crops = [len(pas_images) for _ in images]
|
||||
return pas_images, num_crops
|
||||
|
||||
@add_start_docstrings(
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
||||
@ -274,46 +223,66 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
|
||||
image_std: Optional[Union[float, List[float]]],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
) -> BatchFeature:
|
||||
processed_images = []
|
||||
batch_num_crops = []
|
||||
|
||||
for images_list in images:
|
||||
# Group images by size for batched processing
|
||||
processed_images_grouped = {}
|
||||
num_crops_grouped = {}
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images)
|
||||
for shape_images, stacked_images in grouped_images.items():
|
||||
if do_pan_and_scan:
|
||||
images_list, num_crops = self._process_images_for_pan_and_scan(
|
||||
images=images_list,
|
||||
pas_images, num_crops = self._process_images_for_pan_and_scan(
|
||||
images=stacked_images,
|
||||
do_pan_and_scan=do_pan_and_scan,
|
||||
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
|
||||
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
|
||||
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
|
||||
)
|
||||
else:
|
||||
num_crops = [[0] for _ in images_list]
|
||||
|
||||
# Group images by size for batched processing
|
||||
# Add the thumbnails to the image patches
|
||||
stacked_images = [stacked_images] + pas_images
|
||||
# Group images by size for batched resizing (this will typically group thumbnails together and cropped patches together)
|
||||
processed_image_patches_grouped = {}
|
||||
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(images_list)
|
||||
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(stacked_images)
|
||||
for shape, stacked_image_patches in grouped_image_patches.items():
|
||||
if do_resize:
|
||||
stacked_image_patches = self.resize(
|
||||
image=stacked_image_patches,
|
||||
size=size,
|
||||
interpolation=interpolation,
|
||||
)
|
||||
# Fused rescale and normalize
|
||||
stacked_image_patches = self.rescale_and_normalize(
|
||||
stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||
)
|
||||
processed_image_patches_grouped[shape] = stacked_image_patches
|
||||
processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index)
|
||||
processed_image_patches = (
|
||||
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
|
||||
)
|
||||
processed_images.extend(processed_image_patches)
|
||||
batch_num_crops.extend(num_crops)
|
||||
# Transpose to have the thumbnails with their corresponding patches
|
||||
stacked_images = torch.stack(processed_image_patches, dim=0).transpose(0, 1).contiguous()
|
||||
else:
|
||||
num_crops = [0 for _ in stacked_images]
|
||||
|
||||
if do_resize:
|
||||
stacked_images = self.resize(
|
||||
image=stacked_images,
|
||||
size=size,
|
||||
interpolation=interpolation,
|
||||
)
|
||||
num_crops_grouped[shape_images] = num_crops
|
||||
processed_images_grouped[shape_images] = stacked_images
|
||||
resized_images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||
# If pan and scan is enabled, we need to flatten the list of images
|
||||
if do_pan_and_scan:
|
||||
resized_images = [image for images_list in resized_images for image in images_list]
|
||||
num_crops = reorder_images(num_crops_grouped, grouped_images_index)
|
||||
|
||||
# Group images by size for further processing
|
||||
# Needed in case do_resize is False, or resize returns images with different sizes
|
||||
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
|
||||
processed_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
# Fused rescale and normalize
|
||||
stacked_images = self.rescale_and_normalize(
|
||||
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||
)
|
||||
processed_images_grouped[shape] = stacked_images
|
||||
|
||||
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
return BatchFeature(
|
||||
data={"pixel_values": processed_images, "num_crops": batch_num_crops}, tensor_type=return_tensors
|
||||
data={"pixel_values": processed_images, "num_crops": num_crops}, tensor_type=return_tensors
|
||||
)
|
||||
|
||||
|
||||
|
@ -113,7 +113,8 @@ class Gemma3Processor(ProcessorMixin):
|
||||
)
|
||||
|
||||
# Replace image tokens by the full expanded sequence
|
||||
batch_num_crops = to_py_obj(image_inputs.pop("num_crops"))
|
||||
num_crops = to_py_obj(image_inputs.pop("num_crops"))
|
||||
batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]
|
||||
for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)):
|
||||
image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
|
||||
|
||||
@ -139,7 +140,7 @@ class Gemma3Processor(ProcessorMixin):
|
||||
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
|
||||
|
||||
# Add token type ids manually, as tokenizer can't do arbitrary position token types
|
||||
array_ids = np.array(text_inputs["input_ids"])
|
||||
array_ids = text_inputs["input_ids"]
|
||||
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
||||
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
||||
text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs
|
||||
|
@ -189,6 +189,13 @@ class Gemma3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
expected_output_image_shape = (9, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched unbalanced, 9 images because we have base image + 2 crops per each item
|
||||
encoded_images = image_processing(
|
||||
[[image_inputs[0], image_inputs[1]], [image_inputs[2]]], return_tensors="pt"
|
||||
).pixel_values
|
||||
expected_output_image_shape = (9, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_pil(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
@ -250,3 +257,37 @@ class Gemma3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
@unittest.skip("Gemma3 doesn't work with 4 channels due to pan and scan method")
|
||||
def test_call_numpy_4_channels(self):
|
||||
pass
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
def test_slow_fast_equivalence_batched_pas(self):
|
||||
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test")
|
||||
|
||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
||||
|
||||
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
|
||||
self.skipTest(
|
||||
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
|
||||
)
|
||||
crop_config = {
|
||||
"do_pan_and_scan": True,
|
||||
"pan_and_scan_max_num_crops": 448,
|
||||
"pan_and_scan_min_crop_size": 32,
|
||||
"pan_and_scan_min_ratio_to_activate": 0.3,
|
||||
}
|
||||
image_processor_dict = self.image_processor_dict
|
||||
image_processor_dict.update(crop_config)
|
||||
dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
|
||||
image_processor_slow = self.image_processing_class(**image_processor_dict)
|
||||
image_processor_fast = self.fast_image_processing_class(**image_processor_dict)
|
||||
|
||||
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
|
||||
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
|
||||
|
||||
torch.testing.assert_close(encoding_slow.num_crops, encoding_fast.num_crops)
|
||||
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
|
||||
self.assertLessEqual(
|
||||
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
|
||||
)
|
||||
|
@ -395,7 +395,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
||||
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
||||
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like'] # fmt: skip
|
||||
EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like'] # fmt: skip
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
def test_model_4b_batch(self):
|
||||
@ -467,7 +467,56 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
||||
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_NUM_IMAGES = 3 # one for the origin image and two crops of images
|
||||
EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a beach with a turquoise ocean and blue sky in the background.'] # fmt: skip
|
||||
EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a beach with a turquoise ocean and blue sky in the background. It looks like the cow is enjoying the beach'] # fmt: skip
|
||||
self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES)
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
def test_model_4b_batch_crops(self):
|
||||
model_id = "google/gemma-3-4b-it"
|
||||
|
||||
model = Gemma3ForConditionalGeneration.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
|
||||
).to(torch_device)
|
||||
crop_config = {
|
||||
"images_kwargs": {
|
||||
"do_pan_and_scan": True,
|
||||
"pan_and_scan_max_num_crops": 448,
|
||||
"pan_and_scan_min_crop_size": 32,
|
||||
"pan_and_scan_min_ratio_to_activate": 0.3,
|
||||
}
|
||||
}
|
||||
messages_2 = [
|
||||
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
|
||||
},
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||
{"type": "text", "text": "Are these images identical?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
inputs = self.processor.apply_chat_template(
|
||||
[self.messages, messages_2],
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
add_generation_prompt=True,
|
||||
**crop_config,
|
||||
).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
||||
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
||||
EXPECTED_NUM_IMAGES = 9 # 3 * (one for the origin image and two crops of images) = 9
|
||||
EXPECTED_TEXTS = [
|
||||
"user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a beach with a turquoise ocean and blue sky in the background. It looks like the cow is enjoying the beach",
|
||||
"user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nWhile they all feature a brown cow in the foreground and a similar background (including the stop signs and",
|
||||
] # fmt: skip
|
||||
self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES)
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user