[llava] one pixel is missing from padding when length is odd (#37819)
* [fix] one pixel should be added when length is odd
* [fix] add vision_aspect_ratio args & typo
* [fix] style
* [fix] do not fix fast file directly
* [fix] convert using modular
* remove duplicate code
* match unpad logic with pad logic
* test odd-sized images for llava & aria
* test unpad odd-sized padding for llava family
* fix style
* add kwarg to onevision modular
* move vision_aspect_ratio from image_processor to processor (llava_onevision)
This commit is contained in:
parent 9981214d32
commit acded47fe7
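The change is the same across every processor touched here: splitting an odd padding amount with `// 2` on both sides silently drops one pixel, while `divmod` keeps the remainder and assigns it to the bottom/right edge, so padding and unpadding stay mirror images of each other. Below is a minimal NumPy sketch of that idea, assuming a channels-last image; the helper names `pad_symmetric` and `unpad_symmetric` are made up for this illustration and are not part of the library.

import numpy as np


def pad_symmetric(image: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    """Center-pad an (H, W, C) image; any odd leftover pixel goes to the bottom/right."""
    h, w = image.shape[:2]
    top, extra_h = divmod(target_h - h, 2)   # the old code used (target_h - h) // 2 and lost the remainder
    left, extra_w = divmod(target_w - w, 2)
    return np.pad(image, ((top, top + extra_h), (left, left + extra_w), (0, 0)))


def unpad_symmetric(padded: np.ndarray, orig_h: int, orig_w: int) -> np.ndarray:
    """Undo pad_symmetric with the same divmod split so the crop matches the pad."""
    h, w = padded.shape[:2]
    top, extra_h = divmod(h - orig_h, 2)
    left, extra_w = divmod(w - orig_w, 2)
    return padded[top : h - (top + extra_h), left : w - (left + extra_w)]


image = np.zeros((400, 601, 3))                                   # odd width
padded = pad_symmetric(image, 608, 608)
assert padded.shape == (608, 608, 3)                              # with // 2 on both sides the width would be one pixel short
assert unpad_symmetric(padded, 400, 601).shape == (400, 601, 3)   # round-trips for odd sizes

The diffs below apply exactly this divmod split in each `_pad_for_patching` implementation and mirror it in `unpad_image`.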
setup.py

@@ -466,7 +466,12 @@ setup(
     package_data={"": ["**/*.cu", "**/*.cpp", "**/*.cuh", "**/*.h", "**/*.pyx", "py.typed"]},
     zip_safe=False,
     extras_require=extras,
-    entry_points={"console_scripts": ["transformers=transformers.commands.transformers_cli:main", "transformers-cli=transformers.commands.transformers_cli:main_cli"]},
+    entry_points={
+        "console_scripts": [
+            "transformers=transformers.commands.transformers_cli:main",
+            "transformers-cli=transformers.commands.transformers_cli:main_cli",
+        ]
+    },
     python_requires=">=3.9.0",
     install_requires=list(install_requires),
     classifiers=[
@@ -18,12 +18,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math
 from typing import Iterable, List, Optional, Tuple, Union

 import numpy as np

-from ...image_processing_utils import BaseImageProcessor, BatchFeature, select_best_resolution
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution
 from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
 from ...image_utils import (
     ChannelDimension,
@@ -71,23 +70,6 @@ def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> Li
     return patches


-def _get_patch_output_size(image, target_resolution, input_data_format):
-    original_height, original_width = get_image_size(image, channel_dim=input_data_format)
-    target_height, target_width = target_resolution
-
-    scale_w = target_width / original_width
-    scale_h = target_height / original_height
-
-    if scale_w < scale_h:
-        new_width = target_width
-        new_height = min(math.ceil(original_height * scale_w), target_height)
-    else:
-        new_height = target_height
-        new_width = min(math.ceil(original_width * scale_h), target_width)
-
-    return new_height, new_width
-
-
 class AriaImageProcessor(BaseImageProcessor):
     """
     A vision processor for the Aria model that handles image preprocessing.
@@ -375,7 +357,7 @@ class AriaImageProcessor(BaseImageProcessor):
         Returns:
             np.array: The resized and padded image.
         """
-        new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
+        new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)

         # Resize the image
         resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
@@ -389,12 +371,12 @@
         Pad an image to a target resolution while maintaining aspect ratio.
         """
         target_height, target_width = target_resolution
-        new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
+        new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)

-        paste_x = (target_width - new_width) // 2
-        paste_y = (target_height - new_height) // 2
+        paste_x, r_x = divmod(target_width - new_width, 2)
+        paste_y, r_y = divmod(target_height - new_height, 2)

-        padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
+        padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))

         return padded_image
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math
 from typing import Dict, Iterable, List, Optional, Tuple, Union

 import numpy as np
@@ -20,7 +19,7 @@ import numpy as np
 from ...activations import ACT2FN
 from ...configuration_utils import PretrainedConfig
 from ...generation import GenerationMixin
-from ...image_processing_utils import BaseImageProcessor, BatchFeature, select_best_resolution
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution
 from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
 from ...image_utils import (
     ChannelDimension,
@@ -461,23 +460,6 @@ class AriaProjector(nn.Module):
         return out


-def _get_patch_output_size(image, target_resolution, input_data_format):
-    original_height, original_width = get_image_size(image, channel_dim=input_data_format)
-    target_height, target_width = target_resolution
-
-    scale_w = target_width / original_width
-    scale_h = target_height / original_height
-
-    if scale_w < scale_h:
-        new_width = target_width
-        new_height = min(math.ceil(original_height * scale_w), target_height)
-    else:
-        new_height = target_height
-        new_width = min(math.ceil(original_width * scale_h), target_width)
-
-    return new_height, new_width
-
-
 class AriaImageProcessor(BaseImageProcessor):
     """
     A vision processor for the Aria model that handles image preprocessing.
@@ -765,7 +747,7 @@ class AriaImageProcessor(BaseImageProcessor):
         Returns:
             np.array: The resized and padded image.
         """
-        new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
+        new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)

         # Resize the image
         resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
@@ -779,12 +761,12 @@
         Pad an image to a target resolution while maintaining aspect ratio.
         """
         target_height, target_width = target_resolution
-        new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
+        new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)

-        paste_x = (target_width - new_width) // 2
-        paste_y = (target_height - new_height) // 2
+        paste_x, r_x = divmod(target_width - new_width, 2)
+        paste_y, r_y = divmod(target_height - new_height, 2)

-        padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
+        padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))

         return padded_image
@@ -14,12 +14,17 @@
 # limitations under the License.
 """Image processor class for LLaVa-NeXT."""

-import math
 from typing import Dict, Iterable, List, Optional, Tuple, Union

 import numpy as np

-from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict, select_best_resolution
+from ...image_processing_utils import (
+    BaseImageProcessor,
+    BatchFeature,
+    get_patch_output_size,
+    get_size_dict,
+    select_best_resolution,
+)
 from ...image_transforms import (
     PaddingMode,
     convert_to_rgb,
@@ -99,23 +104,6 @@ def expand_to_square(image: np.array, background_color, input_data_format) -> np
     return result


-def _get_patch_output_size(image, target_resolution, input_data_format):
-    original_height, original_width = get_image_size(image, channel_dim=input_data_format)
-    target_height, target_width = target_resolution
-
-    scale_w = target_width / original_width
-    scale_h = target_height / original_height
-
-    if scale_w < scale_h:
-        new_width = target_width
-        new_height = min(math.ceil(original_height * scale_w), target_height)
-    else:
-        new_height = target_height
-        new_width = min(math.ceil(original_width * scale_h), target_width)
-
-    return new_height, new_width
-
-
 class LlavaNextImageProcessor(BaseImageProcessor):
     r"""
     Constructs a LLaVa-NeXT image processor. Based on [`CLIPImageProcessor`] with incorporation of additional techniques
@@ -429,7 +417,7 @@ class LlavaNextImageProcessor(BaseImageProcessor):
         Returns:
             np.array: The resized and padded image.
         """
-        new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
+        new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)

         # Resize the image
         resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
@@ -443,12 +431,12 @@
         Pad an image to a target resolution while maintaining aspect ratio.
         """
         target_height, target_width = target_resolution
-        new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
+        new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)

-        paste_x = (target_width - new_width) // 2
-        paste_y = (target_height - new_height) // 2
+        paste_x, r_x = divmod(target_width - new_width, 2)
+        paste_y, r_y = divmod(target_height - new_height, 2)

-        padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
+        padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))

         return padded_image
@@ -102,8 +102,8 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast):
         A list of possible resolutions to use for processing high resolution images. Each item in the list should be a tuple or list
         of the form `(height, width)`.
     do_pad (`bool`, *optional*):
         Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
         number of patches in the batch. Padding will be applied to the bottom and right with zeros.
     """,
 )
 def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaNextFastImageProcessorKwargs]) -> BatchFeature:
@@ -164,10 +164,10 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast):
         target_height, target_width = target_resolution
         new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)

-        paste_x = (target_width - new_width) // 2
-        paste_y = (target_height - new_height) // 2
+        paste_x, r_x = divmod(target_width - new_width, 2)
+        paste_y, r_y = divmod(target_height - new_height, 2)

-        padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
+        padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x + r_x, paste_y + r_y])

         return padded_image
@@ -139,14 +139,14 @@ def unpad_image(tensor, original_size):

     if original_aspect_ratio > current_aspect_ratio:
         scale_factor = current_width / original_width
-        new_height = int(round(original_height * scale_factor, 7))
-        padding = (current_height - new_height) // 2
-        unpadded_tensor = tensor[:, padding : current_height - padding, :]
+        new_height = min(math.ceil(original_height * scale_factor), current_height)
+        padding, r = divmod(current_height - new_height, 2)
+        unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
     else:
         scale_factor = current_height / original_height
-        new_width = int(round(original_width * scale_factor, 7))
-        padding = (current_width - new_width) // 2
-        unpadded_tensor = tensor[:, :, padding : current_width - padding]
+        new_width = min(math.ceil(original_width * scale_factor), current_width)
+        padding, r = divmod(current_width - new_width, 2)
+        unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]

     return unpadded_tensor
@@ -262,14 +262,14 @@ def unpad_image(tensor, original_size):

     if original_aspect_ratio > current_aspect_ratio:
         scale_factor = current_width / original_width
-        new_height = int(round(original_height * scale_factor, 7))
-        padding = (current_height - new_height) // 2
-        unpadded_tensor = tensor[:, padding : current_height - padding, :]
+        new_height = min(math.ceil(original_height * scale_factor), current_height)
+        padding, r = divmod(current_height - new_height, 2)
+        unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
     else:
         scale_factor = current_height / original_height
-        new_width = int(round(original_width * scale_factor, 7))
-        padding = (current_width - new_width) // 2
-        unpadded_tensor = tensor[:, :, padding : current_width - padding]
+        new_width = min(math.ceil(original_width * scale_factor), current_width)
+        padding, r = divmod(current_width - new_width, 2)
+        unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]

     return unpadded_tensor
@@ -14,12 +14,17 @@
 # limitations under the License.
 """Image processor class for LLaVa-Onevision."""

-import math
 from typing import Dict, Iterable, List, Optional, Tuple, Union

 import numpy as np

-from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict, select_best_resolution
+from ...image_processing_utils import (
+    BaseImageProcessor,
+    BatchFeature,
+    get_patch_output_size,
+    get_size_dict,
+    select_best_resolution,
+)
 from ...image_transforms import (
     PaddingMode,
     convert_to_rgb,
@@ -99,24 +104,6 @@ def expand_to_square(image: np.array, background_color, input_data_format) -> np
     return result


-# Copied from transformers.models.llava_next.image_processing_llava_next._get_patch_output_size
-def _get_patch_output_size(image, target_resolution, input_data_format):
-    original_height, original_width = get_image_size(image, channel_dim=input_data_format)
-    target_height, target_width = target_resolution
-
-    scale_w = target_width / original_width
-    scale_h = target_height / original_height
-
-    if scale_w < scale_h:
-        new_width = target_width
-        new_height = min(math.ceil(original_height * scale_w), target_height)
-    else:
-        new_height = target_height
-        new_width = min(math.ceil(original_width * scale_h), target_width)
-
-    return new_height, new_width
-
-
 class LlavaOnevisionImageProcessor(BaseImageProcessor):
     r"""
     Constructs a LLaVa-Onevision image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.
@@ -151,8 +138,8 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
             number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
             Can be overridden by the `image_std` parameter in the `preprocess` method.
         do_pad (`bool`, *optional*, defaults to `True`):
             Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
             number of patches in the batch. Padding will be applied to the bottom and right with zeros.
         do_convert_rgb (`bool`, *optional*, defaults to `True`):
             Whether to convert the image to RGB.
         """
@@ -321,7 +308,7 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
         Returns:
             np.array: The resized and padded image.
         """
-        new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
+        new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)

         # Resize the image
         resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
@@ -336,12 +323,12 @@
         Pad an image to a target resolution while maintaining aspect ratio.
         """
         target_height, target_width = target_resolution
-        new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
+        new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)

-        paste_x = (target_width - new_width) // 2
-        paste_y = (target_height - new_height) // 2
+        paste_x, r_x = divmod(target_width - new_width, 2)
+        paste_y, r_y = divmod(target_height - new_height, 2)

-        padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
+        padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))

         return padded_image
@@ -84,8 +84,8 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
         A list of possible resolutions to use for processing high resolution images. Each item in the list should be a tuple or list
         of the form `(height, width)`.
     do_pad (`bool`, *optional*):
         Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
         number of patches in the batch. Padding will be applied to the bottom and right with zeros.
     """,
 )
 def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaOnevisionFastImageProcessorKwargs]) -> BatchFeature:
@@ -146,10 +146,10 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
         target_height, target_width = target_resolution
         new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)

-        paste_x = (target_width - new_width) // 2
-        paste_y = (target_height - new_height) // 2
+        paste_x, r_x = divmod(target_width - new_width, 2)
+        paste_y, r_y = divmod(target_height - new_height, 2)

-        padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
+        padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x + r_x, paste_y + r_y])

         return padded_image
@@ -140,14 +140,14 @@ def unpad_image(tensor, original_size):

     if original_aspect_ratio > current_aspect_ratio:
         scale_factor = current_width / original_width
-        new_height = int(round(original_height * scale_factor, 7))
-        padding = (current_height - new_height) // 2
-        unpadded_tensor = tensor[:, padding : current_height - padding, :]
+        new_height = min(math.ceil(original_height * scale_factor), current_height)
+        padding, r = divmod(current_height - new_height, 2)
+        unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
     else:
         scale_factor = current_height / original_height
-        new_width = int(round(original_width * scale_factor, 7))
-        padding = (current_width - new_width) // 2
-        unpadded_tensor = tensor[:, :, padding : current_width - padding]
+        new_width = min(math.ceil(original_width * scale_factor), current_width)
+        padding, r = divmod(current_width - new_width, 2)
+        unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]

     return unpadded_tensor
@@ -70,6 +70,8 @@ class LlavaOnevisionProcessor(ProcessorMixin):
             Special token used to denote image location.
         video_token (`str`, *optional*, defaults to `"<video>"`):
             Special token used to denote video location.
+        vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
+            Aspect ratio used when processing image features. The default value is "anyres_max_9".
     """

     attributes = ["image_processor", "tokenizer", "video_processor"]
@@ -79,6 +81,7 @@ class LlavaOnevisionProcessor(ProcessorMixin):
         "vision_feature_select_strategy",
         "image_token",
         "video_token",
+        "vision_aspect_ratio",
     ]
     image_processor_class = "AutoImageProcessor"
     tokenizer_class = "AutoTokenizer"
@@ -94,6 +97,7 @@ class LlavaOnevisionProcessor(ProcessorMixin):
         chat_template=None,
         image_token="<image>",
         video_token="<video>",
+        vision_aspect_ratio="anyres_max_9",
         **kwargs,
     ):
         self.num_image_tokens = num_image_tokens
@@ -110,6 +114,7 @@ class LlavaOnevisionProcessor(ProcessorMixin):
             if getattr(tokenizer, "video_token_id", None)
             else tokenizer.convert_tokens_to_ids(self.video_token)
         )
+        self.vision_aspect_ratio = vision_aspect_ratio
         super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)

     def __call__(
@@ -264,7 +269,8 @@ class LlavaOnevisionProcessor(ProcessorMixin):
         unpadded_features = current_height * current_width
         newline_features = current_height

-        ratio = math.sqrt(current_height * current_width / (9 * patches_height**2))
+        max_num_patches = int(self.vision_aspect_ratio.strip("anyres_max_"))
+        ratio = math.sqrt(current_height * current_width / (max_num_patches * patches_height**2))
         if ratio > 1.1:
             unpadded_features = int(current_height // ratio) * int(current_width // ratio)
             newline_features = int(current_height // ratio)
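A note on the processor hunk above: the patch budget is no longer hard-coded to 9 but parsed from the `vision_aspect_ratio` string. `str.strip("anyres_max_")` removes leading and trailing characters belonging to that character set, leaving only the trailing digits. A small illustrative check (not library code, just a sketch of the parsing behavior):

vision_aspect_ratio = "anyres_max_9"
# str.strip treats its argument as a set of characters, so every leading character of
# "anyres_max_9" is stripped until the digit, which is not in the set.
max_num_patches = int(vision_aspect_ratio.strip("anyres_max_"))
assert max_num_patches == 9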
@@ -17,7 +17,7 @@ import unittest

 import numpy as np

-from transformers.image_utils import PILImageResampling
+from transformers.image_utils import ChannelDimension, PILImageResampling
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_torch_available, is_vision_available

@@ -264,3 +264,41 @@ class AriaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
             tuple(encoded_images.shape),
             (self.image_processor_tester.batch_size, *expected_output_image_shape),
         )
+
+    def test_pad_for_patching(self):
+        for image_processing_class in self.image_processor_list:
+            if image_processing_class == self.fast_image_processing_class:
+                numpify = False
+                torchify = True
+                input_data_format = image_processing_class.data_format
+            else:
+                numpify = True
+                torchify = False
+                input_data_format = ChannelDimension.LAST
+            image_processing = image_processing_class(**self.image_processor_dict)
+            # Create odd-sized images
+            image_input = self.image_processor_tester.prepare_image_inputs(
+                batch_size=1,
+                max_resolution=400,
+                num_images=1,
+                equal_resolution=True,
+                numpify=numpify,
+                torchify=torchify,
+            )[0][0]
+            self.assertIn(image_input.shape, [(3, 400, 400), (400, 400, 3)])
+
+            # Test odd-width
+            image_shape = (400, 601)
+            encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
+            encoded_image_shape = (
+                encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
+            )
+            self.assertEqual(encoded_image_shape, image_shape)
+
+            # Test odd-height
+            image_shape = (503, 400)
+            encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
+            encoded_image_shape = (
+                encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
+            )
+            self.assertEqual(encoded_image_shape, image_shape)
@@ -16,7 +16,7 @@ import unittest

 import numpy as np

-from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ChannelDimension
 from transformers.models.llava_next.image_processing_llava_next import select_best_resolution
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
@@ -230,3 +230,38 @@ class LlavaNextImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):

         # Image processor should return same pixel values, independently of input format
         self.assertTrue((encoded_images_nested == encoded_images).all())
+
+    def test_pad_for_patching(self):
+        for image_processing_class in self.image_processor_list:
+            if image_processing_class == self.fast_image_processing_class:
+                numpify = False
+                torchify = True
+                input_data_format = image_processing_class.data_format
+            else:
+                numpify = True
+                torchify = False
+                input_data_format = ChannelDimension.LAST
+            image_processing = image_processing_class(**self.image_processor_dict)
+            # Create odd-sized images
+            image_input = self.image_processor_tester.prepare_image_inputs(
+                equal_resolution=True,
+                numpify=numpify,
+                torchify=torchify,
+            )[0]
+            self.assertIn(image_input.shape, [(3, 400, 400), (400, 400, 3)])
+
+            # Test odd-width
+            image_shape = (400, 601)
+            encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
+            encoded_image_shape = (
+                encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
+            )
+            self.assertEqual(encoded_image_shape, image_shape)
+
+            # Test odd-height
+            image_shape = (503, 400)
+            encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
+            encoded_image_shape = (
+                encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
+            )
+            self.assertEqual(encoded_image_shape, image_shape)
@@ -48,7 +48,7 @@ from ...test_modeling_common import (
 if is_torch_available():
     import torch

-    from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
+    from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches, unpad_image


 if is_vision_available():
@@ -288,6 +288,19 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
         image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
         _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)

+    def test_unpad_image(self):
+        original_size = (400, 400)
+
+        # Test case width is padded
+        pixel_values = floats_tensor([3, 400, 601])
+        unpadded_tensor = unpad_image(pixel_values, original_size)
+        self.assertEqual(unpadded_tensor.shape[1:], original_size)
+
+        # Test case height is padded
+        pixel_values = floats_tensor([3, 503, 400])
+        unpadded_tensor = unpad_image(pixel_values, original_size)
+        self.assertEqual(unpadded_tensor.shape[1:], original_size)
+
     @parameterized.expand(
         [
             (-1,),
@@ -47,6 +47,8 @@ from ...test_modeling_common import (
 if is_torch_available():
     import torch

+    from transformers.models.llava_next_video.modeling_llava_next_video import unpad_image
+

 if is_vision_available():
     from PIL import Image
@@ -302,6 +304,19 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
         image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
         _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)

+    def test_unpad_image(self):
+        original_size = (400, 400)
+
+        # Test case width is padded
+        pixel_values = floats_tensor([3, 400, 601])
+        unpadded_tensor = unpad_image(pixel_values, original_size)
+        self.assertEqual(unpadded_tensor.shape[1:], original_size)
+
+        # Test case height is padded
+        pixel_values = floats_tensor([3, 503, 400])
+        unpadded_tensor = unpad_image(pixel_values, original_size)
+        self.assertEqual(unpadded_tensor.shape[1:], original_size)
+
     @parameterized.expand(
         [
             (-1,),
@@ -16,7 +16,7 @@ import unittest

 import numpy as np

-from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ChannelDimension
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

@@ -305,3 +305,38 @@ class LlavaOnevisionImageProcessingTest(ImageProcessingTestMixin, unittest.TestC
     ) # FIXME yoni
     def test_can_compile_fast_image_processor(self):
         pass
+
+    def test_pad_for_patching(self):
+        for image_processing_class in self.image_processor_list:
+            if image_processing_class == self.fast_image_processing_class:
+                numpify = False
+                torchify = True
+                input_data_format = image_processing_class.data_format
+            else:
+                numpify = True
+                torchify = False
+                input_data_format = ChannelDimension.LAST
+            image_processing = image_processing_class(**self.image_processor_dict)
+            # Create odd-sized images
+            image_input = self.image_processor_tester.prepare_image_inputs(
+                equal_resolution=True,
+                numpify=numpify,
+                torchify=torchify,
+            )[0]
+            self.assertIn(image_input.shape, [(3, 400, 400), (400, 400, 3)])
+
+            # Test odd-width
+            image_shape = (400, 601)
+            encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
+            encoded_image_shape = (
+                encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
+            )
+            self.assertEqual(encoded_image_shape, image_shape)
+
+            # Test odd-height
+            image_shape = (503, 400)
+            encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
+            encoded_image_shape = (
+                encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
+            )
+            self.assertEqual(encoded_image_shape, image_shape)
@@ -48,6 +48,8 @@ from ...test_modeling_common import (
 if is_torch_available():
     import torch

+    from transformers.models.llava_onevision.modeling_llava_onevision import unpad_image
+

 if is_vision_available():
     from PIL import Image
@@ -258,6 +260,19 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         torch.testing.assert_close(out_embeds, out_ids)

+    def test_unpad_image(self):
+        original_size = (400, 400)
+
+        # Test case width is padded
+        pixel_values = floats_tensor([3, 400, 601])
+        unpadded_tensor = unpad_image(pixel_values, original_size)
+        self.assertEqual(unpadded_tensor.shape[1:], original_size)
+
+        # Test case height is padded
+        pixel_values = floats_tensor([3, 503, 400])
+        unpadded_tensor = unpad_image(pixel_values, original_size)
+        self.assertEqual(unpadded_tensor.shape[1:], original_size)
+
     @parameterized.expand(
         [
             (-1,),