mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Add param_name to size_dict logs & tidy (#20205)
This commit is contained in:
parent
f1e8c48c5e
commit
55ba31908a
@ -440,11 +440,50 @@ class BaseImageProcessor(ImageProcessingMixin):
|
||||
raise NotImplementedError("Each image processor must implement its own preprocess method")
|
||||
|
||||
|
||||
# Each entry is one complete set of keys that a size dictionary is allowed to have.
VALID_SIZE_DICT_KEYS = ({"height", "width"}, {"shortest_edge"}, {"shortest_edge", "longest_edge"})


def is_valid_size_dict(size_dict):
    """
    Tells whether `size_dict` is a dictionary whose keys are exactly one of the key sets listed in
    `VALID_SIZE_DICT_KEYS`.

    Args:
        size_dict:
            The object to check. Anything that is not a `dict` is reported as invalid.

    Returns:
        `bool`: `True` if the keys of `size_dict` are equal to one of the allowed key sets, `False` otherwise.
    """
    if isinstance(size_dict, dict):
        present_keys = set(size_dict.keys())
        # The match has to be exact: missing or extra keys both make the dict invalid.
        return any(present_keys == key_set for key_set in VALID_SIZE_DICT_KEYS)
    return False
|
||||
|
||||
|
||||
def convert_to_size_dict(
    size, max_size: Optional[int] = None, default_to_square: bool = True, height_width_order: bool = True
):
    """
    Converts a non-dict `size` value (an int, or a tuple/list) into a size dictionary.

    - If `size` is an int and `default_to_square` is `True`, the result is `{"height": size, "width": size}`. Passing
      `max_size` in this case is an error.
    - If `size` is an int and `default_to_square` is `False`, the result is `{"shortest_edge": size}`, with
      `"longest_edge": max_size` added when `max_size` is not `None`.
    - If `size` is a tuple or list, its first two elements are read as (height, width) when `height_width_order` is
      `True` and as (width, height) otherwise. `max_size` is not used in this case.

    Args:
        size:
            The value to convert.
        max_size (`Optional[int]`, *optional*):
            Value stored under `"longest_edge"` when `size` is an int and `default_to_square` is `False`.
        default_to_square (`bool`, *optional*, defaults to `True`):
            If `size` is an int, whether it describes a square (height == width) or the shortest edge.
        height_width_order (`bool`, *optional*, defaults to `True`):
            If `size` is a tuple or list, whether it is in (height, width) or (width, height) order.

    Returns:
        `dict`: The size dictionary built from `size`.

    Raises:
        ValueError: If `size` is an int with `default_to_square=True` and `max_size` is also given, or if `size`
            cannot be converted (not an int, not a tuple/list, or a tuple/list with fewer than two elements).
    """
    if isinstance(size, int):
        if default_to_square:
            if max_size is not None:
                raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
            return {"height": size, "width": size}

        # With default_to_square=False an int is the length of the shortest edge after resizing.
        size_dict = {"shortest_edge": size}
        if max_size is not None:
            size_dict["longest_edge"] = max_size
        return size_dict

    # A sequence needs at least two entries. Shorter ones fall through to the ValueError below instead of
    # escaping as a bare IndexError from the element access.
    if isinstance(size, (tuple, list)) and len(size) >= 2:
        first, second = size[0], size[1]
        if height_width_order:
            return {"height": first, "width": second}
        return {"height": second, "width": first}

    raise ValueError(f"Could not convert size input to size dict: {size}")
|
||||
|
||||
|
||||
def get_size_dict(
    size: Union[int, Iterable[int], Dict[str, int]] = None,
    max_size: Optional[int] = None,
    height_width_order: bool = True,
    default_to_square: bool = True,
    param_name="size",
) -> dict:
    """
    Converts the old size parameter in the config into the new dict expected in the config. This keeps configs that
    store the size as an int or a tuple/list usable, and removes the ambiguity over whether a tuple is in
    (height, width) or (width, height) order.

    - If `size` is already a dict, it is validated and returned as is.
    - If `size` is a tuple or list, it is converted to `{"height": size[0], "width": size[1]}`, or to
      `{"height": size[1], "width": size[0]}` if `height_width_order` is `False`.
    - If `size` is an int and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`.
    - If `size` is an int and `default_to_square` is `False`, it is converted to `{"shortest_edge": size}`. If
      `max_size` is set, it is added to the dict as `"longest_edge"`.

    Args:
        size (`Union[int, Iterable[int], Dict[str, int]]`, *optional*):
            The `size` parameter to be cast into a size dictionary.
        max_size (`Optional[int]`, *optional*):
            The `max_size` parameter to be cast into a size dictionary.
        height_width_order (`bool`, *optional*, defaults to `True`):
            If `size` is a tuple, whether it's in (height, width) or (width, height) order.
        default_to_square (`bool`, *optional*, defaults to `True`):
            If `size` is an int, whether to default to a square image or not.
        param_name (`str`, *optional*, defaults to `"size"`):
            Name of the parameter being converted. Only used in the log and error messages.

    Returns:
        `dict`: A size dictionary whose keys are one of the sets in `VALID_SIZE_DICT_KEYS`.

    Raises:
        ValueError: If `size` cannot be converted, or if the resulting dictionary does not have one of the allowed
            sets of keys.
    """
    if not isinstance(size, dict):
        size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order)
        # The message must be an f-string, otherwise the placeholders are logged literally.
        logger.info(
            f"{param_name} should be a dictionary on one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got"
            f" {size}. Converted to {size_dict}.",
        )
    else:
        size_dict = size

    if not is_valid_size_dict(size_dict):
        raise ValueError(
            f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}"
        )
    return size_dict
