Change Qwen2_VL image processors to have init and call accept the same kwargs (#36207)
Change the Qwen2VL image processors so that `__init__` and `__call__` accept the same kwargs
This commit is contained in:
parent
65b8e38aac
commit
1c287aecfc
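
For orientation, a minimal usage sketch of the behavior this change enables (not part of the diff; it assumes a transformers build that includes this commit, and the pixel budgets below are illustrative): both processors now take the same size-related kwargs at construction time and at call time, with min_pixels/max_pixels kept as a backward-compatible alias for the `size` dict.

    # Minimal sketch, not part of the diff; assumes this change is installed.
    import numpy as np
    from transformers import Qwen2VLImageProcessor

    image = np.zeros((224, 224, 3), dtype=np.uint8)  # dummy RGB image

    # Configure the resize budget at init time, via `size` (or the legacy min_pixels/max_pixels kwargs).
    processor = Qwen2VLImageProcessor(size={"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280})

    # After this change the same kwargs are also accepted per call, overriding the init-time values.
    out = processor(images=image, min_pixels=256 * 28 * 28, max_pixels=1024 * 28 * 28, return_tensors="np")
    print(out["pixel_values"].shape, out["image_grid_thw"])
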
@@ -91,6 +91,8 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
     Args:
         do_resize (`bool`, *optional*, defaults to `True`):
             Whether to resize the image's (height, width) dimensions.
+        size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}`):
+            Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
         resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
             Resampling filter to use when resizing the image.
         do_rescale (`bool`, *optional*, defaults to `True`):
@@ -122,6 +124,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
     def __init__(
         self,
         do_resize: bool = True,
+        size: Dict[str, int] = None,
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         do_rescale: bool = True,
         rescale_factor: Union[int, float] = 1 / 255,
@@ -129,14 +132,27 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
         do_convert_rgb: bool = True,
-        min_pixels: int = 56 * 56,
-        max_pixels: int = 28 * 28 * 1280,
+        min_pixels: int = None,
+        max_pixels: int = None,
         patch_size: int = 14,
         temporal_patch_size: int = 2,
         merge_size: int = 2,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
+        if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
+            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+        else:
+            size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
+        # backward compatibility: override size with min_pixels and max_pixels if they are provided
+        if min_pixels is not None:
+            size["shortest_edge"] = min_pixels
+        if max_pixels is not None:
+            size["longest_edge"] = max_pixels
+        self.min_pixels = size["shortest_edge"]
+        self.max_pixels = size["longest_edge"]
+        self.size = size
+
         self.do_resize = do_resize
         self.resample = resample
         self.do_rescale = do_rescale
@@ -144,24 +160,26 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
         self.do_normalize = do_normalize
         self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
         self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
-        self.min_pixels = min_pixels
-        self.max_pixels = max_pixels
+
         self.patch_size = patch_size
         self.temporal_patch_size = temporal_patch_size
         self.merge_size = merge_size
-        self.size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
         self.do_convert_rgb = do_convert_rgb

     def _preprocess(
         self,
         images: Union[ImageInput, VideoInput],
         do_resize: bool = None,
+        size: Dict[str, int] = None,
         resample: PILImageResampling = None,
         do_rescale: bool = None,
         rescale_factor: float = None,
         do_normalize: bool = None,
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
+        patch_size: int = None,
+        temporal_patch_size: int = None,
+        merge_size: int = None,
         do_convert_rgb: bool = None,
         data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
@@ -176,6 +194,8 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
                 Optional list of dictionaries containing additional information about vision inputs.
             do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                 Whether to resize the image.
+            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
             resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                 Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
             do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
@@ -188,6 +208,12 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
                 Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
             image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                 Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+            patch_size (`int`, *optional*, defaults to `self.patch_size`):
+                The spacial patch size of the vision encoder.
+            temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+                The temporal patch size of the vision encoder.
+            merge_size (`int`, *optional*, defaults to `self.merge_size`):
+                The merge size of the vision encoder to llm encoder.
             do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                 Whether to convert the image to RGB.
             data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
@@ -226,9 +252,9 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
                 resized_height, resized_width = smart_resize(
                     height,
                     width,
-                    factor=self.patch_size * self.merge_size,
-                    min_pixels=self.min_pixels,
-                    max_pixels=self.max_pixels,
+                    factor=patch_size * merge_size,
+                    min_pixels=size["shortest_edge"],
+                    max_pixels=size["longest_edge"],
                 )
                 image = resize(
                     image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
@@ -248,26 +274,26 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
         patches = np.array(processed_images)
         if data_format == ChannelDimension.LAST:
             patches = patches.transpose(0, 3, 1, 2)
-        if patches.shape[0] % self.temporal_patch_size != 0:
-            repeats = np.repeat(patches[-1][np.newaxis], self.temporal_patch_size - 1, axis=0)
+        if patches.shape[0] % temporal_patch_size != 0:
+            repeats = np.repeat(patches[-1][np.newaxis], temporal_patch_size - 1, axis=0)
             patches = np.concatenate([patches, repeats], axis=0)
         channel = patches.shape[1]
-        grid_t = patches.shape[0] // self.temporal_patch_size
-        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
+        grid_t = patches.shape[0] // temporal_patch_size
+        grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
         patches = patches.reshape(
             grid_t,
-            self.temporal_patch_size,
+            temporal_patch_size,
             channel,
-            grid_h // self.merge_size,
-            self.merge_size,
-            self.patch_size,
-            grid_w // self.merge_size,
-            self.merge_size,
-            self.patch_size,
+            grid_h // merge_size,
+            merge_size,
+            patch_size,
+            grid_w // merge_size,
+            merge_size,
+            patch_size,
         )
         patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
         flatten_patches = patches.reshape(
-            grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
+            grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size
         )

         return flatten_patches, (grid_t, grid_h, grid_w)
@@ -278,12 +304,17 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
         videos: VideoInput = None,
         do_resize: bool = None,
         size: Dict[str, int] = None,
+        min_pixels: int = None,
+        max_pixels: int = None,
         resample: PILImageResampling = None,
         do_rescale: bool = None,
         rescale_factor: float = None,
         do_normalize: bool = None,
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
+        patch_size: int = None,
+        temporal_patch_size: int = None,
+        merge_size: int = None,
         do_convert_rgb: bool = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
@@ -316,6 +347,16 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
             image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                 Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                 `True`.
+            min_pixels (`int`, *optional*, defaults to `self.min_pixels`):
+                The min pixels of the image to resize the image.
+            max_pixels (`int`, *optional*, defaults to `self.max_pixels`):
+                The max pixels of the image to resize the image.
+            patch_size (`int`, *optional*, defaults to `self.patch_size`):
+                The spacial patch size of the vision encoder.
+            temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+                The temporal patch size of the vision encoder.
+            merge_size (`int`, *optional*, defaults to `self.merge_size`):
+                The merge size of the vision encoder to llm encoder.
             do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                 Whether to convert the image to RGB.
             return_tensors (`str` or `TensorType`, *optional*):
@@ -338,14 +379,29 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

         """
+        if size is not None:
+            if "shortest_edge" not in size or "longest_edge" not in size:
+                raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+            min_pixels = size["shortest_edge"]
+        else:
+            size = self.size
+        # backward compatibility: override size with min_pixels and max_pixels if they are provided
+        if min_pixels is not None:
+            size["shortest_edge"] = min_pixels
+        if max_pixels is not None:
+            size["longest_edge"] = max_pixels
+
         do_resize = do_resize if do_resize is not None else self.do_resize
-        size = size if size is not None else self.size
+
         resample = resample if resample is not None else self.resample
         do_rescale = do_rescale if do_rescale is not None else self.do_rescale
         rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
         do_normalize = do_normalize if do_normalize is not None else self.do_normalize
         image_mean = image_mean if image_mean is not None else self.image_mean
         image_std = image_std if image_std is not None else self.image_std
+        patch_size = patch_size if patch_size is not None else self.patch_size
+        temporal_patch_size = temporal_patch_size if temporal_patch_size is not None else self.temporal_patch_size
+        merge_size = merge_size if merge_size is not None else self.merge_size
         do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

         if images is not None:
@@ -375,12 +431,16 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
                 patches, image_grid_thw = self._preprocess(
                     image,
                     do_resize=do_resize,
+                    size=size,
                     resample=resample,
                     do_rescale=do_rescale,
                     rescale_factor=rescale_factor,
                     do_normalize=do_normalize,
                     image_mean=image_mean,
                     image_std=image_std,
+                    patch_size=patch_size,
+                    temporal_patch_size=temporal_patch_size,
+                    merge_size=merge_size,
                     data_format=data_format,
                     do_convert_rgb=do_convert_rgb,
                     input_data_format=input_data_format,
@@ -397,12 +457,16 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
                 patches, video_grid_thw = self._preprocess(
                     images,
                     do_resize=do_resize,
+                    size=size,
                     resample=resample,
                     do_rescale=do_rescale,
                     rescale_factor=rescale_factor,
                     do_normalize=do_normalize,
                     image_mean=image_mean,
                     image_std=image_std,
+                    patch_size=patch_size,
+                    temporal_patch_size=temporal_patch_size,
+                    merge_size=merge_size,
                     data_format=data_format,
                     do_convert_rgb=do_convert_rgb,
                     input_data_format=input_data_format,
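
As a side note on the hunks above: `_preprocess` now computes the patch grid from the call-time patch_size, temporal_patch_size and merge_size instead of the stored `self.*` attributes. A short numeric sketch of that grid arithmetic follows (the resized dimensions and sizes below are illustrative assumptions, not values from the diff).

    # Illustrative sketch of the patch-grid arithmetic; numbers are assumptions.
    import numpy as np

    patch_size, temporal_patch_size, merge_size = 14, 2, 2
    resized_height, resized_width = 392, 532  # multiples of patch_size * merge_size = 28
    frames = np.zeros((1, 3, resized_height, resized_width))  # a single image is one frame

    # Pad the temporal axis by repeating the last frame, as in the diff above.
    if frames.shape[0] % temporal_patch_size != 0:
        repeats = np.repeat(frames[-1][np.newaxis], temporal_patch_size - 1, axis=0)
        frames = np.concatenate([frames, repeats], axis=0)

    grid_t = frames.shape[0] // temporal_patch_size
    grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
    flat_dim = frames.shape[1] * temporal_patch_size * patch_size * patch_size

    # 1 28 38 1176 -> flatten_patches has shape (grid_t * grid_h * grid_w, 1176)
    print(grid_t, grid_h, grid_w, flat_dim)
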
@@ -105,13 +105,26 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
     patch_size = 14
     temporal_patch_size = 2
     merge_size = 2
-    min_pixels = 56 * 56
-    max_pixels = 28 * 28 * 1280
-    valid_kwargs = DefaultFastImageProcessorKwargs
+    min_pixels = None
+    max_pixels = None
+    valid_kwargs = Qwen2VLFastImageProcessorKwargs
     model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]

     def __init__(self, **kwargs: Unpack[Qwen2VLFastImageProcessorKwargs]):
-        super().__init__(**kwargs)
+        size = kwargs.pop("size", None)
+        min_pixels = kwargs.pop("min_pixels", None)
+        max_pixels = kwargs.pop("max_pixels", None)
+        if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
+            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+        else:
+            size = self.size
+        # backward compatibility: override size with min_pixels and max_pixels if they are provided
+        if min_pixels is not None:
+            size["shortest_edge"] = min_pixels
+        if max_pixels is not None:
+            size["longest_edge"] = max_pixels
+
+        super().__init__(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs)

     def _preprocess(
         self,
@@ -124,6 +137,9 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
         do_normalize: bool,
         image_mean: Optional[Union[float, List[float]]],
         image_std: Optional[Union[float, List[float]]],
+        patch_size: int,
+        temporal_patch_size: int,
+        merge_size: int,
         do_convert_rgb: bool,
         input_data_format: Optional[Union[str, ChannelDimension]],
         device: Optional[Union[str, torch.device]],
@@ -138,6 +154,8 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
                 Optional list of dictionaries containing additional information about vision inputs.
             do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                 Whether to resize the image.
+            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
             interpolation (`InterpolationMode`):
                 Resampling filter to use if resizing the image.
             do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
@@ -150,6 +168,12 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
                 Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
             image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                 Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+            patch_size (`int`, *optional*, defaults to `self.patch_size`):
+                The spacial patch size of the vision encoder.
+            temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+                The temporal patch size of the vision encoder.
+            merge_size (`int`, *optional*, defaults to `self.merge_size`):
+                The merge size of the vision encoder to llm encoder.
             do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                 Whether to convert the image to RGB.
             input_data_format (`ChannelDimension` or `str`, *optional*):
@@ -178,9 +202,9 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
                 resized_height, resized_width = smart_resize(
                     height,
                     width,
-                    factor=self.patch_size * self.merge_size,
-                    min_pixels=self.min_pixels,
-                    max_pixels=self.max_pixels,
+                    factor=patch_size * merge_size,
+                    min_pixels=size["shortest_edge"],
+                    max_pixels=size["longest_edge"],
                 )
                 stacked_images = F.resize(
                     stacked_images, size=(resized_height, resized_width), interpolation=interpolation
@@ -201,28 +225,28 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
         patches = torch.stack(processed_images, dim=0)
-        if patches.shape[0] % self.temporal_patch_size != 0:
-            repeats = patches[-1].unsqueeze(0).repeat(self.temporal_patch_size - 1, 1, 1, 1)
+        if patches.shape[0] % temporal_patch_size != 0:
+            repeats = patches[-1].unsqueeze(0).repeat(temporal_patch_size - 1, 1, 1, 1)
             patches = torch.cat([patches, repeats], dim=0)

         channel = patches.shape[1]
-        grid_t = patches.shape[0] // self.temporal_patch_size
-        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
+        grid_t = patches.shape[0] // temporal_patch_size
+        grid_h, grid_w = resized_height // patch_size, resized_width // patch_size

         patches = patches.view(
             grid_t,
-            self.temporal_patch_size,
+            temporal_patch_size,
             channel,
-            grid_h // self.merge_size,
-            self.merge_size,
-            self.patch_size,
-            grid_w // self.merge_size,
-            self.merge_size,
-            self.patch_size,
+            grid_h // merge_size,
+            merge_size,
+            patch_size,
+            grid_w // merge_size,
+            merge_size,
+            patch_size,
         )
         patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
         flatten_patches = patches.reshape(
-            grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
+            grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size
         )

         return flatten_patches, (grid_t, grid_h, grid_w)
@@ -239,6 +263,11 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
         do_normalize: bool = None,
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
+        min_pixels: int = None,
+        max_pixels: int = None,
+        patch_size: int = None,
+        temporal_patch_size: int = None,
+        merge_size: int = None,
         do_convert_rgb: bool = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
@@ -257,8 +286,7 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
             do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                 Whether to resize the image.
             size (`Dict[str, int]`, *optional*, defaults to `self.size`):
-                Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
-                the longest edge resized to keep the input aspect ratio.
+                Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
             resample (`int`, *optional*, defaults to `self.resample`):
                 Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                 has an effect if `do_resize` is set to `True`.
@@ -273,6 +301,16 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
             image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                 Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                 `True`.
+            min_pixels (`int`, *optional*, defaults to `self.min_pixels`):
+                The min pixels of the image to resize the image.
+            max_pixels (`int`, *optional*, defaults to `self.max_pixels`):
+                The max pixels of the image to resize the image.
+            patch_size (`int`, *optional*, defaults to `self.patch_size`):
+                The spacial patch size of the vision encoder.
+            temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+                The temporal patch size of the vision encoder.
+            merge_size (`int`, *optional*, defaults to `self.merge_size`):
+                The merge size of the vision encoder to llm encoder.
             do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                 Whether to convert the image to RGB.
             return_tensors (`str` or `TensorType`, *optional*):
@@ -296,6 +334,18 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
             device (`torch.device`, *optional*):
                 The device to process the images on. If unset, the device is inferred from the input images.
         """
+        if size is not None:
+            if "shortest_edge" not in size or "longest_edge" not in size:
+                raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+            min_pixels = size["shortest_edge"]
+        else:
+            size = self.size
+        # backward compatibility: override size with min_pixels and max_pixels if they are provided
+        if min_pixels is not None:
+            size["shortest_edge"] = min_pixels
+        if max_pixels is not None:
+            size["longest_edge"] = max_pixels
+
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
         resample = resample if resample is not None else self.resample
@@ -304,6 +354,9 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
         do_normalize = do_normalize if do_normalize is not None else self.do_normalize
         image_mean = image_mean if image_mean is not None else self.image_mean
         image_std = image_std if image_std is not None else self.image_std
+        patch_size = patch_size if patch_size is not None else self.patch_size
+        temporal_patch_size = temporal_patch_size if temporal_patch_size is not None else self.temporal_patch_size
+        merge_size = merge_size if merge_size is not None else self.merge_size
         do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

         # Make hashable for cache
@@ -351,6 +404,9 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
                 do_normalize=do_normalize,
                 image_mean=image_mean,
                 image_std=image_std,
+                patch_size=patch_size,
+                temporal_patch_size=temporal_patch_size,
+                merge_size=merge_size,
                 do_convert_rgb=do_convert_rgb,
                 input_data_format=input_data_format,
                 device=device,
@@ -374,6 +430,9 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
                 do_normalize=do_normalize,
                 image_mean=image_mean,
                 image_std=image_std,
+                patch_size=patch_size,
+                temporal_patch_size=temporal_patch_size,
+                merge_size=merge_size,
                 do_convert_rgb=do_convert_rgb,
                 input_data_format=input_data_format,
                 device=device,
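
The fast processor gains the same init/call parity. A hedged sketch follows (it assumes torch and torchvision are installed, since the fast processor requires them; the pixel budgets are illustrative and not taken from the diff).

    # Sketch for the fast processor; assumes torch + torchvision, numbers are illustrative.
    import numpy as np
    from transformers import Qwen2VLImageProcessorFast

    image = np.zeros((224, 224, 3), dtype=np.uint8)

    # The legacy min_pixels/max_pixels kwargs still work and are folded into `size` by __init__ ...
    fast = Qwen2VLImageProcessorFast(min_pixels=256 * 28 * 28, max_pixels=1024 * 28 * 28)

    # ... and, after this change, the same kwargs (or a `size` dict) are accepted per call as well.
    out = fast(images=image, size={"shortest_edge": 256 * 28 * 28, "longest_edge": 1280 * 28 * 28}, return_tensors="pt")
    print(fast.size, out["image_grid_thw"])
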