Add param_name to size_dict logs & tidy (#20205)

This commit is contained in:
amyeroberts 2022-11-15 10:52:58 +00:00 committed by GitHub
parent f1e8c48c5e
commit 55ba31908a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 99 additions and 63 deletions

View File

@ -440,11 +440,50 @@ class BaseImageProcessor(ImageProcessingMixin):
raise NotImplementedError("Each image processor must implement its own preprocess method")
VALID_SIZE_DICT_KEYS = ({"height", "width"}, {"shortest_edge"}, {"shortest_edge", "longest_edge"})
def is_valid_size_dict(size_dict):
if not isinstance(size_dict, dict):
return False
size_dict_keys = set(size_dict.keys())
for allowed_keys in VALID_SIZE_DICT_KEYS:
if size_dict_keys == allowed_keys:
return True
return False
def convert_to_size_dict(
size, max_size: Optional[int] = None, default_to_square: bool = True, height_width_order: bool = True
):
# By default, if size is an int we assume it represents a tuple of (size, size).
if isinstance(size, int) and default_to_square:
if max_size is not None:
raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
return {"height": size, "width": size}
# In other configs, if size is an int and default_to_square is False, size represents the length of
# the shortest edge after resizing.
elif isinstance(size, int) and not default_to_square:
size_dict = {"shortest_edge": size}
if max_size is not None:
size_dict["longest_edge"] = max_size
return size_dict
# Otherwise, if size is a tuple it's either (height, width) or (width, height)
elif isinstance(size, (tuple, list)) and height_width_order:
return {"height": size[0], "width": size[1]}
elif isinstance(size, (tuple, list)) and not height_width_order:
return {"height": size[1], "width": size[0]}
raise ValueError(f"Could not convert size input to size dict: {size}")
def get_size_dict(
size: Union[int, Iterable[int], Dict[str, int]] = None,
max_size: Optional[int] = None,
height_width_order: bool = True,
default_to_square: bool = True,
param_name="size",
) -> dict:
"""
Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
@ -467,40 +506,19 @@ def get_size_dict(
default_to_square (`bool`, *optional*, defaults to `True`):
If `size` is an int, whether to default to a square image or not.
"""
# If a dict is passed, we check if it's a valid size dict and then return it.
if isinstance(size, dict):
size_keys = set(size.keys())
if (
size_keys != set(["height", "width"])
and size_keys != set(["shortest_edge"])
and size_keys != set(["shortest_edge", "longest_edge"])
):
raise ValueError(
"The size dict must contain either the keys ('height', 'width') or ('shortest_edge')"
f"or ('shortest_edge', 'longest_edge') but got {size_keys}"
)
return size
if not isinstance(size, dict):
size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order)
logger.info(
"{param_name} should be a dictionary on of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size}."
" Converted to {size_dict}.",
)
else:
size_dict = size
# By default, if size is an int we assume it represents a tuple of (size, size).
elif isinstance(size, int) and default_to_square:
if max_size is not None:
raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
size_dict = {"height": size, "width": size}
# In other configs, if size is an int and default_to_square is False, size represents the length of the shortest edge after resizing.
elif isinstance(size, int) and not default_to_square:
if max_size is not None:
size_dict = {"shortest_edge": size, "longest_edge": max_size}
else:
size_dict = {"shortest_edge": size}
elif isinstance(size, (tuple, list)) and height_width_order:
size_dict = {"height": size[0], "width": size[1]}
elif isinstance(size, (tuple, list)) and not height_width_order:
size_dict = {"height": size[1], "width": size[0]}
logger.info(
"The size parameter should be a dictionary with keys ('height', 'width'), ('shortest_edge', 'longest_edge')"
f" or ('shortest_edge',) got {size}. Setting as {size_dict}.",
)
if not is_valid_size_dict(size_dict):
raise ValueError(
f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}"
)
return size_dict

View File

@ -118,7 +118,7 @@ class BeitImageProcessor(BaseImageProcessor):
size = size if size is not None else {"height": 256, "width": 256}
size = get_size_dict(size)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
self.do_resize = do_resize
self.size = size
self.resample = resample
@ -152,7 +152,7 @@ class BeitImageProcessor(BaseImageProcessor):
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
size = get_size_dict(size, default_to_square=True, param_name="size")
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` argument must contain `height` and `width` keys. Got {size.keys()}")
return resize(
@ -178,7 +178,7 @@ class BeitImageProcessor(BaseImageProcessor):
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
size = get_size_dict(size, default_to_square=True, param_name="size")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
@ -406,11 +406,11 @@ class BeitImageProcessor(BaseImageProcessor):
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size)
size = get_size_dict(size, default_to_square=True, param_name="size")
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize

View File

@ -114,7 +114,7 @@ class CLIPImageProcessor(BaseImageProcessor):
size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
self.do_resize = do_resize
self.size = size
@ -176,6 +176,8 @@ class CLIPImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` parameter must contain the keys (height, width). Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
@ -285,11 +287,11 @@ class CLIPImageProcessor(BaseImageProcessor):
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
size = get_size_dict(size, param_name="size", default_to_square=False)
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize

