Add Got-OCR 2 Fast image processor and refactor slow one (#36185)

* Refactor slow GOT-OCR2 image processor

* Add working fast image processor

* Fix fast image processor, update docs

* Use one big loop for processing patches
Yoni Gozlan committed 2025-03-01 00:56:00 -05:00 (via GitHub)
parent 51083d1bac
commit 2c5d038f92
11 changed files with 467 additions and 584 deletions
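In short: the commit adds `GotOcr2ImageProcessorFast`, a torch/torchvision-backed image processor selected with `use_fast=True`, and refactors the slow processor so that patch cropping happens inside `preprocess` and a `num_patches` entry is returned. A minimal sketch of the new user-facing surface (assumes a CUDA device for the `device` argument; checkpoint name as used in the docs below):

```python
from transformers import AutoProcessor

# use_fast=True selects the new GotOcr2ImageProcessorFast; device="cuda" runs the
# torchvision-based preprocessing on the GPU (both enabled by this commit)
processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"
inputs = processor(image, return_tensors="pt", device="cuda")
print(inputs["pixel_values"].shape)
```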

docs/source/en/model_doc/got_ocr2.md

@@ -44,13 +44,14 @@ The original code can be found [here](https://github.com/Ucas-HaoranWei/GOT-OCR2
 ```python
 >>> from transformers import AutoProcessor, AutoModelForImageTextToText
+>>> import torch
 >>> device = "cuda" if torch.cuda.is_available() else "cpu"
 >>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
 >>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"
->>> inputs = processor(image, return_tensors="pt").to(device)
+>>> inputs = processor(image, return_tensors="pt", device=device).to(device)
 >>> generate_ids = model.generate(
 ...     **inputs,
@@ -68,15 +69,16 @@ The original code can be found [here](https://github.com/Ucas-HaoranWei/GOT-OCR2
 ```python
 >>> from transformers import AutoProcessor, AutoModelForImageTextToText
+>>> import torch
 >>> device = "cuda" if torch.cuda.is_available() else "cpu"
 >>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
 >>> image1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
 >>> image2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"
->>> inputs = processor([image1, image2], return_tensors="pt").to(device)
+>>> inputs = processor([image1, image2], return_tensors="pt", device=device).to(device)
 >>> generate_ids = model.generate(
 ...     **inputs,
@@ -96,13 +98,14 @@ GOT-OCR2 can also generate formatted text, such as markdown or LaTeX. Here is an
 ```python
 >>> from transformers import AutoProcessor, AutoModelForImageTextToText
+>>> import torch
 >>> device = "cuda" if torch.cuda.is_available() else "cpu"
 >>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
 >>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/latex.png"
->>> inputs = processor(image, return_tensors="pt", format=True).to(device)
+>>> inputs = processor(image, return_tensors="pt", format=True, device=device).to(device)
 >>> generate_ids = model.generate(
 ...     **inputs,
@@ -124,14 +127,15 @@ Here is an example of how to process multiple pages at once:
 ```python
 >>> from transformers import AutoProcessor, AutoModelForImageTextToText
+>>> import torch
 >>> device = "cuda" if torch.cuda.is_available() else "cpu"
 >>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
 >>> image1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/page1.png"
 >>> image2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/page2.png"
->>> inputs = processor([image1, image2], return_tensors="pt", multi_page=True, format=True).to(device)
+>>> inputs = processor([image1, image2], return_tensors="pt", multi_page=True, format=True, device=device).to(device)
 >>> generate_ids = model.generate(
 ...     **inputs,
@@ -153,13 +157,14 @@ Here is an example of how to process cropped patches:
 ```python
 >>> import torch
 >>> from transformers import AutoProcessor, AutoModelForImageTextToText
+>>> import torch
 >>> device = "cuda" if torch.cuda.is_available() else "cpu"
 >>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", torch_dtype=torch.bfloat16, device_map=device)
->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
 >>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/one_column.png"
->>> inputs = processor(image, return_tensors="pt", format=True, crop_to_patches=True, max_patches=3).to(device)
+>>> inputs = processor(image, return_tensors="pt", format=True, crop_to_patches=True, max_patches=3, device=device).to(device)
 >>> generate_ids = model.generate(
 ...     **inputs,
@@ -179,13 +184,14 @@ GOT supports interactive OCR, where the user can specify the region to be recogn
 ```python
 >>> from transformers import AutoProcessor, AutoModelForImageTextToText
+>>> import torch
 >>> device = "cuda" if torch.cuda.is_available() else "cpu"
 >>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
 >>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
->>> inputs = processor(image, return_tensors="pt", color="green").to(device)  # or box=[x1, y1, x2, y2] for coordinates (image pixels)
+>>> inputs = processor(image, return_tensors="pt", color="green", device=device).to(device)  # or box=[x1, y1, x2, y2] for coordinates (image pixels)
 >>> generate_ids = model.generate(
 ...     **inputs,
@@ -206,14 +212,15 @@ Here is an example of how to process sheet music:
 ```python
 >>> from transformers import AutoProcessor, AutoModelForImageTextToText
+>>> import torch
 >>> import verovio
 >>> device = "cuda" if torch.cuda.is_available() else "cpu"
 >>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
 >>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/sheet_music.png"
->>> inputs = processor(image, return_tensors="pt", format=True).to(device)
+>>> inputs = processor(image, return_tensors="pt", format=True, device=device).to(device)
 >>> generate_ids = model.generate(
 ...     **inputs,
@@ -258,6 +265,10 @@ alt="drawing" width="600"/>
 [[autodoc]] GotOcr2ImageProcessor

+## GotOcr2ImageProcessorFast
+
+[[autodoc]] GotOcr2ImageProcessorFast
+
 ## GotOcr2Processor

 [[autodoc]] GotOcr2Processor

src/transformers/__init__.py

@ -1330,6 +1330,7 @@ else:
_import_structure["models.deit"].append("DeiTImageProcessorFast") _import_structure["models.deit"].append("DeiTImageProcessorFast")
_import_structure["models.depth_pro"].append("DepthProImageProcessorFast") _import_structure["models.depth_pro"].append("DepthProImageProcessorFast")
_import_structure["models.detr"].append("DetrImageProcessorFast") _import_structure["models.detr"].append("DetrImageProcessorFast")
_import_structure["models.got_ocr2"].append("GotOcr2ImageProcessorFast")
_import_structure["models.llava"].append("LlavaImageProcessorFast") _import_structure["models.llava"].append("LlavaImageProcessorFast")
_import_structure["models.llava_next"].append("LlavaNextImageProcessorFast") _import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
_import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast") _import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast")
@@ -6526,6 +6527,7 @@ if TYPE_CHECKING:
     from .models.deit import DeiTImageProcessorFast
     from .models.depth_pro import DepthProImageProcessorFast
     from .models.detr import DetrImageProcessorFast
+    from .models.got_ocr2 import GotOcr2ImageProcessorFast
     from .models.llava import LlavaImageProcessorFast
     from .models.llava_next import LlavaNextImageProcessorFast
     from .models.llava_onevision import LlavaOnevisionImageProcessorFast
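With both lazy-import entries in place, the class is importable from the top level. A quick sanity check (defaults taken from the class attributes in the new fast file below):

```python
from transformers import GotOcr2ImageProcessorFast

processor = GotOcr2ImageProcessorFast()
# Defaults declared on the class: 384x384 size, patch cropping off, up to 12 patches
print(processor.size, processor.crop_to_patches, processor.max_patches)
```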

src/transformers/models/auto/image_processing_auto.py

@@ -88,7 +88,7 @@ else:
         ("fuyu", ("FuyuImageProcessor",)),
         ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
         ("glpn", ("GLPNImageProcessor",)),
-        ("got_ocr2", ("GotOcr2ImageProcessor",)),
+        ("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
         ("grounding-dino", ("GroundingDinoImageProcessor",)),
         ("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
         ("hiera", ("BitImageProcessor",)),

src/transformers/models/got_ocr2/__init__.py

@@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
 if TYPE_CHECKING:
     from .configuration_got_ocr2 import *
     from .image_processing_got_ocr2 import *
+    from .image_processing_got_ocr2_fast import *
     from .modeling_got_ocr2 import *
     from .processing_got_ocr2 import *

src/transformers/models/got_ocr2/image_processing_got_ocr2.py

@@ -1,9 +1,3 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/got_ocr2/modular_got_ocr2.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_got_ocr2.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # coding=utf-8
 # Copyright 2024 HuggingFace Inc. team. All rights reserved.
 #
@@ -18,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Image processor class for Got-OCR-2."""

 from functools import lru_cache
 from typing import Dict, List, Optional, Tuple, Union
@@ -27,11 +21,9 @@ import numpy as np
 from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
 from ...image_transforms import (
-    _rescale_for_pil_conversion,
     convert_to_rgb,
     resize,
     to_channel_dimension_format,
-    to_pil_image,
 )
 from ...image_utils import (
     OPENAI_CLIP_MEAN,
@@ -142,6 +134,15 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
         size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
             Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
             method.
+        crop_to_patches (`bool`, *optional*, defaults to `False`):
+            Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the
+            `preprocess` method.
+        min_patches (`int`, *optional*, defaults to 1):
+            The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+            set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method.
+        max_patches (`int`, *optional*, defaults to 12):
+            The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+            set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method.
         resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
             Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
             overridden by the `resample` parameter in the `preprocess` method.
@ -172,6 +173,9 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
self, self,
do_resize: bool = True, do_resize: bool = True,
size: Dict[str, int] = None, size: Dict[str, int] = None,
crop_to_patches: bool = False,
min_patches: int = 1,
max_patches: int = 12,
resample: PILImageResampling = PILImageResampling.BICUBIC, resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True, do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255, rescale_factor: Union[int, float] = 1 / 255,
@@ -187,6 +191,9 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
         self.do_resize = do_resize
         self.size = size
+        self.crop_to_patches = crop_to_patches
+        self.min_patches = min_patches
+        self.max_patches = max_patches
         self.resample = resample
         self.do_rescale = do_rescale
         self.rescale_factor = rescale_factor
@@ -249,6 +256,9 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
         images: ImageInput,
         do_resize: Optional[bool] = None,
         size: Optional[Dict[str, int]] = None,
+        crop_to_patches: Optional[bool] = None,
+        min_patches: Optional[int] = None,
+        max_patches: Optional[int] = None,
         resample: PILImageResampling = None,
         do_rescale: Optional[bool] = None,
         rescale_factor: Optional[float] = None,
@@ -274,6 +284,14 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
                 `size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
                 is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
                 edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
+            crop_to_patches (`bool`, *optional*, defaults to `self.crop_to_patches`):
+                Whether to crop the image to patches.
+            min_patches (`int`, *optional*, defaults to `self.min_patches`):
+                The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+                set to `True`.
+            max_patches (`int`, *optional*, defaults to `self.max_patches`):
+                The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+                set to `True`.
             resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                 Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
             do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
@@ -308,6 +326,9 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
+        crop_to_patches = crop_to_patches if crop_to_patches is not None else self.crop_to_patches
+        min_patches = min_patches if min_patches is not None else self.min_patches
+        max_patches = max_patches if max_patches is not None else self.max_patches
         resample = resample if resample is not None else self.resample
         do_rescale = do_rescale if do_rescale is not None else self.do_rescale
         rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
@@ -353,40 +374,52 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
         # We assume that all images have the same channel dimension format.
         input_data_format = infer_channel_dimension_format(images[0])

-        if do_resize:
-            images = [
-                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
-                for image in images
-            ]
-
-        if do_rescale:
-            images = [
-                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
-                for image in images
-            ]
-
-        if do_normalize:
-            images = [
-                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
-                for image in images
-            ]
-
-        images = [
-            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
-        ]
-
-        encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
+        if crop_to_patches and max_patches > 1:
+            images = [
+                self.crop_image_to_patches(
+                    image,
+                    min_patches=min_patches,
+                    max_patches=max_patches,
+                    patch_size=size,
+                    data_format=input_data_format,
+                )
+                for image in images
+            ]
+            num_patches = np.array([len(image) for image in images])
+            images = [image for images_list in images for image in images_list]
+        else:
+            num_patches = np.array([1] * len(images))
+
+        for i, image in enumerate(images):
+            if do_resize:
+                images[i] = self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
+
+            if do_rescale:
+                images[i] = self.rescale(image=images[i], scale=rescale_factor, input_data_format=input_data_format)
+
+            if do_normalize:
+                images[i] = self.normalize(
+                    image=images[i],
+                    mean=image_mean,
+                    std=image_std,
+                    input_data_format=input_data_format,
+                )
+
+            images[i] = to_channel_dimension_format(images[i], data_format, input_channel_dim=input_data_format)
+
+        encoded_outputs = BatchFeature(
+            data={"pixel_values": images, "num_patches": num_patches}, tensor_type=return_tensors
+        )

         return encoded_outputs

     def crop_image_to_patches(
         self,
-        image: ImageInput,
+        images: np.ndarray,
         min_patches: int,
         max_patches: int,
         use_thumbnail: bool = True,
         patch_size: Union[Tuple, int, dict] = None,
-        return_numpy: bool = False,
         data_format: ChannelDimension = None,
     ):
         """
@@ -396,8 +429,8 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
         The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio.

         Args:
-            image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`):
-                The image to be cropped. The image can be a PIL image, NumPy array or PyTorch tensor.
+            images (`np.ndarray`):
+                The image to be cropped.
             min_patches (`int`):
                 The minimum number of patches to be extracted from the image.
             max_patches (`int`):
@@ -406,24 +439,17 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
                 Whether to add a thumbnail image to the list of cropped patches.
             patch_size (`int`, `Tuple[int, int]`, `dict`, *optional*):
                 The size of the output patches.
-            return_numpy (`bool`, *optional*, defaults to `False`):
-                Whether to return the cropped images as NumPy arrays.
             data_format (`ChannelDimension`, *optional*):
                 The format of the image data. If `None`, the format is inferred from the input image.

         Returns:
             List[`PIL.Image.Image`] or List[np.ndarray]: The list of cropped images.
         """
-        patch_size = patch_size if patch_size is not None else self.size
-        patch_size = get_size_dict(patch_size, default_to_square=True)
-        original_size = get_size_dict(image.size, height_width_order=False)
-        do_rescale = False
-        if not isinstance(image, PIL.Image.Image):
-            do_rescale = _rescale_for_pil_conversion(image)
-            image = to_pil_image(image, do_rescale=do_rescale)
+        if data_format is None:
+            data_format = infer_channel_dimension_format(images)
+        images = to_channel_dimension_format(images, ChannelDimension.FIRST, data_format)
         patch_size_height, patch_size_width = patch_size["height"], patch_size["width"]
-        original_height, original_width = original_size["height"], original_size["width"]
+        original_height, original_width = images.shape[-2:]

         # find the closest aspect ratio to the target
         num_columns, num_rows = get_optimal_tiled_canvas(
             (original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches
@@ -435,8 +461,12 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
         num_blocks = num_columns * num_rows

         # resize the image so that each patch is of patch_size
-        resized_image = image.resize((target_width, target_height))
+        resized_image = self.resize(
+            images,
+            {"height": target_height, "width": target_width},
+            data_format=ChannelDimension.FIRST,
+            input_data_format=ChannelDimension.FIRST,
+        )

         # split the image into patches
         processed_images = []
         for i in range(num_blocks):
@@ -449,33 +479,16 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
                 (row + 1) * patch_size_height,
             )
             # split the image
-            patch_image = resized_image.crop(box)
+            patch_image = resized_image[..., box[1] : box[3], box[0] : box[2]]
+            patch_image = to_channel_dimension_format(patch_image, data_format, ChannelDimension.FIRST)
             processed_images.append(patch_image)

         if use_thumbnail and len(processed_images) != 1:
-            thumbnail_img = image.resize((patch_size_width, patch_size_height))
+            thumbnail_img = self.resize(
+                images, patch_size, data_format=data_format, input_data_format=ChannelDimension.FIRST
+            )
             processed_images.append(thumbnail_img)

-        if return_numpy:
-            processed_images_numpy = []
-            for processed_image in processed_images:
-                processed_image = np.array(processed_image)
-                # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
-                # so we need to add it back if necessary.
-                processed_image = (
-                    np.expand_dims(processed_image, axis=-1) if processed_image.ndim == 2 else processed_image
-                )
-                # The image is always in channels last format after converting from a PIL image
-                if data_format is not None:
-                    processed_image = to_channel_dimension_format(
-                        processed_image, data_format, input_channel_dim=ChannelDimension.LAST
-                    )
-                # If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to
-                # rescale it back to the original range.
-                processed_image = self.rescale(processed_image, 1 / 255) if do_rescale else processed_image
-                processed_images_numpy.append(processed_image)
-            processed_images = processed_images_numpy
-
         return processed_images
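With this refactor, patch cropping is part of `preprocess` itself ("one big loop" per the commit message) and the returned `BatchFeature` carries `num_patches`. A sketch of the new call path on the slow processor (random array as a stand-in for a document image):

```python
import numpy as np
from transformers import GotOcr2ImageProcessor

processor = GotOcr2ImageProcessor()
# A wide 768x1536 image: the optimal tiled canvas is 4x2, i.e. 8 patches plus a thumbnail
image = np.random.randint(0, 256, (768, 1536, 3), dtype=np.uint8)

outputs = processor(image, crop_to_patches=True, max_patches=12, return_tensors="np")
print(outputs["pixel_values"].shape)  # (9, 3, 384, 384): patches flattened across the batch
print(outputs["num_patches"])         # [9]: patches belonging to each input image
```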

src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py

@ -0,0 +1,257 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for Got-OCR-2."""
from typing import List, Optional, Tuple, Union
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
DefaultFastImageProcessorInitKwargs,
DefaultFastImageProcessorPreprocessKwargs,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ImageInput,
PILImageResampling,
SizeDict,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
)
from .image_processing_got_ocr2 import get_optimal_tiled_canvas
if is_torch_available():
import torch
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
class GotOcr2ImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs):
crop_to_patches: Optional[bool]
min_patches: Optional[int]
max_patches: Optional[int]
class GotOcr2ImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs):
crop_to_patches: Optional[bool]
min_patches: Optional[int]
max_patches: Optional[int]
@add_start_docstrings(
"Constructs a fast GotOcr2 image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
crop_to_patches (`bool`, *optional*, defaults to `False`):
Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the
`preprocess` method.
min_patches (`int`, *optional*, defaults to 1):
The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method.
max_patches (`int`, *optional*, defaults to 12):
The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method.
""",
)
class GotOcr2ImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = OPENAI_CLIP_MEAN
image_std = OPENAI_CLIP_STD
size = {"height": 384, "width": 384}
do_resize = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
crop_to_patches = False
min_patches = 1
max_patches = 12
valid_init_kwargs = GotOcr2ImageProcessorInitKwargs
valid_preprocess_kwargs = GotOcr2ImageProcessorPreprocessKwargs
def __init__(self, **kwargs: Unpack[GotOcr2ImageProcessorInitKwargs]):
super().__init__(**kwargs)
@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
"""
crop_to_patches (`bool`, *optional*, defaults to `False`):
Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the
`preprocess` method.
min_patches (`int`, *optional*, defaults to 1):
The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method.
max_patches (`int`, *optional*, defaults to 12):
The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method.
""",
)
def preprocess(self, images: ImageInput, **kwargs: Unpack[GotOcr2ImageProcessorPreprocessKwargs]) -> BatchFeature:
return super().preprocess(images, **kwargs)
def crop_image_to_patches(
self,
images: "torch.Tensor",
min_patches: int,
max_patches: int,
use_thumbnail: bool = True,
patch_size: Union[Tuple, int, dict] = None,
interpolation: Optional["F.InterpolationMode"] = None,
):
"""
Crop the images to patches and return a list of cropped images.
The number of patches and their grid arrangement are determined by the original image size,
the target patch size and the minimum and maximum number of patches.
The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio.
Args:
images (`torch.Tensor`):
The images to be cropped.
min_patches (`int`):
The minimum number of patches to be extracted from the image.
max_patches (`int`):
The maximum number of patches to be extracted from the image.
use_thumbnail (`bool`, *optional*, defaults to `True`):
Whether to add a thumbnail image to the list of cropped patches.
patch_size (`int`, `Tuple[int, int]`, `dict`, *optional*):
The size of the output patches.
interpolation (`F.InterpolationMode`, *optional*):
The interpolation mode to use when resizing the image.
Returns:
`torch.Tensor`: The cropped images stacked into a tensor of shape `(batch_size, num_patches, channels, height, width)`.
"""
patch_size_height, patch_size_width = patch_size.height, patch_size.width
original_height, original_width = images.shape[-2:]
# find the closest aspect ratio to the target
num_columns, num_rows = get_optimal_tiled_canvas(
(original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches
)
# calculate the target width and height
target_width = patch_size_width * num_columns
target_height = patch_size_height * num_rows
num_blocks = num_columns * num_rows
# resize the image so that each patch is of patch_size
resized_image = self.resize(
images, SizeDict(height=target_height, width=target_width), interpolation=interpolation
)
# split the image into patches
processed_images = []
for i in range(num_blocks):
column = i % num_columns
row = i // num_columns
box = (
column * patch_size_width,
row * patch_size_height,
(column + 1) * patch_size_width,
(row + 1) * patch_size_height,
)
# split the image
patch_image = resized_image[..., box[1] : box[3], box[0] : box[2]]
processed_images.append(patch_image)
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = self.resize(images, patch_size, interpolation=interpolation)
processed_images.append(thumbnail_img)
processed_images = torch.stack(processed_images, dim=0).transpose(0, 1).contiguous()
return processed_images
def _preprocess(
self,
images: List["torch.Tensor"],
do_resize: bool,
size: SizeDict,
crop_to_patches: bool,
min_patches: int,
max_patches: int,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
if crop_to_patches:
grouped_images, grouped_images_index = group_images_by_shape(images)
processed_images_grouped = {}
num_patches = {}
for shape, stacked_images in grouped_images.items():
stacked_images = self.crop_image_to_patches(
stacked_images,
min_patches,
max_patches,
patch_size=size,
interpolation=interpolation,
)
processed_images_grouped[shape] = stacked_images
num_patches[shape] = [stacked_images.shape[1]] * stacked_images.shape[0]
images = reorder_images(processed_images_grouped, grouped_images_index)
images = [image for images_list in images for image in images_list]
num_patches = reorder_images(num_patches, grouped_images_index)
else:
num_patches = [1] * len(images)
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
resized_images_grouped[shape] = stacked_images
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:
stacked_images = self.center_crop(stacked_images, crop_size)
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(
data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors
)
__all__ = ["GotOcr2ImageProcessorFast"]

src/transformers/models/got_ocr2/modeling_got_ocr2.py

@@ -32,11 +32,7 @@ from ...activations import ACT2FN
 from ...generation import GenerationMixin
 from ...modeling_outputs import ModelOutput
 from ...modeling_utils import PreTrainedModel
-from ...utils import (
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
 from ..auto import AutoModelForCausalLM
 from .configuration_got_ocr2 import GotOcr2Config, GotOcr2VisionConfig

src/transformers/models/got_ocr2/modular_got_ocr2.py

@@ -14,35 +14,20 @@
 # limitations under the License.

-from functools import lru_cache
 from typing import List, Optional, Tuple, Union

-import numpy as np
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint

-from transformers.models.blip.image_processing_blip import BlipImageProcessor
 from transformers.models.llava.modeling_llava import (
     LlavaCausalLMOutputWithPast,
     LlavaForConditionalGeneration,
     LlavaPreTrainedModel,
 )
 from transformers.models.sam.modeling_sam import SamMLPBlock, SamVisionAttention, SamVisionEncoder, SamVisionLayer
-from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
-from transformers.tokenization_utils_base import (
-    PreTokenizedInput,
-    TextInput,
-)

 from ...configuration_utils import PretrainedConfig
-from ...image_processing_utils import BatchFeature, get_size_dict
-from ...image_transforms import (
-    _rescale_for_pil_conversion,
-    to_channel_dimension_format,
-    to_pil_image,
-)
-from ...image_utils import ChannelDimension, ImageInput
 from ...utils import (
     add_start_docstrings_to_model_forward,
     is_vision_available,

@@ -53,9 +38,7 @@ from ..auto import CONFIG_MAPPING, AutoConfig, AutoModelForCausalLM

 if is_vision_available():
-    import PIL
-
-    from ...image_utils import load_images
+    pass

 logger = logging.get_logger(__name__)
@@ -246,437 +229,6 @@ class GotOcr2Config(PretrainedConfig):
 __all__ = ["GotOcr2VisionConfig", "GotOcr2Config"]
class GotOcr2TextKwargs(TextKwargs, total=False):
format: Optional[bool]
class GotOcr2ImagesKwargs(ImagesKwargs, total=False):
box: Optional[Union[List, Tuple[float, float], Tuple[float, float, float, float]]]
color: Optional[str]
num_image_tokens: Optional[int]
multi_page: Optional[bool]
crop_to_patches: Optional[bool]
min_patches: Optional[int]
max_patches: Optional[int]
class GotOcr2ProcessorKwargs(ProcessingKwargs, total=False):
text_kwargs: GotOcr2TextKwargs
images_kwargs: GotOcr2ImagesKwargs
_defaults = {
"text_kwargs": {
"padding": False,
"format": False,
},
"images_kwargs": {
"num_image_tokens": 256,
"multi_page": False,
"crop_to_patches": False,
"min_patches": 1,
"max_patches": 12,
},
}
def preprocess_box_annotation(box: Union[List, Tuple], image_size: Tuple[int, int]) -> List:
"""
Convert box annotation to the format [x1, y1, x2, y2] in the range [0, 1000].
"""
width, height = image_size
if len(box) == 4:
box[0] = int(box[0] / width * 1000)
box[1] = int(box[1] / height * 1000)
box[2] = int(box[2] / width * 1000)
box[3] = int(box[3] / height * 1000)
else:
raise ValueError("Box must be a list or tuple of lists in the form [x1, y1, x2, y2].")
return list(box)
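A worked example of this box normalization (illustrative numbers; `image_size` is `(width, height)`, and the import path reflects where the processor code lives after this commit):

```python
from transformers.models.got_ocr2.processing_got_ocr2 import preprocess_box_annotation

# A [x1, y1, x2, y2] box on a 500x400 (width x height) image:
# x coordinates scale by 1000 / 500, y coordinates by 1000 / 400
print(preprocess_box_annotation([50, 100, 250, 300], (500, 400)))  # [100, 250, 500, 750]
```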
# Similar to image_processing_mllama.get_all_supported_aspect_ratios
@lru_cache(maxsize=10)
def get_all_supported_aspect_ratios(min_image_tiles: int, max_image_tiles: int) -> List[Tuple[int, int]]:
"""
Computes all allowed aspect ratios for a given minimum and maximum number of input tiles.
This function calculates all possible arrangements of tiles that can be formed
within the constraint of the minimum and maximum number of tiles. Each arrangement is
represented by its aspect ratio (width/height) and the corresponding tile configuration.
Args:
min_image_tiles (`int`):
The minimum number of tiles allowed.
max_image_tiles (`int`):
The maximum number of tiles allowed.
Returns:
`List[Tuple[int, int]]`: A list of tuples, each tuple representing a valid (width, height)
configuration in terms of number of tiles.
Example:
>>> get_all_supported_aspect_ratios(1, 4)
[(1, 1), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (2, 2), (4, 1)]
"""
aspect_ratios = []
for width in range(1, max_image_tiles + 1):
for height in range(1, max_image_tiles + 1):
if width * height <= max_image_tiles and width * height >= min_image_tiles:
aspect_ratios.append((width, height))
aspect_ratios = sorted(aspect_ratios, key=lambda x: x[0] * x[1])
return aspect_ratios
@lru_cache(maxsize=100)
def get_optimal_tiled_canvas(
original_image_size: Tuple[int, int],
target_tile_size: Tuple[int, int],
min_image_tiles: int,
max_image_tiles: int,
) -> Tuple[int, int]:
"""
Given a minimum and maximum number of tiles, find the canvas with the closest aspect ratio to the
original image aspect ratio.
In case of tie-breaking condition when two canvases have the same aspect ratio difference, we favor the canvas with
more tiles, until the area covered by the tiles is more than twice the target area, in order to avoid unnecessarily
excessive tiling.
"""
possible_tile_arrangements = get_all_supported_aspect_ratios(min_image_tiles, max_image_tiles)
original_height, original_width = original_image_size
target_tile_height, target_tile_width = target_tile_size
aspect_ratio = original_width / original_height
area = original_width * original_height
# find the grid with the best aspect ratio
best_ratio_diff = float("inf")
best_grid = (1, 1)
for grid in possible_tile_arrangements:
grid_aspect_ratio = grid[0] / grid[1]
ratio_diff = abs(aspect_ratio - grid_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_grid = grid
elif ratio_diff == best_ratio_diff:
# if the aspect ratio difference is the same, we favor the grid with more patches
# until the area covered by the patches is more than twice the original image area
if area > 0.5 * target_tile_height * target_tile_width * grid[0] * grid[1]:
best_grid = grid
return best_grid
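To make the tie-break concrete, a small worked example (numbers are illustrative; the import path matches the fast file's own `from .image_processing_got_ocr2 import get_optimal_tiled_canvas`):

```python
from transformers.models.got_ocr2.image_processing_got_ocr2 import get_optimal_tiled_canvas

# A 768x1536 image with 384x384 tiles has aspect ratio 2.0, matched by both (2, 1)
# and (4, 2). The image area (768 * 1536) exceeds half the (4, 2) canvas area
# (0.5 * 384 * 384 * 8), so the tie-break keeps the larger grid.
print(get_optimal_tiled_canvas((768, 1536), (384, 384), 1, 12))  # (4, 2): 4 columns, 2 rows
```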
class GotOcr2ImageProcessor(BlipImageProcessor):
def crop_image_to_patches(
self,
image: ImageInput,
min_patches: int,
max_patches: int,
use_thumbnail: bool = True,
patch_size: Union[Tuple, int, dict] = None,
return_numpy: bool = False,
data_format: ChannelDimension = None,
):
"""
Crop the image to patches and return a list of cropped images.
The number of patches and their grid arrangement are determined by the original image size,
the target patch size and the minimum and maximum number of patches.
The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio.
Args:
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`):
The image to be cropped. The image can be a PIL image, NumPy array or PyTorch tensor.
min_patches (`int`):
The minimum number of patches to be extracted from the image.
max_patches (`int`):
The maximum number of patches to be extracted from the image.
use_thumbnail (`bool`, *optional*, defaults to `True`):
Whether to add a thumbnail image to the list of cropped patches.
patch_size (`int`, `Tuple[int, int]`, `dict`, *optional*):
The size of the output patches.
return_numpy (`bool`, *optional*, defaults to `False`):
Whether to return the cropped images as NumPy arrays.
data_format (`ChannelDimension`, *optional*):
The format of the image data. If `None`, the format is inferred from the input image.
Returns:
List[`PIL.Image.Image`] or List[np.ndarray]: The list of cropped images.
"""
patch_size = patch_size if patch_size is not None else self.size
patch_size = get_size_dict(patch_size, default_to_square=True)
original_size = get_size_dict(image.size, height_width_order=False)
do_rescale = False
if not isinstance(image, PIL.Image.Image):
do_rescale = _rescale_for_pil_conversion(image)
image = to_pil_image(image, do_rescale=do_rescale)
patch_size_height, patch_size_width = patch_size["height"], patch_size["width"]
original_height, original_width = original_size["height"], original_size["width"]
# find the closest aspect ratio to the target
num_columns, num_rows = get_optimal_tiled_canvas(
(original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches
)
# calculate the target width and height
target_width = patch_size_width * num_columns
target_height = patch_size_height * num_rows
num_blocks = num_columns * num_rows
# resize the image so that each patch is of patch_size
resized_image = image.resize((target_width, target_height))
# split the image into patches
processed_images = []
for i in range(num_blocks):
column = i % num_columns
row = i // num_columns
box = (
column * patch_size_width,
row * patch_size_height,
(column + 1) * patch_size_width,
(row + 1) * patch_size_height,
)
# split the image
patch_image = resized_image.crop(box)
processed_images.append(patch_image)
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((patch_size_width, patch_size_height))
processed_images.append(thumbnail_img)
if return_numpy:
processed_images_numpy = []
for processed_image in processed_images:
processed_image = np.array(processed_image)
# If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
# so we need to add it back if necessary.
processed_image = (
np.expand_dims(processed_image, axis=-1) if processed_image.ndim == 2 else processed_image
)
# The image is always in channels last format after converting from a PIL image
if data_format is not None:
processed_image = to_channel_dimension_format(
processed_image, data_format, input_channel_dim=ChannelDimension.LAST
)
# If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to
# rescale it back to the original range.
processed_image = self.rescale(processed_image, 1 / 255) if do_rescale else processed_image
processed_images_numpy.append(processed_image)
processed_images = processed_images_numpy
return processed_images
class GotOcr2Processor(ProcessorMixin):
r"""
Constructs a GotOcr2 processor which wraps a [`GotOcr2ImageProcessor`] and
[`PreTrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and
tokenizer functionalities. See the [`~GotOcr2Processor.__call__`] and [`~GotOcr2Processor.decode`] for more information.
Args:
image_processor ([`GotOcr2ImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
image_processor_class = "GotOcr2ImageProcessor"
tokenizer_class = "PreTrainedTokenizerFast"
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
super().__init__(image_processor, tokenizer, chat_template=chat_template)
self.message_start_token = "<|im_start|>"
self.message_end_token = "<|im_end|>"
self.img_start_token = "<img>"
self.img_end_token = "</img>"
self.img_pad_token = "<imgpad>"
self.system_query = "system\nYou should follow the instructions carefully and explain your answers in detail."
def _make_list_of_inputs(self, images, text, box, color, multi_page):
if not isinstance(images, (list, tuple)):
images = [images]
if multi_page:
logger.warning("Multi-page inference is enabled but only one image is passed.")
images = [images]
elif isinstance(images[0], (list, tuple)) and not multi_page:
raise ValueError("Nested images are only supported with `multi_page` set to `True`.")
elif not isinstance(images[0], (list, tuple)) and multi_page:
images = [images]
if isinstance(text, str):
text = [text]
if not isinstance(box[0], (list, tuple)):
# Use the same box for all images
box = [box for _ in range(len(images))]
if not isinstance(color, (list, tuple)):
color = [color for _ in range(len(images))]
return images, text, box, color
def __call__(
self,
images: Optional[ImageInput] = None,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
audio=None,
videos=None,
**kwargs: Unpack[GotOcr2ProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text if `text`
is not `None`, otherwise encode default OCR queries which depend on the `format`, `box`, `color`, `multi_page` and
`crop_to_patches` arguments. To prepare the vision inputs, this method forwards the `images` and `kwargs` arguments to
GotOcr2ImageProcessor's [`~GotOcr2ImageProcessor.__call__`] if `images` is not `None`.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
format (`bool`, *optional*):
If set, will add the format token to the query, and the model will return the OCR result with formatting.
box (`List[float]`, `List[Tuple[float, float]]`, `List[Tuple[float, float, float, float]]`, *optional*):
The box annotation to be added to the query. If a list of floats or a tuple of floats is provided, it
will be interpreted as [x1, y1, x2, y2]. If a list of tuples is provided, each tuple should be in the
form (x1, y1, x2, y2).
color (`str`, *optional*):
The color annotation to be added to the query. The model will return the OCR result within the box with
the specified color.
multi_page (`bool`, *optional*):
If set, will enable multi-page inference. The model will return the OCR result across multiple pages.
crop_to_patches (`bool`, *optional*):
If set, will crop the image to patches. The model will return the OCR result upon the patch reference.
min_patches (`int`, *optional*):
The minimum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to
`True`.
max_patches (`int`, *optional*):
The maximum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to
`True`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """
        output_kwargs = self._merge_kwargs(
            GotOcr2ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        format_output = output_kwargs["text_kwargs"].pop("format")
        num_image_tokens = output_kwargs["images_kwargs"].pop("num_image_tokens")
        box = output_kwargs["images_kwargs"].pop("box", [None])
        color = output_kwargs["images_kwargs"].pop("color", None)
        multi_page = output_kwargs["images_kwargs"].pop("multi_page")
        crop_to_patches = output_kwargs["images_kwargs"].pop("crop_to_patches")
        min_patches = output_kwargs["images_kwargs"].pop("min_patches")
        max_patches = output_kwargs["images_kwargs"].pop("max_patches")

        images, text, box, color = self._make_list_of_inputs(images, text, box, color, multi_page)
        # Load images as we need to know the image size
        images = load_images(images)
        if text is None:
            text = []
            for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)):
                if crop_to_patches:
                    image_group = self.image_processor.crop_image_to_patches(
                        image_group,
                        patch_size=output_kwargs["images_kwargs"].get("size"),
                        min_patches=min_patches,
                        max_patches=max_patches,
                    )
                    images[index] = image_group
                num_images = len(image_group) if (multi_page or crop_to_patches) else 1
                if box_single[0] is not None:
                    box_single = preprocess_box_annotation(box_single, image_group.size)
                query = (
                    f"{f'[{color_single}] ' if color_single is not None else ''}"
                    f"{str(box_single) if box_single[0] is not None else ''} "
                    "OCR"
                    f"{' with format' if format_output else ''}"
                    f"{' across multi pages' if multi_page else ''}"
                    f"{' upon the patch reference' if crop_to_patches else ''}"
                    ": "
                )
                prompt = (
                    self.message_start_token
                    + self.system_query
                    + self.message_end_token
                    + self.message_start_token
                    + "user\n"
                    + self.img_start_token
                    + self.img_pad_token * num_image_tokens * num_images
                    + self.img_end_token
                    + "\n"
                    + query
                    + self.message_end_token
                    + self.message_start_token
                    + "assistant\n"
                )
                text.append(prompt)
        elif crop_to_patches:
            for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)):
                image_group = self.image_processor.crop_image_to_patches(
                    image_group,
                    patch_size=output_kwargs["images_kwargs"].get("size"),
                    min_patches=min_patches,
                    max_patches=max_patches,
                )
                images[index] = image_group

        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        if multi_page or crop_to_patches:
            # flatten images
            images = [image for image_group in images for image in image_group]
        image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])

        return BatchFeature(data={**text_inputs, **image_inputs})

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(tokenizer_input_names) + list(image_processor_input_names)
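For reference, a minimal sketch of how the kwargs handled above are passed at call time (illustrative values; any local path or URL works for the image):

```python
>>> from transformers import AutoProcessor

>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
>>> image = "path/to/document.jpg"  # any local path or URL

>>> # color is injected into the query as "[green] OCR: "; crop_to_patches
>>> # multiplies the image pad tokens by the number of patches produced.
>>> inputs = processor(image, color="green", return_tensors="pt")
>>> inputs = processor(image, crop_to_patches=True, min_patches=1, max_patches=6, return_tensors="pt")
```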
class GotOcr2MLPBlock(SamMLPBlock):
    pass
@@ -972,8 +524,6 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
 __all__ = [
     "GotOcr2VisionConfig",
     "GotOcr2Config",
-    "GotOcr2Processor",
     "GotOcr2PreTrainedModel",
     "GotOcr2ForConditionalGeneration",
-    "GotOcr2ImageProcessor",
 ]

View File

@@ -1,9 +1,3 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/got_ocr2/modular_got_ocr2.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_got_ocr2.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # coding=utf-8
 # Copyright 2024 HuggingFace Inc. team. All rights reserved.
 #
@@ -22,6 +16,8 @@
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
+
 from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
 from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
@@ -100,7 +96,7 @@ class GotOcr2Processor(ProcessorMixin):
     attributes = ["image_processor", "tokenizer"]
     valid_kwargs = ["chat_template"]
-    image_processor_class = "GotOcr2ImageProcessor"
+    image_processor_class = "AutoImageProcessor"
     tokenizer_class = "PreTrainedTokenizerFast"
 
     def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
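With `image_processor_class` switched to `"AutoImageProcessor"`, the processor can bind either backend; a minimal sketch of the selection (standard `use_fast` switch on `AutoImageProcessor`):

```python
>>> from transformers import AutoImageProcessor

>>> # use_fast=True resolves to the torchvision-backed GotOcr2ImageProcessorFast,
>>> # use_fast=False to the slow GotOcr2ImageProcessor.
>>> image_processor = AutoImageProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
```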
@@ -205,28 +201,29 @@ class GotOcr2Processor(ProcessorMixin):
         box = output_kwargs["images_kwargs"].pop("box", [None])
         color = output_kwargs["images_kwargs"].pop("color", None)
         multi_page = output_kwargs["images_kwargs"].pop("multi_page")
-        crop_to_patches = output_kwargs["images_kwargs"].pop("crop_to_patches")
-        min_patches = output_kwargs["images_kwargs"].pop("min_patches")
-        max_patches = output_kwargs["images_kwargs"].pop("max_patches")
+        crop_to_patches = output_kwargs["images_kwargs"].get("crop_to_patches")
 
         images, text, box, color = self._make_list_of_inputs(images, text, box, color, multi_page)
+        if multi_page:
+            # save the number of pages per batch
+            num_pages_per_batch = [len(image_group) for image_group in images]
+            # flatten the list of images
+            images = [image for image_group in images for image in image_group]
+        else:
+            num_pages_per_batch = [1 for _ in range(len(images))]
         # Load images as we need to know the image size
         images = load_images(images)
+        image_sizes = [image.size for image in images]
+        image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+        num_patches_array = image_inputs.pop("num_patches")
         if text is None:
             text = []
-            for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)):
-                if crop_to_patches:
-                    image_group = self.image_processor.crop_image_to_patches(
-                        image_group,
-                        patch_size=output_kwargs["images_kwargs"].get("size"),
-                        min_patches=min_patches,
-                        max_patches=max_patches,
-                    )
-                    images[index] = image_group
-                num_images = len(image_group) if (multi_page or crop_to_patches) else 1
+            patch_indices = np.cumsum(num_pages_per_batch)
+            for index, (num_pages, box_single, color_single) in enumerate(zip(num_pages_per_batch, box, color)):
+                current_patch_index = patch_indices[index - 1] if index > 0 else 0
+                num_patches = sum(num_patches_array[current_patch_index : current_patch_index + num_pages])
                 if box_single[0] is not None:
-                    box_single = preprocess_box_annotation(box_single, image_group.size)
+                    box_single = preprocess_box_annotation(box_single, image_sizes[index])
                 query = (
                     f"{f'[{color_single}] ' if color_single is not None else ''}"
                     f"{str(box_single) if box_single[0] is not None else ''} "
@@ -243,7 +240,7 @@ class GotOcr2Processor(ProcessorMixin):
                     + self.message_start_token
                     + "user\n"
                     + self.img_start_token
-                    + self.img_pad_token * num_image_tokens * num_images
+                    + self.img_pad_token * num_image_tokens * num_patches
                     + self.img_end_token
                     + "\n"
                     + query
@@ -252,22 +249,8 @@ class GotOcr2Processor(ProcessorMixin):
                     + "assistant\n"
                 )
                 text.append(prompt)
-        elif crop_to_patches:
-            for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)):
-                image_group = self.image_processor.crop_image_to_patches(
-                    image_group,
-                    patch_size=output_kwargs["images_kwargs"].get("size"),
-                    min_patches=min_patches,
-                    max_patches=max_patches,
-                )
-                images[index] = image_group
-
         text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
-        if multi_page or crop_to_patches:
-            # flatten images
-            images = [image for image_group in images for image in image_group]
-        image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
 
         return BatchFeature(data={**text_inputs, **image_inputs})
 
     def batch_decode(self, *args, **kwargs):
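A toy walk-through of the patch bookkeeping introduced above (numbers made up for illustration): `num_patches` is returned per flattened page, so each sample's pad-token count sums over its own pages, located via cumulative page offsets.

```python
>>> import numpy as np

>>> num_pages_per_batch = [2, 1]   # sample 0 has two pages, sample 1 has one
>>> num_patches_array = [3, 1, 5]  # patches produced for each flattened page

>>> patch_indices = np.cumsum(num_pages_per_batch)  # array([2, 3])
>>> for index, num_pages in enumerate(num_pages_per_batch):
...     current_patch_index = patch_indices[index - 1] if index > 0 else 0
...     print(sum(num_patches_array[current_patch_index : current_patch_index + num_pages]))
4
5
```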

View File

@@ -58,6 +58,13 @@ class DetrImageProcessorFast(metaclass=DummyObject):
         requires_backends(self, ["torchvision"])
 
 
+class GotOcr2ImageProcessorFast(metaclass=DummyObject):
+    _backends = ["torchvision"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torchvision"])
+
+
 class LlavaImageProcessorFast(metaclass=DummyObject):
     _backends = ["torchvision"]
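The dummy class above is the stock transformers placeholder for a missing optional backend: it can always be imported, but instantiation fails with a pointer to the missing dependency. A simplified sketch of the pattern (the real `DummyObject` and `requires_backends` in `transformers.utils` also check availability and guard attribute access):

```python
# Simplified stand-ins for illustration only, not the library implementation.
def requires_backends(obj, backends):
    # The real helper only raises when a backend is actually missing;
    # in the dummy module it is reached precisely in that case.
    raise ImportError(f"{type(obj).__name__} requires the following backends: {', '.join(backends)}")


class DummyObject(type):
    """Metaclass marker used by the generated dummy modules."""


class GotOcr2ImageProcessorFast(metaclass=DummyObject):
    _backends = ["torchvision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torchvision"])
```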

View File

@@ -16,15 +16,22 @@
 import unittest
 
+from transformers.image_utils import SizeDict
 from transformers.testing_utils import require_torch, require_vision
-from transformers.utils import is_vision_available
+from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
 
+if is_torch_available():
+    import torch
+
 if is_vision_available():
     from transformers import GotOcr2ImageProcessor
 
+if is_torchvision_available():
+    from transformers import GotOcr2ImageProcessorFast
+
 
 class GotOcr2ImageProcessingTester(unittest.TestCase):
     def __init__(
@@ -89,6 +96,7 @@ class GotOcr2ImageProcessingTester(unittest.TestCase):
 @require_vision
 class GotOcr2ProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
     image_processing_class = GotOcr2ImageProcessor if is_vision_available() else None
+    fast_image_processing_class = GotOcr2ImageProcessorFast if is_torchvision_available() else None
 
     def setUp(self):
         super().setUp()
@@ -99,17 +107,72 @@ class GotOcr2ProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         return self.image_processor_tester.prepare_image_processor_dict()
 
     def test_image_processor_properties(self):
-        image_processor = self.image_processing_class(**self.image_processor_dict)
-        self.assertTrue(hasattr(image_processor, "do_resize"))
-        self.assertTrue(hasattr(image_processor, "size"))
-        self.assertTrue(hasattr(image_processor, "do_normalize"))
-        self.assertTrue(hasattr(image_processor, "image_mean"))
-        self.assertTrue(hasattr(image_processor, "image_std"))
-        self.assertTrue(hasattr(image_processor, "do_convert_rgb"))
+        for image_processing_class in self.image_processor_list:
+            image_processor = image_processing_class(**self.image_processor_dict)
+            self.assertTrue(hasattr(image_processor, "do_resize"))
+            self.assertTrue(hasattr(image_processor, "size"))
+            self.assertTrue(hasattr(image_processor, "do_normalize"))
+            self.assertTrue(hasattr(image_processor, "image_mean"))
+            self.assertTrue(hasattr(image_processor, "image_std"))
+            self.assertTrue(hasattr(image_processor, "do_convert_rgb"))
+
+    def test_slow_fast_equivalence_crop_to_patches(self):
+        dummy_image = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)[0]
+        image_processor_slow = self.image_processing_class(**self.image_processor_dict, crop_to_patches=True)
+        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict, crop_to_patches=True)
+
+        encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
+        encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
+
+        torch.testing.assert_close(encoding_slow.num_patches, encoding_fast.num_patches)
+        self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
+        self.assertLessEqual(
+            torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
+        )
+
+    def test_slow_fast_equivalence_batched_crop_to_patches(self):
+        # Prepare image inputs so that we have two groups of images with equal resolution with a group of images with
+        # different resolutions in between
+        dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
+        dummy_images += self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
+        dummy_images += self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
+        image_processor_slow = self.image_processing_class(**self.image_processor_dict, crop_to_patches=True)
+        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict, crop_to_patches=True)
+
+        encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
+        encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
+
+        torch.testing.assert_close(encoding_slow.num_patches, encoding_fast.num_patches)
+        self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
+        self.assertLessEqual(
+            torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
+        )
 
     def test_crop_to_patches(self):
-        image_processor = self.image_processing_class(**self.image_processor_dict)
-        image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)[0]
-        processed_images = image_processor.crop_image_to_patches(image, 1, 6, use_thumbnail=True)
+        # test slow image processor
+        image_processor = self.image_processor_list[0](**self.image_processor_dict)
+        image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)[0]
+        processed_images = image_processor.crop_image_to_patches(
+            image,
+            min_patches=1,
+            max_patches=6,
+            use_thumbnail=True,
+            patch_size={"height": 20, "width": 20},
+        )
         self.assertEqual(len(processed_images), 5)
-        self.assertEqual(processed_images[0].size, (20, 20))
+        self.assertEqual(processed_images[0].shape[:2], (20, 20))
+
+        # test fast image processor (process batch)
+        image_processor = self.image_processor_list[1](**self.image_processor_dict)
+        image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)[0]
+        processed_images = image_processor.crop_image_to_patches(
+            image.unsqueeze(0),
+            min_patches=1,
+            max_patches=6,
+            use_thumbnail=True,
+            patch_size=SizeDict(height=20, width=20),
+        )
+        self.assertEqual(len(processed_images[0]), 5)
+        self.assertEqual(processed_images.shape[-2:], (20, 20))
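The slow/fast equivalence tests pair a loose elementwise tolerance (`atol=1e-1`) with a much tighter bound on the mean absolute difference (`1e-3`); a toy sketch of why both are checked (hypothetical tensors):

```python
>>> import torch

>>> slow = torch.rand(3, 224, 224)
>>> fast = slow + 1e-4 * torch.randn_like(slow)  # small, zero-mean numerical jitter

>>> # Elementwise: no single value may drift by more than 0.1.
>>> torch.allclose(slow, fast, atol=1e-1)
True
>>> # Aggregate: the average error must stay tiny, catching any systematic
>>> # bias that the loose elementwise bound alone would let through.
>>> bool(torch.mean(torch.abs(slow - fast)) <= 1e-3)
True
```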