|
||||
|
||||
|
||||
|
@ -118,7 +118,7 @@ class BeitImageProcessor(BaseImageProcessor):
|
||||
size = size if size is not None else {"height": 256, "width": 256}
|
||||
size = get_size_dict(size)
|
||||
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.resample = resample
|
||||
@ -152,7 +152,7 @@ class BeitImageProcessor(BaseImageProcessor):
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
size = get_size_dict(size, default_to_square=True, param_name="size")
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` argument must contain `height` and `width` keys. Got {size.keys()}")
|
||||
return resize(
|
||||
@ -178,7 +178,7 @@ class BeitImageProcessor(BaseImageProcessor):
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
size = get_size_dict(size, default_to_square=True, param_name="size")
|
||||
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
|
||||
|
||||
def rescale(
|
||||
@ -406,11 +406,11 @@ class BeitImageProcessor(BaseImageProcessor):
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size)
|
||||
size = get_size_dict(size, default_to_square=True, param_name="size")
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
|
||||
crop_size = crop_size if crop_size is not None else self.crop_size
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
|
@ -114,7 +114,7 @@ class CLIPImageProcessor(BaseImageProcessor):
|
||||
size = size if size is not None else {"shortest_edge": 224}
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
@ -176,6 +176,8 @@ class CLIPImageProcessor(BaseImageProcessor):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` parameter must contain the keys (height, width). Got {size.keys()}")
|
||||
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
|
||||
|
||||
def rescale(
|
||||
@ -285,11 +287,11 @@ class CLIPImageProcessor(BaseImageProcessor):
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
size = get_size_dict(size, param_name="size", default_to_square=False)
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
|
||||
crop_size = crop_size if crop_size is not None else self.crop_size
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
|
@ -97,7 +97,7 @@ class DeiTImageProcessor(BaseImageProcessor):
|
||||
size = size if size is not None else {"height": 256, "width": 256}
|
||||
size = get_size_dict(size)
|
||||
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
@ -158,6 +158,8 @@ class DeiTImageProcessor(BaseImageProcessor):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
|
||||
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
|
||||
|
||||
def rescale(
|
||||
@ -272,7 +274,7 @@ class DeiTImageProcessor(BaseImageProcessor):
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size)
|
||||
crop_size = crop_size if crop_size is not None else self.crop_size
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
|
||||
if not is_batched(images):
|
||||
images = [images]
|
||||
|
@ -253,12 +253,12 @@ class FlavaImageProcessor(BaseImageProcessor):
|
||||
size = size if size is not None else {"height": 224, "width": 224}
|
||||
size = get_size_dict(size)
|
||||
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
|
||||
codebook_size = codebook_size if codebook_size is not None else {"height": 112, "width": 112}
|
||||
codebook_size = get_size_dict(codebook_size)
|
||||
codebook_size = get_size_dict(codebook_size, param_name="codebook_size")
|
||||
codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else {"height": 112, "width": 112}
|
||||
codebook_crop_size = get_size_dict(codebook_crop_size)
|
||||
codebook_crop_size = get_size_dict(codebook_crop_size, param_name="codebook_crop_size")
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
@ -360,6 +360,8 @@ class FlavaImageProcessor(BaseImageProcessor):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The size dictionary must contain 'height' and 'width' keys. Got {size.keys()}")
|
||||
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
|
||||
|
||||
def rescale(
|
||||
@ -580,7 +582,7 @@ class FlavaImageProcessor(BaseImageProcessor):
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
|
||||
crop_size = crop_size if crop_size is not None else self.crop_size
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
@ -612,7 +614,7 @@ class FlavaImageProcessor(BaseImageProcessor):
|
||||
)
|
||||
codebook_do_resize = codebook_do_resize if codebook_do_resize is not None else self.codebook_do_resize
|
||||
codebook_size = codebook_size if codebook_size is not None else self.codebook_size
|
||||
codebook_size = get_size_dict(codebook_size)
|
||||
codebook_size = get_size_dict(codebook_size, param_name="codebook_size")
|
||||
codebook_resample = codebook_resample if codebook_resample is not None else self.codebook_resample
|
||||
codebook_do_rescale = codebook_do_rescale if codebook_do_rescale is not None else self.codebook_do_rescale
|
||||
codebook_rescale_factor = (
|
||||
@ -622,7 +624,7 @@ class FlavaImageProcessor(BaseImageProcessor):
|
||||
codebook_do_center_crop if codebook_do_center_crop is not None else self.codebook_do_center_crop
|
||||
)
|
||||
codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else self.codebook_crop_size
|
||||
codebook_crop_size = get_size_dict(codebook_crop_size)
|
||||
codebook_crop_size = get_size_dict(codebook_crop_size, param_name="codebook_crop_size")
|
||||
codebook_do_map_pixels = (
|
||||
codebook_do_map_pixels if codebook_do_map_pixels is not None else self.codebook_do_map_pixels
|
||||
)
|
||||
|
@ -105,7 +105,7 @@ class LevitImageProcessor(BaseImageProcessor):
|
||||
size = size if size is not None else {"shortest_edge": 224}
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
@ -182,6 +182,8 @@ class LevitImageProcessor(BaseImageProcessor):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"Size dict must have keys 'height' and 'width'. Got {size.keys()}")
|
||||
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
|
||||
|
||||
def rescale(
|
||||
@ -299,7 +301,7 @@ class LevitImageProcessor(BaseImageProcessor):
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
crop_size = crop_size if crop_size is not None else self.crop_size
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
|
||||
if not is_batched(images):
|
||||
images = [images]
|
||||
|
@ -109,7 +109,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
|
||||
size = size if size is not None else {"shortest_edge": 256}
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.resample = resample
|
||||
@ -169,6 +169,8 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` parameter must contain the keys `height` and `width`. Got {size.keys()}")
|
||||
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
|
||||
|
||||
def rescale(
|
||||
@ -286,7 +288,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
|
||||
crop_size = crop_size if crop_size is not None else self.crop_size
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
|
@ -123,7 +123,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
||||
size = size if size is not None else {"shortest_edge": 224}
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
@ -182,6 +182,8 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
||||
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
|
||||
|
||||
def rescale(
|
||||
@ -280,7 +282,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
crop_size = crop_size if crop_size is not None else self.crop_size
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
|
||||
if not is_batched(images):
|
||||
images = [images]
|
||||
|
@ -99,7 +99,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
size = size if size is not None else {"height": 224, "width": 224}
|
||||
size = get_size_dict(size)
|
||||
|
||||
@ -141,7 +141,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
|
||||
"""
|
||||
size = self.size if size is None else size
|
||||
size = get_size_dict(size)
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
|
||||
height, width = get_image_size(image)
|
||||
min_dim = min(height, width)
|
||||
@ -278,7 +278,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
|
||||
"""
|
||||
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
|
||||
crop_size = crop_size if crop_size is not None else self.crop_size
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size)
|
||||
|
@ -122,7 +122,7 @@ class PoolFormerImageProcessor(BaseImageProcessor):
|
||||
size = size if size is not None else {"shortest_edge": 224}
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
@ -218,6 +218,8 @@ class PoolFormerImageProcessor(BaseImageProcessor):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"size must contain 'height' and 'width' as keys. Got {size.keys()}")
|
||||
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
|
||||
|
||||
def rescale(
|
||||
@ -335,7 +337,7 @@ class PoolFormerImageProcessor(BaseImageProcessor):
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
crop_size = crop_size if crop_size is not None else self.crop_size
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
|
||||
if not is_batched(images):
|
||||
images = [images]
|
||||
|
@ -167,6 +167,8 @@ class SegformerImageProcessor(BaseImageProcessor):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
||||
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
|
||||
|
||||
def rescale(
|
||||
|
@ -121,7 +121,7 @@ class VideoMAEImageProcessor(BaseImageProcessor):
|
||||
size = size if size is not None else {"shortest_edge": 224}
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
@ -157,7 +157,7 @@ class VideoMAEImageProcessor(BaseImageProcessor):
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
if "shortest_edge" in size:
|
||||
output_size = get_resize_output_image_size(image, size["shortest_edge"], default_to_square=False)
|
||||
elif "height" in size and "width" in size:
|
||||
@ -186,6 +186,8 @@ class VideoMAEImageProcessor(BaseImageProcessor):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"Size must have 'height' and 'width' as keys. Got {size.keys()}")
|
||||
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
|
||||
|
||||
def rescale(
|
||||
@ -346,7 +348,7 @@ class VideoMAEImageProcessor(BaseImageProcessor):
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
crop_size = crop_size if crop_size is not None else self.crop_size
|
||||
crop_size = get_size_dict(crop_size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
|
||||
if not valid_images(videos):
|
||||
raise ValueError(
|
||||
|
Loading…
Reference in New Issue
Block a user