View File

@ -97,7 +97,7 @@ class DeiTImageProcessor(BaseImageProcessor):
size = size if size is not None else {"height": 256, "width": 256}
size = get_size_dict(size)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
self.do_resize = do_resize
self.size = size
@ -158,6 +158,8 @@ class DeiTImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
@ -272,7 +274,7 @@ class DeiTImageProcessor(BaseImageProcessor):
size = size if size is not None else self.size
size = get_size_dict(size)
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
if not is_batched(images):
images = [images]

View File

@ -253,12 +253,12 @@ class FlavaImageProcessor(BaseImageProcessor):
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
codebook_size = codebook_size if codebook_size is not None else {"height": 112, "width": 112}
codebook_size = get_size_dict(codebook_size)
codebook_size = get_size_dict(codebook_size, param_name="codebook_size")
codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else {"height": 112, "width": 112}
codebook_crop_size = get_size_dict(codebook_crop_size)
codebook_crop_size = get_size_dict(codebook_crop_size, param_name="codebook_crop_size")
self.do_resize = do_resize
self.size = size
@ -360,6 +360,8 @@ class FlavaImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must contain 'height' and 'width' keys. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
@ -580,7 +582,7 @@ class FlavaImageProcessor(BaseImageProcessor):
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
@ -612,7 +614,7 @@ class FlavaImageProcessor(BaseImageProcessor):
)
codebook_do_resize = codebook_do_resize if codebook_do_resize is not None else self.codebook_do_resize
codebook_size = codebook_size if codebook_size is not None else self.codebook_size
codebook_size = get_size_dict(codebook_size)
codebook_size = get_size_dict(codebook_size, param_name="codebook_size")
codebook_resample = codebook_resample if codebook_resample is not None else self.codebook_resample
codebook_do_rescale = codebook_do_rescale if codebook_do_rescale is not None else self.codebook_do_rescale
codebook_rescale_factor = (
@ -622,7 +624,7 @@ class FlavaImageProcessor(BaseImageProcessor):
codebook_do_center_crop if codebook_do_center_crop is not None else self.codebook_do_center_crop
)
codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else self.codebook_crop_size
codebook_crop_size = get_size_dict(codebook_crop_size)
codebook_crop_size = get_size_dict(codebook_crop_size, param_name="codebook_crop_size")
codebook_do_map_pixels = (
codebook_do_map_pixels if codebook_do_map_pixels is not None else self.codebook_do_map_pixels
)

View File

@ -105,7 +105,7 @@ class LevitImageProcessor(BaseImageProcessor):
size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
self.do_resize = do_resize
self.size = size
@ -182,6 +182,8 @@ class LevitImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"Size dict must have keys 'height' and 'width'. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
@ -299,7 +301,7 @@ class LevitImageProcessor(BaseImageProcessor):
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
if not is_batched(images):
images = [images]

View File

@ -109,7 +109,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
size = size if size is not None else {"shortest_edge": 256}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
self.do_resize = do_resize
self.size = size
self.resample = resample
@ -169,6 +169,8 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` parameter must contain the keys `height` and `width`. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
@ -286,7 +288,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize

View File

@ -123,7 +123,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
self.do_resize = do_resize
self.size = size
@ -182,6 +182,8 @@ class MobileViTImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
@ -280,7 +282,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
if not is_batched(images):
images = [images]

View File

@ -99,7 +99,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
) -> None:
super().__init__(**kwargs)
crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size)
@ -141,7 +141,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
"""
size = self.size if size is None else size
size = get_size_dict(size)
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
height, width = get_image_size(image)
min_dim = min(height, width)
@ -278,7 +278,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
"""
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size)

View File

@ -122,7 +122,7 @@ class PoolFormerImageProcessor(BaseImageProcessor):
size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
self.do_resize = do_resize
self.size = size
@ -218,6 +218,8 @@ class PoolFormerImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"size must contain 'height' and 'width' as keys. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
@ -335,7 +337,7 @@ class PoolFormerImageProcessor(BaseImageProcessor):
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
if not is_batched(images):
images = [images]

View File

@ -167,6 +167,8 @@ class SegformerImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(

View File

@ -121,7 +121,7 @@ class VideoMAEImageProcessor(BaseImageProcessor):
size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
self.do_resize = do_resize
self.size = size
@ -157,7 +157,7 @@ class VideoMAEImageProcessor(BaseImageProcessor):
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
size = get_size_dict(size, default_to_square=False)
if "shortest_edge" in size:
output_size = get_resize_output_image_size(image, size["shortest_edge"], default_to_square=False)
elif "height" in size and "width" in size:
@ -186,6 +186,8 @@ class VideoMAEImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"Size must have 'height' and 'width' as keys. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
@ -346,7 +348,7 @@ class VideoMAEImageProcessor(BaseImageProcessor):
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
if not valid_images(videos):
raise ValueError(