Uniformize model processors (#31368)

* add initial design for uniform processors + align model

* add uniform processors for altclip + chinese_clip

* add uniform processors for blip + blip2

* fix mutable default 👀

* add configuration test

* handle structured kwargs w defaults + add test

* protect torch-specific test

* fix style

* fix

* rebase

* update processor to generic kwargs + test

* fix style

* add sensible kwargs merge

* update test

* fix assertEqual

* move kwargs merging to processing common

* rework kwargs for type hinting

* just get Unpack from extensions

* run-slow[align]

* handle kwargs passed as nested dict

* add from_pretrained test for nested kwargs handling

* [run-slow]align

* update documentation + imports

* update audio inputs

* protect audio types, silly

* try removing imports

* make things simpler

* simplerer

* move out kwargs test to common mixin

* [run-slow]align

* skip tests for old processors

* [run-slow]align, clip

* !$#@!! protect imports, darn it

* [run-slow]align, clip

* [run-slow]align, clip

* update common processor testing

* add altclip

* add chinese_clip

* add pad_size

* [run-slow]align, clip, chinese_clip, altclip

* remove duplicated tests

* fix

* add blip, blip2, bridgetower

Added tests for bridgetower which override common. Also modified common
tests to force center cropping if existing

* fix

* update doc

* improve documentation for default values

* add model_max_length testing

This parameter depends on tokenizers received.

* Raise if kwargs are specified in two places

* fix

* removed copied from

* match defaults

* force padding

* fix tokenizer test

* clean defaults

* move tests to common

* add missing import

* fix

* adapt bridgetower tests to shortest edge

* uniformize donut processor + tests

* add wav2vec2

* extend common testing to audio processors

* add testing + bert version

* propagate common kwargs to different modalities

* BC order of arguments

* check py version

* revert kwargs merging

* add draft overlap test

* update

* fix blip2 and wav2vec due to updates

* fix copies

* ensure overlapping kwargs do not disappear

* replace .pop by .get to handle duplicated kwargs

* fix copies

* fix missing import

* add clearly wav2vec2_bert to uniformized models

* fix copies

* increase number of features

* fix style

* [run-slow] blip, blip2, bridgetower, donut, wav2vec2, wav2vec2_bert

* [run-slow] blip, blip_2, bridgetower, donut, wav2vec2, wav2vec2_bert

* fix concatenation

* [run-slow] blip, blip_2, bridgetower, donut, wav2vec2, wav2vec2_bert

* Update tests/test_processing_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* 🧹

* address comments

* clean up + tests

* [run-slow] instructblip, blip, blip_2, bridgetower, donut, wav2vec2, wav2vec2_bert

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Pablo Montalvo 2024-10-02 10:41:08 +02:00 committed by GitHub
parent 2292be6c1b
commit 50290cf7a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 769 additions and 273 deletions

View File

@ -19,9 +19,25 @@ Processor class for Blip.
from typing import List, Optional, Union
from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
class BlipProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"add_special_tokens": True,
"padding": False,
"stride": 0,
"return_overflowing_tokens": False,
"return_special_tokens_mask": False,
"return_offsets_mapping": False,
"return_token_type_ids": False,
"return_length": False,
"verbose": True,
},
"images_kwargs": {},
}
class BlipProcessor(ProcessorMixin):
@ -51,84 +67,53 @@ class BlipProcessor(ProcessorMixin):
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_token_type_ids: bool = False,
return_length: bool = False,
verbose: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
audio=None,
videos=None,
**kwargs: Unpack[BlipProcessorKwargs],
) -> BatchEncoding:
"""
This method uses [`BlipImageProcessor.__call__`] method to prepare image(s) for the model, and
[`BertTokenizerFast.__call__`] to prepare text for the model.
Please refer to the docstring of the above two methods for more information.
Args:
images (`ImageInput`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
"""
if images is None and text is None:
raise ValueError("You have to specify either images or text.")
# Get only text
if images is None:
self.current_processor = self.tokenizer
text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
)
return text_encoding
# add pixel_values
encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
text_encoding = None
# add pixel_values encoding. If we also have text_encoding, update image encoding and return it.
# else, return the text encoding.
output_kwargs = self._merge_kwargs(
BlipProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if text is not None:
text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
)
else:
text_encoding = None
text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
if images is not None:
encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
if text_encoding is not None:
encoding_image_processor.update(text_encoding)
if text_encoding is not None:
encoding_image_processor.update(text_encoding)
return encoding_image_processor
return encoding_image_processor
return text_encoding
def batch_decode(self, *args, **kwargs):
"""

View File

@ -18,22 +18,38 @@ Processor class for BLIP-2.
from typing import List, Optional, Union
from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import (
AddedToken,
BatchEncoding,
PaddingStrategy,
PreTokenizedInput,
TextInput,
TruncationStrategy,
)
from ...utils import TensorType, logging
from ...utils import logging
logger = logging.get_logger(__name__)
class Blip2ProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"add_special_tokens": True,
"padding": False,
"stride": 0,
"return_overflowing_tokens": False,
"return_special_tokens_mask": False,
"return_offsets_mapping": False,
"return_token_type_ids": False,
"return_length": False,
"verbose": True,
},
"images_kwargs": {},
}
class Blip2Processor(ProcessorMixin):
r"""
Constructs a BLIP-2 processor which wraps a BLIP image processor and an OPT/T5 tokenizer into a single processor.
@ -67,58 +83,44 @@ class Blip2Processor(ProcessorMixin):
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_token_type_ids: bool = False,
return_length: bool = False,
verbose: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
audio=None,
videos=None,
**kwargs: Unpack[Blip2ProcessorKwargs],
) -> BatchEncoding:
"""
This method uses [`BlipImageProcessor.__call__`] method to prepare image(s) for the model, and
[`BertTokenizerFast.__call__`] to prepare text for the model.
Please refer to the docstring of the above two methods for more information.
Args:
images (`ImageInput`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
"""
if images is None and text is None:
raise ValueError("You have to specify either images or text.")
# Get only text
if images is None:
self.current_processor = self.tokenizer
text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
)
return text_encoding
# add pixel_values
encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
output_kwargs = self._merge_kwargs(
Blip2ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
# BC for explicit return_tensors
if "return_tensors" in output_kwargs["common_kwargs"]:
return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)
else:
return_tensors = None
encoding = BatchFeature(tensor_type=return_tensors)
if text is not None:
if isinstance(text, str):
text = [text]
@ -126,24 +128,10 @@ class Blip2Processor(ProcessorMixin):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
text_encoding = {}
_text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=None, # hardcode "None" here for prepending image tokens
**kwargs,
)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
_text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
output_kwargs["text_kwargs"]["return_tensors"] = return_tensors
# if we know how many query tokens, expand text inside processor. We need this hacky manipulation
# because BLIP expects image tokens to be at the beginning even before BOS token
@ -164,14 +152,14 @@ class Blip2Processor(ProcessorMixin):
)
# cast to desired return tensors type
text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors)
else:
text_encoding = None
encoding.update(BatchEncoding(text_encoding, tensor_type=return_tensors))
# add pixel_values encoding. If we also have text_encoding, update image encoding and return it.
# else, return the text encoding.
if text_encoding is not None:
encoding_image_processor.update(text_encoding)
return encoding_image_processor
if images is not None:
image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
encoding.update(image_encoding)
return encoding
# Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
def batch_decode(self, *args, **kwargs):

