Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-04.

* Fix Fuyu image scaling bug
It could produce negative padding, and hence inference errors, for certain image sizes (see the sketch after the commit list).
* initial rework commit
* add batching capabilities, refactor image processing
* add functional batching for a list of images and texts
* make args explicit
* Fuyu processing update (#27133)
* Add file headers
* Add file headers
* First pass - preprocess method with standard args
* First pass image processor rework
* Small tweaks
* More args and docstrings
* Tidying iterating over batch
* Tidying up
* Modify to have quick tests (for now)
* Fix up
* BatchFeature
* Passing tests
* Add tests for processor
* Sense check when patchifying
* Add some tests
* FuyuBatchFeature
* Post-process box coordinates
* Update to `size` in processor
* Remove unused and duplicate constants
* Store unpadded dims after resize
* Fix up
* Return FuyuBatchFeature
* Get unpadded sizes after resize
* Update exception
* Fix return
* Convert input `<box>` coordinates to model format.
* Post-process point coordinates; support multiple boxes/points in a single sequence
* Replace constants
* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Preprocess List[List[image]]
* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update to Amy's latest state.
* Post-processing returns a list of tensors
* Fix error when target_sizes is None
Co-authored-by: Pablo Montalvo <pablo.montalvo.leroux@gmail.com>
* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Review comments
* Update src/transformers/models/fuyu/image_processing_fuyu.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Fix up
* Fix up
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-72-126.ec2.internal>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Pablo Montalvo <pablo.montalvo.leroux@gmail.com>
* Fix conflicts in fuyu_follow_up_image_processing (#27228)
Fixing conflicts and updating on main.
* Revert "Fix conflicts in fuyu_follow_up_image_processing" (#27232)
Revert "Fix conflicts in fuyu_follow_up_image_processing (#27228)"
This reverts commit acce10b6c6.
---------
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-72-126.ec2.internal>
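
The negative-padding bug from the first commit is easiest to see in isolation. The sketch below is illustrative only, not the repository's actual code: the function names are hypothetical, and the specific failure mechanism shown (scaling by the larger factor instead of the smaller one) is an assumption about how a scale-and-pad pipeline can go wrong.

def pad_amounts_buggy(h: int, w: int, target_h: int, target_w: int) -> tuple[int, int]:
    # Hypothetical bug: picking the LARGER scale factor fills the target along
    # one axis but overshoots it along the other, so the pad goes negative.
    scale = max(target_h / h, target_w / w)
    new_h, new_w = round(h * scale), round(w * scale)
    return target_h - new_h, target_w - new_w  # e.g. (450, 210) into (160, 320) -> (-526, 0)


def pad_amounts_fixed(h: int, w: int, target_h: int, target_w: int) -> tuple[int, int]:
    # Picking the smaller factor keeps both axes inside the target; the clamp
    # makes the invariant the bug violated explicit: padding is never negative.
    scale = min(target_h / h, target_w / w)
    new_h, new_w = round(h * scale), round(w * scale)
    return max(0, target_h - new_h), max(0, target_w - new_w)

Note that the real processor evidently truncates the scaled size rather than rounding (450x210 resizes to 160x74 in the tests below); the defensive clamp is worthwhile either way.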
64 lines · 2.2 KiB · Python
import unittest

import numpy as np

from transformers import is_torch_available, is_vision_available
from transformers.testing_utils import (
    require_torch,
    require_torchvision,
    require_vision,
)


if is_torch_available() and is_vision_available():
    import torch

    from transformers import FuyuImageProcessor

if is_vision_available():
    from PIL import Image


@require_torch
@require_vision
@require_torchvision
class TestFuyuImageProcessor(unittest.TestCase):
    def setUp(self):
        self.size = {"height": 160, "width": 320}
        self.processor = FuyuImageProcessor(size=self.size, padding_value=1.0)
        self.batch_size = 3
        self.channels = 3
        self.height = 300
        self.width = 300

        self.image_input = torch.rand(self.batch_size, self.channels, self.height, self.width)

        self.image_patch_dim_h = 30
        self.image_patch_dim_w = 30
        # A portrait sample image (450 high, 210 wide): fitting it into 160x320
        # is height-bound, so the width shrinks by the same factor.
        self.sample_image = np.zeros((450, 210, 3), dtype=np.uint8)
        self.sample_image_pil = Image.fromarray(self.sample_image)

    def test_patches(self):
        # With 30x30 patches (cf. image_patch_dim_h/w above), a 300x300 input
        # should yield (300 / 30) * (300 / 30) = 100 patches.
        expected_num_patches = self.processor.get_num_patches(image_height=self.height, image_width=self.width)

        patches_final = self.processor.patchify_image(image=self.image_input)
        assert (
            patches_final.shape[1] == expected_num_patches
        ), f"Expected {expected_num_patches} patches, got {patches_final.shape[1]}."

    def test_scale_to_target_aspect_ratio(self):
        # (h:450, w:210) fitting (160, 320) -> (160, 210*160/450)
        scaled_image = self.processor.resize(self.sample_image, size=self.size)
        self.assertEqual(scaled_image.shape[0], 160)
        self.assertEqual(scaled_image.shape[1], 74)

    def test_apply_transformation_numpy(self):
        # After resizing, the image is padded up to the full target size (160, 320).
        transformed_image = self.processor.preprocess(self.sample_image).images[0][0]
        self.assertEqual(transformed_image.shape[1], 160)
        self.assertEqual(transformed_image.shape[2], 320)

    def test_apply_transformation_pil(self):
        transformed_image = self.processor.preprocess(self.sample_image_pil).images[0][0]
        self.assertEqual(transformed_image.shape[1], 160)
        self.assertEqual(transformed_image.shape[2], 320)