fix multi-image case for llava-onevision (#38084)

* _get_padding_size module

* do not patchify images when processing multi-image inputs

* modify llava onevision image processor fast

* tensor to list of tensors

* backward compat

* reuse pad_to_square in llava & some clarification

* add to doc

* fix: consider no image cases (text only or video)

* add integration test

* style & repo_consistency
youngrok cha 2025-05-21 18:50:46 +09:00 committed by GitHub
parent a21f11fca2
commit 101b3fa4ea
13 changed files with 620 additions and 93 deletions


@ -147,7 +147,7 @@ print(processor.decode(output[0], skip_special_tokens=True))
### Multi image inference
LLaVa-OneVision can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). For that you have to use checkpoints with an "ov" suffix. Here is how you can do it:
LLaVa-OneVision can perform inference with multiple images as input, where images either belong to the same prompt or to different prompts (in batched inference). For that, you have to use checkpoints with an "ov" suffix. For multi-image cases, we recommend passing a **nested list of images** as input; otherwise, every image is patchified and consumes a lot of memory. Here is how you can do it:
```python
import requests
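
# What follows is a minimal sketch of nested-list multi-image inference. The checkpoint is the
# one used in this PR's integration test; the example URLs, prompt and generation settings are
# illustrative assumptions -- adapt them to your use case.
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration

model_id = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

# Two images that belong to the same prompt
url_1 = "https://www.ilankelman.org/stopsigns/australia.jpg"
url_2 = "http://images.cocodataset.org/val2017/000000039769.jpg"
image_1 = Image.open(requests.get(url_1, stream=True).raw)
image_2 = Image.open(requests.get(url_2, stream=True).raw)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "image"},
            {"type": "text", "text": "What is the difference between these images?"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# Nested list: one inner list per prompt. Because the two images share a prompt,
# they are only padded to a square instead of being patchified.
inputs = processor(text=prompt, images=[[image_1, image_2]], return_tensors="pt").to(
    model.device, torch.float16
)

output = model.generate(**inputs, max_new_tokens=40)
print(processor.decode(output[0], skip_special_tokens=True))
```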


@ -364,19 +364,23 @@ class AriaImageProcessor(BaseImageProcessor):
return resized_image
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
original_height, original_width = original_resolution
target_height, target_width = target_resolution
paste_x, r_x = divmod(target_width - original_width, 2)
paste_y, r_y = divmod(target_height - original_height, 2)
return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
def _pad_for_patching(
self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
) -> np.array:
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
padding = self._get_padding_size(new_resolution, target_resolution)
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
padded_image = self.pad(image, padding=padding)
return padded_image
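
For reference, a minimal standalone sketch of the symmetric-padding math that the new `_get_padding_size` helper factors out of `_pad_for_patching` (plain Python, illustrative values; `get_padding_size` below is a free-function stand-in for the method):

```python
def get_padding_size(original_resolution, target_resolution):
    # divmod splits the total padding into a "before" half and a "before + remainder" half,
    # which is exactly what _get_padding_size returns as ((top, bottom), (left, right)).
    original_height, original_width = original_resolution
    target_height, target_width = target_resolution
    paste_x, r_x = divmod(target_width - original_width, 2)
    paste_y, r_y = divmod(target_height - original_height, 2)
    return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)

# A 300x401 patch centered in a 384x512 target: 42/42 rows and 55/56 columns of padding.
print(get_padding_size((300, 401), (384, 512)))  # ((42, 42), (55, 56))
```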


@ -748,19 +748,23 @@ class AriaImageProcessor(BaseImageProcessor):
return resized_image
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
original_height, original_width = original_resolution
target_height, target_width = target_resolution
paste_x, r_x = divmod(target_width - original_width, 2)
paste_y, r_y = divmod(target_height - original_height, 2)
return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
def _pad_for_patching(
self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
) -> np.array:
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
padding = self._get_padding_size(new_resolution, target_resolution)
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
padded_image = self.pad(image, padding=padding)
return padded_image


@ -424,19 +424,23 @@ class LlavaNextImageProcessor(BaseImageProcessor):
return resized_image
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
original_height, original_width = original_resolution
target_height, target_width = target_resolution
paste_x, r_x = divmod(target_width - original_width, 2)
paste_y, r_y = divmod(target_height - original_height, 2)
return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
def _pad_for_patching(
self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
) -> np.array:
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
padding = self._get_padding_size(new_resolution, target_resolution)
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
padded_image = self.pad(image, padding=padding)
return padded_image


@ -141,19 +141,23 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast):
return resized_image
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
original_height, original_width = original_resolution
target_height, target_width = target_resolution
paste_x, r_x = divmod(target_width - original_width, 2)
paste_y, r_y = divmod(target_height - original_height, 2)
return [paste_x, paste_y, paste_x + r_x, paste_y + r_y]
def _pad_for_patching(
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
) -> "torch.Tensor":
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
padding = self._get_padding_size(new_resolution, target_resolution)
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x + r_x, paste_y + r_y])
padded_image = F.pad(image, padding=padding)
return padded_image
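
Note that the fast (torchvision-based) processors return the padding as a flat `[left, top, right, bottom]` list, while the NumPy processors above return `((top, bottom), (left, right))`. A tiny sketch with illustrative sizes, showing the same amounts in both conventions:

```python
original, target = (300, 401), (384, 512)

top, r_y = divmod(target[0] - original[0], 2)
left, r_x = divmod(target[1] - original[1], 2)

numpy_padding = ((top, top + r_y), (left, left + r_x))    # slow processors: ((top, bottom), (left, right))
torchvision_padding = [left, top, left + r_x, top + r_y]  # fast processors: [left, top, right, bottom]

print(numpy_padding)        # ((42, 42), (55, 56))
print(torchvision_padding)  # [55, 42, 56, 42]
```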


@ -315,6 +315,14 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
return resized_image
# Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._get_padding_size
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
original_height, original_width = original_resolution
target_height, target_width = target_resolution
paste_x, r_x = divmod(target_width - original_width, 2)
paste_y, r_y = divmod(target_height - original_height, 2)
return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
# Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._pad_for_patching
def _pad_for_patching(
self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
@ -322,13 +330,10 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
padding = self._get_padding_size(new_resolution, target_resolution)
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
padded_image = self.pad(image, padding=padding)
return padded_image
@ -437,6 +442,85 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
return pixel_values
# Copied from transformers.models.llava.image_processing_llava.LlavaImageProcessor.pad_to_square
def pad_to_square(
self,
image: np.ndarray,
background_color: Union[int, Tuple[int, int, int]] = 0,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.array:
"""
Pads an image to a square based on the longest edge.
Args:
image (`np.ndarray`):
The image to pad.
background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
The color to use for the padding. Can be an integer for a single channel or a
tuple of integers for multi-channel images. If passed as an integer
in multi-channel mode, it will default to `0` in the remaining channels.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image.
Returns:
`np.ndarray`: The padded image.
"""
height, width = get_image_size(image, input_data_format)
num_channels = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1]
if height == width:
image = (
to_channel_dimension_format(image, data_format, input_data_format)
if data_format is not None
else image
)
return image
max_dim = max(height, width)
# Ensure background_color is the correct shape
if isinstance(background_color, int):
background_color = [background_color]
elif len(background_color) != num_channels:
raise ValueError(
f"background_color must have no more than {num_channels} elements to match the number of channels"
)
if input_data_format == ChannelDimension.FIRST:
result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype)
for i, color in enumerate(background_color):
result[i, :, :] = color
if width > height:
start = (max_dim - height) // 2
result[:, start : start + height, :] = image
else:
start = (max_dim - width) // 2
result[:, :, start : start + width] = image
else:
result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype)
for i, color in enumerate(background_color):
result[:, :, i] = color
if width > height:
start = (max_dim - height) // 2
result[start : start + height, :, :] = image
else:
start = (max_dim - width) // 2
result[:, start : start + width, :] = image
image = (
to_channel_dimension_format(result, data_format, input_data_format) if data_format is not None else result
)
return image
def _preprocess(
self,
images: ImageInput,
@ -595,6 +679,17 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
do_pad = do_pad if do_pad is not None else self.do_pad
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
if isinstance(images, (tuple, list)) and isinstance(images[0], (tuple, list)):
# if the first element is a list, we assume that all elements are lists
batch_num_images = [len(x) for x in images]
elif isinstance(images, (tuple, list)):
# treat this as a single-image case for backward compatibility
batch_num_images = [1] * len(images)
else:
batch_num_images = [1]
# only single image patching is supported
need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
images = make_flat_list_of_images(images)
if not valid_images(images):
@ -630,25 +725,34 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
size_tuple = (
(size["height"], size["width"])
if "height" in size and "width" in size
else (size["shortest_edge"], size["shortest_edge"])
)
new_images = []
image_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images]
for image in images:
# convert image into a list of patches
# we intentionally use the same data format as the input data format
size_tuple = (
(size["height"], size["width"])
if "height" in size and "width" in size
else (size["shortest_edge"], size["shortest_edge"])
)
image_patches = self.get_image_patches(
image,
image_grid_pinpoints,
size=size_tuple,
patch_size=size_tuple[0],
resample=resample,
data_format=input_data_format,
input_data_format=input_data_format,
)
for i, image in enumerate(images):
if need_patching[i]:
# convert image into a list of patches
# we intentionally use the same data format as the input data format
image_patches = self.get_image_patches(
image,
image_grid_pinpoints,
size=size_tuple,
patch_size=size_tuple[0],
resample=resample,
data_format=input_data_format,
input_data_format=input_data_format,
)
else:
padded_image = self.pad_to_square(
image=image,
background_color=tuple(int(x * 255) for x in self.image_mean),
input_data_format=input_data_format,
)
image_patches = [padded_image]
# preprocess patches
pixel_values = self._preprocess(
@ -671,7 +775,8 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
processed_images = self._pad_for_batching(new_images)
return BatchFeature(
data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
data={"pixel_values": processed_images, "image_sizes": image_sizes, "batch_num_images": batch_num_images},
tensor_type=return_tensors,
)
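
To make the new control flow concrete, here is a minimal standalone sketch (plain Python, illustrative inputs; `infer_batch_num_images` is a stand-in for the branching added to `preprocess`) of how `batch_num_images` and `need_patching` are derived for nested versus flat inputs:

```python
def infer_batch_num_images(images):
    # Mirrors the branching added to preprocess(): a nested list gives the true per-sample
    # image counts; a flat list is treated as one image per sample for backward compatibility.
    if isinstance(images, (tuple, list)) and isinstance(images[0], (tuple, list)):
        return [len(x) for x in images]
    elif isinstance(images, (tuple, list)):
        return [1] * len(images)
    return [1]

# Nested input: sample 0 has one image (patchified), sample 1 has three (only padded to square).
batch_num_images = infer_batch_num_images([["img_a"], ["img_b", "img_c", "img_d"]])
need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
print(batch_num_images)  # [1, 3]
print(need_patching)     # [True, False, False, False]
```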


@ -19,7 +19,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
import torch
@ -89,6 +89,15 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
@auto_docstring
def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaOnevisionFastImageProcessorKwargs]) -> BatchFeature:
if isinstance(images, (tuple, list)) and isinstance(images[0], (tuple, list)):
# if the first element is a list, we assume that all elements are lists
batch_num_images = [len(x) for x in images]
elif isinstance(images, (tuple, list)):
# treat this as a single-image case for backward compatibility
batch_num_images = [1] * len(images)
else:
batch_num_images = [1]
kwargs["batch_num_images"] = batch_num_images
return super().preprocess(images, **kwargs)
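
A hedged usage sketch of the new nested-list path through the fast image processor (the checkpoint and the commented output values are assumptions, not verified here):

```python
from PIL import Image
from transformers import AutoImageProcessor

# use_fast=True selects LlavaOnevisionImageProcessorFast (requires torch + torchvision)
image_processor = AutoImageProcessor.from_pretrained(
    "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", use_fast=True
)

img = Image.new("RGB", (640, 480))
outputs = image_processor([[img, img]], return_tensors="pt")

print(outputs["batch_num_images"])    # expected: tensor([2])
# Both images share one prompt, so each contributes a single square-padded "patch"
print(outputs["pixel_values"].shape)  # expected: (2, 1, 3, 384, 384)
```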
def _prepare_images_structure(
@ -137,19 +146,23 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
return resized_image
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
original_height, original_width = original_resolution
target_height, target_width = target_resolution
paste_x, r_x = divmod(target_width - original_width, 2)
paste_y, r_y = divmod(target_height - original_height, 2)
return [paste_x, paste_y, paste_x + r_x, paste_y + r_y]
def _pad_for_patching(
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
) -> "torch.Tensor":
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
padding = self._get_padding_size(new_resolution, target_resolution)
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x + r_x, paste_y + r_y])
padded_image = F.pad(image, padding=padding)
return padded_image
@ -234,10 +247,15 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
do_pad: bool,
batch_num_images: List[int],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
processed_images = []
image_sizes = []
# only single image patching is supported
need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
# Determine the size tuple
if size and size.height and size.width:
size_tuple = (size.height, size.width)
@ -252,14 +270,20 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
else:
patch_size = size.shortest_edge
for image in images:
image_patches = self._get_image_patches(
image,
image_grid_pinpoints,
size=size_tuple,
patch_size=patch_size,
interpolation=interpolation,
)
for i, image in enumerate(images):
if need_patching[i]:
image_patches = self._get_image_patches(
image,
image_grid_pinpoints,
size=size_tuple,
patch_size=patch_size,
interpolation=interpolation,
)
else:
padded_image = self.pad_to_square(
images=image, background_color=tuple(int(x * 255) for x in self.image_mean)
)
image_patches = [padded_image]
# Group images by size for batched processing
processed_image_patches_grouped = {}
@ -289,8 +313,52 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
processed_images = self._pad_for_batching(processed_images)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(
data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
data={"pixel_values": processed_images, "image_sizes": image_sizes, "batch_num_images": batch_num_images},
tensor_type=return_tensors,
)
# Copied from transformers.models.llava.image_processing_llava_fast.LlavaImageProcessorFast.pad_to_square
def pad_to_square(
self,
images: "torch.Tensor",
background_color: Union[int, Tuple[int, int, int]] = 0,
) -> "torch.Tensor":
"""
Pads an image to a square based on the longest edge.
Args:
images (`torch.Tensor`):
The images to pad.
background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
The color to use for the padding. Can be an integer for a single channel or a
tuple of integers for multi-channel images. If passed as an integer
in multi-channel mode, it will default to `0` in the remaining channels.
Returns:
`torch.Tensor`: The padded images.
"""
height, width = get_image_size(images, ChannelDimension.FIRST)
if height == width:
return images
num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0]
if isinstance(background_color, int):
background_color = [background_color] + [0] * (num_channels - 1)
elif len(background_color) != num_channels:
raise ValueError(
f"background_color must have no more than {num_channels} elements to match the number of channels"
)
max_dim = max(height, width)
paste_x_left = (max_dim - width) // 2
paste_y_left = (max_dim - height) // 2
paste_x_right = max_dim - width - paste_x_left
paste_y_right = max_dim - height - paste_y_left
padded_images = F.pad(
images, padding=[paste_x_left, paste_y_left, paste_x_right, paste_y_right], fill=background_color
)
return padded_images
__all__ = ["LlavaOnevisionImageProcessorFast"]
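
For illustration, a torch-only sketch of the same centering logic for a single `(C, H, W)` image; the real fast processor relies on torchvision's `F.pad` with a per-channel fill, and `pad_to_square_chw` plus its default background color are hypothetical:

```python
import torch

def pad_to_square_chw(image: torch.Tensor, background_color=(127, 127, 127)) -> torch.Tensor:
    # Centers a (C, H, W) image on a square canvas filled with a per-channel background,
    # mirroring the NumPy pad_to_square above; the fast processor uses torchvision instead.
    num_channels, height, width = image.shape
    if height == width:
        return image
    max_dim = max(height, width)
    result = image.new_empty((num_channels, max_dim, max_dim))
    result[:] = torch.tensor(background_color, dtype=image.dtype).view(-1, 1, 1)
    if width > height:
        start = (max_dim - height) // 2
        result[:, start : start + height, :] = image
    else:
        start = (max_dim - width) // 2
        result[:, :, start : start + width] = image
    return result

padded = pad_to_square_chw(torch.zeros(3, 300, 400))
print(padded.shape)  # torch.Size([3, 400, 400])
```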


@ -419,8 +419,9 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
batch_num_images: Optional[torch.LongTensor] = None,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
@ -430,34 +431,34 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each image (H, W).
vision_feature_layer (`Union[int, List[int]]`, *optional*):
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`, *optional*):
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
batch_num_images (`torch.LongTensor`, *optional*):
Number of images in each sample.
Returns:
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing all the visual features of all patches,
with shape `(num_patches, image_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
# ! infer image_num_patches from image_sizes
if batch_num_images is None:
# treat this as a single-image case for backward compatibility
need_patching = [True] * len(image_sizes)
else:
need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.config.image_grid_pinpoints,
patch_size=self.config.vision_config.image_size,
)
for imsize in image_sizes
if should_patch
else 1
for imsize, should_patch in zip(image_sizes, need_patching)
]
if pixel_values.dim() == 5:
# stacked if input is (batch_size, num_patches, num_channels, height, width)
@ -500,6 +501,7 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
vision_aspect_ratio: Optional[str] = None,
batch_num_images: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@ -520,6 +522,8 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
If `"full"`, the full vision features are used.
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
Aspect ratio used when processing image features. The default value is "anyres_max_9".
batch_num_images (`torch.LongTensor`, *optional*):
Number of images in each sample.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@ -558,6 +562,7 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
image_sizes,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
batch_num_images=batch_num_images,
)
image_features, feature_lens = self.pack_image_features(
image_features,
@ -749,6 +754,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
vision_aspect_ratio: Optional[str] = None,
batch_num_images: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@ -771,6 +777,8 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
If `"full"`, the full vision features are used.
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
Aspect ratio used when processing image features. The default value is "anyres_max_9".
batch_num_images (`torch.LongTensor`, *optional*):
Number of images in each sample.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
@ -832,6 +840,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
vision_aspect_ratio=vision_aspect_ratio,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
batch_num_images=batch_num_images,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,


@ -28,18 +28,59 @@ from transformers.models.llava_next_video.modeling_llava_next_video import (
LlavaNextVideoModelOutputWithPast,
LlavaNextVideoPreTrainedModel,
get_anyres_image_grid_shape,
image_size_to_num_patches,
unpad_image,
)
from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import DefaultFastImageProcessorKwargs, group_images_by_shape, reorder_images
from ...image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
get_image_size,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...processing_utils import Unpack
from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
from ...utils import (
TensorType,
auto_docstring,
can_return_tuple,
is_torchdynamo_compiling,
is_torchvision_available,
is_torchvision_v2_available,
logging,
)
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
logger = logging.get_logger(__name__)
class LlavaOnevisionFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
"""
image_grid_pinpoints (`List[List[int]]`, *optional*):
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
method.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
"""
image_grid_pinpoints: Optional[List[List[int]]]
do_pad: Optional[bool]
class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = OPENAI_CLIP_MEAN
@ -56,6 +97,147 @@ class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast):
image_grid_pinpoints = [[384, 384], [384, 768], [384, 1152], [384, 1536], [384, 1920], [384, 2304], [768, 384], [768, 768], [768, 1152], [768, 1536], [768, 1920], [768, 2304], [1152, 384], [1152, 768], [1152, 1152], [1152, 1536], [1152, 1920], [1152, 2304], [1536, 384], [1536, 768], [1536, 1152], [1536, 1536], [1536, 1920], [1536, 2304], [1920, 384], [1920, 768], [1920, 1152], [1920, 1536], [1920, 1920], [1920, 2304], [2304, 384], [2304, 768], [2304, 1152], [2304, 1536], [2304, 1920], [2304, 2304]] # fmt: skip
model_input_names = ["pixel_values_videos"]
# Copied from transformers.models.llava.image_processing_llava_fast.LlavaImageProcessorFast.pad_to_square
def pad_to_square(
self,
images: "torch.Tensor",
background_color: Union[int, Tuple[int, int, int]] = 0,
) -> "torch.Tensor":
"""
Pads an image to a square based on the longest edge.
Args:
images (`torch.Tensor`):
The images to pad.
background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
The color to use for the padding. Can be an integer for a single channel or a
tuple of integers for multi-channel images. If passed as an integer
in multi-channel mode, it will default to `0` in the remaining channels.
Returns:
`torch.Tensor`: The padded images.
"""
height, width = get_image_size(images, ChannelDimension.FIRST)
if height == width:
return images
num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0]
if isinstance(background_color, int):
background_color = [background_color] + [0] * (num_channels - 1)
elif len(background_color) != num_channels:
raise ValueError(
f"background_color must have no more than {num_channels} elements to match the number of channels"
)
max_dim = max(height, width)
paste_x_left = (max_dim - width) // 2
paste_y_left = (max_dim - height) // 2
paste_x_right = max_dim - width - paste_x_left
paste_y_right = max_dim - height - paste_y_left
padded_images = F.pad(
images, padding=[paste_x_left, paste_y_left, paste_x_right, paste_y_right], fill=background_color
)
return padded_images
@auto_docstring
def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaOnevisionFastImageProcessorKwargs]) -> BatchFeature:
if isinstance(images, (tuple, list)) and isinstance(images[0], (tuple, list)):
# if the first element is a list, we assume that all elements are lists
batch_num_images = [len(x) for x in images]
elif isinstance(images, (tuple, list)):
# treat this as a single-image case for backward compatibility
batch_num_images = [1] * len(images)
else:
batch_num_images = [1]
kwargs["batch_num_images"] = batch_num_images
return super().preprocess(images, **kwargs)
def _preprocess(
self,
images: List["torch.Tensor"],
do_resize: bool,
size: SizeDict,
image_grid_pinpoints: List[List[int]],
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
do_pad: bool,
batch_num_images: List[int],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
processed_images = []
image_sizes = []
# only single image patching is supported
need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
# Determine the size tuple
if size and size.height and size.width:
size_tuple = (size.height, size.width)
else:
size_tuple = (size.shortest_edge, size.shortest_edge)
# Determine the patch size
if crop_size and crop_size.height:
patch_size = crop_size.height
elif size and size.height:
patch_size = size.height
else:
patch_size = size.shortest_edge
for i, image in enumerate(images):
if need_patching[i]:
image_patches = self._get_image_patches(
image,
image_grid_pinpoints,
size=size_tuple,
patch_size=patch_size,
interpolation=interpolation,
)
else:
padded_image = self.pad_to_square(
images=image, background_color=tuple(int(x * 255) for x in self.image_mean)
)
image_patches = [padded_image]
# Group images by size for batched processing
processed_image_patches_grouped = {}
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(image_patches)
for shape, stacked_image_patches in grouped_image_patches.items():
if do_resize:
stacked_image_patches = self.resize(
image=stacked_image_patches,
size=size,
interpolation=interpolation,
)
if do_center_crop:
stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
# Fused rescale and normalize
stacked_image_patches = self.rescale_and_normalize(
stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_image_patches_grouped[shape] = stacked_image_patches
processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index)
processed_image_patches = (
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
)
processed_images.append(processed_image_patches)
image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
if do_pad:
processed_images = self._pad_for_batching(processed_images)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(
data={"pixel_values": processed_images, "image_sizes": image_sizes, "batch_num_images": batch_num_images},
tensor_type=return_tensors,
)
class LlavaOnevisionModelOutputWithPast(LlavaNextVideoModelOutputWithPast):
pass
@ -154,6 +336,76 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
image_features = image_features.view(batch_frames, -1, dim)
return image_features
def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
batch_num_images: Optional[torch.LongTensor] = None,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each image (H, W).
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
batch_num_images (`torch.LongTensor`, *optional*):
Number of images in each sample.
Returns:
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing all the visual features of all patches,
with shape `(num_patches, image_length, embed_dim)`.
"""
# ! infer image_num_patches from image_sizes
if batch_num_images is None:
# treat this as a single-image case for backward compatibility
need_patching = [True] * len(image_sizes)
else:
need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.config.image_grid_pinpoints,
patch_size=self.config.vision_config.image_size,
)
if should_patch
else 1
for imsize, should_patch in zip(image_sizes, need_patching)
]
if pixel_values.dim() == 5:
# stacked if input is (batch_size, num_patches, num_channels, height, width)
_pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
pixel_values = torch.cat(_pixel_values_list, dim=0)
elif pixel_values.dim() != 4:
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_image_feature = image_features.hidden_states[vision_feature_layer]
else:
hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_image_feature = torch.cat(hs_pool, dim=-1)
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
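
A small worked example of the patch-count gating added to `get_image_features` (illustrative sizes; the value 5 for a 768x768 image matches the `scale_single**2 + 1` expectation in the new image-processor test):

```python
# With the default 384-based grid pinpoints, a single-image 768x768 sample resolves to
# 2x2 anyres patches plus the base image (5 total), while every image of a multi-image
# sample contributes exactly one entry: its square-padded form.
image_sizes = [(768, 768), (500, 300), (640, 480)]
batch_num_images = [1, 2]  # sample 0 has one image, sample 1 has two
need_patching = [n == 1 for n in batch_num_images for _ in range(n)]

image_num_patches = [
    5 if should_patch else 1  # 5 stands in for image_size_to_num_patches((768, 768), ...)
    for _, should_patch in zip(image_sizes, need_patching)
]
print(need_patching)      # [True, False, False]
print(image_num_patches)  # [5, 1, 1]
```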
def get_video_features(
self,
pixel_values: torch.FloatTensor,
@ -214,6 +466,7 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
vision_aspect_ratio: Optional[str] = None,
batch_num_images: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@ -234,6 +487,8 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
If `"full"`, the full vision features are used.
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
Aspect ratio used when processing image features. The default value is "anyres_max_9".
batch_num_images (`torch.LongTensor`, *optional*):
Number of images in each sample.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@ -272,6 +527,7 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
image_sizes,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
batch_num_images=batch_num_images,
)
image_features, feature_lens = self.pack_image_features(
image_features,
@ -355,6 +611,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
vision_aspect_ratio: Optional[str] = None,
batch_num_images: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@ -377,6 +634,8 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat
If `"full"`, the full vision features are used.
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
Aspect ratio used when processing image features. The default value is "anyres_max_9".
batch_num_images (`torch.LongTensor`, *optional*):
Number of images in each sample.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
@ -438,6 +697,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat
vision_aspect_ratio=vision_aspect_ratio,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
batch_num_images=batch_num_images,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,


@ -170,12 +170,15 @@ class LlavaOnevisionProcessor(ProcessorMixin):
if images is not None:
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
batch_num_images = iter(image_inputs["batch_num_images"])
image_sizes = iter(image_inputs["image_sizes"])
height, width = get_image_size(
to_numpy_array(image_inputs["pixel_values"][0][0]),
channel_dim=output_kwargs["images_kwargs"].get("data_format"),
)
text, num_image_tokens = self._expand_image_tokens(text, image_sizes, height, width, self.image_token)
text, num_image_tokens = self._expand_image_tokens(
text, image_sizes, height, width, self.image_token, batch_num_images
)
if videos is not None:
video_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"])
@ -205,23 +208,29 @@ class LlavaOnevisionProcessor(ProcessorMixin):
height: int,
width: int,
special_token: str,
num_frames: int = 1,
batch_num_images: Iterable[int],
):
prompt_strings = []
max_num_vision_tokens = 0
for sample in text:
if special_token in sample:
is_multi_image = next(batch_num_images) != 1
else:
is_multi_image = False
while special_token in sample:
image_size_list = next(image_sizes)
original_size = image_size_list[0] if num_frames != 1 else image_size_list
if not isinstance(original_size, (list, tuple)):
# cast to list to avoid numerical precision errors when calculating unpadding
original_size = original_size.tolist()
orig_height, orig_width = original_size
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
if is_multi_image:
num_image_tokens = self.num_image_tokens + 1 # one for image_newline
else:
original_size = next(image_sizes)
if not isinstance(original_size, (list, tuple)):
# cast to list to avoid numerical precision errors when calculating unpadding
original_size = original_size.tolist()
orig_height, orig_width = original_size
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
max_num_vision_tokens = max(max_num_vision_tokens, num_image_tokens)
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
sample = sample.replace(special_token, "<placeholder>" * num_image_tokens * num_frames, 1)
sample = sample.replace(special_token, "<placeholder>" * num_image_tokens, 1)
prompt_strings.append(sample)
text = [sample.replace("<placeholder>", special_token) for sample in prompt_strings]
return text, max_num_vision_tokens
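
A minimal standalone sketch of the multi-image branch of the token expansion (names and the per-image token count are hypothetical; the real values come from `self.num_image_tokens` and `self._get_number_of_features`, and the `vision_feature_select_strategy` adjustment is omitted):

```python
from typing import Iterable, List

def expand_image_tokens_multi(
    text: List[str],
    batch_num_images: Iterable[int],
    special_token: str = "<image>",
    base_image_tokens: int = 729,  # hypothetical stand-in for self.num_image_tokens
) -> List[str]:
    # Only the multi-image branch is sketched: each image gets the base-image token count
    # plus one for image_newline; patchified single images would instead go through
    # _get_number_of_features with their actual (height, width).
    batch_num_images = iter(batch_num_images)
    prompt_strings = []
    for sample in text:
        is_multi_image = next(batch_num_images) != 1 if special_token in sample else False
        while special_token in sample:
            num_image_tokens = base_image_tokens + 1 if is_multi_image else base_image_tokens
            sample = sample.replace(special_token, "<placeholder>" * num_image_tokens, 1)
        prompt_strings.append(sample)
    return [s.replace("<placeholder>", special_token) for s in prompt_strings]

expanded = expand_image_tokens_multi(["<image><image>\nCompare these images."], [2])
print(expanded[0].count("<image>"))  # 1460 == 2 * (729 + 1)
```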


@ -202,7 +202,7 @@ class LlavaOnevisionImageProcessingTest(ImageProcessingTestMixin, unittest.TestC
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched as a nested list of images, where each sublist is one batch
image_inputs_nested = [image_inputs[:3], image_inputs[3:]]
image_inputs_nested = [[image_input] for image_input in image_inputs]
encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1522, 3, 20, 20)
self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)
@ -210,6 +210,39 @@ class LlavaOnevisionImageProcessingTest(ImageProcessingTestMixin, unittest.TestC
# Image processor should return same pixel values, independently of input format
self.assertTrue((encoded_images_nested == encoded_images).all())
def test_multi_images(self):
length = 384
scale_single, scale_multi = 2, 3
image_processor_dict = self.image_processor_tester.prepare_image_processor_dict()
image_processor_dict["size"] = {"height": length, "width": length} # patch size
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**image_processor_dict)
# Test batched as a nested list of images, where each sublist is one batch
len_image_1 = length * scale_single
image_inputs_1 = prepare_image_inputs(
batch_size=1,
min_resolution=0, # not used
max_resolution=len_image_1,
num_channels=3,
equal_resolution=True,
)
len_image_2 = length * scale_multi
image_inputs_2 = prepare_image_inputs(
batch_size=7,
min_resolution=0, # not used
max_resolution=len_image_2,
num_channels=3,
equal_resolution=True,
)
image_inputs = [image_inputs_1, image_inputs_2]
# Only single image should be patchified
expected_num_patches = scale_single**2 + 1 # +1 for base image patch
expected_output_image_shape = (8, expected_num_patches, 3, length, length)
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
@unittest.skip(
reason="LlavaOnevisionImageProcessorFast doesn't compile (infinitely) when using class transforms"
) # FIXME yoni

View File

@ -460,6 +460,33 @@ class LlavaOnevisionForConditionalGenerationIntegrationTest(unittest.TestCase):
EXPECTED_DECODED_TEXT,
)
@slow
@require_bitsandbytes
def test_small_model_integration_test_multi_image_nested(self):
# related to (#34585)
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
torch_dtype="float16",
device_map=torch_device,
)
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = (
"user\n<image><image>\nWhat is the difference between these images?<|im_end|>\n<|im_start|>assistant\n"
)
images_nested = [[self.image, image]]
inputs = self.processor(text=prompt, images=images_nested, return_tensors="pt").to(torch_device, torch.float16)
# verify generation
output = model.generate(**inputs, max_new_tokens=40)
EXPECTED_DECODED_TEXT = "user\n\nWhat is the difference between these images?\nassistant\nThe first image is a radar chart showing the performance of different models in a specific task, while the second image is a street scene with a stop sign in the foreground." # fmt: skip
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@slow
@require_bitsandbytes
def test_small_model_integration_test_multi_video(self):


@ -233,7 +233,7 @@ class ImageProcessingTestMixin:
avg_time = sum(sorted(all_times[:3])) / 3.0
return avg_time
dummy_images = torch.randint(0, 255, (4, 3, 224, 224), dtype=torch.uint8)
dummy_images = [torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8) for _ in range(4)]
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)