mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
🚨🚨🚨 Uniformize kwargs for TrOCR Processor (#34587)
* Make kwargs uniform for TrOCR * Add tests * Put back current_processor * Remove args * Add todo comment * Code review - breaking change
This commit is contained in:
parent
0b5b5e6a70
commit
89d7bf584f
@ -18,8 +18,16 @@ Processor class for TrOCR.
|
|||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
from ...processing_utils import ProcessorMixin
|
from ...image_processing_utils import BatchFeature
|
||||||
|
from ...image_utils import ImageInput
|
||||||
|
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||||
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
|
|
||||||
|
|
||||||
|
class TrOCRProcessorKwargs(ProcessingKwargs, total=False):
|
||||||
|
_defaults = {}
|
||||||
|
|
||||||
|
|
||||||
class TrOCRProcessor(ProcessorMixin):
|
class TrOCRProcessor(ProcessorMixin):
|
||||||
@ -61,7 +69,14 @@ class TrOCRProcessor(ProcessorMixin):
|
|||||||
self.current_processor = self.image_processor
|
self.current_processor = self.image_processor
|
||||||
self._in_target_context_manager = False
|
self._in_target_context_manager = False
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(
|
||||||
|
self,
|
||||||
|
images: ImageInput = None,
|
||||||
|
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||||
|
audio=None,
|
||||||
|
videos=None,
|
||||||
|
**kwargs: Unpack[TrOCRProcessorKwargs],
|
||||||
|
) -> BatchFeature:
|
||||||
"""
|
"""
|
||||||
When used in normal mode, this method forwards all its arguments to AutoImageProcessor's
|
When used in normal mode, this method forwards all its arguments to AutoImageProcessor's
|
||||||
[`~AutoImageProcessor.__call__`] and returns its output. If used in the context
|
[`~AutoImageProcessor.__call__`] and returns its output. If used in the context
|
||||||
@ -70,21 +85,21 @@ class TrOCRProcessor(ProcessorMixin):
|
|||||||
"""
|
"""
|
||||||
# For backward compatibility
|
# For backward compatibility
|
||||||
if self._in_target_context_manager:
|
if self._in_target_context_manager:
|
||||||
return self.current_processor(*args, **kwargs)
|
return self.current_processor(images, **kwargs)
|
||||||
|
|
||||||
images = kwargs.pop("images", None)
|
|
||||||
text = kwargs.pop("text", None)
|
|
||||||
if len(args) > 0:
|
|
||||||
images = args[0]
|
|
||||||
args = args[1:]
|
|
||||||
|
|
||||||
if images is None and text is None:
|
if images is None and text is None:
|
||||||
raise ValueError("You need to specify either an `images` or `text` input to process.")
|
raise ValueError("You need to specify either an `images` or `text` input to process.")
|
||||||
|
|
||||||
|
output_kwargs = self._merge_kwargs(
|
||||||
|
TrOCRProcessorKwargs,
|
||||||
|
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
if images is not None:
|
if images is not None:
|
||||||
inputs = self.image_processor(images, *args, **kwargs)
|
inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||||
if text is not None:
|
if text is not None:
|
||||||
encodings = self.tokenizer(text, **kwargs)
|
encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||||
|
|
||||||
if text is None:
|
if text is None:
|
||||||
return inputs
|
return inputs
|
||||||
|
129
tests/models/trocr/test_processor_trocr.py
Normal file
129
tests/models/trocr/test_processor_trocr.py
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from transformers.models.xlm_roberta.tokenization_xlm_roberta import VOCAB_FILES_NAMES
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
require_sentencepiece,
|
||||||
|
require_tokenizers,
|
||||||
|
require_vision,
|
||||||
|
)
|
||||||
|
from transformers.utils import is_vision_available
|
||||||
|
|
||||||
|
from ...test_processing_common import ProcessorTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from transformers import TrOCRProcessor, ViTImageProcessor, XLMRobertaTokenizerFast
|
||||||
|
|
||||||
|
|
||||||
|
@require_sentencepiece
|
||||||
|
@require_tokenizers
|
||||||
|
@require_vision
|
||||||
|
class TrOCRProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||||
|
text_input_name = "labels"
|
||||||
|
processor_class = TrOCRProcessor
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tmpdirname = tempfile.mkdtemp()
|
||||||
|
|
||||||
|
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "want", "##want", "##ed", "wa", "un", "runn", "##ing", ",", "low", "lowest"] # fmt: skip
|
||||||
|
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
|
||||||
|
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||||
|
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||||
|
|
||||||
|
image_processor = ViTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-vit")
|
||||||
|
tokenizer = XLMRobertaTokenizerFast.from_pretrained("FacebookAI/xlm-roberta-base")
|
||||||
|
processor = TrOCRProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||||
|
processor.save_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
shutil.rmtree(self.tmpdirname)
|
||||||
|
|
||||||
|
def get_tokenizer(self, **kwargs):
|
||||||
|
return XLMRobertaTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
|
def get_image_processor(self, **kwargs):
|
||||||
|
return ViTImageProcessor.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
|
def test_save_load_pretrained_default(self):
|
||||||
|
image_processor = self.get_image_processor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
processor = TrOCRProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
processor.save_pretrained(self.tmpdirname)
|
||||||
|
processor = TrOCRProcessor.from_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
self.assertIsInstance(processor.tokenizer, XLMRobertaTokenizerFast)
|
||||||
|
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
|
||||||
|
self.assertIsInstance(processor.image_processor, ViTImageProcessor)
|
||||||
|
self.assertEqual(processor.image_processor.to_json_string(), image_processor.to_json_string())
|
||||||
|
|
||||||
|
def test_save_load_pretrained_additional_features(self):
|
||||||
|
processor = TrOCRProcessor(tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor())
|
||||||
|
processor.save_pretrained(self.tmpdirname)
|
||||||
|
tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
|
||||||
|
image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)
|
||||||
|
|
||||||
|
processor = TrOCRProcessor.from_pretrained(
|
||||||
|
self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIsInstance(processor.tokenizer, XLMRobertaTokenizerFast)
|
||||||
|
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
|
||||||
|
|
||||||
|
self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
|
||||||
|
self.assertIsInstance(processor.image_processor, ViTImageProcessor)
|
||||||
|
|
||||||
|
def test_image_processor(self):
|
||||||
|
image_processor = self.get_image_processor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
processor = TrOCRProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||||
|
image_input = self.prepare_image_inputs()
|
||||||
|
|
||||||
|
input_feat_extract = image_processor(image_input, return_tensors="np")
|
||||||
|
input_processor = processor(images=image_input, return_tensors="np")
|
||||||
|
|
||||||
|
for key in input_feat_extract.keys():
|
||||||
|
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
|
||||||
|
|
||||||
|
def test_tokenizer(self):
|
||||||
|
image_processor = self.get_image_processor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
processor = TrOCRProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||||
|
input_str = "lower newer"
|
||||||
|
|
||||||
|
encoded_processor = processor(text=input_str)
|
||||||
|
encoded_tok = tokenizer(input_str)
|
||||||
|
|
||||||
|
for key in encoded_tok.keys():
|
||||||
|
self.assertListEqual(encoded_tok[key], encoded_processor[key])
|
||||||
|
|
||||||
|
def test_processor_text(self):
|
||||||
|
image_processor = self.get_image_processor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
processor = TrOCRProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||||
|
input_str = "lower newer"
|
||||||
|
image_input = self.prepare_image_inputs()
|
||||||
|
|
||||||
|
inputs = processor(text=input_str, images=image_input)
|
||||||
|
|
||||||
|
self.assertListEqual(list(inputs.keys()), ["pixel_values", "labels"])
|
||||||
|
|
||||||
|
# test if it raises when no input is passed
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
processor()
|
||||||
|
|
||||||
|
def test_tokenizer_decode(self):
|
||||||
|
image_processor = self.get_image_processor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
processor = TrOCRProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||||
|
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
|
||||||
|
|
||||||
|
decoded_processor = processor.batch_decode(predicted_ids)
|
||||||
|
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
||||||
|
|
||||||
|
self.assertListEqual(decoded_tok, decoded_processor)
|
Loading…
Reference in New Issue
Block a user