View File

@ -16,11 +16,29 @@
Processor class for BridgeTower.
"""
from typing import List, Optional, Union
from typing import List, Union
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
class BridgeTowerProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"add_special_tokens": True,
"padding": False,
"stride": 0,
"return_overflowing_tokens": False,
"return_special_tokens_mask": False,
"return_offsets_mapping": False,
"return_length": False,
"verbose": True,
},
"images_kwargs": {
"do_normalize": True,
"do_center_crop": True,
},
}
class BridgeTowerProcessor(ProcessorMixin):
@ -50,21 +68,9 @@ class BridgeTowerProcessor(ProcessorMixin):
self,
images,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
audio=None,
videos=None,
**kwargs: Unpack[BridgeTowerProcessorKwargs],
) -> BatchEncoding:
"""
This method uses [`BridgeTowerImageProcessor.__call__`] method to prepare image(s) for the model, and
@ -72,28 +78,14 @@ class BridgeTowerProcessor(ProcessorMixin):
Please refer to the docstring of the above two methods for more information.
"""
encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
output_kwargs = self._merge_kwargs(
BridgeTowerProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
# add pixel_values + pixel_mask
encoding_image_processor = self.image_processor(
images, return_tensors=return_tensors, do_normalize=True, do_center_crop=True, **kwargs
)
encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
encoding.update(encoding_image_processor)
return encoding

View File

@ -19,8 +19,15 @@ Processor class for Donut.
import re
import warnings
from contextlib import contextmanager
from typing import List, Optional, Union
from ...processing_utils import ProcessorMixin
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
class DonutProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {}
class DonutProcessor(ProcessorMixin):
@ -63,7 +70,14 @@ class DonutProcessor(ProcessorMixin):
self.current_processor = self.image_processor
self._in_target_context_manager = False
def __call__(self, *args, **kwargs):
def __call__(
self,
images: ImageInput = None,
text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
audio=None,
videos=None,
**kwargs: Unpack[DonutProcessorKwargs],
):
"""
When used in normal mode, this method forwards all its arguments to AutoImageProcessor's
[`~AutoImageProcessor.__call__`] and returns its output. If used in the context
@ -72,28 +86,29 @@ class DonutProcessor(ProcessorMixin):
"""
# For backward compatibility
if self._in_target_context_manager:
return self.current_processor(*args, **kwargs)
images = kwargs.pop("images", None)
text = kwargs.pop("text", None)
if len(args) > 0:
images = args[0]
args = args[1:]
return self.current_processor(images, text, **kwargs)
if images is None and text is None:
raise ValueError("You need to specify either an `images` or `text` input to process.")
output_kwargs = self._merge_kwargs(
DonutProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None:
inputs = self.image_processor(images, *args, **kwargs)
inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
if text is not None:
encodings = self.tokenizer(text, **kwargs)
encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
if text is None:
return inputs
elif images is None:
return encodings
else:
inputs["labels"] = encodings["input_ids"]
inputs["labels"] = encodings["input_ids"] # for BC
inputs["input_ids"] = encodings["input_ids"]
return inputs
def batch_decode(self, *args, **kwargs):

View File

@ -17,12 +17,11 @@ Processor class for Idefics3.
"""
import re
import sys
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import AddedToken, BatchEncoding, TextInput
from ...utils import logging
@ -30,11 +29,6 @@ from ...utils import logging
if TYPE_CHECKING:
from ...tokenization_utils_base import PreTokenizedInput
if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack
logger = logging.get_logger(__name__)

