Fuyu: improve image processing (#27007)
* Fix Fuyu image scaling bug
It could produce negative padding and hence inference errors for certain
image sizes.
* initial rework commit
* add batching capabilities, refactor image processing
* add functional batching for a list of images and texts
* make args explicit
* Fuyu processing update (#27133)
* Add file headers
* Add file headers
* First pass - preprocess method with standard args
* First pass image processor rework
* Small tweaks
* More args and docstrings
* Tidying iterating over batch
* Tidying up
* Modify to have quick tests (for now)
* Fix up
* BatchFeature
* Passing tests
* Add tests for processor
* Sense check when patchifying
* Add some tests
* FuyuBatchFeature
* Post-process box coordinates
* Update to `size` in processor
* Remove unused and duplicate constants
* Store unpadded dims after resize
* Fix up
* Return FuyuBatchFeature
* Get unpadded sizes after resize
* Update exception
* Fix return
* Convert input `<box>` coordinates to model format.
* Post-process point coords, support multiple boxes/points in a single
sequence
* Replace constants
* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Preprocess List[List[image]]
* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update to Amy's latest state.
* post-processing returns a list of tensors
* Fix error when target_sizes is None
Co-authored-by: Pablo Montalvo <pablo.montalvo.leroux@gmail.com>
* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Review comments
* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Fix up
* Fix up
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-72-126.ec2.internal>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Pablo Montalvo <pablo.montalvo.leroux@gmail.com>
* Fix conflicts in fuyu_follow_up_image_processing (#27228)
fixing conflicts and updating on main
* Revert "Fix conflicts in fuyu_follow_up_image_processing" (#27232)
Revert "Fix conflicts in fuyu_follow_up_image_processing (#27228)"
This reverts commit acce10b6c6.
---------
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-72-126.ec2.internal>
This commit is contained in: parent 9b25c164bd, commit 8a312956fd
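As a quick orientation (not part of the commit itself), the reworked processing path can be exercised roughly as follows; the checkpoint name, prompt, and file name are illustrative assumptions:

# Hedged usage sketch for the updated Fuyu processing pipeline.
from transformers import FuyuForCausalLM, FuyuProcessor
from PIL import Image

processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b", device_map="auto")

image = Image.open("bus.png")  # hypothetical local file
inputs = processor(text="Generate a coco-style caption.\n", images=image, return_tensors="pt").to("cuda")
generated = model.generate(**inputs, max_new_tokens=10)
print(processor.batch_decode(generated[:, -10:], skip_special_tokens=True))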
@@ -112,17 +112,9 @@ class BatchFeature(UserDict):
    def items(self):
        return self.data.items()

    def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
        """
        Convert the inner content to tensors.

        Args:
            tensor_type (`str` or [`~utils.TensorType`], *optional*):
                The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
                `None`, no modification is done.
        """
    def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] = None):
        if tensor_type is None:
            return self
            return None, None

        # Convert to TensorType
        if not isinstance(tensor_type, TensorType):
@@ -167,6 +159,21 @@ class BatchFeature(UserDict):
            return np.asarray(value, dtype=dtype)

        is_tensor = is_numpy_array
        return is_tensor, as_tensor

    def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
        """
        Convert the inner content to tensors.

        Args:
            tensor_type (`str` or [`~utils.TensorType`], *optional*):
                The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
                `None`, no modification is done.
        """
        if tensor_type is None:
            return self

        is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)

        # Do the tensor conversion in batch
        for key, value in self.items():
@@ -1,27 +1,182 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Fuyu."""

import math
from typing import List, Union
from typing import Dict, List, Optional, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor
from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import (
    normalize,
    pad,
    resize,
    to_channel_dimension_format,
)
from ...image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    is_valid_image,
    make_list_of_images,
    to_numpy_array,
)
from ...utils import (
    TensorType,
    is_torch_available,
    is_torch_device,
    is_torch_dtype,
    logging,
    requires_backends,
)
from ...image_utils import to_numpy_array
from ...utils import is_torch_available, is_vision_available, logging, requires_backends


if is_vision_available():
    import PIL

if is_torch_available():
    import torch


logger = logging.get_logger(__name__)

def make_list_of_list_of_images(
    images: Union[List[List[ImageInput]], List[ImageInput], ImageInput]
) -> List[List[ImageInput]]:
    if is_valid_image(images):
        return [[images]]

    if isinstance(images, list) and all(isinstance(image, list) for image in images):
        return images

    if isinstance(images, list):
        return [make_list_of_images(image) for image in images]

    raise ValueError("images must be a list of list of images or a list of images or an image.")

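# Illustrative sketch (not part of the diff): the helper above normalizes every accepted input
# shape to List[List[image]]. The arrays below are stand-ins for real images.
import numpy as np

img = np.zeros((64, 64, 3), dtype=np.uint8)
make_list_of_list_of_images(img)             # -> [[img]]
make_list_of_list_of_images([img, img])      # -> [[img], [img]]
make_list_of_list_of_images([[img], [img]])  # -> returned unchanged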
class FuyuBatchFeature(BatchFeature):
    """
    BatchFeature class for Fuyu image processor and processor.

    The outputs dictionary from the processors contains a mix of tensors and lists of tensors.
    """

    def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
        """
        Convert the inner content to tensors.

        Args:
            tensor_type (`str` or [`~utils.TensorType`], *optional*):
                The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
                `None`, no modification is done.
        """
        if tensor_type is None:
            return self

        is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type=tensor_type)

        def _convert_tensor(elem):
            if is_tensor(elem):
                return elem
            return as_tensor(elem)

        def _safe_convert_tensor(elem):
            try:
                return _convert_tensor(elem)
            except:  # noqa E722
                if key == "overflowing_values":
                    raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
                raise ValueError(
                    "Unable to create tensor, you should probably activate padding "
                    "with 'padding=True' to have batched tensors with the same length."
                )

        # Do the tensor conversion in batch
        for key, value in self.items():
            if isinstance(value, list) and isinstance(value[0], list):
                # List[List[Any]] -> List[List[Tensor]]
                self[key] = [[_safe_convert_tensor(elem) for elem in elems] for elems in value]
            elif isinstance(value, list):
                # List[Any] -> List[Tensor]
                self[key] = [_safe_convert_tensor(elem) for elem in value]
            else:
                # Any -> Tensor
                self[key] = _safe_convert_tensor(value)
        return self

    def to(self, *args, **kwargs) -> "BatchFeature":
        """
        Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
        different `dtypes` and sending the `BatchFeature` to a different `device`.

        Args:
            args (`Tuple`):
                Will be passed to the `to(...)` function of the tensors.
            kwargs (`Dict`, *optional*):
                Will be passed to the `to(...)` function of the tensors.

        Returns:
            [`BatchFeature`]: The same instance after modification.
        """
        requires_backends(self, ["torch"])
        import torch  # noqa

        new_data = {}
        device = kwargs.get("device")
        # Check if the args are a device or a dtype
        if device is None and len(args) > 0:
            # device should be always the first argument
            arg = args[0]
            if is_torch_dtype(arg):
                # The first argument is a dtype
                pass
            elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
                device = arg
            else:
                # it's something else
                raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")

        def _to(elem):
            # check if v is a floating point
            if torch.is_floating_point(elem):
                # cast and send to device
                return elem.to(*args, **kwargs)
            if device is not None:
                return elem.to(device=device)

            return elem

        # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
        for k, v in self.items():
            if isinstance(v, list) and isinstance(v[0], list):
                # Data structure is a list of lists
                new_v = []
                for elems in v:
                    new_v.append([_to(elem) for elem in elems])
                new_data[k] = new_v
            elif isinstance(v, list):
                # Data structure is a list
                new_data[k] = [_to(elem) for elem in v]
            else:
                new_data[k] = _to(v)
        self.data = new_data
        return self

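# Illustrative sketch (not part of the diff): unlike the base BatchFeature, the Fuyu variant also
# walks lists and lists of lists of tensors when converting or moving data. Shapes are assumptions.
import torch

features = FuyuBatchFeature(data={"image_patches": [[torch.rand(2304, 2700)], [torch.rand(1200, 2700)]]})
features = features.to("cuda", torch.float16)   # assumes a CUDA device is available
print(features["image_patches"][0][0].dtype)    # torch.float16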
class FuyuImageProcessor(BaseImageProcessor):
|
||||
"""
|
||||
This class should handle the image processing part before the main FuyuForCausalLM. In particular, it should
|
||||
@ -29,9 +184,9 @@ class FuyuImageProcessor(BaseImageProcessor):
|
||||
|
||||
- Processing Images:
|
||||
Taking a batch of images as input. If the images are variable-sized, it resizes them based on the desired patch
|
||||
dimensions. The image output is always img_h ........................................... 1080 img_w
|
||||
........................................... 1920 Then, it patches up these images using the patchify_image
|
||||
function.
|
||||
dimensions. The image output is always img_h, img_w of (1080, 1920)
|
||||
|
||||
Then, it patches up these images using the patchify_image function.
|
||||
|
||||
- Creating Image Input IDs:
|
||||
For each patch, a placeholder ID is given to identify where these patches belong in a token sequence. For
|
||||
@ -40,6 +195,32 @@ class FuyuImageProcessor(BaseImageProcessor):
|
||||
- Image Patch Indices:
|
||||
For each image patch, the code maintains an index where these patches should be inserted in a token stream.
|
||||
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image to `size`.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `{"height": 1080, "width": 1920}`):
|
||||
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
|
||||
do_pad (`bool`, *optional*, defaults to `True`):
|
||||
Whether to pad the image to `size`.
|
||||
padding_value (`float`, *optional*, defaults to 1.0):
|
||||
The value to pad the image with.
|
||||
padding_mode (`str`, *optional*, defaults to `"constant"`):
|
||||
The padding mode to use when padding the image.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float`, *optional*, defaults to 0.5):
|
||||
The mean to use when normalizing the image.
|
||||
image_std (`float`, *optional*, defaults to 0.5):
|
||||
The standard deviation to use when normalizing the image.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `1 / 255`):
|
||||
The factor to use when rescaling the image.
|
||||
patch_size (`Dict[str, int]`, *optional*, defaults to `{"height": 30, "width": 30}`):
|
||||
Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
|
||||
"""
|
||||
|
||||
model_input_names = [
|
||||
@ -51,204 +232,483 @@ class FuyuImageProcessor(BaseImageProcessor):
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self, target_height=1080, target_width=1920, padding_value=1.0, padding_mode: str = "constant", **kwargs
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
size: Optional[Dict[str, int]] = None,
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
do_pad: bool = True,
|
||||
padding_value: float = 1.0,
|
||||
padding_mode: str = "constant",
|
||||
do_normalize: bool = True,
|
||||
image_mean: Union[float, List[float]] = 0.5,
|
||||
image_std: Union[float, List[float]] = 0.5,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: float = 1 / 255,
|
||||
patch_size: Optional[Dict[str, int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.target_width = target_width
|
||||
self.target_height = target_height
|
||||
self.do_resize = do_resize
|
||||
self.size = size if size is not None else {"height": 1080, "width": 1920}
|
||||
self.resample = resample
|
||||
self.do_pad = do_pad
|
||||
self.padding_value = padding_value
|
||||
self.padding_mode = padding_mode
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.patch_size = patch_size if patch_size is not None else {"height": 30, "width": 30}
|
||||
|
||||
def get_num_patches(self, img_h: int, img_w: int, patch_dim_h: int, patch_dim_w: int) -> int:
|
||||
"""Calculate number of patches required to encode an image."""
|
||||
if img_h % patch_dim_h != 0:
|
||||
raise ValueError(f"{img_h=} must be divisible by {patch_dim_h=}")
|
||||
if img_w % patch_dim_w != 0:
|
||||
raise ValueError(f"{img_w=} must be divisible by {patch_dim_w=}")
|
||||
|
||||
num_patches_per_dim_h = img_h // patch_dim_h
|
||||
num_patches_per_dim_w = img_w // patch_dim_w
|
||||
num_patches = num_patches_per_dim_h * num_patches_per_dim_w
|
||||
|
||||
return num_patches
|
||||
|
||||
def patchify_image(self, image: "torch.Tensor", patch_dim_h: int, patch_dim_w: int) -> "torch.Tensor":
|
||||
"""
|
||||
Convert an image into a tensor of patches.
|
||||
|
||||
Args:
|
||||
image: Image to convert. Shape: [batch, channels, height, width]
|
||||
patch_dim_h: Height of each patch.
|
||||
patch_dim_w: Width of each patch.
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
# TODO refer to https://github.com/ArthurZucker/transformers/blob/0f0a3fe5ca5697ee58faeb5b53f049af720b5e98/src/transformers/models/vit_mae/modeling_vit_mae.py#L871
|
||||
# torch implementation is faster but does not handle non-squares
|
||||
|
||||
batch_size, channels, height, width = image.shape
|
||||
unfolded_along_height = image.unfold(2, patch_dim_h, patch_dim_h)
|
||||
patches = unfolded_along_height.unfold(3, patch_dim_w, patch_dim_w)
|
||||
|
||||
patches_reshaped = patches.contiguous().view(batch_size, channels, -1, patch_dim_h, patch_dim_w)
|
||||
|
||||
patches_final = patches_reshaped.permute(0, 2, 3, 4, 1).reshape(
|
||||
batch_size, -1, channels * patch_dim_h * patch_dim_w
|
||||
)
|
||||
|
||||
return patches_final
|
||||
|
||||
def process_images_for_model_input(
|
||||
def resize(
|
||||
self,
|
||||
image_input: "torch.Tensor",
|
||||
image_present: "torch.Tensor",
|
||||
image_unpadded_h: "torch.Tensor",
|
||||
image_unpadded_w: "torch.Tensor",
|
||||
image_patch_dim_h: int,
|
||||
image_patch_dim_w: int,
|
||||
image_placeholder_id: int,
|
||||
image_newline_id: int,
|
||||
variable_sized: bool,
|
||||
) -> dict:
|
||||
"""Process images for model input. In particular, variable-sized images are handled here.
|
||||
image: np.ndarray,
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resize an image to `(size["height"], size["width"])`.
|
||||
|
||||
Args:
|
||||
image_input: [batch_size, 1, c, h, w] tensor of images padded to model input size.
|
||||
image_present: [batch_size, 1] tensor of 1s and 0s indicating whether an image is present.
|
||||
image_unpadded_h: [batch_size, 1] tensor of unpadded image heights.
|
||||
image_unpadded_w: [batch_size, 1] tensor of unpadded image widths.
|
||||
image_patch_dim_h: The height of the image patches.
|
||||
image_patch_dim_w: The width of the image patches.
|
||||
image_placeholder_id: The id of the image placeholder token.
|
||||
image_newline_id: The id of the image newline token.
|
||||
variable_sized: Whether to process images as variable-sized.
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The resized image.
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
# Only images that are present.
|
||||
images: List[List[torch.Tensor]] = []
|
||||
image_patches: List[List[torch.Tensor]] = []
|
||||
# Image input ids for every subsequence, including ones with no image present.
|
||||
image_input_ids: List[List[torch.Tensor]] = []
|
||||
for bi in range(image_input.shape[0]):
|
||||
images.append([])
|
||||
image_input_ids.append([])
|
||||
image_patches.append([])
|
||||
for si in range(image_input.shape[1]):
|
||||
if image_present[bi, si]:
|
||||
image = image_input[bi, si]
|
||||
if variable_sized:
|
||||
# The min() is required here due to floating point issues:
|
||||
# math.ceil(torch.tensor(300).cuda() / 30) == 11
|
||||
new_h = min(
|
||||
image.shape[1], math.ceil(image_unpadded_h[bi, si] / image_patch_dim_h) * image_patch_dim_h
|
||||
)
|
||||
new_w = min(
|
||||
image.shape[2], math.ceil(image_unpadded_w[bi, si] / image_patch_dim_w) * image_patch_dim_w
|
||||
)
|
||||
image = image[:, :new_h, :new_w]
|
||||
images[bi].append(image)
|
||||
num_patches = self.get_num_patches(
|
||||
img_h=image.shape[1],
|
||||
img_w=image.shape[2],
|
||||
patch_dim_h=image_patch_dim_h,
|
||||
patch_dim_w=image_patch_dim_w,
|
||||
)
|
||||
ids = torch.full([num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device)
|
||||
patches = self.patchify_image(
|
||||
image=image.unsqueeze(0), patch_dim_h=image_patch_dim_h, patch_dim_w=image_patch_dim_w
|
||||
).squeeze(0)
|
||||
if variable_sized:
|
||||
# Now terminate each line with |NEWLINE|.
|
||||
ids = ids.reshape(-1, new_w // image_patch_dim_w)
|
||||
ids = torch.cat(
|
||||
[
|
||||
ids,
|
||||
torch.full(
|
||||
[ids.shape[0], 1], image_newline_id, dtype=torch.int32, device=image_input.device
|
||||
),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
ids = ids.reshape(-1)
|
||||
image_input_ids[bi].append(ids)
|
||||
image_patches[bi].append(patches)
|
||||
else:
|
||||
image_input_ids[bi].append(torch.tensor([], dtype=torch.int32, device=image_input.device))
|
||||
image_height, image_width = get_image_size(image, input_data_format)
|
||||
target_height, target_width = size["height"], size["width"]
|
||||
|
||||
# Create image_patch_input_indices, where non-negative values correspond to image patches to be inserted in
|
||||
# the stream.
|
||||
image_patch_indices_per_batch: List[List[torch.Tensor]] = []
|
||||
image_patch_indices_per_subsequence: List[List[torch.Tensor]] = []
|
||||
for bi in range(len(image_input_ids)):
|
||||
image_patch_indices_per_batch.append([])
|
||||
image_patch_indices_per_subsequence.append([])
|
||||
index_offset = 0
|
||||
for si in range(len(image_input_ids[bi])):
|
||||
# Indices of image patches.
|
||||
num_patches = torch.count_nonzero(image_input_ids[bi][si] == image_placeholder_id)
|
||||
indices = torch.arange(
|
||||
num_patches,
|
||||
dtype=image_input_ids[bi][si].dtype,
|
||||
device=image_input_ids[bi][si].device,
|
||||
)
|
||||
|
||||
# Place those indices in the image input ids token stream, with -1 representing non-index tokens.
|
||||
indices_in_stream_per_batch = torch.full_like(image_input_ids[bi][si], -1)
|
||||
indices_in_stream_per_subsequence = torch.full_like(image_input_ids[bi][si], -1)
|
||||
indices_in_stream_per_batch[
|
||||
torch.nonzero(image_input_ids[bi][si] == image_placeholder_id, as_tuple=True)[0]
|
||||
] = (indices + index_offset)
|
||||
indices_in_stream_per_subsequence[
|
||||
torch.nonzero(image_input_ids[bi][si] == image_placeholder_id, as_tuple=True)[0]
|
||||
] = indices
|
||||
|
||||
image_patch_indices_per_batch[bi].append(indices_in_stream_per_batch)
|
||||
image_patch_indices_per_subsequence[bi].append(indices_in_stream_per_subsequence)
|
||||
index_offset += num_patches
|
||||
|
||||
return {
|
||||
"images": images,
|
||||
"image_input_ids": image_input_ids,
|
||||
"image_patches": image_patches,
|
||||
"image_patch_indices_per_batch": image_patch_indices_per_batch,
|
||||
"image_patch_indices_per_subsequence": image_patch_indices_per_subsequence,
|
||||
}
|
||||
|
||||
def _scale_to_target_aspect_ratio(self, image: np.ndarray) -> np.ndarray:
|
||||
image_height, image_width, _ = image.shape
|
||||
if image_width <= self.target_width and image_height <= self.target_height:
|
||||
if image_width <= target_width and image_height <= target_height:
|
||||
return image
|
||||
|
||||
height_scale_factor = self.target_height / image_height
|
||||
width_scale_factor = self.target_width / image_width
|
||||
height_scale_factor = target_height / image_height
|
||||
width_scale_factor = target_width / image_width
|
||||
optimal_scale_factor = min(height_scale_factor, width_scale_factor)
|
||||
|
||||
new_height = int(image_height * optimal_scale_factor)
|
||||
new_width = int(image_width * optimal_scale_factor)
|
||||
|
||||
scaled_image = resize(image=image, size=(new_height, new_width))
|
||||
return np.array(scaled_image)
|
||||
scaled_image = resize(
|
||||
image=image,
|
||||
size=(new_height, new_width),
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
return scaled_image
|
||||
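# Resize sketch (not part of the diff): the scale factor preserves aspect ratio and the method
# never upscales, so an image that already fits inside the target canvas is returned unchanged.
# A 2160x3840 frame gets min(1080/2160, 1920/3840) = 0.5 and comes back as 1080x1920.
import numpy as np

resized = FuyuImageProcessor().resize(np.zeros((2160, 3840, 3), dtype=np.uint8), size={"height": 1080, "width": 1920})
print(resized.shape)  # (1080, 1920, 3)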
|
||||
def _pad_to_target_size(self, image: np.ndarray) -> np.ndarray:
|
||||
image_height, image_width, _ = image.shape
|
||||
def pad_image(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
size: Dict[str, int],
|
||||
mode: str = "constant",
|
||||
constant_values: float = 1.0,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Pad an image to `(size["height"], size["width"])`.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to pad.
|
||||
size (`Dict[str, int]`):
|
||||
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
||||
data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The data format of the output image. If unset, the same format as the input image is used.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
image_height, image_width = get_image_size(image, input_data_format)
|
||||
target_height, target_width = size["height"], size["width"]
|
||||
padding_top = 0
|
||||
padding_left = 0
|
||||
padding_bottom = self.target_height - image_height
|
||||
padding_right = self.target_width - image_width
|
||||
|
||||
padding_bottom = target_height - image_height
|
||||
padding_right = target_width - image_width
|
||||
padded_image = pad(
|
||||
image,
|
||||
((padding_top, padding_bottom), (padding_left, padding_right)),
|
||||
mode=self.padding_mode,
|
||||
constant_values=self.padding_value,
|
||||
padding=((padding_top, padding_bottom), (padding_left, padding_right)),
|
||||
mode=mode,
|
||||
constant_values=constant_values,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
return padded_image
|
||||
|
||||
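# Padding sketch (not part of the diff): images are anchored at the top-left corner, so all padding
# goes on the bottom and right edges and is filled with `constant_values`. For a 540x960 input and
# the default 1080x1920 canvas: padding_bottom = 1080 - 540 = 540, padding_right = 1920 - 960 = 960.
import numpy as np

padded = FuyuImageProcessor().pad_image(np.zeros((540, 960, 3), dtype=np.uint8), size={"height": 1080, "width": 1920})
print(padded.shape)  # (1080, 1920, 3)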
def apply_transformation(self, image: Union[np.ndarray, PIL.Image.Image]) -> np.ndarray:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
image = to_numpy_array(image)
|
||||
scaled_image = self._scale_to_target_aspect_ratio(image)
|
||||
padded_image = self._pad_to_target_size(scaled_image)
|
||||
normalized_padded_image = normalize(padded_image, 0.5, 0.5)
|
||||
return normalized_padded_image
|
||||
def preprocess(
|
||||
self,
|
||||
images,
|
||||
do_resize: Optional[bool] = None,
|
||||
size: Optional[Dict[str, int]] = None,
|
||||
resample: Optional[PILImageResampling] = None,
|
||||
do_pad: Optional[bool] = None,
|
||||
padding_value: Optional[float] = None,
|
||||
padding_mode: Optional[str] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[float] = None,
|
||||
image_std: Optional[float] = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
patch_size: Optional[Dict[str, int]] = None,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
return_tensors: Optional[TensorType] = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Utility function to preprocess the images and extract necessary information about original formats.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Images to preprocess. Expects a single image, a list or images or a list of lists of images. Pixel
|
||||
values range from 0 to 255, or between 0 and 1 if `do_rescale` is `False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image to `size`.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
|
||||
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
|
||||
Whether to pad the image to `size`.
|
||||
padding_value (`float`, *optional*, defaults to `self.padding_value`):
|
||||
The value to pad the image with.
|
||||
padding_mode (`str`, *optional*, defaults to `self.padding_mode`):
|
||||
The padding mode to use when padding the image.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float`, *optional*, defaults to `self.image_mean`):
|
||||
The mean to use when normalizing the image.
|
||||
image_std (`float`, *optional*, defaults to `self.image_std`):
|
||||
The standard deviation to use when normalizing the image.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
The factor to use when rescaling the image.
|
||||
patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
|
||||
Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format of the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
"""
|
||||
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_pad = do_pad if do_pad is not None else self.do_pad
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
padding_value = padding_value if padding_value is not None else self.padding_value
|
||||
padding_mode = padding_mode if padding_mode is not None else self.padding_mode
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
patch_size = patch_size if patch_size is not None else self.patch_size
|
||||
|
||||
if isinstance(images, list) and any(isinstance(elem, list) and len(elem) >= 2 for elem in images):
|
||||
raise ValueError("Multiple images for a single sample are not yet supported.")
|
||||
|
||||
batch_images = make_list_of_list_of_images(images)
|
||||
|
||||
if do_resize and size is None:
|
||||
raise ValueError("Size must be specified if do_resize is True.")
|
||||
|
||||
if do_rescale and rescale_factor is None:
|
||||
raise ValueError("Rescale factor must be specified if do_rescale is True.")
|
||||
|
||||
if do_normalize and (image_mean is None or image_std is None):
|
||||
raise ValueError("image_mean and image_std must be specified if do_normalize is True.")
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
batch_images = [[to_numpy_array(image) for image in images] for images in batch_images]
|
||||
|
||||
if is_scaled_image(batch_images[0][0]) and do_rescale:
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(batch_images[0][0])
|
||||
|
||||
original_image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
|
||||
|
||||
if do_resize:
|
||||
batch_images = [
|
||||
[self.resize(image, size=size, input_data_format=input_data_format) for image in images]
|
||||
for images in batch_images
|
||||
]
|
||||
|
||||
image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
|
||||
image_unpadded_heights = [[image_size[0]] for image_size in image_sizes]
|
||||
image_unpadded_widths = [[image_size[1]] for image_size in image_sizes]
|
||||
|
||||
# scale_h is the same as scale_w
|
||||
image_scale_factors = [
|
||||
[resized_size[0] / original_size[0]]
|
||||
for original_size, resized_size in zip(original_image_sizes, image_sizes)
|
||||
]
|
||||
|
||||
if do_pad:
|
||||
batch_images = [
|
||||
[
|
||||
self.pad_image(
|
||||
image,
|
||||
size=size,
|
||||
mode=padding_mode,
|
||||
constant_values=padding_value,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
for images in batch_images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
batch_images = [
|
||||
[self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) for image in images]
|
||||
for images in batch_images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
batch_images = [
|
||||
[
|
||||
self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
for images in batch_images
|
||||
]
|
||||
|
||||
if data_format is not None:
|
||||
batch_images = [
|
||||
[to_channel_dimension_format(image, data_format, input_data_format) for image in images]
|
||||
for images in batch_images
|
||||
]
|
||||
|
||||
data = {
|
||||
"images": batch_images,
|
||||
"image_unpadded_heights": image_unpadded_heights,
|
||||
"image_unpadded_widths": image_unpadded_widths,
|
||||
"image_scale_factors": image_scale_factors,
|
||||
}
|
||||
return FuyuBatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
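# Illustrative sketch (not part of the diff): preprocess accepts a single image, a list of images,
# or a list of lists, and returns a FuyuBatchFeature. The image sizes below are assumptions.
import numpy as np

image_processor = FuyuImageProcessor()
images = [
    np.random.randint(0, 256, (2160, 3840, 3), dtype=np.uint8),  # scaled down by 0.5 to 1080x1920
    np.random.randint(0, 256, (540, 960, 3), dtype=np.uint8),    # already fits, only padded
]
batch = image_processor.preprocess(images, return_tensors="pt")
print(batch["images"][0][0].shape)      # torch.Size([3, 1080, 1920])
print(batch["image_unpadded_heights"])  # per-image heights after resizing, before padding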
def get_num_patches(self, image_height: int, image_width: int, patch_size: Dict[str, int] = None) -> int:
|
||||
"""
|
||||
Calculate number of patches required to encode an image.
|
||||
|
||||
Args:
|
||||
image_height (`int`):
|
||||
Height of the image.
|
||||
image_width (`int`):
|
||||
Width of the image.
|
||||
patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
|
||||
Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
|
||||
"""
|
||||
patch_size = patch_size if patch_size is not None else self.patch_size
|
||||
patch_height, patch_width = patch_size["height"], patch_size["width"]
|
||||
|
||||
if image_height % patch_height != 0:
|
||||
raise ValueError(f"{image_height=} must be divisible by {patch_height}")
|
||||
if image_width % patch_width != 0:
|
||||
raise ValueError(f"{image_width=} must be divisible by {patch_width}")
|
||||
|
||||
num_patches_per_dim_h = image_height // patch_height
|
||||
num_patches_per_dim_w = image_width // patch_width
|
||||
num_patches = num_patches_per_dim_h * num_patches_per_dim_w
|
||||
return num_patches
|
||||
|
||||
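# Quick arithmetic check (not part of the diff): the default 1080x1920 canvas with 30x30 patches
# yields (1080 // 30) * (1920 // 30) = 36 * 64 = 2304 patches per image.
assert FuyuImageProcessor().get_num_patches(image_height=1080, image_width=1920) == 2304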
def patchify_image(self, image: "torch.Tensor", patch_size: Optional[Dict[str, int]] = None) -> "torch.Tensor":
|
||||
"""
|
||||
Convert an image into a tensor of patches.
|
||||
|
||||
Args:
|
||||
image (`torch.Tensor`):
|
||||
Image to convert. Shape: [batch, channels, height, width]
|
||||
patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
|
||||
Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
patch_size = patch_size if patch_size is not None else self.patch_size
|
||||
patch_height, patch_width = patch_size["height"], patch_size["width"]
|
||||
|
||||
# TODO refer to https://github.com/ArthurZucker/transformers/blob/0f0a3fe5ca5697ee58faeb5b53f049af720b5e98/src/transformers/models/vit_mae/modeling_vit_mae.py#L871
|
||||
# torch implementation is faster but does not handle non-squares
|
||||
|
||||
batch_size, channels, _, _ = image.shape
|
||||
unfolded_along_height = image.unfold(2, patch_height, patch_height)
|
||||
patches = unfolded_along_height.unfold(3, patch_width, patch_width)
|
||||
patches = patches.contiguous()
|
||||
patches = patches.view(batch_size, channels, -1, patch_height, patch_width)
|
||||
patches = patches.permute(0, 2, 3, 4, 1)
|
||||
patches = patches.reshape(batch_size, -1, channels * patch_height * patch_width)
|
||||
return patches
|
||||
|
||||
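# Shape sketch (not part of the diff): a padded 3-channel 1080x1920 image unfolds into 2304
# patches of 30 * 30 * 3 = 2700 values each with the default patch size.
import torch

patches = FuyuImageProcessor().patchify_image(torch.rand(1, 3, 1080, 1920))
print(patches.shape)  # torch.Size([1, 2304, 2700])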
def preprocess_with_tokenizer_info(
|
||||
self,
|
||||
image_input: "torch.Tensor",
|
||||
image_present: "torch.Tensor",
|
||||
image_unpadded_h: "torch.Tensor",
|
||||
image_unpadded_w: "torch.Tensor",
|
||||
image_placeholder_id: int,
|
||||
image_newline_id: int,
|
||||
variable_sized: bool,
|
||||
patch_size: Optional[Dict[str, int]] = None,
|
||||
) -> FuyuBatchFeature:
|
||||
"""Process images for model input. In particular, variable-sized images are handled here.
|
||||
|
||||
Args:
|
||||
image_input (`torch.Tensor` of shape [batch_size, subsequence_size, num_channels, height, width]):
|
||||
Tensor of images padded to model input size.
|
||||
image_present (`torch.Tensor` of shape [batch_size, subsequence_size, num_images]):
|
||||
Tensor of 1s and 0s indicating whether an image is present.
|
||||
image_unpadded_h (`torch.Tensor` of shape [batch_size, subsequence_size]):
|
||||
Tensor of unpadded image heights.
|
||||
image_unpadded_w (`torch.Tensor` of shape [batch_size, subsequence_size]):
|
||||
Tensor of unpadded image widths.
|
||||
image_placeholder_id (int):
|
||||
The id of the image placeholder token. Comes from an associated tokenizer.
|
||||
image_newline_id (int):
|
||||
The id of the image newline token. Comes from an associated tokenizer.
|
||||
variable_sized (bool):
|
||||
Whether to process images as variable-sized.
|
||||
patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
|
||||
Size of the patches.
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
patch_size = patch_size if patch_size is not None else self.patch_size
|
||||
patch_height, patch_width = patch_size["height"], patch_size["width"]
|
||||
|
||||
# Only images that are present.
|
||||
images: List[List[torch.Tensor]] = []
|
||||
batch_image_patches: List[List[torch.Tensor]] = []
|
||||
# Image input ids for every subsequence, including ones with no image present.
|
||||
batch_image_input_ids: List[List[torch.Tensor]] = []
|
||||
for batch_index in range(image_input.shape[0]):
|
||||
image_input_ids = []
|
||||
image_patches = []
|
||||
for subseq_index in range(image_input.shape[1]):
|
||||
if image_present[batch_index, subseq_index]:
|
||||
image = image_input[batch_index, subseq_index]
|
||||
image_height, image_width = image.shape[1], image.shape[2]
|
||||
if variable_sized:
|
||||
# The min() is required here due to floating point issues:
|
||||
# math.ceil(torch.tensor(300).cuda() / 30) == 11
|
||||
new_h = min(
|
||||
image_height,
|
||||
math.ceil(image_unpadded_h[batch_index, subseq_index] / patch_height) * patch_height,
|
||||
)
|
||||
new_w = min(
|
||||
image_width,
|
||||
math.ceil(image_unpadded_w[batch_index, subseq_index] / patch_width) * patch_width,
|
||||
)
|
||||
image = image[:, :new_h, :new_w]
|
||||
image_height, image_width = new_h, new_w
|
||||
|
||||
num_patches = self.get_num_patches(image_height=image_height, image_width=image_width)
|
||||
tensor_of_image_ids = torch.full(
|
||||
[num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device
|
||||
)
|
||||
patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0)
|
||||
assert num_patches == patches.shape[0]
|
||||
|
||||
if variable_sized:
|
||||
# Now terminate each line with |NEWLINE|.
|
||||
tensor_of_image_ids = tensor_of_image_ids.reshape(-1, image_width // patch_width)
|
||||
newline_ids = torch.full(
|
||||
[tensor_of_image_ids.shape[0], 1],
|
||||
image_newline_id,
|
||||
dtype=torch.int32,
|
||||
device=image_input.device,
|
||||
)
|
||||
tensor_of_image_ids = torch.cat([tensor_of_image_ids, newline_ids], dim=1)
|
||||
tensor_of_image_ids = tensor_of_image_ids.reshape(-1)
|
||||
|
||||
images.append([image])
|
||||
image_input_ids.append(tensor_of_image_ids)
|
||||
image_patches.append(patches)
|
||||
else:
|
||||
image_input_ids.append(torch.tensor([], dtype=torch.int32, device=image_input.device))
|
||||
|
||||
batch_image_input_ids.append(image_input_ids)
|
||||
batch_image_patches.append(image_patches)
|
||||
|
||||
# Create image_patch_input_indices, where non-negative values correspond to image patches to be inserted in
|
||||
# the stream.
|
||||
image_patch_indices_per_batch: List[List[torch.Tensor]] = []
|
||||
image_patch_indices_per_subsequence: List[List[torch.Tensor]] = []
|
||||
|
||||
for sample_image_input_ids in batch_image_input_ids:
|
||||
index_offset = 0
|
||||
per_batch_indices = []
|
||||
per_subsequence_indices = []
|
||||
for subseq_image_input_ids in sample_image_input_ids:
|
||||
# Indices of image patches.
|
||||
patches_mask = subseq_image_input_ids == image_placeholder_id
|
||||
num_patches = torch.count_nonzero(patches_mask)
|
||||
indices = torch.arange(
|
||||
num_patches, dtype=subseq_image_input_ids.dtype, device=subseq_image_input_ids.device
|
||||
)
|
||||
|
||||
# Place those indices in the image input ids token stream, with -1 representing non-index tokens.
|
||||
indices_in_stream_per_batch = torch.full_like(subseq_image_input_ids, -1)
|
||||
indices_in_stream_per_subsequence = torch.full_like(subseq_image_input_ids, -1)
|
||||
patches_inds = torch.nonzero(patches_mask, as_tuple=True)[0]
|
||||
|
||||
indices_in_stream_per_batch[patches_inds] = indices + index_offset
|
||||
indices_in_stream_per_subsequence[patches_inds] = indices
|
||||
|
||||
per_batch_indices.append(indices_in_stream_per_batch)
|
||||
per_subsequence_indices.append(indices_in_stream_per_subsequence)
|
||||
index_offset += num_patches
|
||||
|
||||
image_patch_indices_per_batch.append(per_batch_indices)
|
||||
image_patch_indices_per_subsequence.append(per_subsequence_indices)
|
||||
|
||||
return FuyuBatchFeature(
|
||||
data={
|
||||
"images": images,
|
||||
"image_input_ids": batch_image_input_ids,
|
||||
"image_patches": batch_image_patches,
|
||||
"image_patch_indices_per_batch": image_patch_indices_per_batch,
|
||||
"image_patch_indices_per_subsequence": image_patch_indices_per_subsequence,
|
||||
}
|
||||
)
|
||||
|
@@ -257,8 +257,10 @@ class FuyuForCausalLM(FuyuPreTrainedModel):
        if inputs_embeds is None:
            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
            if image_patches is not None and past_key_values is None:
                patch_embeddings = self.vision_embed_tokens(image_patches.to(self.vision_embed_tokens.weight.dtype))
                patch_embeddings = patch_embeddings.to(inputs_embeds.device)
                patch_embeddings = [
                    self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)).squeeze(0)
                    for patch in image_patches
                ]
                inputs_embeds = self.gather_continuous_embeddings(
                    word_embeddings=inputs_embeds,
                    continuous_embeddings=patch_embeddings,
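# Illustrative sketch (not part of the diff): image_patches is now a list of per-sample patch
# tensors that may have different lengths, so each entry is embedded separately instead of being
# stacked into one batch tensor. The projection layer and shapes below are assumptions.
import torch

vision_embed_tokens = torch.nn.Linear(2700, 4096)  # stand-in for the model's patch projection
image_patches = [torch.rand(1, 2304, 2700), torch.rand(1, 1200, 2700)]
patch_embeddings = [
    vision_embed_tokens(patch.to(vision_embed_tokens.weight.dtype)).squeeze(0) for patch in image_patches
]
print([emb.shape for emb in patch_embeddings])  # [torch.Size([2304, 4096]), torch.Size([1200, 4096])]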
@ -1,45 +1,50 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Image/Text processor class for GIT
|
||||
"""
|
||||
import re
|
||||
from typing import Any, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_utils import (
|
||||
ChannelDimension,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
to_numpy_array,
|
||||
)
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...utils import is_torch_available, is_vision_available, logging
|
||||
from ...tokenization_utils_base import PaddingStrategy, TruncationStrategy
|
||||
from ...utils import TensorType, is_torch_available, logging, requires_backends
|
||||
|
||||
|
||||
if is_torch_available() and is_vision_available():
|
||||
from .image_processing_fuyu import FuyuImageProcessor
|
||||
if is_torch_available():
|
||||
from .image_processing_fuyu import FuyuBatchFeature
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
BBOX_OPEN_STRING = "<0x00>" # <bbox>
|
||||
BBOX_CLOSE_STRING = "<0x01>" # </bbox>
|
||||
POINT_OPEN_STRING = "<0x02>" # <point>
|
||||
POINT_CLOSE_STRING = "<0x03>" # </point>
|
||||
|
||||
TEXT_REPR_BBOX_OPEN = "<box>"
|
||||
TEXT_REPR_BBOX_CLOSE = "</box>"
|
||||
TEXT_REPR_POINT_OPEN = "<point>"
|
||||
TEXT_REPR_POINT_CLOSE = "</point>"
|
||||
|
||||
TOKEN_BBOX_OPEN_STRING = BBOX_OPEN_STRING = "<0x00>" # <bbox>
|
||||
BBOX_CLOSE_STRING = "<0x01>" # </bbox>
|
||||
TOKEN_BBOX_CLOSE_STRING = TOKEN_POINT_OPEN_STRING = POINT_OPEN_STRING = "<0x02>" # <point>
|
||||
TOKEN_POINT_CLOSE_STRING = POINT_CLOSE_STRING = "<0x03>" # </point>
|
||||
TOKEN_BBOX_OPEN_STRING = "<0x00>" # <bbox>
|
||||
TOKEN_BBOX_CLOSE_STRING = "<0x01>" # </bbox>
|
||||
TOKEN_POINT_OPEN_STRING = "<0x02>" # <point>
|
||||
TOKEN_POINT_CLOSE_STRING = "<0x03>" # </point>
|
||||
BEGINNING_OF_ANSWER_STRING = "<0x04>" # <boa>
|
||||
|
||||
|
||||
@ -87,18 +92,16 @@ def construct_full_unpacked_stream(
|
||||
|
||||
all_bi_stream = []
|
||||
|
||||
for bi in range(batch_size):
|
||||
for batch_index in range(batch_size):
|
||||
all_si_stream = []
|
||||
|
||||
# First, construct full token stream (including image placeholder tokens) and loss mask for each subsequence
|
||||
# and append to lists. We use lists rather than tensors because each subsequence is variable-sized.
|
||||
for si in range(num_sub_sequences):
|
||||
image_adjustment = image_tokens[bi][si]
|
||||
si_stream = torch.cat([image_adjustment, input_stream[bi, si]], dim=0)
|
||||
num_real_tokens = image_adjustment.shape[0] + num_real_text_tokens[bi][si]
|
||||
|
||||
all_si_stream.append(si_stream[:num_real_tokens])
|
||||
# Combine all subsequences for this batch entry. Still using a list because each batch entry is variable-sized.
|
||||
# TODO Remove this logic in a subsequent release since subsequences are not supported.
|
||||
image_adjustment = image_tokens[batch_index][0]
|
||||
subsequence_stream = torch.cat([image_adjustment, input_stream[batch_index, 0]], dim=0)
|
||||
num_real_tokens = image_adjustment.shape[0] + num_real_text_tokens[batch_index][0]
|
||||
all_si_stream.append(subsequence_stream[:num_real_tokens])
|
||||
all_bi_stream.append(torch.cat(all_si_stream, dim=0))
|
||||
|
||||
return all_bi_stream
|
||||
@ -137,7 +140,7 @@ def _segment_prompt_into_text_token_conversions(prompt: str) -> List:
|
||||
return prompt_text_list
|
||||
|
||||
|
||||
def _transform_coordinates_and_tokenize(prompt: str, transformed_image, tokenizer) -> List[int]:
|
||||
def _transform_coordinates_and_tokenize(prompt: str, scale_factor: float, tokenizer) -> List[int]:
|
||||
"""
|
||||
This function transforms the prompt in the following fashion:
|
||||
- <box> <point> and </box> </point> to their respective token mappings
|
||||
@ -161,7 +164,7 @@ def _transform_coordinates_and_tokenize(prompt: str, transformed_image, tokenize
|
||||
for elem in prompt_text_list:
|
||||
if elem[1]:
|
||||
# This is a location, we need to tokenize it
|
||||
within_tag_tokenized = _transform_within_tags(elem[0], transformed_image, tokenizer)
|
||||
within_tag_tokenized = _transform_within_tags(elem[0], scale_factor, tokenizer)
|
||||
# Surround the text with the open and close tags
|
||||
transformed_prompt_tokens.extend(within_tag_tokenized)
|
||||
else:
|
||||
@ -169,7 +172,7 @@ def _transform_coordinates_and_tokenize(prompt: str, transformed_image, tokenize
|
||||
return transformed_prompt_tokens
|
||||
|
||||
|
||||
def _transform_within_tags(text: str, transformed_image, tokenizer) -> List[int]:
|
||||
def _transform_within_tags(text: str, scale_factor: float, tokenizer) -> List[int]:
|
||||
"""
|
||||
Given a bounding box of the fashion <box>1, 2, 3, 4</box> | <point>1, 2</point> This function is responsible for
|
||||
converting 1, 2, 3, 4 into tokens of 1 2 3 4 without any commas.
|
||||
@ -188,16 +191,14 @@ def _transform_within_tags(text: str, transformed_image, tokenizer) -> List[int]
|
||||
num_ints = [float(num.strip()) for num in num_int_strs]
|
||||
# scale to transformed image size
|
||||
if len(num_ints) == 2:
|
||||
num_ints_translated = scale_point_to_transformed_image(
|
||||
x=num_ints[0], y=num_ints[1], transformed_image=transformed_image
|
||||
)
|
||||
num_ints_translated = scale_point_to_transformed_image(x=num_ints[0], y=num_ints[1], scale_factor=scale_factor)
|
||||
elif len(num_ints) == 4:
|
||||
num_ints_translated = scale_bbox_to_transformed_image(
|
||||
top=num_ints[0],
|
||||
left=num_ints[1],
|
||||
bottom=num_ints[2],
|
||||
right=num_ints[3],
|
||||
transformed_image=transformed_image,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid number of ints: {len(num_ints)}")
|
||||
@ -209,7 +210,7 @@ def _transform_within_tags(text: str, transformed_image, tokenizer) -> List[int]
|
||||
def _tokenize_prompts_with_image_and_batch(
|
||||
tokenizer,
|
||||
prompts: List[List[str]],
|
||||
transformed_images: Optional[List[List["torch.Tensor"]]],
|
||||
scale_factors: Optional[List[List["torch.Tensor"]]],
|
||||
max_tokens_to_generate: int,
|
||||
max_position_embeddings: int,
|
||||
add_BOS: bool, # Same issue with types as above
|
||||
@ -223,13 +224,13 @@ def _tokenize_prompts_with_image_and_batch(
|
||||
"""
|
||||
|
||||
# If not tool use, tranform the coordinates while tokenizing
|
||||
if transformed_images is not None:
|
||||
if scale_factors is not None:
|
||||
transformed_prompt_tokens = []
|
||||
for prompt_seq, transformed_image_seq in zip(prompts, transformed_images):
|
||||
for prompt_seq, scale_factor_seq in zip(prompts, scale_factors):
|
||||
transformed_prompt_tokens.append(
|
||||
[
|
||||
_transform_coordinates_and_tokenize(prompt, transformed_image, tokenizer)
|
||||
for prompt, transformed_image in zip(prompt_seq, transformed_image_seq)
|
||||
_transform_coordinates_and_tokenize(prompt, scale_factor.item(), tokenizer)
|
||||
for prompt, scale_factor in zip(prompt_seq, scale_factor_seq)
|
||||
]
|
||||
)
|
||||
else:
|
||||
@ -260,7 +261,7 @@ def _tokenize_prompts_with_image_and_batch(
|
||||
# Number of tokens in the each sample of the batch.
|
||||
samples_length = min(max_prompt_len + max_tokens_to_generate, max_position_embeddings)
|
||||
if max_prompt_len + max_tokens_to_generate > max_position_embeddings:
|
||||
print(
|
||||
logger.warning(
|
||||
f"Max subsequence prompt length of {max_prompt_len} + max tokens to generate {max_tokens_to_generate}",
|
||||
f"exceeds context length of {max_position_embeddings}. Will generate as many tokens as possible.",
|
||||
)
|
||||
@ -279,86 +280,30 @@ def _tokenize_prompts_with_image_and_batch(
|
||||
return prompts_tokens_tensor, prompts_length_tensor
|
||||
|
||||
|
||||
def original_to_transformed_h_coords(self, original_coords):
|
||||
# apply crop
|
||||
cropped_coords = (
|
||||
self._clamp_coords(original_coords, min_value=self.crop_top, max_value=self.crop_bottom) - self.crop_top
|
||||
)
|
||||
# apply scale
|
||||
scaled_coords = self._scale_coords(cropped_coords, scale=self.scaled_h / self.original_h)
|
||||
# apply pad
|
||||
return scaled_coords + self.padding_top
|
||||
# Simplified assuming self.crop_top = self.padding_top = 0
|
||||
def original_to_transformed_h_coords(original_coords, scale_h):
|
||||
return np.round(original_coords * scale_h).astype(np.int32)
|
||||
|
||||
|
||||
def original_to_transformed_w_coords(self, original_coords):
|
||||
# apply crop
|
||||
cropped_coords = (
|
||||
self._clamp_coords(original_coords, min_value=self.crop_left, max_value=self.crop_right) - self.crop_left
|
||||
)
|
||||
# apply scale
|
||||
scaled_coords = self._scale_coords(cropped_coords, scale=self.scaled_w / self.original_w)
|
||||
# apply pad
|
||||
return scaled_coords + self.padding_left
|
||||
# Simplified assuming self.crop_left = self.padding_left = 0
|
||||
def original_to_transformed_w_coords(original_coords, scale_w):
|
||||
return np.round(original_coords * scale_w).astype(np.int32)
|
||||
|
||||
|
||||
def scale_point_to_transformed_image(x: float, y: float) -> List[int]:
|
||||
x_scaled = original_to_transformed_w_coords(np.array([x / 2]))[0]
|
||||
y_scaled = original_to_transformed_h_coords(np.array([y / 2]))[0]
|
||||
def scale_point_to_transformed_image(x: float, y: float, scale_factor: float) -> List[int]:
|
||||
x_scaled = original_to_transformed_w_coords(np.array([x / 2]), scale_factor)[0]
|
||||
y_scaled = original_to_transformed_h_coords(np.array([y / 2]), scale_factor)[0]
|
||||
return [x_scaled, y_scaled]
|
||||
|
||||
|
||||
def scale_bbox_to_transformed_image(top: float, left: float, bottom: float, right: float) -> List[int]:
|
||||
top_scaled = original_to_transformed_w_coords(np.array([top / 2]))[0]
|
||||
left_scaled = original_to_transformed_h_coords(np.array([left / 2]))[0]
|
||||
bottom_scaled = original_to_transformed_w_coords(np.array([bottom / 2]))[0]
|
||||
right_scaled = original_to_transformed_h_coords(np.array([right / 2]))[0]
|
||||
return [top_scaled, left_scaled, bottom_scaled, right_scaled]
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.max_across_indices
|
||||
def max_across_indices(values: Iterable[Any]) -> List[Any]:
|
||||
"""
|
||||
Return the maximum value across all indices of an iterable of values.
|
||||
"""
|
||||
return [max(values_i) for values_i in zip(*values)]
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
|
||||
def get_max_height_width(
|
||||
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
def scale_bbox_to_transformed_image(
|
||||
top: float, left: float, bottom: float, right: float, scale_factor: float
|
||||
) -> List[int]:
|
||||
"""
|
||||
Get the maximum height and width across all images in a batch.
|
||||
"""
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if input_data_format == ChannelDimension.FIRST:
|
||||
_, max_height, max_width = max_across_indices([img.shape for img in images])
|
||||
elif input_data_format == ChannelDimension.LAST:
|
||||
max_height, max_width, _ = max_across_indices([img.shape for img in images])
|
||||
else:
|
||||
raise ValueError(f"Invalid channel dimension format: {input_data_format}")
|
||||
return (max_height, max_width)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
|
||||
def make_pixel_mask(
|
||||
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to make the pixel mask for.
|
||||
output_size (`Tuple[int, int]`):
|
||||
Output size of the mask.
|
||||
"""
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
mask = np.zeros(output_size, dtype=np.int64)
|
||||
mask[:input_height, :input_width] = 1
|
||||
return mask
|
||||
top_scaled = original_to_transformed_w_coords(np.array([top / 2]), scale_factor)[0]
|
||||
left_scaled = original_to_transformed_h_coords(np.array([left / 2]), scale_factor)[0]
|
||||
bottom_scaled = original_to_transformed_w_coords(np.array([bottom / 2]), scale_factor)[0]
|
||||
right_scaled = original_to_transformed_h_coords(np.array([right / 2]), scale_factor)[0]
|
||||
return [top_scaled, left_scaled, bottom_scaled, right_scaled]
|
||||
|
||||
|
||||
class FuyuProcessor(ProcessorMixin):
|
||||
@@ -384,42 +329,148 @@ class FuyuProcessor(ProcessorMixin):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_tokens_to_generate = 10
|
||||
self.max_position_embeddings = 16384 # TODO Can't derive this from model files: where to set it?
|
||||
self.image_processor = FuyuImageProcessor()
|
||||
self.pad_token_id = 0
|
||||
self.dummy_image_index = -1
|
||||
|
||||
def _process_images(self, images):
|
||||
"""Utility function to preprocess the images and extract necessary information about original formats."""
|
||||
batch_images = []
|
||||
image_unpadded_heights = []
|
||||
image_unpadded_widths = []
|
||||
def _left_pad_inputs_with_attention_mask(self, model_inputs: List[Dict], return_attention_mask: bool):
|
||||
max_length_input_ids = max(entry["input_ids"].shape[1] for entry in model_inputs)
|
||||
max_length_image_patch_indices = max(entry["image_patches_indices"].shape[1] for entry in model_inputs)
|
||||
|
||||
for image in images:
|
||||
image = to_numpy_array(image)
|
||||
if not is_scaled_image(image):
|
||||
image = image / 255.0
|
||||
channel_dimension = infer_channel_dimension_format(image, 3)
|
||||
if channel_dimension == ChannelDimension.FIRST:
|
||||
width_index = 2
|
||||
height_index = 1
|
||||
elif channel_dimension == ChannelDimension.LAST:
|
||||
width_index = 1
|
||||
height_index = 0
|
||||
batched_inputs = {"input_ids": [], "image_patches": [], "image_patches_indices": [], "attention_mask": []}
|
||||
|
||||
image_unpadded_widths.append([image.shape[width_index]])
|
||||
image_unpadded_heights.append([image.shape[height_index]])
|
||||
for entry in model_inputs:
|
||||
for key, tensor in entry.items():
|
||||
if key == "input_ids":
|
||||
num_padding_tokens = max_length_input_ids - tensor.shape[1]
|
||||
padded_input_ids = torch.cat(
|
||||
[
|
||||
torch.full((tensor.shape[0], num_padding_tokens), self.pad_token_id, dtype=torch.long),
|
||||
tensor,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
batched_inputs[key].append(padded_input_ids)
|
||||
|
||||
# Reproduce adept padding sampler
|
||||
padded_image = self.image_processor.apply_transformation(image)
|
||||
attention_mask = torch.cat(
|
||||
[torch.zeros(tensor.shape[0], num_padding_tokens, dtype=torch.long), torch.ones_like(tensor)],
|
||||
dim=1,
|
||||
)
|
||||
batched_inputs["attention_mask"].append(attention_mask)
|
||||
|
||||
tensor_img = torch.Tensor(padded_image).permute(2, 0, 1)
|
||||
batch_images.append([tensor_img])
|
||||
elif key == "image_patches":
|
||||
# For image_patches, we don't pad but just append them to the list.
|
||||
batched_inputs[key].append(tensor)
|
||||
|
||||
return batch_images, torch.Tensor(image_unpadded_heights), torch.Tensor(image_unpadded_widths)
|
||||
else: # for image_patches_indices
|
||||
num_padding_indices = max_length_image_patch_indices - tensor.shape[1]
|
||||
padded_indices = torch.cat(
|
||||
[
|
||||
torch.full(
|
||||
(tensor.shape[0], num_padding_indices), self.dummy_image_index, dtype=torch.long
|
||||
),
|
||||
tensor,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
batched_inputs[key].append(padded_indices)
|
||||
batched_keys = ["input_ids", "image_patches_indices"]
|
||||
if return_attention_mask:
|
||||
batched_keys.append("attention_mask")
|
||||
for key in batched_keys:
|
||||
batched_inputs[key] = torch.cat(batched_inputs[key], dim=0)
|
||||
|
||||
def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
|
||||
return batched_inputs
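`_left_pad_inputs_with_attention_mask` pads every sample on the left up to the longest `input_ids` / `image_patches_indices` in the batch, using the pad token id for tokens, the dummy image index for patch indices, and zeros in the attention mask. A self-contained sketch of the same pattern on two toy single-row sequences (all values are made up):

import torch

pad_token_id = 0
a = torch.tensor([[5, 6, 7]])
b = torch.tensor([[8, 9]])
max_len = max(t.shape[1] for t in (a, b))

padded, masks = [], []
for t in (a, b):
    n = max_len - t.shape[1]
    # left-pad the ids and mark the padded positions as 0 in the attention mask
    padded.append(torch.cat([torch.full((1, n), pad_token_id, dtype=torch.long), t], dim=1))
    masks.append(torch.cat([torch.zeros(1, n, dtype=torch.long), torch.ones_like(t)], dim=1))

input_ids = torch.cat(padded, dim=0)      # tensor([[5, 6, 7], [0, 8, 9]])
attention_mask = torch.cat(masks, dim=0)  # tensor([[1, 1, 1], [0, 1, 1]])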
|
||||
|
||||
def get_sample_encoding(
|
||||
self,
|
||||
prompts,
|
||||
scale_factors,
|
||||
image_unpadded_heights,
|
||||
image_unpadded_widths,
|
||||
image_placeholder_id,
|
||||
image_newline_id,
|
||||
tensor_batch_images,
|
||||
):
|
||||
image_present = torch.ones(1, 1, 1)
|
||||
model_image_input = self.image_processor.preprocess_with_tokenizer_info(
|
||||
image_input=tensor_batch_images,
|
||||
image_present=image_present,
|
||||
image_unpadded_h=image_unpadded_heights,
|
||||
image_unpadded_w=image_unpadded_widths,
|
||||
image_placeholder_id=image_placeholder_id,
|
||||
image_newline_id=image_newline_id,
|
||||
variable_sized=True,
|
||||
)
|
||||
# FIXME max_tokens_to_generate is embedded into this processor's call.
|
||||
prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
|
||||
tokenizer=self.tokenizer,
|
||||
prompts=prompts,
|
||||
scale_factors=scale_factors,
|
||||
max_tokens_to_generate=self.max_tokens_to_generate,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
add_BOS=True,
|
||||
add_beginning_of_answer_token=True,
|
||||
)
|
||||
image_padded_unpacked_tokens = construct_full_unpacked_stream(
|
||||
num_real_text_tokens=prompts_length,
|
||||
input_stream=prompt_tokens,
|
||||
image_tokens=model_image_input["image_input_ids"],
|
||||
batch_size=1,
|
||||
num_sub_sequences=self.subsequence_length,
|
||||
)
|
||||
# Construct inputs for image patch indices.
|
||||
unpacked_image_patch_indices_per_batch = construct_full_unpacked_stream(
|
||||
num_real_text_tokens=prompts_length,
|
||||
input_stream=torch.full_like(prompt_tokens, -1),
|
||||
image_tokens=model_image_input["image_patch_indices_per_batch"],
|
||||
batch_size=1,
|
||||
num_sub_sequences=self.subsequence_length,
|
||||
)
|
||||
max_prompt_length = max(x.shape[-1] for x in image_padded_unpacked_tokens)
|
||||
max_seq_len_batch = min(max_prompt_length + self.max_tokens_to_generate, self.max_position_embeddings)
|
||||
tokens_to_place = min(max_seq_len_batch, max(0, image_padded_unpacked_tokens[0].shape[0]))
|
||||
|
||||
# Use same packing logic for the image patch indices.
|
||||
image_patch_input_indices = full_unpacked_stream_to_tensor(
|
||||
all_bi_tokens_to_place=[tokens_to_place],
|
||||
full_unpacked_stream=unpacked_image_patch_indices_per_batch,
|
||||
fill_value=-1,
|
||||
batch_size=1,
|
||||
new_seq_len=max_seq_len_batch,
|
||||
offset=0,
|
||||
)
|
||||
image_patches_tensor = torch.stack([img[0] for img in model_image_input["image_patches"]])
|
||||
batch_encoding = {
|
||||
"input_ids": image_padded_unpacked_tokens[0].unsqueeze(0),
|
||||
"image_patches": image_patches_tensor,
|
||||
"image_patches_indices": image_patch_input_indices,
|
||||
}
|
||||
return batch_encoding
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text=None,
|
||||
images=None,
|
||||
add_special_tokens: bool = True,
|
||||
return_attention_mask: bool = True,
|
||||
padding: Union[bool, str, PaddingStrategy] = False,
|
||||
truncation: Union[bool, str, TruncationStrategy] = None,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_token_type_ids: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
**kwargs,
|
||||
) -> "FuyuBatchFeature":
"""
Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to
encode the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
encode the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
FuyuImageProcessor's [`~FuyuImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.

@@ -433,130 +484,211 @@ class FuyuProcessor(ProcessorMixin):
|
||||
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
|
||||
number of channels, H and W are image height and width.
|
||||
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
[`FuyuBatchEncoding`]: A [`FuyuBatchEncoding`] with the following fields:

- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **input_ids** -- Tensor of token ids to be fed to a model. Returned when `text` is not `None`.
- **image_patches** -- List of tensors of image patches. Returned when `images` is not `None`.
- **image_patches_indices** -- Tensor of indices where patch embeddings have to be inserted by the model.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model when
`return_attention_mask=True`.
"""
requires_backends(self, ["torch"])

# --- Check input validity ---
|
||||
if not return_attention_mask:
|
||||
raise ValueError("`return_attention_mask=False` is not supported for this model.")
|
||||
if text is None and images is None:
|
||||
raise ValueError("You have to specify either text or images. Both cannot be none.")
|
||||
raise ValueError("You have to specify either text or images. Both cannot be None.")
|
||||
if text is not None and images is None:
|
||||
logger.warning("You are processing a text with no associated image. Make sure it is intended.")
|
||||
self.current_processor = self.tokenizer
|
||||
text_encoding = self.tokenizer(
|
||||
text=text,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
return_tensors=return_tensors,
|
||||
**kwargs,
|
||||
)
|
||||
return text_encoding
|
||||
|
||||
if text is None and images is not None:
|
||||
logger.warning("You are processing an image with no associated text. Make sure it is intended.")
|
||||
prompts = [[""]]
|
||||
if text is not None and images is not None:
|
||||
if isinstance(text, str):
|
||||
prompts = [[text]]
|
||||
elif isinstance(text, list):
|
||||
prompts = [[text_seq] for text_seq in text]
|
||||
batch_images = []
|
||||
if isinstance(images, PIL.Image.Image):
|
||||
images = [images]
|
||||
if isinstance(images, list):
|
||||
batch_images, image_unpadded_heights, image_unpadded_widths = self._process_images(images)
|
||||
# image_unpadded_heights = image_unpadded_heights.unsqueeze(0)
|
||||
# image_unpadded_widths = image_unpadded_widths.unsqueeze(0)
|
||||
else:
|
||||
raise ValueError("images must be a list of ndarrays or PIL Images to be processed.")
|
||||
|
||||
# Note: the original adept code has a handling of image_unpadded_h and w, but it doesn't seem to hold
|
||||
# when there are several different size subsequences per batch. The current implementation reflects
|
||||
# that limitation and should be documented.
|
||||
#
|
||||
self.subsequence_length = 1 # Each batch contains only one sequence.
|
||||
self.batch_size = len(batch_images)
|
||||
# FIXME max_tokens_to_generate is embedded into this processor's call.
|
||||
prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
|
||||
tokenizer=self.tokenizer,
|
||||
prompts=prompts,
|
||||
transformed_images=batch_images,
|
||||
max_tokens_to_generate=self.max_tokens_to_generate,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
add_BOS=True,
|
||||
add_beginning_of_answer_token=True,
|
||||
)
|
||||
# same so far
|
||||
# --- Preprocess images using self.image_processor ---
|
||||
|
||||
# This is 1 if there is an image per subsequence, else 0. [batch, 1, presence]
|
||||
# the remainder of current image processing logic assumes subsequence_size = 1.
|
||||
# Here it is OK as the model cannot handle > 1 subsequences
|
||||
# the image could be absent however and image presence should be inferred from user batch input
|
||||
# hence this code assumes the images are present. Use an assert?
|
||||
# FIXME - We hard code "pt" here because the rest of the processing assumes torch tensors
|
||||
image_encoding = self.image_processor.preprocess(images, return_tensors="pt")
|
||||
batch_images = image_encoding["images"]
|
||||
image_unpadded_heights = image_encoding["image_unpadded_heights"]
|
||||
image_unpadded_widths = image_encoding["image_unpadded_widths"]
|
||||
scale_factors = image_encoding["image_scale_factors"]
|
||||
self.subsequence_length = 1 # Each batch contains only one sequence.
|
||||
self.batch_size = len(batch_images)
|
||||
|
||||
image_present = torch.ones(self.batch_size, 1, 1)
|
||||
# --- Use self.tokenizer to get the ids of special tokens to insert into image ids ---
|
||||
|
||||
image_placeholder_id = self.tokenizer("|SPEAKER|", add_special_tokens=False)["input_ids"][1]
|
||||
image_newline_id = self.tokenizer("|NEWLINE|", add_special_tokens=False)["input_ids"][1]
|
||||
tensor_batch_images = torch.stack([img[0] for img in batch_images]).unsqueeze(1)
|
||||
model_image_input = self.image_processor.process_images_for_model_input(
|
||||
image_input=tensor_batch_images,
|
||||
image_present=image_present,
|
||||
image_unpadded_h=image_unpadded_heights,
|
||||
image_unpadded_w=image_unpadded_widths,
|
||||
image_patch_dim_h=30,
|
||||
image_patch_dim_w=30,
|
||||
image_placeholder_id = self.tokenizer("|SPEAKER|", add_special_tokens=False)["input_ids"][1]
|
||||
image_newline_id = self.tokenizer("|NEWLINE|", add_special_tokens=False)["input_ids"][1]
|
||||
tensor_batch_images = torch.stack([img[0] for img in batch_images]).unsqueeze(1)
|
||||
|
||||
# --- Use self.image_processor again to obtain the full token ids and batch inputs ---
|
||||
all_encodings = []
|
||||
|
||||
for prompt, scale_factor, image_unpadded_height, image_unpadded_width, tensor_batch_image in zip(
|
||||
prompts, scale_factors, image_unpadded_heights, image_unpadded_widths, tensor_batch_images
|
||||
):
|
||||
sample_encoding = self.get_sample_encoding(
|
||||
prompts=[prompt],
|
||||
scale_factors=[scale_factor],
|
||||
image_unpadded_heights=torch.tensor([image_unpadded_height]),
|
||||
image_unpadded_widths=torch.tensor([image_unpadded_width]),
|
||||
image_placeholder_id=image_placeholder_id,
|
||||
image_newline_id=image_newline_id,
|
||||
variable_sized=True,
|
||||
tensor_batch_images=tensor_batch_image.unsqueeze(0),
|
||||
)
|
||||
all_encodings.append(sample_encoding)
|
||||
batch_encoding = self._left_pad_inputs_with_attention_mask(
|
||||
model_inputs=all_encodings, return_attention_mask=return_attention_mask
|
||||
)
|
||||
return FuyuBatchFeature(data=batch_encoding)
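End to end, the new `__call__` path above resolves to: preprocess the images, tokenize each prompt with its scale factor, build per-sample encodings, then left-pad and batch them into a `FuyuBatchFeature`. A hedged usage sketch (assumes the `adept/fuyu-8b` checkpoint used by the tests below and network access for the sample bus image):

import io
import requests
from PIL import Image
from transformers import FuyuProcessor

processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
image = Image.open(io.BytesIO(requests.get(url).content))

inputs = processor(text="Generate a coco-style caption.\n", images=image)
# input_ids, image_patches, image_patches_indices and attention_mask, batched with batch size 1
print(inputs["input_ids"].shape, inputs["image_patches_indices"].shape)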
|
||||
|
||||
image_padded_unpacked_tokens = construct_full_unpacked_stream(
|
||||
num_real_text_tokens=prompts_length,
|
||||
input_stream=prompt_tokens,
|
||||
image_tokens=model_image_input["image_input_ids"],
|
||||
batch_size=self.batch_size,
|
||||
num_sub_sequences=self.subsequence_length,
|
||||
)
|
||||
# Construct inputs for image patch indices.
|
||||
unpacked_image_patch_indices_per_batch = construct_full_unpacked_stream(
|
||||
num_real_text_tokens=prompts_length,
|
||||
input_stream=torch.full_like(prompt_tokens, -1),
|
||||
image_tokens=model_image_input["image_patch_indices_per_batch"],
|
||||
batch_size=self.batch_size,
|
||||
num_sub_sequences=self.subsequence_length,
|
||||
)
|
||||
max_prompt_length = max(x.shape[-1] for x in image_padded_unpacked_tokens)
|
||||
max_seq_len_batch = min(max_prompt_length + self.max_tokens_to_generate, self.max_position_embeddings)
|
||||
all_bi_tokens_to_place = []
|
||||
for bi in range(self.batch_size):
|
||||
tokens_to_place = min(max_seq_len_batch, max(0, image_padded_unpacked_tokens[bi].shape[0]))
|
||||
all_bi_tokens_to_place.append(tokens_to_place)
|
||||
def post_process_box_coordinates(self, outputs, target_sizes=None):
|
||||
"""
|
||||
Transforms raw coordinates detected by [`FuyuForCausalLM`] to the original images' coordinate space.
|
||||
Coordinates will be returned in "box" format, with the following pattern:
|
||||
`<box>top, left, bottom, right</box>`
|
||||
|
||||
# Use same packing logic for the image patch indices.
|
||||
image_patch_input_indices = full_unpacked_stream_to_tensor(
|
||||
all_bi_tokens_to_place=all_bi_tokens_to_place,
|
||||
full_unpacked_stream=unpacked_image_patch_indices_per_batch,
|
||||
fill_value=-1,
|
||||
batch_size=self.batch_size,
|
||||
new_seq_len=max_seq_len_batch,
|
||||
offset=0,
|
||||
)
|
||||
Point coordinates are not supported yet.
|
||||
|
||||
image_patches_tensor = torch.stack([img[0] for img in model_image_input["image_patches"]]).unsqueeze(1)
|
||||
return {
|
||||
"input_ids": image_padded_unpacked_tokens[0].unsqueeze(0),
|
||||
"image_patches": image_patches_tensor[0][0].unsqueeze(0),
|
||||
"image_patches_indices": image_patch_input_indices,
|
||||
}
|
||||
Args:
|
||||
outputs ([`GenerateOutput`]):
|
||||
Raw outputs from `generate`.
|
||||
target_sizes (`torch.Tensor`, *optional*):
|
||||
Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
|
||||
the batch. If set, found coordinates in the output sequence are rescaled to the target sizes. If left
|
||||
to None, coordinates will not be rescaled.
|
||||
|
||||
Returns:
|
||||
`GenerateOutput`: Same output type returned by `generate`, with output token ids replaced with
|
||||
boxed and possibly rescaled coordinates.
|
||||
"""
|
||||
|
||||
def scale_factor_to_fit(original_size, target_size=None):
|
||||
height, width = original_size
|
||||
if target_size is None:
|
||||
max_height = self.image_processor.size["height"]
|
||||
max_width = self.image_processor.size["width"]
|
||||
else:
|
||||
max_height, max_width = target_size
|
||||
if width <= max_width and height <= max_height:
|
||||
return 1.0
|
||||
return min(max_height / height, max_width / width)
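`scale_factor_to_fit` returns the factor the original image was shrunk by during preprocessing (1.0 if it already fit inside the target). A quick numeric check, restated outside the class so it runs standalone; the 1080x1920 target and the concrete sizes are illustrative:

def scale_factor_to_fit(original_size, target_size=(1080, 1920)):
    height, width = original_size
    max_height, max_width = target_size
    if width <= max_width and height <= max_height:
        return 1.0
    return min(max_height / height, max_width / width)

print(scale_factor_to_fit((2160, 3840)))  # 0.5: both dimensions are twice the target
print(scale_factor_to_fit((500, 500)))    # 1.0: already fits, coordinates are unchanged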
|
||||
|
||||
def find_delimiters_pair(tokens, start_token, end_token):
|
||||
start_id = self.tokenizer.convert_tokens_to_ids(start_token)
|
||||
end_id = self.tokenizer.convert_tokens_to_ids(end_token)
|
||||
|
||||
starting_positions = (tokens == start_id).nonzero(as_tuple=True)[0]
|
||||
ending_positions = (tokens == end_id).nonzero(as_tuple=True)[0]
|
||||
|
||||
if torch.any(starting_positions) and torch.any(ending_positions):
|
||||
return (starting_positions[0], ending_positions[0])
|
||||
return (None, None)
|
||||
|
||||
def tokens_to_boxes(tokens, original_size):
|
||||
while (pair := find_delimiters_pair(tokens, TOKEN_BBOX_OPEN_STRING, TOKEN_BBOX_CLOSE_STRING)) != (
|
||||
None,
|
||||
None,
|
||||
):
|
||||
start, end = pair
|
||||
if end != start + 5:
|
||||
continue
|
||||
|
||||
# Retrieve transformed coordinates from tokens
|
||||
coords = self.tokenizer.convert_ids_to_tokens(tokens[start + 1 : end])
|
||||
|
||||
# Scale back to original image size and multiply by 2
|
||||
scale = scale_factor_to_fit(original_size)
|
||||
top, left, bottom, right = [2 * int(float(c) / scale) for c in coords]
|
||||
|
||||
# Replace the IDs so they get detokenized right
|
||||
replacement = f" {TEXT_REPR_BBOX_OPEN}{top}, {left}, {bottom}, {right}{TEXT_REPR_BBOX_CLOSE}"
|
||||
replacement = self.tokenizer.tokenize(replacement)[1:]
|
||||
replacement = self.tokenizer.convert_tokens_to_ids(replacement)
|
||||
replacement = torch.tensor(replacement).to(tokens)
|
||||
|
||||
tokens = torch.cat([tokens[:start], replacement, tokens[end + 1 :]], 0)
|
||||
return tokens
|
||||
|
||||
def tokens_to_points(tokens, original_size):
|
||||
while (pair := find_delimiters_pair(tokens, TOKEN_POINT_OPEN_STRING, TOKEN_POINT_CLOSE_STRING)) != (
|
||||
None,
|
||||
None,
|
||||
):
|
||||
start, end = pair
|
||||
if end != start + 3:
|
||||
continue
|
||||
|
||||
# Retrieve transformed coordinates from tokens
|
||||
coords = self.tokenizer.convert_ids_to_tokens(tokens[start + 1 : end])
|
||||
|
||||
# Scale back to original image size and multiply by 2
|
||||
scale = scale_factor_to_fit(original_size)
|
||||
x, y = [2 * int(float(c) / scale) for c in coords]
|
||||
|
||||
# Replace the IDs so they get detokenized right
|
||||
replacement = f" {TEXT_REPR_POINT_OPEN}{x}, {y}{TEXT_REPR_POINT_CLOSE}"
|
||||
replacement = self.tokenizer.tokenize(replacement)[1:]
|
||||
replacement = self.tokenizer.convert_tokens_to_ids(replacement)
|
||||
replacement = torch.tensor(replacement).to(tokens)
|
||||
|
||||
tokens = torch.cat([tokens[:start], replacement, tokens[end + 1 :]], 0)
|
||||
return tokens
|
||||
|
||||
if target_sizes is None:
|
||||
target_sizes = ((self.image_processor.size["height"], self.image_processor.size["width"]),) * len(outputs)
|
||||
elif target_sizes.shape[1] != 2:
|
||||
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
|
||||
|
||||
if len(outputs) != len(target_sizes):
|
||||
raise ValueError("Make sure that you pass in as many target sizes as output sequences")
|
||||
|
||||
results = []
|
||||
for seq, size in zip(outputs, target_sizes):
|
||||
seq = tokens_to_boxes(seq, size)
|
||||
seq = tokens_to_points(seq, size)
|
||||
results.append(seq)
|
||||
|
||||
return results
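`post_process_box_coordinates` is meant to be chained after `generate` so that any `<box>`/`<point>` tokens in the output are mapped back to the original image resolution. A hedged sketch of that flow (checkpoint, image URL and prompt are taken from the tests in this commit; treat it as an illustration, not canonical usage):

import io
import requests
import torch
from PIL import Image
from transformers import FuyuForCausalLM, FuyuProcessor

processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b")
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
image = Image.open(io.BytesIO(requests.get(url).content))

prompt = "When presented with a box, perform OCR to extract text contained within it. If provided with text, generate the corresponding bounding box.\nWilliams"
inputs = processor(text=prompt, images=image)
outputs = model.generate(**inputs, max_new_tokens=10)

# Rescale box/point coordinates back to the original (height, width) of the image
target_sizes = torch.tensor([[image.height, image.width]])
rescaled = processor.post_process_box_coordinates(outputs, target_sizes=target_sizes)
print(processor.batch_decode(rescaled, skip_special_tokens=True))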
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
@@ -24,7 +24,8 @@ if is_vision_available():
|
||||
@require_torchvision
|
||||
class TestFuyuImageProcessor(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.processor = FuyuImageProcessor(target_height=160, target_width=320, padding_value=1.0)
|
||||
self.size = {"height": 160, "width": 320}
|
||||
self.processor = FuyuImageProcessor(size=self.size, padding_value=1.0)
|
||||
self.batch_size = 3
|
||||
self.channels = 3
|
||||
self.height = 300
|
||||
@@ -38,29 +39,25 @@ class TestFuyuImageProcessor(unittest.TestCase):
|
||||
self.sample_image_pil = Image.fromarray(self.sample_image)
|
||||
|
||||
def test_patches(self):
|
||||
expected_num_patches = self.processor.get_num_patches(
|
||||
img_h=self.height, img_w=self.width, patch_dim_h=self.image_patch_dim_h, patch_dim_w=self.image_patch_dim_w
|
||||
)
|
||||
expected_num_patches = self.processor.get_num_patches(image_height=self.height, image_width=self.width)
|
||||
|
||||
patches_final = self.processor.patchify_image(
|
||||
image=self.image_input, patch_dim_h=self.image_patch_dim_h, patch_dim_w=self.image_patch_dim_w
|
||||
)
|
||||
patches_final = self.processor.patchify_image(image=self.image_input)
|
||||
assert (
|
||||
patches_final.shape[1] == expected_num_patches
|
||||
), f"Expected {expected_num_patches} patches, got {patches_final.shape[1]}."
|
||||
|
||||
def test_scale_to_target_aspect_ratio(self):
|
||||
# (h:450, w:210) fitting (160, 320) -> (160, 210*160/450)
|
||||
scaled_image = self.processor._scale_to_target_aspect_ratio(self.sample_image)
|
||||
scaled_image = self.processor.resize(self.sample_image, size=self.size)
|
||||
self.assertEqual(scaled_image.shape[0], 160)
|
||||
self.assertEqual(scaled_image.shape[1], 74)
|
||||
|
||||
def test_apply_transformation_numpy(self):
|
||||
transformed_image = self.processor.apply_transformation(self.sample_image)
|
||||
self.assertEqual(transformed_image.shape[0], 160)
|
||||
self.assertEqual(transformed_image.shape[1], 320)
|
||||
transformed_image = self.processor.preprocess(self.sample_image).images[0][0]
|
||||
self.assertEqual(transformed_image.shape[1], 160)
|
||||
self.assertEqual(transformed_image.shape[2], 320)
|
||||
|
||||
def test_apply_transformation_pil(self):
|
||||
transformed_image = self.processor.apply_transformation(self.sample_image_pil)
|
||||
self.assertEqual(transformed_image.shape[0], 160)
|
||||
self.assertEqual(transformed_image.shape[1], 320)
|
||||
transformed_image = self.processor.preprocess(self.sample_image_pil).images[0][0]
|
||||
self.assertEqual(transformed_image.shape[1], 160)
|
||||
self.assertEqual(transformed_image.shape[2], 320)
|
||||
|
@@ -3,7 +3,7 @@ import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from transformers import AutoTokenizer, FuyuConfig, is_torch_available, is_vision_available
|
||||
from transformers import FuyuConfig, is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device
|
||||
|
||||
from ...test_modeling_common import ids_tensor, random_attention_mask
|
||||
@@ -14,7 +14,7 @@ if is_vision_available():
|
||||
|
||||
|
||||
if is_torch_available() and is_vision_available():
|
||||
from transformers import FuyuImageProcessor, FuyuProcessor
|
||||
from transformers import FuyuProcessor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -267,11 +267,8 @@ class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin)
|
||||
all_model_classes = ("FuyuForCausalLM") if is_torch_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
self.pretrained_model_name = "huggingface/new_model_release_weights"
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_name)
|
||||
image_processor = FuyuImageProcessor()
|
||||
|
||||
self.processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
self.pretrained_model_name = "adept/fuyu-8b"
|
||||
self.processor = FuyuProcessor.from_pretrained(self.pretrained_model_name)
|
||||
self.model = FuyuForCausalLM.from_pretrained(self.pretrained_model_name)
|
||||
self.bus_image_url = (
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
|
||||
@@ -280,9 +277,8 @@ class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin)
|
||||
|
||||
@slow
|
||||
def test_model_8b_chat_greedy_generation_bus_captioning(self):
|
||||
EXPECTED_TEXT_COMPLETION = """A bus parked on the side of a road.|ENDOFTEXT|"""
|
||||
EXPECTED_TEXT_COMPLETION = """A blue bus parked on the side of a road.|ENDOFTEXT|"""
|
||||
text_prompt_coco_captioning = "Generate a coco-style caption.\n"
|
||||
|
||||
model_inputs_bus_captioning = self.processor(text=text_prompt_coco_captioning, images=self.bus_image_pil)
|
||||
generated_tokens = self.model.generate(**model_inputs_bus_captioning, max_new_tokens=10)
|
||||
text = self.processor.tokenizer.batch_decode(generated_tokens)
|
||||
@@ -297,7 +293,7 @@ class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin)
|
||||
|
||||
"""
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_model_8b_chat_greedy_generation_bus_color(self):
|
||||
EXPECTED_TEXT_COMPLETION = "The bus is blue.\n|ENDOFTEXT|"
|
||||
text_prompt_bus_color = "What color is the bus?\n"
|
||||
@@ -314,7 +310,7 @@ class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, clean_sequence)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_model_8b_chat_greedy_generation_chart_vqa(self):
|
||||
# fmt: off
|
||||
EXPECTED_TEXT_TOKENS = ["The","life expectancy","at","birth","of male","s in","","20","18","is","","80",".","7",".","\n","|ENDOFTEXT|",]
|
||||
@@ -340,7 +336,7 @@ class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin)
|
||||
self.assertEqual(expected_text_completion, clean_sequence)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_model_8b_chat_greedy_generation_bounding_box(self):
|
||||
EXPECTED_TEXT_COMPLETION = "\x00194213202244\x01|ENDOFTEXT|"
|
||||
text_prompt_bbox = "When presented with a box, perform OCR to extract text contained within it. If provided with text, generate the corresponding bounding box.\\nWilliams" # noqa: E231
|
||||
|
@@ -26,16 +26,14 @@ class FuyuProcessingTest(unittest.TestCase): # TODO Which mixins do we add here
|
||||
""" """
|
||||
|
||||
def setUp(self):
|
||||
pretrained_model_name = "huggingface/pre_release_model"
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
|
||||
image_processor = FuyuImageProcessor()
|
||||
pretrained_model_name = "adept/fuyu-8b"
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
|
||||
self.image_processor = FuyuImageProcessor()
|
||||
|
||||
processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
text_prompt = "Generate a coco-style caption.\\n"
|
||||
self.processor = FuyuProcessor(image_processor=self.image_processor, tokenizer=self.tokenizer)
|
||||
self.text_prompt = "Generate a coco-style caption.\\n"
|
||||
bus_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
|
||||
bus_image_pil = Image.open(io.BytesIO(requests.get(bus_image_url).content))
|
||||
|
||||
self.one_image_bus_model_inputs = processor(text=text_prompt, images=bus_image_pil)
|
||||
self.bus_image_pil = Image.open(io.BytesIO(requests.get(bus_image_url).content))
|
||||
|
||||
def test_fuyu_processing(self):
|
||||
"""
|
||||
@@ -44,11 +42,119 @@ class FuyuProcessingTest(unittest.TestCase): # TODO Which mixins do we add here
|
||||
# fmt: off
|
||||
EXPECTED_IMAGE_PATCH_INPUTS = torch.Tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, -1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,]]).to(torch.int64)
|
||||
EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122,]]).to(torch.int64)
|
||||
|
||||
one_image_bus_model_inputs = self.processor(text=self.text_prompt, images=self.bus_image_pil)
|
||||
|
||||
# fmt: on
|
||||
torch.testing.assert_close(
|
||||
self.one_image_bus_model_inputs["image_patches_indices"], EXPECTED_IMAGE_PATCH_INPUTS
|
||||
torch.testing.assert_close(one_image_bus_model_inputs["image_patches_indices"], EXPECTED_IMAGE_PATCH_INPUTS)
|
||||
torch.testing.assert_close(one_image_bus_model_inputs["input_ids"], EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS)
|
||||
|
||||
def test_fuyu_processing_no_image(self):
|
||||
"""
|
||||
Test to check processor works with just text input
|
||||
"""
|
||||
processor_outputs = self.processor(text=self.text_prompt)
|
||||
tokenizer_outputs = self.tokenizer(self.text_prompt)
|
||||
self.assertEqual(processor_outputs["input_ids"], tokenizer_outputs["input_ids"])
|
||||
|
||||
def test_fuyu_processing_no_text(self):
|
||||
"""
|
||||
Test to check processor works with just image input
|
||||
"""
|
||||
# fmt: off
|
||||
EXPECTED_IMAGE_PATCH_INPUTS = torch.Tensor([
|
||||
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
|
||||
14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26,
|
||||
27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
|
||||
41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
|
||||
54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66,
|
||||
67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
|
||||
81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93,
|
||||
94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
|
||||
108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
|
||||
121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133,
|
||||
134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
|
||||
148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160,
|
||||
161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174,
|
||||
175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
|
||||
188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200,
|
||||
201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214,
|
||||
215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 226, 227,
|
||||
228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241,
|
||||
-1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254,
|
||||
255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267,
|
||||
268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281,
|
||||
282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294,
|
||||
295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1,
|
||||
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
|
||||
]).to(torch.int64)
|
||||
# fmt: on
|
||||
|
||||
processor_outputs = self.processor(images=self.bus_image_pil)
|
||||
self.assertTrue((processor_outputs["image_patches_indices"] == EXPECTED_IMAGE_PATCH_INPUTS).all())
|
||||
|
||||
def test_fuyu_processing_multiple_image_sample(self):
|
||||
"""
|
||||
Test to check processor works with multiple image inputs for a single text input
|
||||
"""
|
||||
# fmt: off
|
||||
SINGLE_IMAGE_PATCH_INPUTS = torch.Tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, -1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,]]).to(torch.int64)
|
||||
SINGLE_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122,]]).to(torch.int64)
|
||||
|
||||
SINGLE_RESIZED_IMAGE_PATCH_INPUTS = torch.Tensor([[ 0, 1, 2, -1, 3, 4, 5, -1, 6, 7, 8, -1, 9, 10, 11, -1, 12, 13, 14, -1, 15, 16, 17, -1, 18, 19, 20, -1, 21, 22, 23, -1, 24, 25, 26, -1, 27, 28, 29, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]])
|
||||
SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[ 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122]])
|
||||
# fmt: on
|
||||
|
||||
# Batch of two images - equally sized
|
||||
images = [self.bus_image_pil, self.bus_image_pil]
|
||||
processor_outputs = self.processor(text=[self.text_prompt, self.text_prompt], images=images)
|
||||
|
||||
self.assertTrue(
|
||||
(
|
||||
processor_outputs["image_patches_indices"]
|
||||
== torch.cat([SINGLE_IMAGE_PATCH_INPUTS, SINGLE_IMAGE_PATCH_INPUTS], dim=0)
|
||||
).all()
|
||||
)
|
||||
torch.testing.assert_close(self.one_image_bus_model_inputs["input_ids"], EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS)
|
||||
self.assertTrue(
|
||||
(
|
||||
processor_outputs["input_ids"]
|
||||
== torch.cat([SINGLE_PADDED_UNPACKED_TOKEN_INPUTS, SINGLE_PADDED_UNPACKED_TOKEN_INPUTS], dim=0)
|
||||
).all()
|
||||
)
|
||||
|
||||
# Processes single images with different sizes as expected
|
||||
images = [self.bus_image_pil]
|
||||
processor_outputs = self.processor(text=self.text_prompt, images=images)
|
||||
self.assertTrue((processor_outputs["image_patches_indices"] == SINGLE_IMAGE_PATCH_INPUTS).all())
|
||||
self.assertTrue((processor_outputs["input_ids"] == SINGLE_PADDED_UNPACKED_TOKEN_INPUTS).all())
|
||||
|
||||
images = [self.bus_image_pil.resize((64, 300))]
|
||||
processor_outputs = self.processor(text=self.text_prompt, images=images)
|
||||
self.assertTrue((processor_outputs["image_patches_indices"] == SINGLE_RESIZED_IMAGE_PATCH_INPUTS).all())
|
||||
self.assertTrue((processor_outputs["input_ids"] == SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS).all())
|
||||
|
||||
# Batch of two images - different sizes. Left-pads the smaller image inputs
|
||||
images = [self.bus_image_pil, self.bus_image_pil.resize((64, 300))]
|
||||
processor_outputs = self.processor(text=[self.text_prompt, self.text_prompt], images=images)
|
||||
|
||||
padding_len_patch = SINGLE_IMAGE_PATCH_INPUTS.shape[1] - SINGLE_RESIZED_IMAGE_PATCH_INPUTS.shape[1]
|
||||
padded_single_resized_image_patch = torch.cat(
|
||||
[torch.ones([1, padding_len_patch]) * -1, SINGLE_RESIZED_IMAGE_PATCH_INPUTS], dim=1
|
||||
)
|
||||
expected_image_patch_inputs = torch.cat([SINGLE_IMAGE_PATCH_INPUTS, padded_single_resized_image_patch], dim=0)
|
||||
|
||||
padding_len_token = (
|
||||
SINGLE_PADDED_UNPACKED_TOKEN_INPUTS.shape[1] - SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS.shape[1]
|
||||
)
|
||||
padded_single_resized_padded_unpacked_token_inputs = torch.cat(
|
||||
[torch.zeros([1, padding_len_token]), SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS], dim=1
|
||||
)
|
||||
expected_padded_unpacked_token_inputs = torch.cat(
|
||||
[SINGLE_PADDED_UNPACKED_TOKEN_INPUTS, padded_single_resized_padded_unpacked_token_inputs], dim=0
|
||||
)
|
||||
|
||||
self.assertTrue((processor_outputs["image_patches_indices"] == expected_image_patch_inputs).all())
|
||||
self.assertTrue((processor_outputs["input_ids"] == expected_padded_unpacked_token_inputs).all())
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -97,7 +203,6 @@ class TestProcessImagesForModelInput(unittest.TestCase):
|
||||
"""
|
||||
Adding a mix of present and absent images.
|
||||
"""
|
||||
self.image_processor = FuyuImageProcessor()
|
||||
|
||||
self.image_input = torch.randn([1, 1, 3, 64, 64])
|
||||
self.image_present = torch.tensor([[1]])
|
||||
@@ -108,19 +213,19 @@ class TestProcessImagesForModelInput(unittest.TestCase):
|
||||
self.image_placeholder_id = 999
|
||||
self.image_newline_id = 888
|
||||
self.variable_sized = True
|
||||
self.image_processor = FuyuImageProcessor(
|
||||
patch_size={"height": self.image_patch_dim_h, "width": self.image_patch_dim_w}
|
||||
)
|
||||
|
||||
def test_process_images_for_model_input_fixed_sized(self):
|
||||
self.variable_sized = False
|
||||
result = self.image_processor.process_images_for_model_input(
|
||||
result = self.image_processor.preprocess_with_tokenizer_info(
|
||||
image_input=self.image_input,
|
||||
image_present=self.image_present,
|
||||
image_unpadded_h=self.image_unpadded_h,
|
||||
image_unpadded_w=self.image_unpadded_w,
|
||||
image_patch_dim_h=self.image_patch_dim_h,
|
||||
image_patch_dim_w=self.image_patch_dim_w,
|
||||
image_placeholder_id=self.image_placeholder_id,
|
||||
image_newline_id=self.image_newline_id,
|
||||
variable_sized=self.variable_sized,
|
||||
)
|
||||
print(result["images"][0][0])
|
||||
self.assertEqual(result["images"][0][0].shape, torch.Size([3, 64, 64]))