Fall back to slow image processor in ImageProcessingAuto when no fast processor available (#34785)
* refactor image_processing_auto logic
* fix fast image processor tests
* Fix tests fast vit image processor
* Add safeguard when use_fast True and torchvision not available
* change default use_fast back to None, add warnings
* remove debugging print
* call get_image_processor_class_from_name once
parent ca03842cdc
commit 5615a39369
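In short, the new behavior, as a minimal sketch (assumes a `transformers` build that includes this commit; the fallbacks are reported through the library's `logging` module):

```python
from transformers import AutoImageProcessor

# Requesting a fast processor now falls back to the slow, numpy-based class
# (with a warning) when torchvision is missing or when no fast class exists
# for the architecture, instead of raising an error.
processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50", use_fast=True)

# With use_fast unset, the class the checkpoint was saved with is kept, and a
# warning announces that use_fast=True becomes the default in v4.48.
processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
```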
@@ -27,6 +27,7 @@ from transformers import AutoImageProcessor
 
 processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50", use_fast=True)
 ```
+Note that `use_fast` will be set to `True` by default in a future release.
 
 When using a fast image processor, you can also set the `device` argument to specify the device on which the processing should be done. By default, the processing is done on the same device as the inputs if the inputs are tensors, or on the CPU otherwise.
 
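A usage sketch for the `device` argument described above (assumes a CUDA machine; the input shape is illustrative):

```python
import torch
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50", use_fast=True)

# A uint8 CHW image; fast processors accept torch tensors directly.
images = [torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)]

# Processing runs on the GPU and the returned tensors stay there.
inputs = processor(images, return_tensors="pt", device="cuda")
print(inputs["pixel_values"].device)  # cuda:0
```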
@@ -42,22 +43,18 @@ images_processed = processor(images, return_tensors="pt", device="cuda")
 
 Here are some speed comparisons between the base and fast image processors for the `DETR` and `RT-DETR` models, and how they impact overall inference time:
 
 <div class="flex">
 <div>
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/benchmark_results_full_pipeline_detr_fast_padded.png" />
 </div>
 <div>
-<div class="flex">
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/benchmark_results_full_pipeline_detr_fast_batched_compiled.png" />
 </div>
 </div>
 
 <div class="flex">
 <div>
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/benchmark_results_full_pipeline_rt_detr_fast_single.png" />
 </div>
 <div>
-<div class="flex">
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/benchmark_results_full_pipeline_rt_detr_fast_batched.png" />
 </div>
 </div>
 
 These benchmarks were run on an [AWS EC2 g5.2xlarge instance](https://aws.amazon.com/ec2/instance-types/g5/), utilizing an NVIDIA A10G Tensor Core GPU.
@@ -175,7 +175,7 @@ for model_type, image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.items():
 IMAGE_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)
 
 
-def image_processor_class_from_name(class_name: str):
+def get_image_processor_class_from_name(class_name: str):
     if class_name == "BaseImageProcessorFast":
         return BaseImageProcessorFast
 
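The renamed helper maps a class-name string to the processor class itself and, as the old comment further down in this diff notes, returns `None` when the name is unknown. A hypothetical call (the import path is internal, not public API):

```python
from transformers.models.auto.image_processing_auto import get_image_processor_class_from_name

cls = get_image_processor_class_from_name("ViTImageProcessorFast")
print(cls)      # the ViTImageProcessorFast class when torchvision is installed

missing = get_image_processor_class_from_name("NoSuchImageProcessor")
print(missing)  # None
```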
@@ -368,7 +368,7 @@ class AutoImageProcessor:
                 identifier allowed by git.
             use_fast (`bool`, *optional*, defaults to `False`):
                 Use a fast torchvision-based image processor if it is supported for a given model.
-                If a fast tokenizer is not available for a given model, a normal numpy-based image processor
+                If a fast image processor is not available for a given model, a normal numpy-based image processor
                 is returned instead.
             return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                 If `False`, then this function returns just the final image processor object. If `True`, then this
@@ -416,6 +416,7 @@ class AutoImageProcessor:
             kwargs["token"] = use_auth_token
 
         config = kwargs.pop("config", None)
+        # TODO: @yoni, change in v4.48 (use_fast set to True by default)
         use_fast = kwargs.pop("use_fast", None)
         trust_remote_code = kwargs.pop("trust_remote_code", None)
         kwargs["_from_auto"] = True
@@ -451,23 +452,23 @@ class AutoImageProcessor:
             if not is_timm_config_dict(config_dict):
                 raise initial_exception
 
-        image_processor_class = config_dict.get("image_processor_type", None)
+        image_processor_type = config_dict.get("image_processor_type", None)
         image_processor_auto_map = None
         if "AutoImageProcessor" in config_dict.get("auto_map", {}):
             image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]
 
         # If we still don't have the image processor class, check if we're loading from a previous feature extractor config
         # and if so, infer the image processor class from there.
-        if image_processor_class is None and image_processor_auto_map is None:
+        if image_processor_type is None and image_processor_auto_map is None:
             feature_extractor_class = config_dict.pop("feature_extractor_type", None)
             if feature_extractor_class is not None:
-                image_processor_class = feature_extractor_class.replace("FeatureExtractor", "ImageProcessor")
+                image_processor_type = feature_extractor_class.replace("FeatureExtractor", "ImageProcessor")
             if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
                 feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
                 image_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "ImageProcessor")
 
         # If we don't find the image processor class in the image processor config, let's try the model config.
-        if image_processor_class is None and image_processor_auto_map is None:
+        if image_processor_type is None and image_processor_auto_map is None:
             if not isinstance(config, PretrainedConfig):
                 config = AutoConfig.from_pretrained(
                     pretrained_model_name_or_path,
@@ -475,18 +476,47 @@ class AutoImageProcessor:
                 **kwargs,
             )
             # It could be in `config.image_processor_type`
-            image_processor_class = getattr(config, "image_processor_type", None)
+            image_processor_type = getattr(config, "image_processor_type", None)
             if hasattr(config, "auto_map") and "AutoImageProcessor" in config.auto_map:
                 image_processor_auto_map = config.auto_map["AutoImageProcessor"]
 
-        if image_processor_class is not None:
-            # Update class name to reflect the use_fast option. If class is not found, None is returned.
-            if use_fast is not None:
-                if use_fast and not image_processor_class.endswith("Fast"):
-                    image_processor_class += "Fast"
-                elif not use_fast and image_processor_class.endswith("Fast"):
-                    image_processor_class = image_processor_class[:-4]
-            image_processor_class = image_processor_class_from_name(image_processor_class)
+        image_processor_class = None
+        # TODO: @yoni, change logic in v4.48 (when use_fast set to True by default)
+        if image_processor_type is not None:
+            # If use_fast is not set and the processor was saved with a fast processor, we use it, otherwise we use the slow processor.
+            if use_fast is None:
+                use_fast = image_processor_type.endswith("Fast")
+                if not use_fast:
+                    logger.warning_once(
+                        "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
+                        "`use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. "
+                        "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
+                    )
+            # Update class name to reflect the use_fast option. If class is not found, we fall back to the slow version.
+            if use_fast and not is_torchvision_available():
+                logger.warning_once(
+                    "Using `use_fast=True` but `torchvision` is not available. Falling back to the slow image processor."
+                )
+                use_fast = False
+            if use_fast:
+                if not image_processor_type.endswith("Fast"):
+                    image_processor_type += "Fast"
+                for _, image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.items():
+                    if image_processor_type in image_processors:
+                        break
+                else:
+                    image_processor_type = image_processor_type[:-4]
+                    use_fast = False
+                    logger.warning_once(
+                        "`use_fast` is set to `True` but the image processor class does not have a fast version. "
+                        "Falling back to the slow version."
+                    )
+                image_processor_class = get_image_processor_class_from_name(image_processor_type)
+            else:
+                image_processor_type = (
+                    image_processor_type[:-4] if image_processor_type.endswith("Fast") else image_processor_type
+                )
+                image_processor_class = get_image_processor_class_from_name(image_processor_type)
 
         has_remote_code = image_processor_auto_map is not None
         has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING
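Condensed, the new selection rules amount to the standalone sketch below (`resolve_image_processor_type` is a hypothetical name; the real method additionally handles remote code, feature-extractor configs, and instantiation):

```python
from typing import Optional, Set, Tuple


def resolve_image_processor_type(
    saved_type: str,
    use_fast: Optional[bool],
    torchvision_available: bool,
    registered_fast_types: Set[str],
) -> Tuple[str, bool]:
    """Return the resolved class name and the effective use_fast flag."""
    # Rule 1: use_fast unset -> follow whatever the checkpoint was saved with.
    if use_fast is None:
        use_fast = saved_type.endswith("Fast")
    # Rule 2: fast requested but torchvision missing -> fall back to slow.
    if use_fast and not torchvision_available:
        use_fast = False
    if use_fast:
        fast_type = saved_type if saved_type.endswith("Fast") else saved_type + "Fast"
        # Rule 3: fast requested but no fast class registered -> fall back to slow.
        if fast_type not in registered_fast_types:
            return fast_type[:-4], False
        return fast_type, True
    return (saved_type[:-4] if saved_type.endswith("Fast") else saved_type), False


# DETR has a registered fast processor, so use_fast=True resolves to it...
assert resolve_image_processor_type(
    "DetrImageProcessor", True, True, {"DetrImageProcessorFast"}
) == ("DetrImageProcessorFast", True)
# ...while a model without one falls back (the real code also warns).
assert resolve_image_processor_type(
    "SomeImageProcessor", True, True, set()
) == ("SomeImageProcessor", False)
```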
@@ -254,6 +254,7 @@ class ViTImageProcessorFast(BaseImageProcessorFast):
         image_std = image_std if image_std is not None else self.image_std
         size = size if size is not None else self.size
         do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+        return_tensors = "pt" if return_tensors is None else return_tensors
         # Make hashable for cache
         size = SizeDict(**size)
         image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
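The `# Make hashable for cache` conversions exist because the fast processor memoizes derived state behind `functools.lru_cache`, which requires hashable arguments: lists are unhashable, tuples are fine. A minimal sketch (`_build_transforms` is a hypothetical stand-in for the cached helper):

```python
from functools import lru_cache


@lru_cache(maxsize=8)
def _build_transforms(image_mean: tuple, image_std: tuple) -> str:
    # The real processor would assemble torchvision transforms here; a string
    # tag is returned just to demonstrate the caching behavior.
    return f"transforms(mean={image_mean}, std={image_std})"


mean = [0.5, 0.5, 0.5]
_build_transforms(tuple(mean), tuple(mean))  # computed on the first call
_build_transforms(tuple(mean), tuple(mean))  # served from the cache
# _build_transforms(mean, mean)              # TypeError: unhashable type: 'list'
```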
@@ -140,6 +140,7 @@ class AutoImageProcessorTest(unittest.TestCase):
     def test_use_fast_selection(self):
         checkpoint = "hf-internal-testing/tiny-random-vit"
 
+        # TODO: @yoni, change in v4.48 (when use_fast set to True by default)
         # Slow image processor is selected by default
         image_processor = AutoImageProcessor.from_pretrained(checkpoint)
         self.assertIsInstance(image_processor, ViTImageProcessor)
@@ -19,7 +19,7 @@ import unittest
 
 import numpy as np
 
-from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow
+from transformers.testing_utils import require_torch, require_torch_gpu, require_torchvision, require_vision, slow
 from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
 
 from ...test_image_processing_common import AnnotationFormatTestMixin, ImageProcessingTestMixin, prepare_image_inputs
@@ -669,6 +669,7 @@ class DetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcessingTestMixin, unittest.TestCase):
 
     @slow
     @require_torch_gpu
+    @require_torchvision
     def test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations(self):
         # prepare image and target
         image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
@@ -724,6 +725,7 @@ class DetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcessingTestMixin, unittest.TestCase):
 
     @slow
     @require_torch_gpu
+    @require_torchvision
     def test_fast_processor_equivalence_cpu_gpu_coco_panoptic_annotations(self):
         # prepare image, target and masks_path
         image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
@@ -16,7 +16,7 @@ import unittest
 
 import requests
 
-from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow
+from transformers.testing_utils import require_torch, require_torch_gpu, require_torchvision, require_vision, slow
 from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@@ -374,6 +374,7 @@ class RtDetrImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
 
     @slow
     @require_torch_gpu
+    @require_torchvision
     # Copied from tests.models.detr.test_image_processing_detr.DetrImageProcessingTest.test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations
     def test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations(self):
         # prepare image and target
@@ -21,13 +21,13 @@ import unittest
 from transformers import BertTokenizerFast
 from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES, BertTokenizer
 from transformers.testing_utils import require_tokenizers, require_vision
-from transformers.utils import IMAGE_PROCESSOR_NAME, is_vision_available
+from transformers.utils import IMAGE_PROCESSOR_NAME, is_torchvision_available, is_vision_available
 
 from ...test_processing_common import ProcessorTesterMixin
 
 
 if is_vision_available():
-    from transformers import VisionTextDualEncoderProcessor, ViTImageProcessor
+    from transformers import VisionTextDualEncoderProcessor, ViTImageProcessor, ViTImageProcessorFast
 
 
 @require_tokenizers
@@ -63,6 +63,8 @@ class VisionTextDualEncoderProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
 
     def get_image_processor(self, **kwargs):
+        if is_torchvision_available():
+            return ViTImageProcessorFast.from_pretrained(self.tmpdirname, **kwargs)
         return ViTImageProcessor.from_pretrained(self.tmpdirname, **kwargs)
 
     def tearDown(self):
@@ -81,7 +83,7 @@ class VisionTextDualEncoderProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         self.assertIsInstance(processor.tokenizer, (BertTokenizer, BertTokenizerFast))
 
         self.assertEqual(processor.image_processor.to_json_string(), image_processor.to_json_string())
-        self.assertIsInstance(processor.image_processor, ViTImageProcessor)
+        self.assertIsInstance(processor.image_processor, (ViTImageProcessor, ViTImageProcessorFast))
 
     def test_save_load_pretrained_additional_features(self):
         processor = VisionTextDualEncoderProcessor(
@@ -100,7 +102,7 @@ class VisionTextDualEncoderProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         self.assertIsInstance(processor.tokenizer, (BertTokenizer, BertTokenizerFast))
 
         self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
-        self.assertIsInstance(processor.image_processor, ViTImageProcessor)
+        self.assertIsInstance(processor.image_processor, (ViTImageProcessor, ViTImageProcessorFast))
 
     def test_image_processor(self):
         image_processor = self.get_image_processor()
@@ -110,8 +112,8 @@ class VisionTextDualEncoderProcessorTest(ProcessorTesterMixin, unittest.TestCase):
 
         image_input = self.prepare_image_inputs()
 
-        input_feat_extract = image_processor(image_input, return_tensors="np")
-        input_processor = processor(images=image_input, return_tensors="np")
+        input_feat_extract = image_processor(image_input, return_tensors="pt")
+        input_processor = processor(images=image_input, return_tensors="pt")
 
         for key in input_feat_extract.keys():
             self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
@@ -228,14 +228,15 @@ class ImageProcessingTestMixin:
         self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())
 
     def test_image_processor_save_load_with_autoimageprocessor(self):
-        for image_processing_class in self.image_processor_list:
+        for i, image_processing_class in enumerate(self.image_processor_list):
             image_processor_first = image_processing_class(**self.image_processor_dict)
 
             with tempfile.TemporaryDirectory() as tmpdirname:
                 saved_file = image_processor_first.save_pretrained(tmpdirname)[0]
                 check_json_file_has_correct_format(saved_file)
 
-                image_processor_second = AutoImageProcessor.from_pretrained(tmpdirname)
+                use_fast = i == 1
+                image_processor_second = AutoImageProcessor.from_pretrained(tmpdirname, use_fast=use_fast)
 
                 self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())
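Note that `use_fast = i == 1` relies on the mixin convention that `self.image_processor_list` holds the slow class first and the fast class second, so each saved processor is reloaded through `AutoImageProcessor` with the matching `use_fast` flag.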