View File

@ -19,19 +19,9 @@ from typing import List, Optional, Union
import numpy as np
try:
from typing import Unpack
except ImportError:
from typing_extensions import Unpack
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import (
ImagesKwargs,
ProcessingKwargs,
ProcessorMixin,
)
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import (
BatchEncoding,
PreTokenizedInput,

View File

@ -16,13 +16,12 @@
Processor class for OmDet-Turbo.
"""
import sys
from typing import List, Optional, Tuple, Union
from ...feature_extraction_utils import BatchFeature
from ...image_transforms import center_to_corners_format
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs
from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import (
TensorType,
@ -31,12 +30,6 @@ from ...utils import (
)
if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack
class OmDetTurboTextKwargs(TextKwargs, total=False):
task: Optional[Union[str, List[str], TextInput, PreTokenizedInput]]

View File

@ -18,12 +18,18 @@ Speech processor class for Wav2Vec2
import warnings
from contextlib import contextmanager
from typing import List, Optional, Union
from ...processing_utils import ProcessorMixin
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput
from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
class Wav2Vec2ProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {}
class Wav2Vec2Processor(ProcessorMixin):
r"""
Constructs a Wav2Vec2 processor which wraps a Wav2Vec2 feature extractor and a Wav2Vec2 CTC tokenizer into a single
@ -66,35 +72,46 @@ class Wav2Vec2Processor(ProcessorMixin):
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
def __call__(self, *args, **kwargs):
def __call__(
self,
audio: AudioInput = None,
text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
images=None,
videos=None,
**kwargs: Unpack[Wav2Vec2ProcessorKwargs],
):
"""
When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
[`~Wav2Vec2FeatureExtractor.__call__`] and returns its output. If used in the context
[`~Wav2Vec2Processor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
[`~PreTrainedTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
"""
# For backward compatibility
if self._in_target_context_manager:
return self.current_processor(*args, **kwargs)
if "raw_speech" in kwargs:
warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
audio = kwargs.pop("raw_speech")
else:
audio = kwargs.pop("audio", None)
sampling_rate = kwargs.pop("sampling_rate", None)
text = kwargs.pop("text", None)
if len(args) > 0:
audio = args[0]
args = args[1:]
if audio is None and text is None:
raise ValueError("You need to specify either an `audio` or `text` input to process.")
output_kwargs = self._merge_kwargs(
Wav2Vec2ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
# For backward compatibility
if self._in_target_context_manager:
return self.current_processor(
audio,
**output_kwargs["audio_kwargs"],
**output_kwargs["text_kwargs"],
**output_kwargs["common_kwargs"],
)
if audio is not None:
inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)
inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
if text is not None:
encodings = self.tokenizer(text, **kwargs)
encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
if text is None:
return inputs

View File

@ -17,12 +17,18 @@ Speech processor class for Wav2Vec2-BERT
"""
import warnings
from typing import List, Optional, Union
from ...processing_utils import ProcessorMixin
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput
from ..seamless_m4t.feature_extraction_seamless_m4t import SeamlessM4TFeatureExtractor
from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
class Wav2Vec2BertProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {}
class Wav2Vec2BertProcessor(ProcessorMixin):
r"""
Constructs a Wav2Vec2-BERT processor which wraps a Wav2Vec2-BERT feature extractor and a Wav2Vec2 CTC tokenizer into a single
@ -63,7 +69,14 @@ class Wav2Vec2BertProcessor(ProcessorMixin):
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
def __call__(self, audio=None, text=None, **kwargs):
def __call__(
self,
audio: AudioInput = None,
text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
images=None,
videos=None,
**kwargs: Unpack[Wav2Vec2BertProcessorKwargs],
):
"""
Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `audio`
and `kwargs` arguments to SeamlessM4TFeatureExtractor's [`~SeamlessM4TFeatureExtractor.__call__`] if `audio` is not
@ -71,17 +84,15 @@ class Wav2Vec2BertProcessor(ProcessorMixin):
PreTrainedTokenizer's [`~PreTrainedTokenizer.__call__`] if `text` is not `None`. Please refer to the doctsring of the above two methods for more information.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The audio or batch of audios to be prepared. Each audio can be NumPy array or PyTorch tensor. In case
of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels,
and T the sample length of the audio.
kwargs (*optional*):
Remaining dictionary of keyword arguments that will be passed to the feature extractor and/or the
tokenizer.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
Returns:
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
- **input_features** -- Audio input features to be fed to a model. Returned when `audio` is not `None`.
@ -91,15 +102,18 @@ class Wav2Vec2BertProcessor(ProcessorMixin):
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None` and `audio` is `None`.
"""
sampling_rate = kwargs.pop("sampling_rate", None)
if audio is None and text is None:
raise ValueError("You need to specify either an `audio` or `text` input to process.")
output_kwargs = self._merge_kwargs(
Wav2Vec2BertProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if audio is not None:
inputs = self.feature_extractor(audio, sampling_rate=sampling_rate, **kwargs)
inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
if text is not None:
encodings = self.tokenizer(text, **kwargs)
encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
if text is None:
return inputs

View File

@ -820,6 +820,8 @@ class ProcessorMixin(PushToHubMixin):
"common_kwargs": {},
}
used_keys = set()
# get defaults from set model processor kwargs if they exist
for modality in default_kwargs:
default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy()
@ -846,18 +848,29 @@ class ProcessorMixin(PushToHubMixin):
f"in a dictionary for {modality} and as a **kwarg."
)
elif modality_key in kwargs:
kwarg_value = kwargs.pop(modality_key, "__empty__")
# we get a modality_key instead of popping it because modality-specific processors
# can have overlapping kwargs
kwarg_value = kwargs.get(modality_key, "__empty__")
else:
kwarg_value = "__empty__"
if kwarg_value != "__empty__":
output_kwargs[modality][modality_key] = kwarg_value
# if something remains in kwargs, it belongs to common after flattening
if set(kwargs) & set(default_kwargs):
# here kwargs is dictionary-based since it shares keys with default set
[output_kwargs["common_kwargs"].update(subdict) for _, subdict in kwargs.items()]
used_keys.add(modality_key)
# Determine if kwargs is a flat dictionary or contains nested dictionaries
if any(key in default_kwargs for key in kwargs):
# kwargs is dictionary-based, and some keys match modality names
for modality, subdict in kwargs.items():
if modality in default_kwargs:
for subkey, subvalue in subdict.items():
if subkey not in used_keys:
output_kwargs[modality][subkey] = subvalue
used_keys.add(subkey)
else:
# here it's a flat dict
output_kwargs["common_kwargs"].update(kwargs)
# kwargs is a flat dictionary
for key in kwargs:
if key not in used_keys:
output_kwargs["common_kwargs"][key] = kwargs[key]
# all modality-specific kwargs are updated with common kwargs
for modality in output_kwargs:

View File

@ -17,17 +17,12 @@
import tempfile
import unittest
from transformers import XLMRobertaTokenizer, XLMRobertaTokenizerFast
from transformers import AltCLIPProcessor, CLIPImageProcessor, XLMRobertaTokenizer, XLMRobertaTokenizerFast
from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from transformers import AltCLIPProcessor, CLIPImageProcessor
@require_vision
class AltClipProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = AltCLIPProcessor

View File

@ -17,7 +17,7 @@ import unittest
import pytest
from transformers.testing_utils import require_vision
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin
@ -139,3 +139,29 @@ class BlipProcessorTest(ProcessorTesterMixin, unittest.TestCase):
# For now the processor supports only ['pixel_values', 'input_ids', 'attention_mask']
self.assertListEqual(list(inputs.keys()), ["pixel_values", "input_ids", "attention_mask"])
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
crop_size={"height": 214, "width": 214},
size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 24)

View File

@ -17,7 +17,7 @@ import unittest
import pytest
from transformers.testing_utils import require_vision
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin
@ -94,7 +94,7 @@ class Blip2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
encoded_tok = tokenizer(input_str, return_token_type_ids=False)
for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key])
self.assertListEqual(encoded_tok[key], encoded_processor[key][0])
def test_processor(self):
image_processor = self.get_image_processor()
@ -107,7 +107,7 @@ class Blip2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
inputs = processor(text=input_str, images=image_input)
self.assertListEqual(list(inputs.keys()), ["pixel_values", "input_ids", "attention_mask"])
self.assertCountEqual(list(inputs.keys()), ["input_ids", "pixel_values", "attention_mask"])
# test if it raises when no input is passed
with pytest.raises(ValueError):
@ -138,4 +138,31 @@ class Blip2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
inputs = processor(text=input_str, images=image_input)
# For now the processor supports only ['pixel_values', 'input_ids', 'attention_mask']
self.assertListEqual(list(inputs.keys()), ["pixel_values", "input_ids", "attention_mask"])
self.assertCountEqual(list(inputs.keys()), ["input_ids", "pixel_values", "attention_mask"])
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
crop_size={"height": 214, "width": 214},
size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 11)

