diff --git a/docs/source/en/model_doc/llava_onevision.md b/docs/source/en/model_doc/llava_onevision.md index e265177590b..14d5f6508ad 100644 --- a/docs/source/en/model_doc/llava_onevision.md +++ b/docs/source/en/model_doc/llava_onevision.md @@ -147,7 +147,7 @@ print(processor.decode(output[0], skip_special_tokens=True)) ### Multi image inference -LLaVa-OneVision can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). For that you have to use checkpoints with an "ov" suffix. Here is how you can do it: +LLaVa-OneVision can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). For that you have to use checkpoints with an "ov" suffix. For multi-image cases, we recommend using a **nested list of images** as input. Otherwise, every image will be patchified and consume a lot of memory. Here is how you can do it: ```python import requests diff --git a/src/transformers/models/aria/image_processing_aria.py b/src/transformers/models/aria/image_processing_aria.py index 44c6d40a4c6..54a2ec9488c 100644 --- a/src/transformers/models/aria/image_processing_aria.py +++ b/src/transformers/models/aria/image_processing_aria.py @@ -364,19 +364,23 @@ class AriaImageProcessor(BaseImageProcessor): return resized_image + def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple): + original_height, original_width = original_resolution + target_height, target_width = target_resolution + paste_x, r_x = divmod(target_width - original_width, 2) + paste_y, r_y = divmod(target_height - original_height, 2) + return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x) + def _pad_for_patching( self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension ) -> np.array: """ Pad an image to a target resolution while maintaining aspect ratio. """ - target_height, target_width = target_resolution - new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) + new_resolution = get_patch_output_size(image, target_resolution, input_data_format) + padding = self._get_padding_size(new_resolution, target_resolution) - paste_x, r_x = divmod(target_width - new_width, 2) - paste_y, r_y = divmod(target_height - new_height, 2) - - padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x))) + padded_image = self.pad(image, padding=padding) return padded_image diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index 738b269b0bb..5afc05e9159 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -748,19 +748,23 @@ class AriaImageProcessor(BaseImageProcessor): return resized_image + def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple): + original_height, original_width = original_resolution + target_height, target_width = target_resolution + paste_x, r_x = divmod(target_width - original_width, 2) + paste_y, r_y = divmod(target_height - original_height, 2) + return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x) + def _pad_for_patching( self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension ) -> np.array: """ Pad an image to a target resolution while maintaining aspect ratio. 
""" - target_height, target_width = target_resolution - new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) + new_resolution = get_patch_output_size(image, target_resolution, input_data_format) + padding = self._get_padding_size(new_resolution, target_resolution) - paste_x, r_x = divmod(target_width - new_width, 2) - paste_y, r_y = divmod(target_height - new_height, 2) - - padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x))) + padded_image = self.pad(image, padding=padding) return padded_image diff --git a/src/transformers/models/llava_next/image_processing_llava_next.py b/src/transformers/models/llava_next/image_processing_llava_next.py index bf8920e955a..06601b45c52 100644 --- a/src/transformers/models/llava_next/image_processing_llava_next.py +++ b/src/transformers/models/llava_next/image_processing_llava_next.py @@ -424,19 +424,23 @@ class LlavaNextImageProcessor(BaseImageProcessor): return resized_image + def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple): + original_height, original_width = original_resolution + target_height, target_width = target_resolution + paste_x, r_x = divmod(target_width - original_width, 2) + paste_y, r_y = divmod(target_height - original_height, 2) + return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x) + def _pad_for_patching( self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension ) -> np.array: """ Pad an image to a target resolution while maintaining aspect ratio. """ - target_height, target_width = target_resolution - new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) + new_resolution = get_patch_output_size(image, target_resolution, input_data_format) + padding = self._get_padding_size(new_resolution, target_resolution) - paste_x, r_x = divmod(target_width - new_width, 2) - paste_y, r_y = divmod(target_height - new_height, 2) - - padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x))) + padded_image = self.pad(image, padding=padding) return padded_image diff --git a/src/transformers/models/llava_next/image_processing_llava_next_fast.py b/src/transformers/models/llava_next/image_processing_llava_next_fast.py index 8d4b3c48ba9..ac90290cef4 100644 --- a/src/transformers/models/llava_next/image_processing_llava_next_fast.py +++ b/src/transformers/models/llava_next/image_processing_llava_next_fast.py @@ -141,19 +141,23 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast): return resized_image + def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple): + original_height, original_width = original_resolution + target_height, target_width = target_resolution + paste_x, r_x = divmod(target_width - original_width, 2) + paste_y, r_y = divmod(target_height - original_height, 2) + return [paste_x, paste_y, paste_x + r_x, paste_y + r_y] + def _pad_for_patching( self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension ) -> "torch.Tensor": """ Pad an image to a target resolution while maintaining aspect ratio. 
""" - target_height, target_width = target_resolution - new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) + new_resolution = get_patch_output_size(image, target_resolution, input_data_format) + padding = self._get_padding_size(new_resolution, target_resolution) - paste_x, r_x = divmod(target_width - new_width, 2) - paste_y, r_y = divmod(target_height - new_height, 2) - - padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x + r_x, paste_y + r_y]) + padded_image = F.pad(image, padding=padding) return padded_image diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py index a664cfa7b64..6a471d712a9 100644 --- a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py +++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py @@ -315,6 +315,14 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor): return resized_image + # Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._get_padding_size + def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple): + original_height, original_width = original_resolution + target_height, target_width = target_resolution + paste_x, r_x = divmod(target_width - original_width, 2) + paste_y, r_y = divmod(target_height - original_height, 2) + return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x) + # Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._pad_for_patching def _pad_for_patching( self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension @@ -322,13 +330,10 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor): """ Pad an image to a target resolution while maintaining aspect ratio. """ - target_height, target_width = target_resolution - new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) + new_resolution = get_patch_output_size(image, target_resolution, input_data_format) + padding = self._get_padding_size(new_resolution, target_resolution) - paste_x, r_x = divmod(target_width - new_width, 2) - paste_y, r_y = divmod(target_height - new_height, 2) - - padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x))) + padded_image = self.pad(image, padding=padding) return padded_image @@ -437,6 +442,85 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor): return pixel_values + # Copied from transformers.models.llava.image_processing_llava.LlavaImageProcessor.pad_to_square + def pad_to_square( + self, + image: np.ndarray, + background_color: Union[int, Tuple[int, int, int]] = 0, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.array: + """ + Pads an image to a square based on the longest edge. + + Args: + image (`np.ndarray`): + The image to pad. + background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0): + The color to use for the padding. Can be an integer for single channel or a + tuple of integers representing for multi-channel images. If passed as integer + in mutli-channel mode, it will default to `0` in subsequent channels. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. 
Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + + Returns: + `np.ndarray`: The padded image. + """ + height, width = get_image_size(image, input_data_format) + num_channels = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1] + + if height == width: + image = ( + to_channel_dimension_format(image, data_format, input_data_format) + if data_format is not None + else image + ) + return image + + max_dim = max(height, width) + + # Ensure background_color is the correct shape + if isinstance(background_color, int): + background_color = [background_color] + elif len(background_color) != num_channels: + raise ValueError( + f"background_color must have no more than {num_channels} elements to match the number of channels" + ) + + if input_data_format == ChannelDimension.FIRST: + result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype) + for i, color in enumerate(background_color): + result[i, :, :] = color + if width > height: + start = (max_dim - height) // 2 + result[:, start : start + height, :] = image + else: + start = (max_dim - width) // 2 + result[:, :, start : start + width] = image + else: + result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype) + for i, color in enumerate(background_color): + result[:, :, i] = color + if width > height: + start = (max_dim - height) // 2 + result[start : start + height, :, :] = image + else: + start = (max_dim - width) // 2 + result[:, start : start + width, :] = image + + image = ( + to_channel_dimension_format(result, data_format, input_data_format) if data_format is not None else result + ) + return image + def _preprocess( self, images: ImageInput, @@ -595,6 +679,17 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor): do_pad = do_pad if do_pad is not None else self.do_pad do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + if isinstance(images, (tuple, list)) and isinstance(images[0], (tuple, list)): + # if the first element is a list, we assume that all elements are lists + batch_num_images = [len(x) for x in images] + elif isinstance(images, (tuple, list)): + # treat this as a single-image case for backward compatibility + batch_num_images = [1] * len(images) + else: + batch_num_images = [1] + # only single image patching is supported + need_patching = [n == 1 for n in batch_num_images for _ in range(n)] + images = make_flat_list_of_images(images) if not valid_images(images): @@ -630,25 +725,34 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor): # We assume that all images have the same channel dimension format. 
input_data_format = infer_channel_dimension_format(images[0]) + size_tuple = ( + (size["height"], size["width"]) + if "height" in size and "width" in size + else (size["shortest_edge"], size["shortest_edge"]) + ) + new_images = [] image_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images] - for image in images: - # convert image into a list of patches - # we intentionally use the same data format as the input data format - size_tuple = ( - (size["height"], size["width"]) - if "height" in size and "width" in size - else (size["shortest_edge"], size["shortest_edge"]) - ) - image_patches = self.get_image_patches( - image, - image_grid_pinpoints, - size=size_tuple, - patch_size=size_tuple[0], - resample=resample, - data_format=input_data_format, - input_data_format=input_data_format, - ) + for i, image in enumerate(images): + if need_patching[i]: + # convert image into a list of patches + # we intentionally use the same data format as the input data format + image_patches = self.get_image_patches( + image, + image_grid_pinpoints, + size=size_tuple, + patch_size=size_tuple[0], + resample=resample, + data_format=input_data_format, + input_data_format=input_data_format, + ) + else: + padded_image = self.pad_to_square( + image=image, + background_color=tuple(int(x * 255) for x in self.image_mean), + input_data_format=input_data_format, + ) + image_patches = [padded_image] # preprocess patches pixel_values = self._preprocess( @@ -671,7 +775,8 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor): processed_images = self._pad_for_batching(new_images) return BatchFeature( - data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors + data={"pixel_values": processed_images, "image_sizes": image_sizes, "batch_num_images": batch_num_images}, + tensor_type=return_tensors, ) diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py index dc7a324c441..a29631fcb6a 100644 --- a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py @@ -19,7 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import torch @@ -89,6 +89,15 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast): @auto_docstring def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaOnevisionFastImageProcessorKwargs]) -> BatchFeature: + if isinstance(images, (tuple, list)) and isinstance(images[0], (tuple, list)): + # if the first element is a list, we assume that all elements are lists + batch_num_images = [len(x) for x in images] + elif isinstance(images, (tuple, list)): + # treat this as a single-image case for backward compatibility + batch_num_images = [1] * len(images) + else: + batch_num_images = [1] + kwargs["batch_num_images"] = batch_num_images return super().preprocess(images, **kwargs) def _prepare_images_structure( @@ -137,19 +146,23 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast): return resized_image + def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple): + original_height, original_width = original_resolution + target_height, target_width = target_resolution + paste_x, r_x = divmod(target_width - original_width, 2) + paste_y, r_y = divmod(target_height - original_height, 2) + return [paste_x, paste_y, paste_x + r_x, paste_y + r_y] + def _pad_for_patching( self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension ) -> "torch.Tensor": """ Pad an image to a target resolution while maintaining aspect ratio. """ - target_height, target_width = target_resolution - new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) + new_resolution = get_patch_output_size(image, target_resolution, input_data_format) + padding = self._get_padding_size(new_resolution, target_resolution) - paste_x, r_x = divmod(target_width - new_width, 2) - paste_y, r_y = divmod(target_height - new_height, 2) - - padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x + r_x, paste_y + r_y]) + padded_image = F.pad(image, padding=padding) return padded_image @@ -234,10 +247,15 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast): image_mean: Optional[Union[float, List[float]]], image_std: Optional[Union[float, List[float]]], do_pad: bool, + batch_num_images: List[int], return_tensors: Optional[Union[str, TensorType]], ) -> BatchFeature: processed_images = [] image_sizes = [] + + # only single image patching is supported + need_patching = [n == 1 for n in batch_num_images for _ in range(n)] + # Determine the size tuple if size and size.height and size.width: size_tuple = (size.height, size.width) @@ -252,14 +270,20 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast): else: patch_size = size.shortest_edge - for image in images: - image_patches = self._get_image_patches( - image, - image_grid_pinpoints, - size=size_tuple, - patch_size=patch_size, - interpolation=interpolation, - ) + for i, image in enumerate(images): + if need_patching[i]: + image_patches = self._get_image_patches( + image, + image_grid_pinpoints, + size=size_tuple, + patch_size=patch_size, + interpolation=interpolation, + ) + else: + padded_image = self.pad_to_square( + images=image, background_color=tuple(int(x * 255) for x in self.image_mean) + ) + image_patches = [padded_image] # Group images by size for batched processing processed_image_patches_grouped = {} @@ -289,8 +313,52 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast): processed_images = self._pad_for_batching(processed_images) 
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature( - data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors + data={"pixel_values": processed_images, "image_sizes": image_sizes, "batch_num_images": batch_num_images}, + tensor_type=return_tensors, ) + # Copied from transformers.models.llava.image_processing_llava_fast.LlavaImageProcessorFast.pad_to_square + def pad_to_square( + self, + images: "torch.Tensor", + background_color: Union[int, Tuple[int, int, int]] = 0, + ) -> "torch.Tensor": + """ + Pads an image to a square based on the longest edge. + + Args: + images (`np.ndarray`): + The images to pad. + background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0): + The color to use for the padding. Can be an integer for single channel or a + tuple of integers representing for multi-channel images. If passed as integer + in mutli-channel mode, it will default to `0` in subsequent channels. + Returns: + `torch.Tensor`: The padded images. + """ + height, width = get_image_size(images, ChannelDimension.FIRST) + + if height == width: + return images + + num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0] + if isinstance(background_color, int): + background_color = [background_color] + [0] * (num_channels - 1) + elif len(background_color) != num_channels: + raise ValueError( + f"background_color must have no more than {num_channels} elements to match the number of channels" + ) + + max_dim = max(height, width) + paste_x_left = (max_dim - width) // 2 + paste_y_left = (max_dim - height) // 2 + paste_x_right = max_dim - width - paste_x_left + paste_y_right = max_dim - height - paste_y_left + padded_images = F.pad( + images, padding=[paste_x_left, paste_y_left, paste_x_right, paste_y_right], fill=background_color + ) + + return padded_images + __all__ = ["LlavaOnevisionImageProcessorFast"] diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index be600e8a96e..1a60c092ed9 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -419,8 +419,9 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor, - vision_feature_layer: Optional[Union[int, List[int]]] = None, - vision_feature_select_strategy: Optional[str] = None, + vision_feature_layer: Union[int, List[int]], + vision_feature_select_strategy: str, + batch_num_images: Optional[torch.LongTensor] = None, ): """ Obtains image last hidden states from the vision tower and apply multimodal projection. @@ -430,34 +431,34 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): The tensors corresponding to the input images. image_sizes (`torch.Tensor` of shape `(num_images, 2)`) Actual image size of each images (H, W). - vision_feature_layer (`Union[int, List[int]]`, *optional*): + vision_feature_layer (`Union[int, List[int]]`): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. - vision_feature_select_strategy (`str`, *optional*): + vision_feature_select_strategy (`str`): The feature selection strategy used to select the vision feature from the vision backbone. 
Can be one of `"default"` or `"full"` + batch_num_images (`torch.LongTensor`, *optional*): + Number of images in each sample. Returns: image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches and are of shape `(num_patches, image_length, embed_dim)`). """ - vision_feature_layer = ( - vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer - ) - vision_feature_select_strategy = ( - vision_feature_select_strategy - if vision_feature_select_strategy is not None - else self.config.vision_feature_select_strategy - ) - # ! infer image_num_patches from image_sizes + if batch_num_images is None: + # treat this as a single-image case for backward compatibility + need_patching = [True] * len(image_sizes) + else: + need_patching = [n == 1 for n in batch_num_images for _ in range(n)] image_num_patches = [ image_size_to_num_patches( image_size=imsize, grid_pinpoints=self.config.image_grid_pinpoints, patch_size=self.config.vision_config.image_size, ) - for imsize in image_sizes + if should_patch + else 1 + for imsize, should_patch in zip(image_sizes, need_patching) ] if pixel_values.dim() == 5: # stacked if input is (batch_size, num_patches, num_channels, height, width) @@ -500,6 +501,7 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): vision_feature_layer: Optional[Union[int, List[int]]] = None, vision_feature_select_strategy: Optional[str] = None, vision_aspect_ratio: Optional[str] = None, + batch_num_images: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -520,6 +522,8 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): If `"full"`, the full vision features are used. vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`): Aspect ratio used when processong image features. The default value is "anyres_max_9". + batch_num_images (`torch.LongTensor`, *optional*): + Number of images in each sample. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -558,6 +562,7 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): image_sizes, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, + batch_num_images=batch_num_images, ) image_features, feature_lens = self.pack_image_features( image_features, @@ -749,6 +754,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene vision_feature_layer: Optional[Union[int, List[int]]] = None, vision_feature_select_strategy: Optional[str] = None, vision_aspect_ratio: Optional[str] = None, + batch_num_images: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -771,6 +777,8 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene If `"full"`, the full vision features are used. vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`): Aspect ratio used when processong image features. The default value is "anyres_max_9". + batch_num_images (`torch.LongTensor`, *optional*): + Number of images in each sample. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored @@ -832,6 +840,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene vision_aspect_ratio=vision_aspect_ratio, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, + batch_num_images=batch_num_images, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, diff --git a/src/transformers/models/llava_onevision/modular_llava_onevision.py b/src/transformers/models/llava_onevision/modular_llava_onevision.py index 34e13b50462..f838c5f703f 100644 --- a/src/transformers/models/llava_onevision/modular_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modular_llava_onevision.py @@ -28,18 +28,59 @@ from transformers.models.llava_next_video.modeling_llava_next_video import ( LlavaNextVideoModelOutputWithPast, LlavaNextVideoPreTrainedModel, get_anyres_image_grid_shape, + image_size_to_num_patches, unpad_image, ) -from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import DefaultFastImageProcessorKwargs, group_images_by_shape, reorder_images +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, + get_image_size, +) from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...processing_utils import Unpack -from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils import ( + TensorType, + auto_docstring, + can_return_tuple, + is_torchdynamo_compiling, + is_torchvision_available, + is_torchvision_v2_available, + logging, +) + + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F logger = logging.get_logger(__name__) +class LlavaOnevisionFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + image_grid_pinpoints (`List[List[int]]`, *optional*): + A list of possible resolutions to use for processing high resolution images. The best resolution is selected + based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess` + method. + do_pad (`bool`, *optional*): + Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest + number of patches in the batch. Padding will be applied to the bottom and right with zeros. 
+ """ + + image_grid_pinpoints: Optional[List[List[int]]] + do_pad: Optional[bool] + + class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast): resample = PILImageResampling.BICUBIC image_mean = OPENAI_CLIP_MEAN @@ -56,6 +97,147 @@ class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast): image_grid_pinpoints = [[384, 384], [384, 768], [384, 1152], [384, 1536], [384, 1920], [384, 2304], [768, 384], [768, 768], [768, 1152], [768, 1536], [768, 1920], [768, 2304], [1152, 384], [1152, 768], [1152, 1152], [1152, 1536], [1152, 1920], [1152, 2304], [1536, 384], [1536, 768], [1536, 1152], [1536, 1536], [1536, 1920], [1536, 2304], [1920, 384], [1920, 768], [1920, 1152], [1920, 1536], [1920, 1920], [1920, 2304], [2304, 384], [2304, 768], [2304, 1152], [2304, 1536], [2304, 1920], [2304, 2304]] # fmt: skip model_input_names = ["pixel_values_videos"] + # Copied from transformers.models.llava.image_processing_llava_fast.LlavaImageProcessorFast.pad_to_square + def pad_to_square( + self, + images: "torch.Tensor", + background_color: Union[int, Tuple[int, int, int]] = 0, + ) -> "torch.Tensor": + """ + Pads an image to a square based on the longest edge. + + Args: + images (`np.ndarray`): + The images to pad. + background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0): + The color to use for the padding. Can be an integer for single channel or a + tuple of integers representing for multi-channel images. If passed as integer + in mutli-channel mode, it will default to `0` in subsequent channels. + Returns: + `torch.Tensor`: The padded images. + """ + height, width = get_image_size(images, ChannelDimension.FIRST) + + if height == width: + return images + + num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0] + if isinstance(background_color, int): + background_color = [background_color] + [0] * (num_channels - 1) + elif len(background_color) != num_channels: + raise ValueError( + f"background_color must have no more than {num_channels} elements to match the number of channels" + ) + + max_dim = max(height, width) + paste_x_left = (max_dim - width) // 2 + paste_y_left = (max_dim - height) // 2 + paste_x_right = max_dim - width - paste_x_left + paste_y_right = max_dim - height - paste_y_left + padded_images = F.pad( + images, padding=[paste_x_left, paste_y_left, paste_x_right, paste_y_right], fill=background_color + ) + + return padded_images + + @auto_docstring + def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaOnevisionFastImageProcessorKwargs]) -> BatchFeature: + if isinstance(images, (tuple, list)) and isinstance(images[0], (tuple, list)): + # if the first element is a list, we assume that all elements are lists + batch_num_images = [len(x) for x in images] + elif isinstance(images, (tuple, list)): + # treat this as a single-image case for backward compatibility + batch_num_images = [1] * len(images) + else: + batch_num_images = [1] + kwargs["batch_num_images"] = batch_num_images + return super().preprocess(images, **kwargs) + + def _preprocess( + self, + images: List["torch.Tensor"], + do_resize: bool, + size: SizeDict, + image_grid_pinpoints: List[List[int]], + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, List[float]]], + image_std: Optional[Union[float, List[float]]], + do_pad: bool, + batch_num_images: List[int], + return_tensors: Optional[Union[str, TensorType]], + ) -> 
BatchFeature: + processed_images = [] + image_sizes = [] + + # only single image patching is supported + need_patching = [n == 1 for n in batch_num_images for _ in range(n)] + + # Determine the size tuple + if size and size.height and size.width: + size_tuple = (size.height, size.width) + else: + size_tuple = (size.shortest_edge, size.shortest_edge) + + # Determine the patch size + if crop_size and crop_size.height: + patch_size = crop_size.height + elif size and size.height: + patch_size = size.height + else: + patch_size = size.shortest_edge + + for i, image in enumerate(images): + if need_patching[i]: + image_patches = self._get_image_patches( + image, + image_grid_pinpoints, + size=size_tuple, + patch_size=patch_size, + interpolation=interpolation, + ) + else: + padded_image = self.pad_to_square( + images=image, background_color=tuple(int(x * 255) for x in self.image_mean) + ) + image_patches = [padded_image] + + # Group images by size for batched processing + processed_image_patches_grouped = {} + grouped_image_patches, grouped_image_patches_index = group_images_by_shape(image_patches) + for shape, stacked_image_patches in grouped_image_patches.items(): + if do_resize: + stacked_image_patches = self.resize( + image=stacked_image_patches, + size=size, + interpolation=interpolation, + ) + if do_center_crop: + stacked_image_patches = self.center_crop(stacked_image_patches, crop_size) + # Fused rescale and normalize + stacked_image_patches = self.rescale_and_normalize( + stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_image_patches_grouped[shape] = stacked_image_patches + processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index) + processed_image_patches = ( + torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches + ) + processed_images.append(processed_image_patches) + image_sizes.append(get_image_size(image, ChannelDimension.FIRST)) + + if do_pad: + processed_images = self._pad_for_batching(processed_images) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + return BatchFeature( + data={"pixel_values": processed_images, "image_sizes": image_sizes, "batch_num_images": batch_num_images}, + tensor_type=return_tensors, + ) + class LlavaOnevisionModelOutputWithPast(LlavaNextVideoModelOutputWithPast): pass @@ -154,6 +336,76 @@ class LlavaOnevisionModel(LlavaNextVideoModel): image_features = image_features.view(batch_frames, -1, dim) return image_features + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + vision_feature_layer: Union[int, List[int]], + vision_feature_select_strategy: str, + batch_num_images: Optional[torch.LongTensor] = None, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`) + The tensors corresponding to the input images. + image_sizes (`torch.Tensor` of shape `(num_images, 2)`) + Actual image size of each images (H, W). + vision_feature_layer (`Union[int, List[int]]`): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. 
+ Can be one of `"default"` or `"full"` + batch_num_images (`torch.LongTensor`, *optional*): + Number of images in each sample. + Returns: + image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches + and are of shape `(num_patches, image_length, embed_dim)`). + """ + # ! infer image_num_patches from image_sizes + if batch_num_images is None: + # treat this as a single-image case for backward compatibility + need_patching = [True] * len(image_sizes) + else: + need_patching = [n == 1 for n in batch_num_images for _ in range(n)] + image_num_patches = [ + image_size_to_num_patches( + image_size=imsize, + grid_pinpoints=self.config.image_grid_pinpoints, + patch_size=self.config.vision_config.image_size, + ) + if should_patch + else 1 + for imsize, should_patch in zip(image_sizes, need_patching) + ] + if pixel_values.dim() == 5: + # stacked if input is (batch_size, num_patches, num_channels, height, width) + _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)] + pixel_values = torch.cat(_pixel_values_list, dim=0) + elif pixel_values.dim() != 4: + # otherwise has to be stacked from list of (num_patches, num_channels, height, width) + raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions") + + image_features = self.vision_tower(pixel_values, output_hidden_states=True) + # If we have one vision feature layer, return the corresponding hidden states, + # otherwise, select the hidden states of each feature layer and concatenate them + if isinstance(vision_feature_layer, int): + selected_image_feature = image_features.hidden_states[vision_feature_layer] + else: + hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer] + selected_image_feature = torch.cat(hs_pool, dim=-1) + + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + image_features = self.multi_modal_projector(selected_image_feature) + image_features = torch.split(image_features, image_num_patches, dim=0) + return image_features + def get_video_features( self, pixel_values: torch.FloatTensor, @@ -214,6 +466,7 @@ class LlavaOnevisionModel(LlavaNextVideoModel): vision_feature_layer: Optional[Union[int, List[int]]] = None, vision_feature_select_strategy: Optional[str] = None, vision_aspect_ratio: Optional[str] = None, + batch_num_images: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -234,6 +487,8 @@ class LlavaOnevisionModel(LlavaNextVideoModel): If `"full"`, the full vision features are used. vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`): Aspect ratio used when processong image features. The default value is "anyres_max_9". + batch_num_images (`torch.LongTensor`, *optional*): + Number of images in each sample. 
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -272,6 +527,7 @@ class LlavaOnevisionModel(LlavaNextVideoModel): image_sizes, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, + batch_num_images=batch_num_images, ) image_features, feature_lens = self.pack_image_features( image_features, @@ -355,6 +611,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat vision_feature_layer: Optional[Union[int, List[int]]] = None, vision_feature_select_strategy: Optional[str] = None, vision_aspect_ratio: Optional[str] = None, + batch_num_images: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -377,6 +634,8 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat If `"full"`, the full vision features are used. vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`): Aspect ratio used when processong image features. The default value is "anyres_max_9". + batch_num_images (`torch.LongTensor`, *optional*): + Number of images in each sample. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored @@ -438,6 +697,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat vision_aspect_ratio=vision_aspect_ratio, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, + batch_num_images=batch_num_images, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, diff --git a/src/transformers/models/llava_onevision/processing_llava_onevision.py b/src/transformers/models/llava_onevision/processing_llava_onevision.py index 0c114e96e55..ca45ed63f39 100644 --- a/src/transformers/models/llava_onevision/processing_llava_onevision.py +++ b/src/transformers/models/llava_onevision/processing_llava_onevision.py @@ -170,12 +170,15 @@ class LlavaOnevisionProcessor(ProcessorMixin): if images is not None: image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) + batch_num_images = iter(image_inputs["batch_num_images"]) image_sizes = iter(image_inputs["image_sizes"]) height, width = get_image_size( to_numpy_array(image_inputs["pixel_values"][0][0]), channel_dim=output_kwargs["images_kwargs"].get("data_format"), ) - text, num_image_tokens = self._expand_image_tokens(text, image_sizes, height, width, self.image_token) + text, num_image_tokens = self._expand_image_tokens( + text, image_sizes, height, width, self.image_token, batch_num_images + ) if videos is not None: video_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"]) @@ -205,23 +208,29 @@ class LlavaOnevisionProcessor(ProcessorMixin): height: int, width: int, special_token: str, - num_frames: int = 1, + batch_num_images: Iterable[int], ): prompt_strings = [] max_num_vision_tokens = 0 for sample in text: + if special_token in sample: + is_multi_image = next(batch_num_images) != 1 + else: + is_multi_image = False while special_token in sample: - image_size_list = next(image_sizes) - original_size = image_size_list[0] if num_frames != 1 else image_size_list - if not isinstance(original_size, (list, tuple)): - # 
cast to list to avoid numerical precision errors when calculating unpadding - original_size = original_size.tolist() - orig_height, orig_width = original_size - num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) + if is_multi_image: + num_image_tokens = self.num_image_tokens + 1 # one for image_newline + else: + original_size = next(image_sizes) + if not isinstance(original_size, (list, tuple)): + # cast to list to avoid numerical precision errors when calculating unpadding + original_size = original_size.tolist() + orig_height, orig_width = original_size + num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) max_num_vision_tokens = max(max_num_vision_tokens, num_image_tokens) if self.vision_feature_select_strategy == "default": num_image_tokens -= 1 - sample = sample.replace(special_token, "" * num_image_tokens * num_frames, 1) + sample = sample.replace(special_token, "" * num_image_tokens, 1) prompt_strings.append(sample) text = [sample.replace("", special_token) for sample in prompt_strings] return text, max_num_vision_tokens diff --git a/tests/models/llava_onevision/test_image_processing_llava_onevision.py b/tests/models/llava_onevision/test_image_processing_llava_onevision.py index 285be5ecf81..4aba232c9df 100644 --- a/tests/models/llava_onevision/test_image_processing_llava_onevision.py +++ b/tests/models/llava_onevision/test_image_processing_llava_onevision.py @@ -202,7 +202,7 @@ class LlavaOnevisionImageProcessingTest(ImageProcessingTestMixin, unittest.TestC self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) # Test batched as a nested list of images, where each sublist is one batch - image_inputs_nested = [image_inputs[:3], image_inputs[3:]] + image_inputs_nested = [[image_input] for image_input in image_inputs] encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values expected_output_image_shape = (7, 1522, 3, 20, 20) self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape) @@ -210,6 +210,39 @@ class LlavaOnevisionImageProcessingTest(ImageProcessingTestMixin, unittest.TestC # Image processor should return same pixel values, independently of input format self.assertTrue((encoded_images_nested == encoded_images).all()) + def test_multi_images(self): + length = 384 + scale_single, scale_multi = 2, 3 + image_processor_dict = self.image_processor_tester.prepare_image_processor_dict() + image_processor_dict["size"] = {"height": length, "width": length} # patch size + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**image_processor_dict) + + # Test batched as a nested list of images, where each sublist is one batch + len_image_1 = length * scale_single + image_inputs_1 = prepare_image_inputs( + batch_size=1, + min_resolution=0, # not used + max_resolution=len_image_1, + num_channels=3, + equal_resolution=True, + ) + len_image_2 = length * scale_multi + image_inputs_2 = prepare_image_inputs( + batch_size=7, + min_resolution=0, # not used + max_resolution=len_image_2, + num_channels=3, + equal_resolution=True, + ) + image_inputs = [image_inputs_1, image_inputs_2] + + # Only single image should be patchified + expected_num_patches = scale_single**2 + 1 # +1 for base image patch + expected_output_image_shape = (8, expected_num_patches, 3, length, length) + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + 
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + @unittest.skip( reason="LlavaOnevisionImageProcessorFast doesn't compile (infinitely) when using class transforms" ) # FIXME yoni diff --git a/tests/models/llava_onevision/test_modeling_llava_onevision.py b/tests/models/llava_onevision/test_modeling_llava_onevision.py index fba739b9956..53dc267d778 100644 --- a/tests/models/llava_onevision/test_modeling_llava_onevision.py +++ b/tests/models/llava_onevision/test_modeling_llava_onevision.py @@ -460,6 +460,33 @@ class LlavaOnevisionForConditionalGenerationIntegrationTest(unittest.TestCase): EXPECTED_DECODED_TEXT, ) + @slow + @require_bitsandbytes + def test_small_model_integration_test_multi_image_nested(self): + # related to (#34585) + model = LlavaOnevisionForConditionalGeneration.from_pretrained( + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", + torch_dtype="float16", + device_map=torch_device, + ) + + url = "https://www.ilankelman.org/stopsigns/australia.jpg" + image = Image.open(requests.get(url, stream=True).raw) + prompt = ( + "user\n\nWhat is the difference between these images?<|im_end|>\n<|im_start|>assistant\n" + ) + images_nested = [[self.image, image]] + inputs = self.processor(text=prompt, images=images_nested, return_tensors="pt").to(torch_device, torch.float16) + + # verify generation + output = model.generate(**inputs, max_new_tokens=40) + EXPECTED_DECODED_TEXT = "user\n\nWhat is the difference between these images?\nassistant\nThe first image is a radar chart showing the performance of different models in a specific task, while the second image is a street scene with a stop sign in the foreground." # fmt: skip + + self.assertEqual( + self.processor.decode(output[0], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + @slow @require_bitsandbytes def test_small_model_integration_test_multi_video(self): diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py index f70adea169c..12b9531b848 100644 --- a/tests/test_image_processing_common.py +++ b/tests/test_image_processing_common.py @@ -233,7 +233,7 @@ class ImageProcessingTestMixin: avg_time = sum(sorted(all_times[:3])) / 3.0 return avg_time - dummy_images = torch.randint(0, 255, (4, 3, 224, 224), dtype=torch.uint8) + dummy_images = [torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8) for _ in range(4)] image_processor_slow = self.image_processing_class(**self.image_processor_dict) image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
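
End-to-end sketch of the nested-list multi-image flow this patch enables (a usage sketch, not part of the diff above). It assumes the `llava-hf/llava-onevision-qwen2-0.5b-ov-hf` checkpoint used in the new integration test; the second image (a rotated copy of the first) and the printed shapes are illustrative only.

```python
# Nested-list multi-image inference with an LLaVA-OneVision "ov" checkpoint.
# Illustrative images and shapes; expectations follow the new test_multi_images test.
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration

model_id = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image1 = Image.open(requests.get(url, stream=True).raw)
image2 = image1.rotate(90)  # any second PIL image would do here

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "image"},
            {"type": "text", "text": "What is the difference between these images?"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# One sub-list per prompt: both images belong to the same sample, so each one is
# padded to a square instead of being patchified, and the processor reports the
# per-sample image count in the new "batch_num_images" key.
inputs = processor(text=prompt, images=[[image1, image2]], return_tensors="pt").to(
    model.device, torch.float16
)
print(inputs["batch_num_images"])    # tensor([2])
print(inputs["pixel_values"].shape)  # expected (2, 1, 3, 384, 384): one padded "patch" per image

# batch_num_images is forwarded to the model so that image_num_patches is set to 1
# for the un-patchified multi-image inputs.
output = model.generate(**inputs, max_new_tokens=40)
print(processor.decode(output[0], skip_special_tokens=True))
```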