Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
Add Got-OCR 2 Fast image processor and refactor slow one (#36185)
* Refactor the slow GOT-OCR2 image processor
* Add a working fast image processor
* Fix the fast image processor and update the docs
* Use one big loop for processing patches
This commit is contained in:
parent 51083d1bac
commit 2c5d038f92
@@ -44,13 +44,14 @@ The original code can be found [here](https://github.com/Ucas-HaoranWei/GOT-OCR2

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"
>>> inputs = processor(image, return_tensors="pt").to(device)
>>> inputs = processor(image, return_tensors="pt", device=device).to(device)

>>> generate_ids = model.generate(
... **inputs,
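The change above swaps in the fast backend (`use_fast=True`) and moves preprocessing onto the accelerator (`device=device`). As a rough, hypothetical comparison that is not part of the original documentation, the sketch below times both image processor backends on a dummy image; the class names follow this PR, and any speedup depends on your hardware and on torchvision being installed.

```python
# Hypothetical timing sketch, assuming this PR is installed together with torchvision.
import time

import numpy as np
import torch
from PIL import Image

from transformers import GotOcr2ImageProcessor, GotOcr2ImageProcessorFast

image = Image.fromarray(np.random.randint(0, 255, (1024, 768, 3), dtype=np.uint8))

slow = GotOcr2ImageProcessor()
fast = GotOcr2ImageProcessorFast()

start = time.perf_counter()
slow_out = slow(image, return_tensors="pt")
print(f"slow: {time.perf_counter() - start:.3f}s, pixel_values {tuple(slow_out['pixel_values'].shape)}")

device = "cuda" if torch.cuda.is_available() else "cpu"
start = time.perf_counter()
# The fast processor works on torch tensors and can run its transforms directly on the GPU.
fast_out = fast(image, return_tensors="pt", device=device)
print(f"fast: {time.perf_counter() - start:.3f}s, pixel_values {tuple(fast_out['pixel_values'].shape)}")
```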
@@ -68,15 +69,16 @@ The original code can be found [here](https://github.com/Ucas-HaoranWei/GOT-OCR2

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)

>>> image1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
>>> image2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"

>>> inputs = processor([image1, image2], return_tensors="pt").to(device)
>>> inputs = processor([image1, image2], return_tensors="pt", device=device).to(device)

>>> generate_ids = model.generate(
... **inputs,
@@ -96,13 +98,14 @@ GOT-OCR2 can also generate formatted text, such as markdown or LaTeX. Here is an

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/latex.png"
>>> inputs = processor(image, return_tensors="pt", format=True).to(device)
>>> inputs = processor(image, return_tensors="pt", format=True, device=device).to(device)

>>> generate_ids = model.generate(
... **inputs,
@@ -124,14 +127,15 @@ Here is an example of how to process multiple pages at once:

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)

>>> image1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/page1.png"
>>> image2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/page2.png"
>>> inputs = processor([image1, image2], return_tensors="pt", multi_page=True, format=True).to(device)
>>> inputs = processor([image1, image2], return_tensors="pt", multi_page=True, format=True, device=device).to(device)

>>> generate_ids = model.generate(
... **inputs,
@@ -153,13 +157,14 @@ Here is an example of how to process cropped patches:
```python
>>> import torch
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", torch_dtype=torch.bfloat16, device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/one_column.png"
>>> inputs = processor(image, return_tensors="pt", format=True, crop_to_patches=True, max_patches=3).to(device)
>>> inputs = processor(image, return_tensors="pt", format=True, crop_to_patches=True, max_patches=3, device=device).to(device)

>>> generate_ids = model.generate(
... **inputs,
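When `crop_to_patches=True`, the processor tiles the page into at most `max_patches` patches plus a thumbnail, so the leading dimension of `pixel_values` grows accordingly. A quick sanity check (a sketch that assumes the example above has just been run):

```python
>>> # Sketch: how many patches did the processor produce for this page?
>>> print(inputs["pixel_values"].shape)  # (N, 3, 384, 384) with N <= max_patches + 1 (thumbnail included)
```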
@@ -179,13 +184,14 @@ GOT supports interactive OCR, where the user can specify the region to be recogn

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
>>> inputs = processor(image, return_tensors="pt", color="green").to(device) # or box=[x1, y1, x2, y2] for coordinates (image pixels)
>>> inputs = processor(image, return_tensors="pt", color="green", device=device).to(device) # or box=[x1, y1, x2, y2] for coordinates (image pixels)

>>> generate_ids = model.generate(
... **inputs,
@@ -206,14 +212,15 @@ Here is an example of how to process sheet music:

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch
>>> import verovio

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/sheet_music.png"
>>> inputs = processor(image, return_tensors="pt", format=True).to(device)
>>> inputs = processor(image, return_tensors="pt", format=True, device=device).to(device)

>>> generate_ids = model.generate(
... **inputs,
@@ -258,6 +265,10 @@ alt="drawing" width="600"/>

[[autodoc]] GotOcr2ImageProcessor

## GotOcr2ImageProcessorFast

[[autodoc]] GotOcr2ImageProcessorFast

## GotOcr2Processor

[[autodoc]] GotOcr2Processor
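The `GotOcr2ImageProcessorFast` autodoc entry above covers the new class; as a minimal sketch (assuming this PR is installed), it can also be used on its own. Per the `_preprocess` implementation added below, the returned `BatchFeature` carries a `num_patches` entry alongside `pixel_values`:

```python
from io import BytesIO

import requests
from PIL import Image

from transformers import GotOcr2ImageProcessorFast

url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/one_column.png"
image = Image.open(BytesIO(requests.get(url).content)).convert("RGB")

processor = GotOcr2ImageProcessorFast.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
out = processor(image, return_tensors="pt", crop_to_patches=True, max_patches=6)
print(out["pixel_values"].shape)  # (patches + thumbnail, 3, 384, 384)
print(out["num_patches"])         # patches produced per input image
```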
@@ -1330,6 +1330,7 @@ else:
_import_structure["models.deit"].append("DeiTImageProcessorFast")
_import_structure["models.depth_pro"].append("DepthProImageProcessorFast")
_import_structure["models.detr"].append("DetrImageProcessorFast")
_import_structure["models.got_ocr2"].append("GotOcr2ImageProcessorFast")
_import_structure["models.llava"].append("LlavaImageProcessorFast")
_import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
_import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast")
@@ -6526,6 +6527,7 @@ if TYPE_CHECKING:
from .models.deit import DeiTImageProcessorFast
from .models.depth_pro import DepthProImageProcessorFast
from .models.detr import DetrImageProcessorFast
from .models.got_ocr2 import GotOcr2ImageProcessorFast
from .models.llava import LlavaImageProcessorFast
from .models.llava_next import LlavaNextImageProcessorFast
from .models.llava_onevision import LlavaOnevisionImageProcessorFast
@@ -88,7 +88,7 @@ else:
("fuyu", ("FuyuImageProcessor",)),
("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("glpn", ("GLPNImageProcessor",)),
("got_ocr2", ("GotOcr2ImageProcessor",)),
("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
("grounding-dino", ("GroundingDinoImageProcessor",)),
("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("hiera", ("BitImageProcessor",)),
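With the mapping entry extended to list both classes, `AutoImageProcessor` can resolve either backend for GOT-OCR2 checkpoints. A short sketch of the expected behaviour (the fast class is only registered when torchvision is available):

```python
from transformers import AutoImageProcessor

slow = AutoImageProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=False)
fast = AutoImageProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
print(type(slow).__name__)  # GotOcr2ImageProcessor
print(type(fast).__name__)  # GotOcr2ImageProcessorFast
```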
@@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_got_ocr2 import *
from .image_processing_got_ocr2 import *
from .image_processing_got_ocr2_fast import *
from .modeling_got_ocr2 import *
from .processing_got_ocr2 import *
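Because the module is re-exported lazily, importing the fast processor only succeeds when its torchvision dependency is present, mirroring the `is_torchvision_available()` guard in the new file. A defensive import sketch (an assumption about typical downstream usage, not code from this PR):

```python
from transformers.utils import is_torchvision_available

if is_torchvision_available():
    from transformers.models.got_ocr2 import GotOcr2ImageProcessorFast as GotOcr2Preprocessor
else:
    # Fall back to the slow, PIL/NumPy-based processor.
    from transformers.models.got_ocr2 import GotOcr2ImageProcessor as GotOcr2Preprocessor

processor = GotOcr2Preprocessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
```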
@ -1,9 +1,3 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/got_ocr2/modular_got_ocr2.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_got_ocr2.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
@ -18,7 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Image processor class for Got-OCR-2."""
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
@ -27,11 +21,9 @@ import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import (
|
||||
_rescale_for_pil_conversion,
|
||||
convert_to_rgb,
|
||||
resize,
|
||||
to_channel_dimension_format,
|
||||
to_pil_image,
|
||||
)
|
||||
from ...image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
@ -142,6 +134,15 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
|
||||
size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
|
||||
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
|
||||
method.
|
||||
crop_to_patches (`bool`, *optional*, defaults to `False`):
|
||||
Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the
|
||||
`preprocess` method.
|
||||
min_patches (`int`, *optional*, defaults to 1):
|
||||
The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
|
||||
set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method.
|
||||
max_patches (`int`, *optional*, defaults to 12):
|
||||
The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
|
||||
set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
|
||||
overridden by the `resample` parameter in the `preprocess` method.
|
||||
@ -172,6 +173,9 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
size: Dict[str, int] = None,
|
||||
crop_to_patches: bool = False,
|
||||
min_patches: int = 1,
|
||||
max_patches: int = 12,
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
@ -187,6 +191,9 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.crop_to_patches = crop_to_patches
|
||||
self.min_patches = min_patches
|
||||
self.max_patches = max_patches
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
@ -249,6 +256,9 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
|
||||
images: ImageInput,
|
||||
do_resize: Optional[bool] = None,
|
||||
size: Optional[Dict[str, int]] = None,
|
||||
crop_to_patches: Optional[bool] = None,
|
||||
min_patches: Optional[int] = None,
|
||||
max_patches: Optional[int] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
@ -274,6 +284,14 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
|
||||
`size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
|
||||
is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
|
||||
edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
|
||||
crop_to_patches (`bool`, *optional*, defaults to `self.crop_to_patches`):
|
||||
Whether to crop the image to patches.
|
||||
min_patches (`int`, *optional*, defaults to `self.min_patches`):
|
||||
The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
|
||||
set to `True`.
|
||||
max_patches (`int`, *optional*, defaults to `self.max_patches`):
|
||||
The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
|
||||
set to `True`.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
@ -308,6 +326,9 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
crop_to_patches = crop_to_patches if crop_to_patches is not None else self.crop_to_patches
|
||||
min_patches = min_patches if min_patches is not None else self.min_patches
|
||||
max_patches = max_patches if max_patches is not None else self.max_patches
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
@ -353,40 +374,52 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
if crop_to_patches and max_patches > 1:
|
||||
images = [
|
||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
self.crop_image_to_patches(
|
||||
image,
|
||||
min_patches=min_patches,
|
||||
max_patches=max_patches,
|
||||
patch_size=size,
|
||||
data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
num_patches = np.array([len(image) for image in images])
|
||||
images = [image for images_list in images for image in images_list]
|
||||
else:
|
||||
num_patches = np.array([1] * len(images))
|
||||
|
||||
for i, image in enumerate(images):
|
||||
if do_resize:
|
||||
images[i] = self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
|
||||
if do_rescale:
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
images[i] = self.rescale(image=images[i], scale=rescale_factor, input_data_format=input_data_format)
|
||||
|
||||
if do_normalize:
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
images[i] = self.normalize(
|
||||
image=images[i],
|
||||
mean=image_mean,
|
||||
std=image_std,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
images[i] = to_channel_dimension_format(images[i], data_format, input_channel_dim=input_data_format)
|
||||
|
||||
encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
|
||||
encoded_outputs = BatchFeature(
|
||||
data={"pixel_values": images, "num_patches": num_patches}, tensor_type=return_tensors
|
||||
)
|
||||
|
||||
return encoded_outputs
|
||||
|
||||
def crop_image_to_patches(
|
||||
self,
|
||||
image: ImageInput,
|
||||
images: np.ndarray,
|
||||
min_patches: int,
|
||||
max_patches: int,
|
||||
use_thumbnail: bool = True,
|
||||
patch_size: Union[Tuple, int, dict] = None,
|
||||
return_numpy: bool = False,
|
||||
data_format: ChannelDimension = None,
|
||||
):
|
||||
"""
|
||||
@ -396,8 +429,8 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
|
||||
The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`):
|
||||
The image to be cropped. The image can be a PIL image, NumPy array or PyTorch tensor.
|
||||
images (`np.ndarray`):
|
||||
The image to be cropped.
|
||||
min_patches (`int`):
|
||||
The minimum number of patches to be extracted from the image.
|
||||
max_patches (`int`):
|
||||
@ -406,24 +439,17 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
|
||||
Whether to add a thumbnail image to the list of cropped patches.
|
||||
patch_size (`int`, `Tuple[int, int]`, `dict`, *optional*):
|
||||
The size of the output patches.
|
||||
return_numpy (`bool`, *optional*, defaults to `False`):
|
||||
Whether to return the cropped images as NumPy arrays.
|
||||
data_format (`ChannelDimension`, *optional*):
|
||||
The format of the image data. If `None`, the format is inferred from the input image.
|
||||
|
||||
Returns:
|
||||
List[`PIL.Image.Image`] or List[np.ndarray]: The list of cropped images.
|
||||
"""
|
||||
patch_size = patch_size if patch_size is not None else self.size
|
||||
patch_size = get_size_dict(patch_size, default_to_square=True)
|
||||
original_size = get_size_dict(image.size, height_width_order=False)
|
||||
do_rescale = False
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
do_rescale = _rescale_for_pil_conversion(image)
|
||||
image = to_pil_image(image, do_rescale=do_rescale)
|
||||
|
||||
if data_format is None:
|
||||
data_format = infer_channel_dimension_format(images)
|
||||
images = to_channel_dimension_format(images, ChannelDimension.FIRST, data_format)
|
||||
patch_size_height, patch_size_width = patch_size["height"], patch_size["width"]
|
||||
original_height, original_width = original_size["height"], original_size["width"]
|
||||
original_height, original_width = images.shape[-2:]
|
||||
# find the closest aspect ratio to the target
|
||||
num_columns, num_rows = get_optimal_tiled_canvas(
|
||||
(original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches
|
||||
@ -435,8 +461,12 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
|
||||
num_blocks = num_columns * num_rows
|
||||
|
||||
# resize the image so that each patch is of patch_size
|
||||
resized_image = image.resize((target_width, target_height))
|
||||
|
||||
resized_image = self.resize(
|
||||
images,
|
||||
{"height": target_height, "width": target_width},
|
||||
data_format=ChannelDimension.FIRST,
|
||||
input_data_format=ChannelDimension.FIRST,
|
||||
)
|
||||
# split the image into patches
|
||||
processed_images = []
|
||||
for i in range(num_blocks):
|
||||
@ -449,33 +479,16 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
|
||||
(row + 1) * patch_size_height,
|
||||
)
|
||||
# split the image
|
||||
patch_image = resized_image.crop(box)
|
||||
patch_image = resized_image[..., box[1] : box[3], box[0] : box[2]]
|
||||
patch_image = to_channel_dimension_format(patch_image, data_format, ChannelDimension.FIRST)
|
||||
processed_images.append(patch_image)
|
||||
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
thumbnail_img = image.resize((patch_size_width, patch_size_height))
|
||||
thumbnail_img = self.resize(
|
||||
images, patch_size, data_format=data_format, input_data_format=ChannelDimension.FIRST
|
||||
)
|
||||
processed_images.append(thumbnail_img)
|
||||
|
||||
if return_numpy:
|
||||
processed_images_numpy = []
|
||||
for processed_image in processed_images:
|
||||
processed_image = np.array(processed_image)
|
||||
# If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
|
||||
# so we need to add it back if necessary.
|
||||
processed_image = (
|
||||
np.expand_dims(processed_image, axis=-1) if processed_image.ndim == 2 else processed_image
|
||||
)
|
||||
# The image is always in channels last format after converting from a PIL image
|
||||
if data_format is not None:
|
||||
processed_image = to_channel_dimension_format(
|
||||
processed_image, data_format, input_channel_dim=ChannelDimension.LAST
|
||||
)
|
||||
# If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to
|
||||
# rescale it back to the original range.
|
||||
processed_image = self.rescale(processed_image, 1 / 255) if do_rescale else processed_image
|
||||
processed_images_numpy.append(processed_image)
|
||||
processed_images = processed_images_numpy
|
||||
|
||||
return processed_images
|
||||
|
||||
|
||||
|
@ -0,0 +1,257 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fast Image processor class for Got-OCR-2."""
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_processing_utils_fast import (
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
||||
BaseImageProcessorFast,
|
||||
DefaultFastImageProcessorInitKwargs,
|
||||
DefaultFastImageProcessorPreprocessKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from ...image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
)
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
add_start_docstrings,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
)
|
||||
from .image_processing_got_ocr2 import get_optimal_tiled_canvas
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchvision_available():
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
else:
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
class GotOcr2ImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs):
|
||||
crop_to_patches: Optional[bool]
|
||||
min_patches: Optional[int]
|
||||
max_patches: Optional[int]
|
||||
|
||||
|
||||
class GotOcr2ImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs):
|
||||
crop_to_patches: Optional[bool]
|
||||
min_patches: Optional[int]
|
||||
max_patches: Optional[int]
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Constructs a fast GotOcr2 image processor.",
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||
"""
|
||||
crop_to_patches (`bool`, *optional*, defaults to `False`):
|
||||
Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the
|
||||
`preprocess` method.
|
||||
min_patches (`int`, *optional*, defaults to 1):
|
||||
The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
|
||||
set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method.
|
||||
max_patches (`int`, *optional*, defaults to 12):
|
||||
The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
|
||||
set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method.
|
||||
""",
|
||||
)
|
||||
class GotOcr2ImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = OPENAI_CLIP_MEAN
|
||||
image_std = OPENAI_CLIP_STD
|
||||
size = {"height": 384, "width": 384}
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
crop_to_patches = False
|
||||
min_patches = 1
|
||||
max_patches = 12
|
||||
valid_init_kwargs = GotOcr2ImageProcessorInitKwargs
|
||||
valid_preprocess_kwargs = GotOcr2ImageProcessorPreprocessKwargs
|
||||
|
||||
def __init__(self, **kwargs: Unpack[GotOcr2ImageProcessorInitKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@add_start_docstrings(
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
||||
"""
|
||||
crop_to_patches (`bool`, *optional*, defaults to `False`):
|
||||
Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the
|
||||
`preprocess` method.
|
||||
min_patches (`int`, *optional*, defaults to 1):
|
||||
The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
|
||||
set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method.
|
||||
max_patches (`int`, *optional*, defaults to 12):
|
||||
The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
|
||||
set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method.
|
||||
""",
|
||||
)
|
||||
def preprocess(self, images: ImageInput, **kwargs: Unpack[GotOcr2ImageProcessorPreprocessKwargs]) -> BatchFeature:
|
||||
return super().preprocess(images, **kwargs)
|
||||
|
||||
def crop_image_to_patches(
|
||||
self,
|
||||
images: "torch.Tensor",
|
||||
min_patches: int,
|
||||
max_patches: int,
|
||||
use_thumbnail: bool = True,
|
||||
patch_size: Union[Tuple, int, dict] = None,
|
||||
interpolation: Optional["F.InterpolationMode"] = None,
|
||||
):
|
||||
"""
|
||||
Crop the images to patches and return a list of cropped images.
|
||||
The number of patches and their grid arrangement are determined by the original image size,
|
||||
the target patch size and the minimum and maximum number of patches.
|
||||
The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio.
|
||||
|
||||
Args:
|
||||
images (`torch.Tensor`):
|
||||
The images to be cropped.
|
||||
min_patches (`int`):
|
||||
The minimum number of patches to be extracted from the image.
|
||||
max_patches (`int`):
|
||||
The maximum number of patches to be extracted from the image.
|
||||
use_thumbnail (`bool`, *optional*, defaults to `True`):
|
||||
Whether to add a thumbnail image to the list of cropped patches.
|
||||
patch_size (`int`, `Tuple[int, int]`, `dict`, *optional*):
|
||||
The size of the output patches.
|
||||
The format of the image data. If `None`, the format is inferred from the input image.
|
||||
|
||||
Returns:
|
||||
List[`PIL.Image.Image`] or List[np.ndarray]: The list of cropped images.
|
||||
"""
|
||||
patch_size_height, patch_size_width = patch_size.height, patch_size.width
|
||||
original_height, original_width = images.shape[-2:]
|
||||
# find the closest aspect ratio to the target
|
||||
num_columns, num_rows = get_optimal_tiled_canvas(
|
||||
(original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches
|
||||
)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = patch_size_width * num_columns
|
||||
target_height = patch_size_height * num_rows
|
||||
num_blocks = num_columns * num_rows
|
||||
|
||||
# resize the image so that each patch is of patch_size
|
||||
resized_image = self.resize(
|
||||
images, SizeDict(height=target_height, width=target_width), interpolation=interpolation
|
||||
)
|
||||
# split the image into patches
|
||||
processed_images = []
|
||||
for i in range(num_blocks):
|
||||
column = i % num_columns
|
||||
row = i // num_columns
|
||||
box = (
|
||||
column * patch_size_width,
|
||||
row * patch_size_height,
|
||||
(column + 1) * patch_size_width,
|
||||
(row + 1) * patch_size_height,
|
||||
)
|
||||
# split the image
|
||||
patch_image = resized_image[..., box[1] : box[3], box[0] : box[2]]
|
||||
processed_images.append(patch_image)
|
||||
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
thumbnail_img = self.resize(images, patch_size, interpolation=interpolation)
|
||||
processed_images.append(thumbnail_img)
|
||||
|
||||
processed_images = torch.stack(processed_images, dim=0).transpose(0, 1).contiguous()
|
||||
|
||||
return processed_images
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: List["torch.Tensor"],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
crop_to_patches: bool,
|
||||
min_patches: int,
|
||||
max_patches: int,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: Optional[Union[float, List[float]]],
|
||||
image_std: Optional[Union[float, List[float]]],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
) -> BatchFeature:
|
||||
if crop_to_patches:
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images)
|
||||
processed_images_grouped = {}
|
||||
num_patches = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
stacked_images = self.crop_image_to_patches(
|
||||
stacked_images,
|
||||
min_patches,
|
||||
max_patches,
|
||||
patch_size=size,
|
||||
interpolation=interpolation,
|
||||
)
|
||||
processed_images_grouped[shape] = stacked_images
|
||||
num_patches[shape] = [stacked_images.shape[1]] * stacked_images.shape[0]
|
||||
images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||
images = [image for images_list in images for image in images_list]
|
||||
num_patches = reorder_images(num_patches, grouped_images_index)
|
||||
else:
|
||||
num_patches = [1] * len(images)
|
||||
|
||||
# Group images by size for batched resizing
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images)
|
||||
resized_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_resize:
|
||||
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
|
||||
resized_images_grouped[shape] = stacked_images
|
||||
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
|
||||
|
||||
# Group images by size for further processing
|
||||
# Needed in case do_resize is False, or resize returns images with different sizes
|
||||
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
|
||||
processed_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_center_crop:
|
||||
stacked_images = self.center_crop(stacked_images, crop_size)
|
||||
# Fused rescale and normalize
|
||||
stacked_images = self.rescale_and_normalize(
|
||||
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||
)
|
||||
processed_images_grouped[shape] = stacked_images
|
||||
|
||||
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
|
||||
return BatchFeature(
|
||||
data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["GotOcr2ImageProcessorFast"]
|
@ -32,11 +32,7 @@ from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_outputs import ModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ..auto import AutoModelForCausalLM
|
||||
from .configuration_got_ocr2 import GotOcr2Config, GotOcr2VisionConfig
|
||||
|
||||
|
@ -14,35 +14,20 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from transformers.models.blip.image_processing_blip import BlipImageProcessor
|
||||
from transformers.models.llava.modeling_llava import (
|
||||
LlavaCausalLMOutputWithPast,
|
||||
LlavaForConditionalGeneration,
|
||||
LlavaPreTrainedModel,
|
||||
)
|
||||
from transformers.models.sam.modeling_sam import SamMLPBlock, SamVisionAttention, SamVisionEncoder, SamVisionLayer
|
||||
from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
|
||||
from transformers.tokenization_utils_base import (
|
||||
PreTokenizedInput,
|
||||
TextInput,
|
||||
)
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...image_processing_utils import BatchFeature, get_size_dict
|
||||
from ...image_transforms import (
|
||||
_rescale_for_pil_conversion,
|
||||
to_channel_dimension_format,
|
||||
to_pil_image,
|
||||
)
|
||||
from ...image_utils import ChannelDimension, ImageInput
|
||||
from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_vision_available,
|
||||
@ -53,9 +38,7 @@ from ..auto import CONFIG_MAPPING, AutoConfig, AutoModelForCausalLM
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
|
||||
from ...image_utils import load_images
|
||||
pass
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@ -246,437 +229,6 @@ class GotOcr2Config(PretrainedConfig):
|
||||
__all__ = ["GotOcr2VisionConfig", "GotOcr2Config"]
|
||||
|
||||
|
||||
class GotOcr2TextKwargs(TextKwargs, total=False):
|
||||
format: Optional[bool]
|
||||
|
||||
|
||||
class GotOcr2ImagesKwargs(ImagesKwargs, total=False):
|
||||
box: Optional[Union[List, Tuple[float, float], Tuple[float, float, float, float]]]
|
||||
color: Optional[str]
|
||||
num_image_tokens: Optional[int]
|
||||
multi_page: Optional[bool]
|
||||
crop_to_patches: Optional[bool]
|
||||
min_patches: Optional[int]
|
||||
max_patches: Optional[int]
|
||||
|
||||
|
||||
class GotOcr2ProcessorKwargs(ProcessingKwargs, total=False):
|
||||
text_kwargs: GotOcr2TextKwargs
|
||||
images_kwargs: GotOcr2ImagesKwargs
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
"format": False,
|
||||
},
|
||||
"images_kwargs": {
|
||||
"num_image_tokens": 256,
|
||||
"multi_page": False,
|
||||
"crop_to_patches": False,
|
||||
"min_patches": 1,
|
||||
"max_patches": 12,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def preprocess_box_annotation(box: Union[List, Tuple], image_size: Tuple[int, int]) -> List:
|
||||
"""
|
||||
Convert box annotation to the format [x1, y1, x2, y2] in the range [0, 1000].
|
||||
"""
|
||||
width, height = image_size
|
||||
if len(box) == 4:
|
||||
box[0] = int(box[0] / width * 1000)
|
||||
box[1] = int(box[1] / height * 1000)
|
||||
box[2] = int(box[2] / width * 1000)
|
||||
box[3] = int(box[3] / height * 1000)
|
||||
else:
|
||||
raise ValueError("Box must be a list or tuple of lists in the form [x1, y1, x2, y2].")
|
||||
|
||||
return list(box)
|
||||
|
||||
|
||||
# Similar to image_processing_mllama.get_all_supported_aspect_ratios
|
||||
@lru_cache(maxsize=10)
|
||||
def get_all_supported_aspect_ratios(min_image_tiles: int, max_image_tiles: int) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Computes all allowed aspect ratios for a given minimum and maximum number of input tiles.
|
||||
|
||||
This function calculates all possible arrangements of tiles that can be formed
|
||||
within the constraint of the minimum and maximum number of tiles. Each arrangement is
|
||||
represented by its aspect ratio (width/height) and the corresponding tile configuration.
|
||||
|
||||
Args:
|
||||
min_image_tiles (`int`):
|
||||
The minimum number of tiles allowed.
|
||||
max_image_tiles (`int`):
|
||||
The maximum number of tiles allowed.
|
||||
|
||||
Returns:
|
||||
`List[Tuple[int, int]]`: A list of tuples, each tuple representing a valid (width, height)
|
||||
configuration in terms of number of tiles.
|
||||
|
||||
Example:
|
||||
>>> get_all_supported_aspect_ratios(1, 4)
|
||||
[(1, 1), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (2, 2), (4, 1)]
|
||||
|
||||
"""
|
||||
aspect_ratios = []
|
||||
for width in range(1, max_image_tiles + 1):
|
||||
for height in range(1, max_image_tiles + 1):
|
||||
if width * height <= max_image_tiles and width * height >= min_image_tiles:
|
||||
aspect_ratios.append((width, height))
|
||||
|
||||
aspect_ratios = sorted(aspect_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
return aspect_ratios
|
||||
|
||||
|
||||
@lru_cache(maxsize=100)
|
||||
def get_optimal_tiled_canvas(
|
||||
original_image_size: Tuple[int, int],
|
||||
target_tile_size: Tuple[int, int],
|
||||
min_image_tiles: int,
|
||||
max_image_tiles: int,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Given a minimum and maximum number of tiles, find the canvas with the closest aspect ratio to the
|
||||
original image aspect ratio.
|
||||
In case of tie-breaking condition when two canvases have the same aspect ratio difference, we favor the canvas with
|
||||
more tiles, until the area covered by the tiles is more than twice the target area, in order to avoid unnecessarily
|
||||
excessive tiling.
|
||||
"""
|
||||
possible_tile_arrangements = get_all_supported_aspect_ratios(min_image_tiles, max_image_tiles)
|
||||
|
||||
original_height, original_width = original_image_size
|
||||
target_tile_height, target_tile_width = target_tile_size
|
||||
aspect_ratio = original_width / original_height
|
||||
area = original_width * original_height
|
||||
|
||||
# find the grid with the best aspect ratio
|
||||
best_ratio_diff = float("inf")
|
||||
best_grid = (1, 1)
|
||||
for grid in possible_tile_arrangements:
|
||||
grid_aspect_ratio = grid[0] / grid[1]
|
||||
ratio_diff = abs(aspect_ratio - grid_aspect_ratio)
|
||||
if ratio_diff < best_ratio_diff:
|
||||
best_ratio_diff = ratio_diff
|
||||
best_grid = grid
|
||||
elif ratio_diff == best_ratio_diff:
|
||||
# if the aspect ratio difference is the same, we favor the grid with more patches
|
||||
# until the area covered by the patches is more than twice the original image area
|
||||
if area > 0.5 * target_tile_height * target_tile_width * grid[0] * grid[1]:
|
||||
best_grid = grid
|
||||
|
||||
return best_grid
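# Note (added for illustration, not part of this PR): the docstring above describes the grid
# selection and tie-breaking in prose; a small worked example with hypothetical numbers,
# assuming the function behaves as written here, makes it concrete.
#
# from transformers.models.got_ocr2.image_processing_got_ocr2 import get_optimal_tiled_canvas
#
# num_columns, num_rows = get_optimal_tiled_canvas(
#     original_image_size=(1000, 2000),  # (height, width) of a wide page
#     target_tile_size=(384, 384),
#     min_image_tiles=1,
#     max_image_tiles=12,
# )
# # The image aspect ratio is 2.0, so both the (2, 1) and (4, 2) grids match it exactly;
# # the tie-break keeps favoring more tiles while the image area exceeds half the tiled area,
# # so a 4-column by 2-row grid is expected here.
# print(num_columns, num_rows)  # expected: 4 2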
|
||||
|
||||
|
||||
class GotOcr2ImageProcessor(BlipImageProcessor):
|
||||
def crop_image_to_patches(
|
||||
self,
|
||||
image: ImageInput,
|
||||
min_patches: int,
|
||||
max_patches: int,
|
||||
use_thumbnail: bool = True,
|
||||
patch_size: Union[Tuple, int, dict] = None,
|
||||
return_numpy: bool = False,
|
||||
data_format: ChannelDimension = None,
|
||||
):
|
||||
"""
|
||||
Crop the image to patches and return a list of cropped images.
|
||||
The number of patches and their grid arrangement are determined by the original image size,
|
||||
the target patch size and the minimum and maximum number of patches.
|
||||
The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`):
|
||||
The image to be cropped. The image can be a PIL image, NumPy array or PyTorch tensor.
|
||||
min_patches (`int`):
|
||||
The minimum number of patches to be extracted from the image.
|
||||
max_patches (`int`):
|
||||
The maximum number of patches to be extracted from the image.
|
||||
use_thumbnail (`bool`, *optional*, defaults to `True`):
|
||||
Whether to add a thumbnail image to the list of cropped patches.
|
||||
patch_size (`int`, `Tuple[int, int]`, `dict`, *optional*):
|
||||
The size of the output patches.
|
||||
return_numpy (`bool`, *optional*, defaults to `False`):
|
||||
Whether to return the cropped images as NumPy arrays.
|
||||
data_format (`ChannelDimension`, *optional*):
|
||||
The format of the image data. If `None`, the format is inferred from the input image.
|
||||
|
||||
Returns:
|
||||
List[`PIL.Image.Image`] or List[np.ndarray]: The list of cropped images.
|
||||
"""
|
||||
patch_size = patch_size if patch_size is not None else self.size
|
||||
patch_size = get_size_dict(patch_size, default_to_square=True)
|
||||
original_size = get_size_dict(image.size, height_width_order=False)
|
||||
do_rescale = False
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
do_rescale = _rescale_for_pil_conversion(image)
|
||||
image = to_pil_image(image, do_rescale=do_rescale)
|
||||
|
||||
patch_size_height, patch_size_width = patch_size["height"], patch_size["width"]
|
||||
original_height, original_width = original_size["height"], original_size["width"]
|
||||
# find the closest aspect ratio to the target
|
||||
num_columns, num_rows = get_optimal_tiled_canvas(
|
||||
(original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches
|
||||
)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = patch_size_width * num_columns
|
||||
target_height = patch_size_height * num_rows
|
||||
num_blocks = num_columns * num_rows
|
||||
|
||||
# resize the image so that each patch is of patch_size
|
||||
resized_image = image.resize((target_width, target_height))
|
||||
|
||||
# split the image into patches
|
||||
processed_images = []
|
||||
for i in range(num_blocks):
|
||||
column = i % num_columns
|
||||
row = i // num_columns
|
||||
box = (
|
||||
column * patch_size_width,
|
||||
row * patch_size_height,
|
||||
(column + 1) * patch_size_width,
|
||||
(row + 1) * patch_size_height,
|
||||
)
|
||||
# split the image
|
||||
patch_image = resized_image.crop(box)
|
||||
processed_images.append(patch_image)
|
||||
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
thumbnail_img = image.resize((patch_size_width, patch_size_height))
|
||||
processed_images.append(thumbnail_img)
|
||||
|
||||
if return_numpy:
|
||||
processed_images_numpy = []
|
||||
for processed_image in processed_images:
|
||||
processed_image = np.array(processed_image)
|
||||
# If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
|
||||
# so we need to add it back if necessary.
|
||||
processed_image = (
|
||||
np.expand_dims(processed_image, axis=-1) if processed_image.ndim == 2 else processed_image
|
||||
)
|
||||
# The image is always in channels last format after converting from a PIL image
|
||||
if data_format is not None:
|
||||
processed_image = to_channel_dimension_format(
|
||||
processed_image, data_format, input_channel_dim=ChannelDimension.LAST
|
||||
)
|
||||
# If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to
|
||||
# rescale it back to the original range.
|
||||
processed_image = self.rescale(processed_image, 1 / 255) if do_rescale else processed_image
|
||||
processed_images_numpy.append(processed_image)
|
||||
processed_images = processed_images_numpy
|
||||
|
||||
return processed_images
|
||||
|
||||
|
||||
class GotOcr2Processor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a GotOcr2 processor which wraps a [`GotOcr2ImageProcessor`] and
|
||||
[`PretrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and
|
||||
tokenizer functionalities. See the [`~GotOcr2Processor.__call__`] and [`~GotOcr2Processor.decode`] for more information.
|
||||
Args:
|
||||
image_processor ([`GotOcr2ImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = ["chat_template"]
|
||||
image_processor_class = "GotOcr2ImageProcessor"
|
||||
tokenizer_class = "PreTrainedTokenizerFast"
|
||||
|
||||
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
self.message_start_token = "<|im_start|>"
|
||||
self.message_end_token = "<|im_end|>"
|
||||
self.img_start_token = "<img>"
|
||||
self.img_end_token = "</img>"
|
||||
self.img_pad_token = "<imgpad>"
|
||||
self.system_query = "system\nYou should follow the instructions carefully and explain your answers in detail."
|
||||
|
||||
def _make_list_of_inputs(self, images, text, box, color, multi_page):
|
||||
if not isinstance(images, (list, tuple)):
|
||||
images = [images]
|
||||
if multi_page:
|
||||
logger.warning("Multi-page inference is enabled but only one image is passed.")
|
||||
images = [images]
|
||||
elif isinstance(images[0], (list, tuple)) and not multi_page:
|
||||
raise ValueError("Nested images are only supported with `multi_page` set to `True`.")
|
||||
elif not isinstance(images[0], (list, tuple)) and multi_page:
|
||||
images = [images]
|
||||
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
|
||||
if not isinstance(box[0], (list, tuple)):
|
||||
# Use the same box for all images
|
||||
box = [box for _ in range(len(images))]
|
||||
if not isinstance(color, (list, tuple)):
|
||||
color = [color for _ in range(len(images))]
|
||||
|
||||
return images, text, box, color
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: Optional[ImageInput] = None,
|
||||
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
|
||||
audio=None,
|
||||
videos=None,
|
||||
**kwargs: Unpack[GotOcr2ProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
||||
and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text if `text`
|
||||
is not `None`, otherwise encode default OCR queries which depends on the `format`, `box`, `color`, `multi_page` and
|
||||
`crop_to_patches` arguments. To prepare the vision inputs, this method forwards the `images` and `kwrags` arguments to
|
||||
GotOcr2ImageProcessor's [`~GotOcr2ImageProcessor.__call__`] if `images` is not `None`.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
format (`bool`, *optional*):
|
||||
If set, will add the format token to the query, and the model will return the OCR result with formatting.
|
||||
box (`List[float]`, `List[Tuple[float, float]]`, `List[Tuple[float, float, float, float]]`, *optional*):
|
||||
The box annotation to be added to the query. If a list of floats or a tuple of floats is provided, it
|
||||
will be interpreted as [x1, y1, x2, y2]. If a list of tuples is provided, each tuple should be in the
|
||||
form (x1, y1, x2, y2).
|
||||
color (`str`, *optional*):
|
||||
The color annotation to be added to the query. The model will return the OCR result within the box with
|
||||
the specified color.
|
||||
multi_page (`bool`, *optional*):
|
||||
If set, will enable multi-page inference. The model will return the OCR result across multiple pages.
|
||||
crop_to_patches (`bool`, *optional*):
|
||||
If set, will crop the image to patches. The model will return the OCR result upon the patch reference.
|
||||
min_patches (`int`, *optional*):
|
||||
The minimum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to
|
||||
`True`.
|
||||
max_patches (`int`, *optional*):
|
||||
The maximum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to
|
||||
`True`.
|
||||
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
"""
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
GotOcr2ProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
format_output = output_kwargs["text_kwargs"].pop("format")
|
||||
num_image_tokens = output_kwargs["images_kwargs"].pop("num_image_tokens")
|
||||
box = output_kwargs["images_kwargs"].pop("box", [None])
|
||||
color = output_kwargs["images_kwargs"].pop("color", None)
|
||||
multi_page = output_kwargs["images_kwargs"].pop("multi_page")
|
||||
crop_to_patches = output_kwargs["images_kwargs"].pop("crop_to_patches")
|
||||
min_patches = output_kwargs["images_kwargs"].pop("min_patches")
|
||||
max_patches = output_kwargs["images_kwargs"].pop("max_patches")
|
||||
|
||||
images, text, box, color = self._make_list_of_inputs(images, text, box, color, multi_page)
|
||||
|
||||
# Load images as we need to know the image size
|
||||
images = load_images(images)
|
||||
if text is None:
|
||||
text = []
|
||||
for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)):
|
||||
if crop_to_patches:
|
||||
image_group = self.image_processor.crop_image_to_patches(
|
||||
image_group,
|
||||
patch_size=output_kwargs["images_kwargs"].get("size"),
|
||||
min_patches=min_patches,
|
||||
max_patches=max_patches,
|
||||
)
|
||||
images[index] = image_group
|
||||
num_images = len(image_group) if (multi_page or crop_to_patches) else 1
|
||||
if box_single[0] is not None:
|
||||
box_single = preprocess_box_annotation(box_single, image_group.size)
|
||||
query = (
|
||||
f"{f'[{color_single}] ' if color_single is not None else ''}"
|
||||
f"{str(box_single) if box_single[0] is not None else ''} "
|
||||
"OCR"
|
||||
f"{' with format' if format_output else ''}"
|
||||
f"{' across multi pages' if multi_page else ''}"
|
||||
f"{' upon the patch reference' if crop_to_patches else ''}"
|
||||
": "
|
||||
)
|
||||
prompt = (
|
||||
self.message_start_token
|
||||
+ self.system_query
|
||||
+ self.message_end_token
|
||||
+ self.message_start_token
|
||||
+ "user\n"
|
||||
+ self.img_start_token
|
||||
+ self.img_pad_token * num_image_tokens * num_images
|
||||
+ self.img_end_token
|
||||
+ "\n"
|
||||
+ query
|
||||
+ self.message_end_token
|
||||
+ self.message_start_token
|
||||
+ "assistant\n"
|
||||
)
|
||||
text.append(prompt)
|
||||
elif crop_to_patches:
|
||||
for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)):
|
||||
image_group = self.image_processor.crop_image_to_patches(
|
||||
image_group,
|
||||
patch_size=output_kwargs["images_kwargs"].get("size"),
|
||||
min_patches=min_patches,
|
||||
max_patches=max_patches,
|
||||
)
|
||||
images[index] = image_group
|
||||
|
||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
if multi_page or crop_to_patches:
|
||||
# flatten images
|
||||
images = [image for image_group in images for image in image_group]
|
||||
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
|
||||
|
||||
return BatchFeature(data={**text_inputs, **image_inputs})
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(tokenizer_input_names) + list(image_processor_input_names)
|
||||
|
||||
|
||||
class GotOcr2MLPBlock(SamMLPBlock):
|
||||
pass
|
||||
|
||||
@ -972,8 +524,6 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
__all__ = [
|
||||
"GotOcr2VisionConfig",
|
||||
"GotOcr2Config",
|
||||
"GotOcr2Processor",
|
||||
"GotOcr2PreTrainedModel",
|
||||
"GotOcr2ForConditionalGeneration",
|
||||
"GotOcr2ImageProcessor",
|
||||
]
|
||||
|
@ -1,9 +1,3 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/got_ocr2/modular_got_ocr2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_got_ocr2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#
@ -22,6 +16,8 @@

from typing import List, Optional, Tuple, Union

import numpy as np

from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput

@ -100,7 +96,7 @@ class GotOcr2Processor(ProcessorMixin):

attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
image_processor_class = "GotOcr2ImageProcessor"
image_processor_class = "AutoImageProcessor"
tokenizer_class = "PreTrainedTokenizerFast"

def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
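With `image_processor_class` switched from the concrete slow class to `"AutoImageProcessor"`, the processor can resolve either the slow or the fast (torchvision-backed) image processor. A minimal sketch of what that looks like at load time, assuming the stepfun-ai/GOT-OCR-2.0-hf checkpoint and torchvision installed:

```python
# Sketch: resolving the slow vs. fast GOT-OCR2 image processor through AutoImageProcessor.
# The class names printed below are what this commit is expected to register.
from transformers import AutoImageProcessor

slow_processor = AutoImageProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=False)
fast_processor = AutoImageProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
print(type(slow_processor).__name__)  # GotOcr2ImageProcessor
print(type(fast_processor).__name__)  # GotOcr2ImageProcessorFast (requires torchvision)
```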
@ -205,28 +201,29 @@ class GotOcr2Processor(ProcessorMixin):
box = output_kwargs["images_kwargs"].pop("box", [None])
color = output_kwargs["images_kwargs"].pop("color", None)
multi_page = output_kwargs["images_kwargs"].pop("multi_page")
crop_to_patches = output_kwargs["images_kwargs"].pop("crop_to_patches")
min_patches = output_kwargs["images_kwargs"].pop("min_patches")
max_patches = output_kwargs["images_kwargs"].pop("max_patches")

crop_to_patches = output_kwargs["images_kwargs"].get("crop_to_patches")
images, text, box, color = self._make_list_of_inputs(images, text, box, color, multi_page)

if multi_page:
# save the number of pages per batch
num_pages_per_batch = [len(image_group) for image_group in images]
# flatten the list of images
images = [image for image_group in images for image in image_group]
else:
num_pages_per_batch = [1 for _ in range(len(images))]
# Load images as we need to know the image size
images = load_images(images)
image_sizes = [image.size for image in images]
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
num_patches_array = image_inputs.pop("num_patches")
if text is None:
text = []
for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)):
if crop_to_patches:
image_group = self.image_processor.crop_image_to_patches(
image_group,
patch_size=output_kwargs["images_kwargs"].get("size"),
min_patches=min_patches,
max_patches=max_patches,
)
images[index] = image_group
num_images = len(image_group) if (multi_page or crop_to_patches) else 1
patch_indices = np.cumsum(num_pages_per_batch)
for index, (num_pages, box_single, color_single) in enumerate(zip(num_pages_per_batch, box, color)):
current_patch_index = patch_indices[index - 1] if index > 0 else 0
num_patches = sum(num_patches_array[current_patch_index : current_patch_index + num_pages])
if box_single[0] is not None:
box_single = preprocess_box_annotation(box_single, image_group.size)
box_single = preprocess_box_annotation(box_single, image_sizes[index])
query = (
f"{f'[{color_single}] ' if color_single is not None else ''}"
f"{str(box_single) if box_single[0] is not None else ''} "
@ -243,7 +240,7 @@ class GotOcr2Processor(ProcessorMixin):
+ self.message_start_token
+ "user\n"
+ self.img_start_token
+ self.img_pad_token * num_image_tokens * num_images
+ self.img_pad_token * num_image_tokens * num_patches
+ self.img_end_token
+ "\n"
+ query
@ -252,22 +249,8 @@ class GotOcr2Processor(ProcessorMixin):
+ "assistant\n"
)
text.append(prompt)
elif crop_to_patches:
for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)):
image_group = self.image_processor.crop_image_to_patches(
image_group,
patch_size=output_kwargs["images_kwargs"].get("size"),
min_patches=min_patches,
max_patches=max_patches,
)
images[index] = image_group

text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
if multi_page or crop_to_patches:
# flatten images
images = [image for image_group in images for image in image_group]
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])

return BatchFeature(data={**text_inputs, **image_inputs})
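The refactor above moves patch cropping into the image processor, which now returns a per-page `num_patches` list; the processor then maps those counts back to samples via the cumulative page counts so it can size the image-pad block per prompt. A standalone sketch of that bookkeeping, using invented page and patch counts:

```python
# Standalone sketch of the num_patches bookkeeping above, with invented counts.
import numpy as np

num_pages_per_batch = [1, 3, 2]            # pages per sample (all 1 unless multi_page)
num_patches_array = [5, 2, 4, 3, 1, 6]     # patches per flattened page, as popped from image_inputs

patch_indices = np.cumsum(num_pages_per_batch)  # [1, 4, 6]
for index, num_pages in enumerate(num_pages_per_batch):
    current_patch_index = patch_indices[index - 1] if index > 0 else 0
    num_patches = sum(num_patches_array[current_patch_index : current_patch_index + num_pages])
    print(f"sample {index}: {num_patches} patches -> {num_patches} blocks of image pad tokens")
# sample 0: 5, sample 1: 2 + 4 + 3 = 9, sample 2: 1 + 6 = 7
```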
def batch_decode(self, *args, **kwargs):

@ -58,6 +58,13 @@ class DetrImageProcessorFast(metaclass=DummyObject):
requires_backends(self, ["torchvision"])


class GotOcr2ImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])


class LlavaImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
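The hunk above registers a torchvision-gated placeholder for the new fast processor, so importing it without torchvision fails only at use time. Below is a simplified, self-contained illustration of that placeholder pattern; the class and metaclass names are stand-ins, not the transformers internals.

```python
# Simplified stand-in for the dummy-object pattern: the class can be imported,
# but instantiating it without the backend raises. Names here are illustrative only.
class _RequiresTorchvision(type):
    def __call__(cls, *args, **kwargs):
        raise ImportError(f"{cls.__name__} requires the torchvision backend, which is not installed.")


class FakeFastImageProcessor(metaclass=_RequiresTorchvision):
    _backends = ["torchvision"]


try:
    FakeFastImageProcessor()
except ImportError as err:
    print(err)  # FakeFastImageProcessor requires the torchvision backend, which is not installed.
```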
@ -16,15 +16,22 @@

import unittest

from transformers.image_utils import SizeDict
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


if is_torch_available():
import torch

if is_vision_available():
from transformers import GotOcr2ImageProcessor

if is_torchvision_available():
from transformers import GotOcr2ImageProcessorFast


class GotOcr2ImageProcessingTester(unittest.TestCase):
def __init__(
@ -89,6 +96,7 @@ class GotOcr2ImageProcessingTester(unittest.TestCase):
@require_vision
class GotOcr2ProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = GotOcr2ImageProcessor if is_vision_available() else None
fast_image_processing_class = GotOcr2ImageProcessorFast if is_torchvision_available() else None

def setUp(self):
super().setUp()
@ -99,7 +107,8 @@ class GotOcr2ProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
return self.image_processor_tester.prepare_image_processor_dict()

def test_image_processor_properties(self):
image_processor = self.image_processing_class(**self.image_processor_dict)
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processor, "do_resize"))
self.assertTrue(hasattr(image_processor, "size"))
self.assertTrue(hasattr(image_processor, "do_normalize"))
@ -107,9 +116,63 @@ class GotOcr2ProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
self.assertTrue(hasattr(image_processor, "image_std"))
self.assertTrue(hasattr(image_processor, "do_convert_rgb"))
def test_slow_fast_equivalence_crop_to_patches(self):
dummy_image = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)[0]

image_processor_slow = self.image_processing_class(**self.image_processor_dict, crop_to_patches=True)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict, crop_to_patches=True)

encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")

torch.testing.assert_close(encoding_slow.num_patches, encoding_fast.num_patches)
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)

def test_slow_fast_equivalence_batched_crop_to_patches(self):
# Prepare image inputs so that we have two groups of images with equal resolution with a group of images with
# different resolutions in between
dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
dummy_images += self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
dummy_images += self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)

image_processor_slow = self.image_processing_class(**self.image_processor_dict, crop_to_patches=True)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict, crop_to_patches=True)

encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")

torch.testing.assert_close(encoding_slow.num_patches, encoding_fast.num_patches)
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)
def test_crop_to_patches(self):
image_processor = self.image_processing_class(**self.image_processor_dict)
image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)[0]
processed_images = image_processor.crop_image_to_patches(image, 1, 6, use_thumbnail=True)
# test slow image processor
image_processor = self.image_processor_list[0](**self.image_processor_dict)
image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)[0]
processed_images = image_processor.crop_image_to_patches(
image,
min_patches=1,
max_patches=6,
use_thumbnail=True,
patch_size={"height": 20, "width": 20},
)
self.assertEqual(len(processed_images), 5)
self.assertEqual(processed_images[0].size, (20, 20))
self.assertEqual(processed_images[0].shape[:2], (20, 20))

# test fast image processor (process batch)
image_processor = self.image_processor_list[1](**self.image_processor_dict)
image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)[0]
processed_images = image_processor.crop_image_to_patches(
image.unsqueeze(0),
min_patches=1,
max_patches=6,
use_thumbnail=True,
patch_size=SizeDict(height=20, width=20),
)
self.assertEqual(len(processed_images[0]), 5)
self.assertEqual(processed_images.shape[-2:], (20, 20))
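For anyone who wants to exercise the cropping entry point outside the test suite, here is a minimal sketch mirroring the slow-processor call in the test above. The checkpoint id is taken from elsewhere in this commit, the 384x384 patch size is an arbitrary illustration value, and the resulting patch count depends on the input's aspect ratio.

```python
# Sketch: calling crop_image_to_patches directly on the slow GOT-OCR2 image processor.
# Checkpoint id and patch size are assumptions for illustration.
import numpy as np
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=False)
image = np.zeros((150, 300, 3), dtype=np.uint8)  # dummy wide image, (height, width, channels)

patches = processor.crop_image_to_patches(
    image,
    min_patches=1,
    max_patches=6,
    use_thumbnail=True,  # append a thumbnail of the full image when more than one patch is produced
    patch_size={"height": 384, "width": 384},
)
print(len(patches), patches[0].shape)  # patch count depends on aspect ratio; each patch is 384x384
```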