View File

@ -0,0 +1,218 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from PIL import Image
from transformers import (
AutoProcessor,
BridgeTowerImageProcessor,
BridgeTowerProcessor,
RobertaTokenizerFast,
)
@require_vision
class Blip2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = BridgeTowerProcessor
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = BridgeTowerImageProcessor()
tokenizer = RobertaTokenizerFast.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
processor = BridgeTowerProcessor(image_processor, tokenizer)
processor.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def prepare_image_inputs(self):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
"""
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
return image_inputs
# Some kwargs tests are overriden from common tests to handle shortest_edge
# and size_divisor behaviour
@require_torch
@require_vision
def test_image_processor_defaults_preserved_by_image_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component(
"image_processor",
crop_size={"shortest_edge": 234, "longest_edge": 234},
)
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0]), 234)
@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {
"crop_size": {"shortest_edge": 214},
},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)
@require_torch
@require_vision
def test_kwargs_overrides_default_image_processor_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", crop_size={"shortest_edge": 234})
tokenizer = self.get_component("tokenizer", max_length=117)
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input, crop_size={"shortest_edge": 224})
self.assertEqual(len(inputs["pixel_values"][0][0]), 224)
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
crop_size={"shortest_edge": 214},
padding="longest",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 6)
@require_torch
@require_vision
def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
crop_size={"shortest_edge": 214},
padding="max_length",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)
@require_torch
@require_vision
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"crop_size": {"shortest_edge": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)

View File

@ -14,16 +14,32 @@
# limitations under the License.
import tempfile
import unittest
from transformers import DonutProcessor
from transformers import DonutImageProcessor, DonutProcessor, XLMRobertaTokenizerFast
from transformers.testing_utils import (
require_torch,
require_vision,
)
from ...test_processing_common import ProcessorTesterMixin
class DonutProcessorTest(unittest.TestCase):
class DonutProcessorTest(ProcessorTesterMixin, unittest.TestCase):
from_pretrained_id = "naver-clova-ix/donut-base"
processor_class = DonutProcessor
def setUp(self):
self.processor = DonutProcessor.from_pretrained(self.from_pretrained_id)
self.tmpdirname = tempfile.mkdtemp()
image_processor = DonutImageProcessor()
tokenizer = XLMRobertaTokenizerFast.from_pretrained(self.from_pretrained_id)
processor = DonutProcessor(image_processor, tokenizer)
processor.save_pretrained(self.tmpdirname)
def test_token2json(self):
expected_json = {
@ -49,3 +65,30 @@ class DonutProcessorTest(unittest.TestCase):
actual_json = self.processor.token2json(sequence)
self.assertDictEqual(actual_json, expected_json)
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
crop_size={"height": 214, "width": 214},
size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 7)

View File

@ -18,14 +18,19 @@ import shutil
import tempfile
import unittest
import numpy as np
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
from transformers.utils import FEATURE_EXTRACTOR_NAME
from ...test_processing_common import ProcessorTesterMixin
from .test_feature_extraction_wav2vec2 import floats_list
class Wav2Vec2ProcessorTest(unittest.TestCase):
class Wav2Vec2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = Wav2Vec2Processor
def setUp(self):
vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
vocab_tokens = dict(zip(vocab, range(len(vocab))))
@ -53,6 +58,9 @@ class Wav2Vec2ProcessorTest(unittest.TestCase):
with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(feature_extractor_map) + "\n")
tokenizer = self.get_tokenizer()
tokenizer.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs_init):
kwargs = self.add_kwargs_tokens_map.copy()
kwargs.update(kwargs_init)
@ -117,7 +125,6 @@ class Wav2Vec2ProcessorTest(unittest.TestCase):
processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
input_str = "This is a test string"
encoded_processor = processor(text=input_str)
encoded_tok = tokenizer(input_str)
@ -125,6 +132,22 @@ class Wav2Vec2ProcessorTest(unittest.TestCase):
for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key])
def test_padding_argument_not_ignored(self):
# padding, or any other overlap arg between audio extractor and tokenizer
# should be passed to both text and audio and not ignored
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
batch_duration_in_seconds = [1, 3, 2, 6]
input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
# padding = True should not raise an error and will if the audio processor popped its value to None
_ = processor(
input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
)
def test_tokenizer_decode(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()

View File

@ -18,17 +18,21 @@ import shutil
import tempfile
import unittest
import numpy as np
from transformers.models.seamless_m4t import SeamlessM4TFeatureExtractor
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
from transformers.models.wav2vec2_bert import Wav2Vec2BertProcessor
from transformers.utils import FEATURE_EXTRACTOR_NAME
from ...test_processing_common import ProcessorTesterMixin
from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list
# Copied from tests.models.wav2vec2.test_processor_wav2vec2.Wav2Vec2ProcessorTest with Wav2Vec2FeatureExtractor->SeamlessM4TFeatureExtractor, Wav2Vec2Processor->Wav2Vec2BertProcessor
class Wav2Vec2BertProcessorTest(unittest.TestCase):
class Wav2Vec2BertProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = Wav2Vec2BertProcessor
def setUp(self):
vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
vocab_tokens = dict(zip(vocab, range(len(vocab))))
@ -40,7 +44,7 @@ class Wav2Vec2BertProcessorTest(unittest.TestCase):
"eos_token": "</s>",
}
feature_extractor_map = {
"feature_size": 1,
"feature_size": 80,
"padding_value": 0.0,
"sampling_rate": 16000,
"return_attention_mask": False,
@ -56,6 +60,9 @@ class Wav2Vec2BertProcessorTest(unittest.TestCase):
with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(feature_extractor_map) + "\n")
tokenizer = self.get_tokenizer()
tokenizer.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs_init):
kwargs = self.add_kwargs_tokens_map.copy()
kwargs.update(kwargs_init)
@ -122,7 +129,6 @@ class Wav2Vec2BertProcessorTest(unittest.TestCase):
processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
input_str = "This is a test string"
encoded_processor = processor(text=input_str)
encoded_tok = tokenizer(input_str)
@ -130,6 +136,22 @@ class Wav2Vec2BertProcessorTest(unittest.TestCase):
for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key])
def test_padding_argument_not_ignored(self):
# padding, or any other overlap arg between audio extractor and tokenizer
# should be passed to both text and audio and not ignored
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
batch_duration_in_seconds = [1, 3, 2, 6]
input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
# padding = True should not raise an error and will if the audio processor popped its value to None
# processor(input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt")
_ = processor(
input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
)
def test_tokenizer_decode(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()

View File

@ -16,6 +16,7 @@
import inspect
import json
import random
import tempfile
from typing import Optional
@ -31,11 +32,7 @@ from transformers.testing_utils import (
from transformers.utils import is_vision_available
try:
from typing import Unpack
except ImportError:
from typing_extensions import Unpack
global_rng = random.Random()
if is_vision_available():
from PIL import Image
@ -48,6 +45,21 @@ def prepare_image_inputs():
return image_inputs
# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list
def floats_list(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor"""
if rng is None:
rng = global_rng
values = []
for batch_idx in range(shape[0]):
values.append([])
for _ in range(shape[1]):
values[-1].append(rng.random() * scale)
return values
@require_torch
@require_vision
class ProcessorTesterMixin:
@ -333,6 +345,135 @@ class ProcessorTesterMixin:
self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)
self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
# text + audio kwargs testing
@require_torch
def test_tokenizer_defaults_preserved_by_kwargs_audio(self):
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"):
tokenizer = self.get_tokenizer(max_length=117, padding="max_length")
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
else:
self.assertTrue(False, "Processor doesn't have get_tokenizer or get_component defined")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
raw_speech = floats_list((3, 1000))
inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt")
if "input_ids" in inputs:
self.assertEqual(len(inputs["input_ids"][0]), 117)
elif "labels" in inputs:
self.assertEqual(len(inputs["labels"][0]), 117)
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs_audio(self):
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"):
tokenizer = self.get_tokenizer(max_length=117)
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer", max_length=117)
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
raw_speech = floats_list((3, 1000))
inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt", max_length=112, padding="max_length")
if "input_ids" in inputs:
self.assertEqual(len(inputs["input_ids"][0]), 112)
elif "labels" in inputs:
self.assertEqual(len(inputs["labels"][0]), 112)
@require_torch
def test_unstructured_kwargs_audio(self):
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"):
tokenizer = self.get_tokenizer(max_length=117)
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer", max_length=117)
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
raw_speech = floats_list((3, 1000))
inputs = processor(
text=input_str,
audio=raw_speech,
return_tensors="pt",
padding="max_length",
max_length=76,
)
if "input_ids" in inputs:
self.assertEqual(len(inputs["input_ids"][0]), 76)
elif "labels" in inputs:
self.assertEqual(len(inputs["labels"][0]), 76)
@require_torch
def test_doubly_passed_kwargs_audio(self):
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"):
tokenizer = self.get_tokenizer()
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer"]
raw_speech = floats_list((3, 1000))
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
audio=raw_speech,
audio_kwargs={"padding": "max_length"},
padding="max_length",
)
@require_torch
@require_vision
def test_structured_kwargs_audio_nested(self):
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"):
tokenizer = self.get_tokenizer()
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer"]
raw_speech = floats_list((3, 1000))
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"text_kwargs": {"padding": "max_length", "max_length": 76},
"audio_kwargs": {"padding": "max_length", "max_length": 66},
}
inputs = processor(text=input_str, audio=raw_speech, **all_kwargs)
if "input_ids" in inputs:
self.assertEqual(len(inputs["input_ids"][0]), 76)
elif "labels" in inputs:
self.assertEqual(len(inputs["labels"][0]), 76)
# TODO: the same test, but for audio + text processors that have strong overlap in kwargs
# TODO (molbap) use the same structure of attribute kwargs for other tests to avoid duplication
def test_overlapping_text_kwargs_handling(self):