smol improvements to support more flexible usage (#34857)

* smol improvements to support more flexible usage

* ruff
Andrés Marafioti 2024-11-22 16:34:38 +01:00 committed by GitHub
parent 42b36d7395
commit 861758e235

@@ -38,6 +38,7 @@ from ...utils import TensorType, is_vision_available, logging
 logger = logging.get_logger(__name__)
+MAX_IMAGE_SIZE = 4096  # 4k resolution as absolute maximum
 if is_vision_available():
@@ -116,7 +117,6 @@ def _resize_output_size_scale_below_upper_bound(
 def get_resize_output_image_size(
     image,
     resolution_max_side: int,
-    max_image_size: int = 1820,
     input_data_format: Optional[Union[str, ChannelDimension]] = None,
 ) -> Tuple[int, int]:
     """
@@ -126,24 +126,18 @@ def get_resize_output_image_size(
             Image to resize.
         resolution_max_side (`int`):
             The longest edge of the image will be resized to this value. The shortest edge will be resized to keep the
-            input aspect ratio, with a lower bound of `min_image_size`.
-        max_image_size (`int`, *optional*, defaults to 1820):
-            Maximum image resolution. If the image is larger than this size, the longest edge will be resized to this
-            value, with the shortest edge resized to keep the input aspect ratio, with a lower bound of `min_image_size`.
+            input aspect ratio.
         input_data_format (`ChannelDimension` or `str`):
             The channel dimension format of the input image.
     Returns:
         The output size of the image after resizing.
     """
-    if resolution_max_side > max_image_size:
-        raise ValueError("`resolution_max_side` cannot be larger than `max_image_size`")
     height, width = get_image_size(image, channel_dim=input_data_format)
     # Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
     height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=resolution_max_side)
-    # Find the output size when scaling the image to be below the max_image_size
-    height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=max_image_size)
+    # Find the output size when scaling the image to be below the MAX_IMAGE_SIZE
+    height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
     return height, width
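
To make the two-step size computation above concrete, here is a minimal, self-contained sketch: the longest edge is first rescaled to `resolution_max_side` while preserving the aspect ratio, and the result is then clamped so the longest edge never exceeds the module-level MAX_IMAGE_SIZE. The helper names below only mirror the private functions in the diff; they are simplified re-implementations (no minimum-size or even-dimension handling), not the library code.

MAX_IMAGE_SIZE = 4096  # hard upper bound, matching the new module-level constant

def rescale_to_max_len(height: int, width: int, max_len: int) -> tuple:
    # Resize the longest edge to max_len and keep the aspect ratio.
    aspect_ratio = width / height
    if width >= height:
        return max(int(max_len / aspect_ratio), 1), max_len
    return max_len, max(int(max_len * aspect_ratio), 1)

def scale_below_upper_bound(height: int, width: int, max_len: int) -> tuple:
    # Only shrink when the longest edge already exceeds max_len.
    if max(height, width) <= max_len:
        return height, width
    return rescale_to_max_len(height, width, max_len)

# A 3000x9000 image requested at longest_edge=5000 is first rescaled to
# 1666x5000, then clamped so its longest edge is at most MAX_IMAGE_SIZE.
height, width = rescale_to_max_len(3000, 9000, max_len=5000)
height, width = scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
print(height, width)  # -> 1364 4096
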
@@ -251,7 +245,7 @@ def convert_to_rgb(
     data_format = input_data_format if data_format is None else data_format
     mode = "P" if palette is not None else None
-    image = to_pil_image(image, image_mode=mode)
+    image = to_pil_image(image, image_mode=mode, input_data_format=input_data_format)
     if image.mode == "P" and palette is not None:
         image.putpalette(palette)
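
The change above (and the matching one in `resize` below) forwards the caller's `input_data_format` into `to_pil_image` instead of letting it be re-inferred. A hedged sketch of why this matters, assuming a channels-first numpy input; the shape and values are made up for illustration:

import numpy as np
from transformers.image_transforms import to_pil_image
from transformers.image_utils import ChannelDimension

# A channels-first (C, H, W) uint8 image; without an explicit format hint,
# inference has to guess which axis holds the channels.
chw_image = np.random.randint(0, 256, (3, 64, 48), dtype=np.uint8)
pil_image = to_pil_image(chw_image, input_data_format=ChannelDimension.FIRST)
print(pil_image.size)  # (48, 64) -- PIL reports (width, height)
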
@@ -404,7 +398,7 @@ class Idefics3ImageProcessor(BaseImageProcessor):
         image_mode = None
         if image.ndim == 2 or image.shape[-1] == 1:
             image_mode = "P"
-        image = to_pil_image(image, image_mode=image_mode)
+        image = to_pil_image(image, image_mode=image_mode, input_data_format=input_data_format)
         resized_image = image.resize((size[1], size[0]), resample=resample)
         resized_image = np.array(resized_image)
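
For single-channel inputs this `resize` path round-trips through PIL in mode "P" so the image keeps one channel instead of being expanded to RGB. A small sketch under the assumption of a (height, width, 1) uint8 array (by this point `preprocess` has already added the channel axis, see the hunks below):

import numpy as np
from transformers.image_transforms import to_pil_image
from transformers.image_utils import ChannelDimension

gray = np.random.randint(0, 256, (100, 80, 1), dtype=np.uint8)  # (H, W, 1)
pil_gray = to_pil_image(gray, image_mode="P", input_data_format=ChannelDimension.LAST)
resized = np.array(pil_gray.resize((40, 50)))  # PIL.resize takes (width, height)
print(resized.shape)  # (50, 40) -- height, width; PIL squeezes the single channel
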
@@ -754,6 +748,16 @@ class Idefics3ImageProcessor(BaseImageProcessor):
         # All transformations expect numpy arrays.
         images_list = [[to_numpy_array(image) for image in images] for images in images_list]
+        # Extra channel dimension for grayscale images
+        if input_data_format in [ChannelDimension.LAST, None]:
+            images_list = [
+                [np.expand_dims(img, axis=-1) if img.ndim == 2 else img for img in images] for images in images_list
+            ]
+        elif input_data_format == ChannelDimension.FIRST:
+            images_list = [
+                [np.expand_dims(img, axis=0) if img.ndim == 2 else img for img in images] for images in images_list
+            ]
         if is_scaled_image(images_list[0][0]) and do_rescale:
             logger.warning_once(
                 "It looks like you are trying to rescale already rescaled images. If the input"
@@ -764,18 +768,6 @@ class Idefics3ImageProcessor(BaseImageProcessor):
         if input_data_format is None:
             input_data_format = infer_channel_dimension_format(images_list[0][0], num_channels=(1, 3, 4))
-        # Extra channel dimension for grayscale images
-        if input_data_format == ChannelDimension.LAST:
-            images_list = [
-                [np.expand_dims(img, axis=-1) if img.ndim == 2 else img for img in images] for images in images_list
-            ]
-        elif input_data_format == ChannelDimension.FIRST:
-            images_list = [
-                [np.expand_dims(img, axis=0) if img.ndim == 2 else img for img in images] for images in images_list
-            ]
-        else:
-            raise ValueError(f"Invalid channel dimension format {input_data_format}.")
         if do_resize:
             images_list = [
                 [