Uniformize model processors (#31368)
* add initial design for uniform processors + align model
* add uniform processors for altclip + chinese_clip
* add uniform processors for blip + blip2
* fix mutable default 👀
* add configuration test
* handle structured kwargs w defaults + add test
* protect torch-specific test
* fix style
* fix
* rebase
* update processor to generic kwargs + test
* fix style
* add sensible kwargs merge
* update test
* fix assertEqual
* move kwargs merging to processing common
* rework kwargs for type hinting
* just get Unpack from extensions
* run-slow[align]
* handle kwargs passed as nested dict
* add from_pretrained test for nested kwargs handling
* [run-slow]align
* update documentation + imports
* update audio inputs
* protect audio types, silly
* try removing imports
* make things simpler
* simplerer
* move out kwargs test to common mixin
* [run-slow]align
* skip tests for old processors
* [run-slow]align, clip
* !$#@!! protect imports, darn it
* [run-slow]align, clip
* [run-slow]align, clip
* update common processor testing
* add altclip
* add chinese_clip
* add pad_size
* [run-slow]align, clip, chinese_clip, altclip
* remove duplicated tests
* fix
* add blip, blip2, bridgetower

  Added tests for bridgetower which override common. Also modified common tests to force center cropping if existing.

* fix
* update doc
* improve documentation for default values
* add model_max_length testing

  This parameter depends on the tokenizers received.

* Raise if kwargs are specified in two places
* fix
* removed copied from
* match defaults
* force padding
* fix tokenizer test
* clean defaults
* move tests to common
* add missing import
* fix
* adapt bridgetower tests to shortest edge
* uniformize donut processor + tests
* add wav2vec2
* extend common testing to audio processors
* add testing + bert version
* propagate common kwargs to different modalities
* BC order of arguments
* check py version
* revert kwargs merging
* add draft overlap test
* update
* fix blip2 and wav2vec due to updates
* fix copies
* ensure overlapping kwargs do not disappear
* replace .pop by .get to handle duplicated kwargs
* fix copies
* fix missing import
* add clearly wav2vec2_bert to uniformized models
* fix copies
* increase number of features
* fix style
* [run-slow] blip, blip2, bridgetower, donut, wav2vec2, wav2vec2_bert
* [run-slow] blip, blip_2, bridgetower, donut, wav2vec2, wav2vec2_bert
* fix concatenation
* [run-slow] blip, blip_2, bridgetower, donut, wav2vec2, wav2vec2_bert
* Update tests/test_processing_common.py

  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* 🧹
* address comments
* clean up + tests
* [run-slow] instructblip, blip, blip_2, bridgetower, donut, wav2vec2, wav2vec2_bert

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in: parent 2292be6c1b, commit 50290cf7a0
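For orientation, every processor touched by this PR converges on the same `__call__` shape: a per-model `ProcessingKwargs` subclass carries typed, per-modality defaults, and `_merge_kwargs` folds defaults, tokenizer init kwargs, and call-time kwargs into one structure. A hedged sketch of that convention (the two `Some*` names are placeholders, not classes from this diff):

```python
# Sketch of the uniform processor __call__ convention this PR establishes.
# `SomeProcessorKwargs` / `SomeProcessor` stand in for any uniformized class
# (Blip, Blip2, BridgeTower, Donut, Wav2Vec2, ...).
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack


class SomeProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {"text_kwargs": {"padding": False}, "images_kwargs": {}}


class SomeProcessor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __call__(self, images=None, text=None, audio=None, videos=None,
                 **kwargs: Unpack[SomeProcessorKwargs]):
        # One merge step replaces a dozen explicit keyword parameters:
        output_kwargs = self._merge_kwargs(
            SomeProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
        image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
        image_encoding.update(text_encoding)
        return image_encoding
```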
@@ -19,9 +19,25 @@ Processor class for Blip.
 from typing import List, Optional, Union
 
 from ...image_utils import ImageInput
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
+
+
+class BlipProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "stride": 0,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_token_type_ids": False,
+            "return_length": False,
+            "verbose": True,
+        },
+        "images_kwargs": {},
+    }
 
 
 class BlipProcessor(ProcessorMixin):
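For readers new to the pattern: `ProcessingKwargs` classes are `TypedDict`s, so `total=False` makes every key optional, and `_defaults` is a plain class attribute that `_merge_kwargs` reads at call time. A minimal standalone sketch of the idea, using simplified stand-ins for the real transformers classes (the real definitions in processing_utils.py are richer):

```python
# Simplified sketch of the ProcessingKwargs pattern; names mirror the diff but
# these are illustrative stand-ins, not the library's actual definitions.
from typing import TypedDict


class TextKwargs(TypedDict, total=False):
    add_special_tokens: bool
    padding: bool
    max_length: int


class ImagesKwargs(TypedDict, total=False):
    do_resize: bool


class ProcessingKwargs(TypedDict, total=False):
    # total=False makes every key optional: callers pass only what they need
    text_kwargs: TextKwargs
    images_kwargs: ImagesKwargs


class MyProcessorKwargs(ProcessingKwargs, total=False):
    # _defaults is an ordinary class attribute read by the merging logic at
    # call time; static type checkers ignore it, the runtime does not
    _defaults = {
        "text_kwargs": {"add_special_tokens": True, "padding": False},
        "images_kwargs": {},
    }
```

`Unpack[...]` on `**kwargs` then exposes the flat key set to type checkers, which is why the signatures below shrink to `**kwargs: Unpack[BlipProcessorKwargs]`.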
@@ -51,84 +67,53 @@ class BlipProcessor(ProcessorMixin):
     def __call__(
         self,
         images: ImageInput = None,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        add_special_tokens: bool = True,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: Optional[int] = None,
-        stride: int = 0,
-        pad_to_multiple_of: Optional[int] = None,
-        return_attention_mask: Optional[bool] = None,
-        return_overflowing_tokens: bool = False,
-        return_special_tokens_mask: bool = False,
-        return_offsets_mapping: bool = False,
-        return_token_type_ids: bool = False,
-        return_length: bool = False,
-        verbose: bool = True,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        **kwargs,
+        text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[BlipProcessorKwargs],
     ) -> BatchEncoding:
         """
         This method uses [`BlipImageProcessor.__call__`] method to prepare image(s) for the model, and
         [`BertTokenizerFast.__call__`] to prepare text for the model.
 
         Please refer to the docstring of the above two methods for more information.
         Args:
             images (`ImageInput`):
                 The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                 tensor. Both channels-first and channels-last formats are supported.
             text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`):
                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
             return_tensors (`str` or [`~utils.TensorType`], *optional*):
                 If set, will return tensors of a particular framework. Acceptable values are:
                 - `'tf'`: Return TensorFlow `tf.constant` objects.
                 - `'pt'`: Return PyTorch `torch.Tensor` objects.
                 - `'np'`: Return NumPy `np.ndarray` objects.
                 - `'jax'`: Return JAX `jnp.ndarray` objects.
         """
         if images is None and text is None:
             raise ValueError("You have to specify either images or text.")
 
-        # Get only text
-        if images is None:
-            self.current_processor = self.tokenizer
-            text_encoding = self.tokenizer(
-                text=text,
-                add_special_tokens=add_special_tokens,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                stride=stride,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_attention_mask=return_attention_mask,
-                return_overflowing_tokens=return_overflowing_tokens,
-                return_special_tokens_mask=return_special_tokens_mask,
-                return_offsets_mapping=return_offsets_mapping,
-                return_token_type_ids=return_token_type_ids,
-                return_length=return_length,
-                verbose=verbose,
-                return_tensors=return_tensors,
-                **kwargs,
-            )
-            return text_encoding
-
-        # add pixel_values
-        encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
+        text_encoding = None
 
+        # add pixel_values encoding. If we also have text_encoding, update image encoding and return it.
+        # else, return the text encoding.
+        output_kwargs = self._merge_kwargs(
+            BlipProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
         if text is not None:
-            text_encoding = self.tokenizer(
-                text=text,
-                add_special_tokens=add_special_tokens,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                stride=stride,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_attention_mask=return_attention_mask,
-                return_overflowing_tokens=return_overflowing_tokens,
-                return_special_tokens_mask=return_special_tokens_mask,
-                return_offsets_mapping=return_offsets_mapping,
-                return_token_type_ids=return_token_type_ids,
-                return_length=return_length,
-                verbose=verbose,
-                return_tensors=return_tensors,
-                **kwargs,
-            )
-        else:
-            text_encoding = None
+            text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
+        if images is not None:
+            encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
 
-        if text_encoding is not None:
-            encoding_image_processor.update(text_encoding)
+            if text_encoding is not None:
+                encoding_image_processor.update(text_encoding)
+            return encoding_image_processor
 
-        return encoding_image_processor
+        return text_encoding
 
     def batch_decode(self, *args, **kwargs):
         """
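After this change, both flat and structured kwargs exercise the merged path. A hypothetical usage sketch (the checkpoint is a real BLIP checkpoint; the kwarg values are illustrative):

```python
# Hypothetical usage sketch of the uniformized BlipProcessor API.
import numpy as np
from PIL import Image

from transformers import BlipProcessor

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))

# Flat kwargs: _merge_kwargs routes each key to the right modality.
inputs = processor(images=image, text="a photo of", padding="max_length", max_length=32, return_tensors="pt")

# Structured kwargs: the same call, nested per modality.
inputs = processor(
    images=image,
    text="a photo of",
    text_kwargs={"padding": "max_length", "max_length": 32},
    common_kwargs={"return_tensors": "pt"},
)
```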
@@ -18,22 +18,38 @@ Processor class for BLIP-2.
 
 from typing import List, Optional, Union
 
 from ...image_processing_utils import BatchFeature
 from ...image_utils import ImageInput
-from ...processing_utils import ProcessorMixin
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils_base import (
     AddedToken,
     BatchEncoding,
     PaddingStrategy,
     PreTokenizedInput,
     TextInput,
     TruncationStrategy,
 )
-from ...utils import TensorType, logging
+from ...utils import logging
 
 
 logger = logging.get_logger(__name__)
 
 
+class Blip2ProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "stride": 0,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_token_type_ids": False,
+            "return_length": False,
+            "verbose": True,
+        },
+        "images_kwargs": {},
+    }
+
+
 class Blip2Processor(ProcessorMixin):
     r"""
     Constructs a BLIP-2 processor which wraps a BLIP image processor and an OPT/T5 tokenizer into a single processor.
@@ -67,58 +83,44 @@ class Blip2Processor(ProcessorMixin):
     def __call__(
         self,
         images: ImageInput = None,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        add_special_tokens: bool = True,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: Optional[int] = None,
-        stride: int = 0,
-        pad_to_multiple_of: Optional[int] = None,
-        return_attention_mask: Optional[bool] = None,
-        return_overflowing_tokens: bool = False,
-        return_special_tokens_mask: bool = False,
-        return_offsets_mapping: bool = False,
-        return_token_type_ids: bool = False,
-        return_length: bool = False,
-        verbose: bool = True,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        **kwargs,
+        text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[Blip2ProcessorKwargs],
     ) -> BatchEncoding:
         """
         This method uses [`BlipImageProcessor.__call__`] method to prepare image(s) for the model, and
         [`BertTokenizerFast.__call__`] to prepare text for the model.
 
         Please refer to the docstring of the above two methods for more information.
         Args:
             images (`ImageInput`):
                 The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                 tensor. Both channels-first and channels-last formats are supported.
             text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`):
                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
             return_tensors (`str` or [`~utils.TensorType`], *optional*):
                 If set, will return tensors of a particular framework. Acceptable values are:
                 - `'tf'`: Return TensorFlow `tf.constant` objects.
                 - `'pt'`: Return PyTorch `torch.Tensor` objects.
                 - `'np'`: Return NumPy `np.ndarray` objects.
                 - `'jax'`: Return JAX `jnp.ndarray` objects.
         """
         if images is None and text is None:
             raise ValueError("You have to specify either images or text.")
 
-        # Get only text
-        if images is None:
-            self.current_processor = self.tokenizer
-            text_encoding = self.tokenizer(
-                text=text,
-                add_special_tokens=add_special_tokens,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                stride=stride,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_attention_mask=return_attention_mask,
-                return_overflowing_tokens=return_overflowing_tokens,
-                return_special_tokens_mask=return_special_tokens_mask,
-                return_offsets_mapping=return_offsets_mapping,
-                return_token_type_ids=return_token_type_ids,
-                return_length=return_length,
-                verbose=verbose,
-                return_tensors=return_tensors,
-                **kwargs,
-            )
-            return text_encoding
-
-        # add pixel_values
-        encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
-
+        output_kwargs = self._merge_kwargs(
+            Blip2ProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+        # BC for explicit return_tensors
+        if "return_tensors" in output_kwargs["common_kwargs"]:
+            return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)
+        else:
+            return_tensors = None
+        encoding = BatchFeature(tensor_type=return_tensors)
         if text is not None:
             if isinstance(text, str):
                 text = [text]
@@ -126,24 +128,10 @@ class Blip2Processor(ProcessorMixin):
             raise ValueError("Invalid input text. Please provide a string, or a list of strings")
 
         text_encoding = {}
-        _text_encoding = self.tokenizer(
-            text=text,
-            add_special_tokens=add_special_tokens,
-            padding=padding,
-            truncation=truncation,
-            max_length=max_length,
-            stride=stride,
-            pad_to_multiple_of=pad_to_multiple_of,
-            return_attention_mask=return_attention_mask,
-            return_overflowing_tokens=return_overflowing_tokens,
-            return_special_tokens_mask=return_special_tokens_mask,
-            return_offsets_mapping=return_offsets_mapping,
-            return_token_type_ids=return_token_type_ids,
-            return_length=return_length,
-            verbose=verbose,
-            return_tensors=None,  # hardcode "None" here for prepending image tokens
-            **kwargs,
-        )
+
+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+        _text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
+        output_kwargs["text_kwargs"]["return_tensors"] = return_tensors
 
         # if we know how many query tokens, expand text inside processor. We need this hacky manipulation
         # because BLIP expects image tokens to be at the beginning even before BOS token
@@ -164,14 +152,14 @@ class Blip2Processor(ProcessorMixin):
             )
 
             # cast to desired return tensors type
-            text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors)
-        else:
-            text_encoding = None
-
-        # add pixel_values encoding. If we also have text_encoding, update image encoding and return it.
-        # else, return the text encoding.
-
-        if text_encoding is not None:
-            encoding_image_processor.update(text_encoding)
-
-        return encoding_image_processor
+            encoding.update(BatchEncoding(text_encoding, tensor_type=return_tensors))
+        if images is not None:
+            image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
+            encoding.update(image_encoding)
+        return encoding
 
     # Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
     def batch_decode(self, *args, **kwargs):
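The `return_tensors` juggling in the BLIP-2 hunks exists because the image-token prefix must be prepended while ids are still plain Python lists, and only then cast to tensors. A standalone sketch of that pattern (`tokenizer`, `prefix_ids`, and the helper name are illustrative, not from the diff):

```python
# Sketch of the "tokenize as lists, tensorize later" pattern used above.
from transformers.tokenization_utils_base import BatchEncoding


def encode_with_prefix(tokenizer, texts, prefix_ids, return_tensors=None):
    # `texts` must be a list of strings so input_ids is a list of lists.
    enc = tokenizer(texts, return_tensors=None)  # keep lists, not tensors
    enc["input_ids"] = [prefix_ids + ids for ids in enc["input_ids"]]
    enc["attention_mask"] = [[1] * len(prefix_ids) + m for m in enc["attention_mask"]]
    # cast to the requested framework only after the prefix is in place
    return BatchEncoding(enc, tensor_type=return_tensors)
```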
@@ -16,11 +16,29 @@
 Processor class for BridgeTower.
 """
 
-from typing import List, Optional, Union
+from typing import List, Union
 
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
+
+
+class BridgeTowerProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "stride": 0,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_length": False,
+            "verbose": True,
+        },
+        "images_kwargs": {
+            "do_normalize": True,
+            "do_center_crop": True,
+        },
+    }
 
 
 class BridgeTowerProcessor(ProcessorMixin):
@@ -50,21 +68,9 @@ class BridgeTowerProcessor(ProcessorMixin):
         self,
         images,
         text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        add_special_tokens: bool = True,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: Optional[int] = None,
-        stride: int = 0,
-        pad_to_multiple_of: Optional[int] = None,
-        return_token_type_ids: Optional[bool] = None,
-        return_attention_mask: Optional[bool] = None,
-        return_overflowing_tokens: bool = False,
-        return_special_tokens_mask: bool = False,
-        return_offsets_mapping: bool = False,
-        return_length: bool = False,
-        verbose: bool = True,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        **kwargs,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[BridgeTowerProcessorKwargs],
     ) -> BatchEncoding:
         """
         This method uses [`BridgeTowerImageProcessor.__call__`] method to prepare image(s) for the model, and
@@ -72,28 +78,14 @@ class BridgeTowerProcessor(ProcessorMixin):
 
         Please refer to the docstring of the above two methods for more information.
         """
-        encoding = self.tokenizer(
-            text=text,
-            add_special_tokens=add_special_tokens,
-            padding=padding,
-            truncation=truncation,
-            max_length=max_length,
-            stride=stride,
-            pad_to_multiple_of=pad_to_multiple_of,
-            return_token_type_ids=return_token_type_ids,
-            return_attention_mask=return_attention_mask,
-            return_overflowing_tokens=return_overflowing_tokens,
-            return_special_tokens_mask=return_special_tokens_mask,
-            return_offsets_mapping=return_offsets_mapping,
-            return_length=return_length,
-            verbose=verbose,
-            return_tensors=return_tensors,
+        output_kwargs = self._merge_kwargs(
+            BridgeTowerProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
             **kwargs,
         )
+        encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
         # add pixel_values + pixel_mask
-        encoding_image_processor = self.image_processor(
-            images, return_tensors=return_tensors, do_normalize=True, do_center_crop=True, **kwargs
-        )
+        encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
        encoding.update(encoding_image_processor)
 
         return encoding
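Note the behavioural subtlety here: `do_normalize=True` and `do_center_crop=True` used to be hardcoded into the image-processor call; they now live in `_defaults`, so a caller can override them. A hypothetical usage sketch (the checkpoint name is a real BridgeTower checkpoint; the override is illustrative):

```python
# Hypothetical usage sketch: BridgeTower's center-crop default can now be
# overridden per call, because it is a merged default rather than a
# hardcoded argument.
import numpy as np
from PIL import Image

from transformers import BridgeTowerProcessor

processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base")
image = Image.fromarray(np.zeros((288, 288, 3), dtype=np.uint8))

inputs = processor(
    images=image,
    text="two cats on a couch",
    do_center_crop=False,  # overrides the "do_center_crop": True default
    return_tensors="pt",
)
```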
@@ -19,8 +19,15 @@ Processor class for Donut.
 import re
 import warnings
 from contextlib import contextmanager
+from typing import List, Optional, Union
 
-from ...processing_utils import ProcessorMixin
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+class DonutProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {}
 
 
 class DonutProcessor(ProcessorMixin):
@@ -63,7 +70,14 @@ class DonutProcessor(ProcessorMixin):
         self.current_processor = self.image_processor
         self._in_target_context_manager = False
 
-    def __call__(self, *args, **kwargs):
+    def __call__(
+        self,
+        images: ImageInput = None,
+        text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[DonutProcessorKwargs],
+    ):
         """
         When used in normal mode, this method forwards all its arguments to AutoImageProcessor's
         [`~AutoImageProcessor.__call__`] and returns its output. If used in the context
@@ -72,28 +86,29 @@ class DonutProcessor(ProcessorMixin):
         """
         # For backward compatibility
         if self._in_target_context_manager:
-            return self.current_processor(*args, **kwargs)
-
-        images = kwargs.pop("images", None)
-        text = kwargs.pop("text", None)
-        if len(args) > 0:
-            images = args[0]
-            args = args[1:]
+            return self.current_processor(images, text, **kwargs)
 
         if images is None and text is None:
             raise ValueError("You need to specify either an `images` or `text` input to process.")
 
+        output_kwargs = self._merge_kwargs(
+            DonutProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
         if images is not None:
-            inputs = self.image_processor(images, *args, **kwargs)
+            inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
         if text is not None:
-            encodings = self.tokenizer(text, **kwargs)
+            encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
 
         if text is None:
             return inputs
         elif images is None:
             return encodings
         else:
-            inputs["labels"] = encodings["input_ids"]
+            inputs["labels"] = encodings["input_ids"]  # for BC
+            inputs["input_ids"] = encodings["input_ids"]
             return inputs
 
     def batch_decode(self, *args, **kwargs):
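The Donut hunk above also changes the output contract: when both images and text are passed, the encoding now carries `input_ids` alongside the backward-compatible `labels` copy of the same ids. A hypothetical usage sketch (the checkpoint appears elsewhere in this diff; the prompt is a standard Donut task token):

```python
# Hypothetical usage sketch of the updated DonutProcessor outputs.
import numpy as np
from PIL import Image

from transformers import DonutProcessor

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
image = Image.fromarray(np.zeros((480, 480, 3), dtype=np.uint8))

inputs = processor(images=image, text="<s_cord-v2>", return_tensors="pt")
assert (inputs["labels"] == inputs["input_ids"]).all()  # same ids under two keys
```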
@@ -17,12 +17,11 @@ Processor class for Idefics3.
 """
 
 import re
-import sys
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, is_valid_image, load_image
-from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils_base import AddedToken, BatchEncoding, TextInput
 from ...utils import logging
@@ -30,11 +29,6 @@ from ...utils import logging
 if TYPE_CHECKING:
     from ...tokenization_utils_base import PreTokenizedInput
 
-if sys.version_info >= (3, 11):
-    from typing import Unpack
-else:
-    from typing_extensions import Unpack
-
 logger = logging.get_logger(__name__)
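This hunk and the two that follow replace per-file `Unpack` version gates with a single import from `processing_utils`. The shim being centralized is exactly the removed code; as a standalone sketch (placed once in the shared module rather than in every processor file):

```python
# Centralized Unpack compatibility shim, so model files can simply do
# `from ...processing_utils import Unpack` instead of repeating a version gate.
import sys

if sys.version_info >= (3, 11):
    from typing import Unpack  # noqa: F401  (in the stdlib from 3.11 on)
else:
    from typing_extensions import Unpack  # noqa: F401  (backport below 3.11)
```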
@@ -19,19 +19,9 @@ from typing import List, Optional, Union
 
 import numpy as np
 
-
-try:
-    from typing import Unpack
-except ImportError:
-    from typing_extensions import Unpack
-
 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput
-from ...processing_utils import (
-    ImagesKwargs,
-    ProcessingKwargs,
-    ProcessorMixin,
-)
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils_base import (
     BatchEncoding,
     PreTokenizedInput,
@@ -16,13 +16,12 @@
 Processor class for OmDet-Turbo.
 """
 
-import sys
 from typing import List, Optional, Tuple, Union
 
 from ...feature_extraction_utils import BatchFeature
 from ...image_transforms import center_to_corners_format
 from ...image_utils import ImageInput
-from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
 from ...utils import (
     TensorType,
@@ -31,12 +30,6 @@ from ...utils import (
 )
 
 
-if sys.version_info >= (3, 11):
-    from typing import Unpack
-else:
-    from typing_extensions import Unpack
-
-
 class OmDetTurboTextKwargs(TextKwargs, total=False):
     task: Optional[Union[str, List[str], TextInput, PreTokenizedInput]]
@@ -18,12 +18,18 @@ Speech processor class for Wav2Vec2
 
 import warnings
 from contextlib import contextmanager
+from typing import List, Optional, Union
 
-from ...processing_utils import ProcessorMixin
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput
 from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
 from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
 
 
+class Wav2Vec2ProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {}
+
+
 class Wav2Vec2Processor(ProcessorMixin):
     r"""
     Constructs a Wav2Vec2 processor which wraps a Wav2Vec2 feature extractor and a Wav2Vec2 CTC tokenizer into a single
@@ -66,35 +72,46 @@ class Wav2Vec2Processor(ProcessorMixin):
 
         return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
 
-    def __call__(self, *args, **kwargs):
+    def __call__(
+        self,
+        audio: AudioInput = None,
+        text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
+        images=None,
+        videos=None,
+        **kwargs: Unpack[Wav2Vec2ProcessorKwargs],
+    ):
         """
         When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
         [`~Wav2Vec2FeatureExtractor.__call__`] and returns its output. If used in the context
         [`~Wav2Vec2Processor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
         [`~PreTrainedTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
         """
-        # For backward compatibility
-        if self._in_target_context_manager:
-            return self.current_processor(*args, **kwargs)
-
         if "raw_speech" in kwargs:
             warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
             audio = kwargs.pop("raw_speech")
-        else:
-            audio = kwargs.pop("audio", None)
-        sampling_rate = kwargs.pop("sampling_rate", None)
-        text = kwargs.pop("text", None)
-        if len(args) > 0:
-            audio = args[0]
-            args = args[1:]
 
         if audio is None and text is None:
             raise ValueError("You need to specify either an `audio` or `text` input to process.")
 
+        output_kwargs = self._merge_kwargs(
+            Wav2Vec2ProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+        # For backward compatibility
+        if self._in_target_context_manager:
+            return self.current_processor(
+                audio,
+                **output_kwargs["audio_kwargs"],
+                **output_kwargs["text_kwargs"],
+                **output_kwargs["common_kwargs"],
+            )
+
         if audio is not None:
-            inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)
+            inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
         if text is not None:
-            encodings = self.tokenizer(text, **kwargs)
+            encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
 
         if text is None:
             return inputs
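One motivation visible in this hunk (and tested by `test_padding_argument_not_ignored` below) is that an overlapping kwarg such as `padding` should reach both the feature extractor and the tokenizer instead of being consumed by whichever component ran first. A hypothetical usage sketch (checkpoint name and values illustrative):

```python
# Hypothetical usage sketch: with the uniformized Wav2Vec2Processor, a flat
# kwarg like `padding` is routed to both the audio and the text path.
import numpy as np
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
speech = [np.random.randn(16_000), np.random.randn(32_000)]  # 1s and 2s of noise

inputs = processor(
    audio=speech,
    text=["HELLO WORLD", "A LONGER TRANSCRIPT"],
    sampling_rate=16_000,
    padding=True,          # applies to the feature extractor *and* the tokenizer
    return_tensors="pt",
)
```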
@@ -17,12 +17,18 @@ Speech processor class for Wav2Vec2-BERT
 """
 
 import warnings
+from typing import List, Optional, Union
 
-from ...processing_utils import ProcessorMixin
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput
 from ..seamless_m4t.feature_extraction_seamless_m4t import SeamlessM4TFeatureExtractor
 from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
 
 
+class Wav2Vec2BertProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {}
+
+
 class Wav2Vec2BertProcessor(ProcessorMixin):
     r"""
     Constructs a Wav2Vec2-BERT processor which wraps a Wav2Vec2-BERT feature extractor and a Wav2Vec2 CTC tokenizer into a single
@@ -63,7 +69,14 @@ class Wav2Vec2BertProcessor(ProcessorMixin):
 
         return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
 
-    def __call__(self, audio=None, text=None, **kwargs):
+    def __call__(
+        self,
+        audio: AudioInput = None,
+        text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
+        images=None,
+        videos=None,
+        **kwargs: Unpack[Wav2Vec2BertProcessorKwargs],
+    ):
         """
         Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `audio`
         and `kwargs` arguments to SeamlessM4TFeatureExtractor's [`~SeamlessM4TFeatureExtractor.__call__`] if `audio` is not
@@ -71,17 +84,15 @@ class Wav2Vec2BertProcessor(ProcessorMixin):
         PreTrainedTokenizer's [`~PreTrainedTokenizer.__call__`] if `text` is not `None`. Please refer to the doctsring of the above two methods for more information.
 
         Args:
-            text (`str`, `List[str]`, `List[List[str]]`):
-                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
-                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
-                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
             audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                 The audio or batch of audios to be prepared. Each audio can be NumPy array or PyTorch tensor. In case
                 of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels,
                 and T the sample length of the audio.
-            kwargs (*optional*):
-                Remaining dictionary of keyword arguments that will be passed to the feature extractor and/or the
-                tokenizer.
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
         Returns:
             [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
             - **input_features** -- Audio input features to be fed to a model. Returned when `audio` is not `None`.
@@ -91,15 +102,18 @@ class Wav2Vec2BertProcessor(ProcessorMixin):
             - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None` and `audio` is `None`.
         """
 
-        sampling_rate = kwargs.pop("sampling_rate", None)
-
         if audio is None and text is None:
             raise ValueError("You need to specify either an `audio` or `text` input to process.")
 
+        output_kwargs = self._merge_kwargs(
+            Wav2Vec2BertProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
         if audio is not None:
-            inputs = self.feature_extractor(audio, sampling_rate=sampling_rate, **kwargs)
+            inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
         if text is not None:
-            encodings = self.tokenizer(text, **kwargs)
+            encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
 
         if text is None:
             return inputs
@@ -820,6 +820,8 @@ class ProcessorMixin(PushToHubMixin):
             "common_kwargs": {},
         }
 
+        used_keys = set()
+
         # get defaults from set model processor kwargs if they exist
         for modality in default_kwargs:
             default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy()
@@ -846,18 +848,29 @@ class ProcessorMixin(PushToHubMixin):
                         f"in a dictionary for {modality} and as a **kwarg."
                     )
                 elif modality_key in kwargs:
-                    kwarg_value = kwargs.pop(modality_key, "__empty__")
+                    # we get a modality_key instead of popping it because modality-specific processors
+                    # can have overlapping kwargs
+                    kwarg_value = kwargs.get(modality_key, "__empty__")
                 else:
                     kwarg_value = "__empty__"
                 if kwarg_value != "__empty__":
                     output_kwargs[modality][modality_key] = kwarg_value
-        # if something remains in kwargs, it belongs to common after flattening
-        if set(kwargs) & set(default_kwargs):
-            # here kwargs is dictionary-based since it shares keys with default set
-            [output_kwargs["common_kwargs"].update(subdict) for _, subdict in kwargs.items()]
+                    used_keys.add(modality_key)
+
+        # Determine if kwargs is a flat dictionary or contains nested dictionaries
+        if any(key in default_kwargs for key in kwargs):
+            # kwargs is dictionary-based, and some keys match modality names
+            for modality, subdict in kwargs.items():
+                if modality in default_kwargs:
+                    for subkey, subvalue in subdict.items():
+                        if subkey not in used_keys:
+                            output_kwargs[modality][subkey] = subvalue
+                            used_keys.add(subkey)
         else:
-            # here it's a flat dict
-            output_kwargs["common_kwargs"].update(kwargs)
+            # kwargs is a flat dictionary
+            for key in kwargs:
+                if key not in used_keys:
+                    output_kwargs["common_kwargs"][key] = kwargs[key]
 
         # all modality-specific kwargs are updated with common kwargs
         for modality in output_kwargs:
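The merging order implemented here is: per-processor `_defaults`, overridden by tokenizer init kwargs, overridden by call-time kwargs (structured or flat), with `used_keys` preventing a flat kwarg from being double-applied as a common kwarg. A simplified standalone sketch of that precedence (illustrative only; the real `_merge_kwargs` also validates keys against the typed-kwargs annotations):

```python
# Simplified sketch of the kwargs-merging precedence: defaults < tokenizer
# init kwargs < caller kwargs. Not the real method, just the idea.
def merge_kwargs(defaults, tokenizer_init_kwargs, caller_kwargs):
    merged = {modality: dict(sub) for modality, sub in defaults.items()}
    merged.setdefault("common_kwargs", {})
    for modality, sub in merged.items():
        for key in list(sub):
            if key in tokenizer_init_kwargs:           # e.g. model_max_length
                sub[key] = tokenizer_init_kwargs[key]
    for modality, sub in caller_kwargs.items():        # call-time kwargs win
        merged.setdefault(modality, {}).update(sub)
    return merged


merged = merge_kwargs(
    defaults={"text_kwargs": {"padding": False}, "images_kwargs": {}},
    tokenizer_init_kwargs={"padding": True},
    caller_kwargs={"text_kwargs": {"max_length": 76}},
)
# -> {"text_kwargs": {"padding": True, "max_length": 76},
#     "images_kwargs": {}, "common_kwargs": {}}
```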
@@ -17,17 +17,12 @@
 import tempfile
 import unittest
 
-from transformers import XLMRobertaTokenizer, XLMRobertaTokenizerFast
+from transformers import AltCLIPProcessor, CLIPImageProcessor, XLMRobertaTokenizer, XLMRobertaTokenizerFast
 from transformers.testing_utils import require_vision
 from transformers.utils import is_vision_available
 
 from ...test_processing_common import ProcessorTesterMixin
 
 
-if is_vision_available():
-    from transformers import AltCLIPProcessor, CLIPImageProcessor
-
-
 @require_vision
 class AltClipProcessorTest(ProcessorTesterMixin, unittest.TestCase):
     processor_class = AltCLIPProcessor
@@ -17,7 +17,7 @@ import unittest
 
 import pytest
 
-from transformers.testing_utils import require_vision
+from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_vision_available
 
 from ...test_processing_common import ProcessorTesterMixin
@@ -139,3 +139,29 @@ class BlipProcessorTest(ProcessorTesterMixin, unittest.TestCase):
 
         # For now the processor supports only ['pixel_values', 'input_ids', 'attention_mask']
         self.assertListEqual(list(inputs.keys()), ["pixel_values", "input_ids", "attention_mask"])
+
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs_batched(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = ["lower newer", "upper older longer string"]
+        image_input = self.prepare_image_inputs() * 2
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            crop_size={"height": 214, "width": 214},
+            size={"height": 214, "width": 214},
+            padding="longest",
+            max_length=76,
+        )
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["input_ids"][0]), 24)
@@ -17,7 +17,7 @@ import unittest
 
 import pytest
 
-from transformers.testing_utils import require_vision
+from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_vision_available
 
 from ...test_processing_common import ProcessorTesterMixin
@@ -94,7 +94,7 @@ class Blip2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         encoded_tok = tokenizer(input_str, return_token_type_ids=False)
 
         for key in encoded_tok.keys():
-            self.assertListEqual(encoded_tok[key], encoded_processor[key])
+            self.assertListEqual(encoded_tok[key], encoded_processor[key][0])
 
     def test_processor(self):
         image_processor = self.get_image_processor()
@@ -107,7 +107,7 @@ class Blip2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
 
         inputs = processor(text=input_str, images=image_input)
 
-        self.assertListEqual(list(inputs.keys()), ["pixel_values", "input_ids", "attention_mask"])
+        self.assertCountEqual(list(inputs.keys()), ["input_ids", "pixel_values", "attention_mask"])
 
         # test if it raises when no input is passed
         with pytest.raises(ValueError):
@@ -138,4 +138,31 @@ class Blip2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         inputs = processor(text=input_str, images=image_input)
 
         # For now the processor supports only ['pixel_values', 'input_ids', 'attention_mask']
-        self.assertListEqual(list(inputs.keys()), ["pixel_values", "input_ids", "attention_mask"])
+        self.assertCountEqual(list(inputs.keys()), ["input_ids", "pixel_values", "attention_mask"])
+
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs_batched(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+        if not tokenizer.pad_token:
+            tokenizer.pad_token = "[TEST_PAD]"
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = ["lower newer", "upper older longer string"]
+        image_input = self.prepare_image_inputs() * 2
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            crop_size={"height": 214, "width": 214},
+            size={"height": 214, "width": 214},
+            padding="longest",
+            max_length=76,
+        )
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["input_ids"][0]), 11)
tests/models/bridgetower/test_processing_bridgetower.py (new file, 218 lines)
@@ -0,0 +1,218 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_vision_available
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+if is_vision_available():
+    from PIL import Image
+
+    from transformers import (
+        AutoProcessor,
+        BridgeTowerImageProcessor,
+        BridgeTowerProcessor,
+        RobertaTokenizerFast,
+    )
+
+
+@require_vision
+class Blip2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = BridgeTowerProcessor
+
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+
+        image_processor = BridgeTowerImageProcessor()
+        tokenizer = RobertaTokenizerFast.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
+
+        processor = BridgeTowerProcessor(image_processor, tokenizer)
+
+        processor.save_pretrained(self.tmpdirname)
+
+    def get_tokenizer(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
+
+    def get_image_processor(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdirname)
+
+    def prepare_image_inputs(self):
+        """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
+        or a list of PyTorch tensors if one specifies torchify=True.
+        """
+
+        image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
+
+        image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
+
+        return image_inputs
+
+    # Some kwargs tests are overriden from common tests to handle shortest_edge
+    # and size_divisor behaviour
+
+    @require_torch
+    @require_vision
+    def test_image_processor_defaults_preserved_by_image_kwargs(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component(
+            "image_processor",
+            crop_size={"shortest_edge": 234, "longest_edge": 234},
+        )
+        tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        inputs = processor(text=input_str, images=image_input)
+        self.assertEqual(len(inputs["pixel_values"][0][0]), 234)
+
+    @require_torch
+    @require_vision
+    def test_structured_kwargs_nested_from_dict(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "images_kwargs": {
+                "crop_size": {"shortest_edge": 214},
+            },
+            "text_kwargs": {"padding": "max_length", "max_length": 76},
+        }
+
+        inputs = processor(text=input_str, images=image_input, **all_kwargs)
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
+
+    @require_torch
+    @require_vision
+    def test_kwargs_overrides_default_image_processor_kwargs(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor", crop_size={"shortest_edge": 234})
+        tokenizer = self.get_component("tokenizer", max_length=117)
+        if not tokenizer.pad_token:
+            tokenizer.pad_token = "[TEST_PAD]"
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+        inputs = processor(text=input_str, images=image_input, crop_size={"shortest_edge": 224})
+        self.assertEqual(len(inputs["pixel_values"][0][0]), 224)
+
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs_batched(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+        if not tokenizer.pad_token:
+            tokenizer.pad_token = "[TEST_PAD]"
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = ["lower newer", "upper older longer string"]
+        image_input = self.prepare_image_inputs() * 2
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            crop_size={"shortest_edge": 214},
+            padding="longest",
+            max_length=76,
+        )
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["input_ids"][0]), 6)
+
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+        if not tokenizer.pad_token:
+            tokenizer.pad_token = "[TEST_PAD]"
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            crop_size={"shortest_edge": 214},
+            padding="max_length",
+            max_length=76,
+        )
+
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
+
+    @require_torch
+    @require_vision
+    def test_structured_kwargs_nested(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+        if not tokenizer.pad_token:
+            tokenizer.pad_token = "[TEST_PAD]"
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "images_kwargs": {"crop_size": {"shortest_edge": 214}},
+            "text_kwargs": {"padding": "max_length", "max_length": 76},
+        }
+
+        inputs = processor(text=input_str, images=image_input, **all_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
@@ -14,16 +14,32 @@
 # limitations under the License.
 
 
+import tempfile
 import unittest
 
-from transformers import DonutProcessor
+from transformers import DonutImageProcessor, DonutProcessor, XLMRobertaTokenizerFast
+from transformers.testing_utils import (
+    require_torch,
+    require_vision,
+)
+
+from ...test_processing_common import ProcessorTesterMixin
 
 
-class DonutProcessorTest(unittest.TestCase):
+class DonutProcessorTest(ProcessorTesterMixin, unittest.TestCase):
     from_pretrained_id = "naver-clova-ix/donut-base"
+    processor_class = DonutProcessor
 
     def setUp(self):
         self.processor = DonutProcessor.from_pretrained(self.from_pretrained_id)
+        self.tmpdirname = tempfile.mkdtemp()
+
+        image_processor = DonutImageProcessor()
+        tokenizer = XLMRobertaTokenizerFast.from_pretrained(self.from_pretrained_id)
+
+        processor = DonutProcessor(image_processor, tokenizer)
+
+        processor.save_pretrained(self.tmpdirname)
 
     def test_token2json(self):
         expected_json = {
@@ -49,3 +65,30 @@ class DonutProcessorTest(unittest.TestCase):
         actual_json = self.processor.token2json(sequence)
 
         self.assertDictEqual(actual_json, expected_json)
+
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs_batched(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+        if not tokenizer.pad_token:
+            tokenizer.pad_token = "[TEST_PAD]"
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = ["lower newer", "upper older longer string"]
+        image_input = self.prepare_image_inputs() * 2
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            crop_size={"height": 214, "width": 214},
+            size={"height": 214, "width": 214},
+            padding="longest",
+            max_length=76,
+        )
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["input_ids"][0]), 7)
@@ -18,14 +18,19 @@ import shutil
 import tempfile
 import unittest
 
+import numpy as np
+
 from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
 from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
 from transformers.utils import FEATURE_EXTRACTOR_NAME
 
+from ...test_processing_common import ProcessorTesterMixin
 from .test_feature_extraction_wav2vec2 import floats_list
 
 
-class Wav2Vec2ProcessorTest(unittest.TestCase):
+class Wav2Vec2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = Wav2Vec2Processor
+
     def setUp(self):
         vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
@@ -53,6 +58,9 @@ class Wav2Vec2ProcessorTest(unittest.TestCase):
         with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
             fp.write(json.dumps(feature_extractor_map) + "\n")
 
+        tokenizer = self.get_tokenizer()
+        tokenizer.save_pretrained(self.tmpdirname)
+
     def get_tokenizer(self, **kwargs_init):
         kwargs = self.add_kwargs_tokens_map.copy()
         kwargs.update(kwargs_init)
@@ -117,7 +125,6 @@ class Wav2Vec2ProcessorTest(unittest.TestCase):
         processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
 
         input_str = "This is a test string"
-
         encoded_processor = processor(text=input_str)
 
         encoded_tok = tokenizer(input_str)
@@ -125,6 +132,22 @@ class Wav2Vec2ProcessorTest(unittest.TestCase):
         for key in encoded_tok.keys():
             self.assertListEqual(encoded_tok[key], encoded_processor[key])
 
+    def test_padding_argument_not_ignored(self):
+        # padding, or any other overlap arg between audio extractor and tokenizer
+        # should be passed to both text and audio and not ignored
+
+        feature_extractor = self.get_feature_extractor()
+        tokenizer = self.get_tokenizer()
+
+        processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+        batch_duration_in_seconds = [1, 3, 2, 6]
+        input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
+
+        # padding = True should not raise an error and will if the audio processor popped its value to None
+        _ = processor(
+            input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
+        )
+
     def test_tokenizer_decode(self):
         feature_extractor = self.get_feature_extractor()
         tokenizer = self.get_tokenizer()
@@ -18,17 +18,21 @@ import shutil
 import tempfile
 import unittest
 
+import numpy as np
+
 from transformers.models.seamless_m4t import SeamlessM4TFeatureExtractor
 from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer
 from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
 from transformers.models.wav2vec2_bert import Wav2Vec2BertProcessor
 from transformers.utils import FEATURE_EXTRACTOR_NAME
 
+from ...test_processing_common import ProcessorTesterMixin
 from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list
 
 
 # Copied from tests.models.wav2vec2.test_processor_wav2vec2.Wav2Vec2ProcessorTest with Wav2Vec2FeatureExtractor->SeamlessM4TFeatureExtractor, Wav2Vec2Processor->Wav2Vec2BertProcessor
-class Wav2Vec2BertProcessorTest(unittest.TestCase):
+class Wav2Vec2BertProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = Wav2Vec2BertProcessor
+
     def setUp(self):
         vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
@@ -40,7 +44,7 @@ class Wav2Vec2BertProcessorTest(unittest.TestCase):
             "eos_token": "</s>",
         }
         feature_extractor_map = {
-            "feature_size": 1,
+            "feature_size": 80,
             "padding_value": 0.0,
             "sampling_rate": 16000,
             "return_attention_mask": False,
@@ -56,6 +60,9 @@ class Wav2Vec2BertProcessorTest(unittest.TestCase):
         with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
             fp.write(json.dumps(feature_extractor_map) + "\n")
 
+        tokenizer = self.get_tokenizer()
+        tokenizer.save_pretrained(self.tmpdirname)
+
     def get_tokenizer(self, **kwargs_init):
         kwargs = self.add_kwargs_tokens_map.copy()
         kwargs.update(kwargs_init)
@@ -122,7 +129,6 @@ class Wav2Vec2BertProcessorTest(unittest.TestCase):
         processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
 
         input_str = "This is a test string"
-
        encoded_processor = processor(text=input_str)
 
         encoded_tok = tokenizer(input_str)

@@ -130,6 +136,22 @@ class Wav2Vec2BertProcessorTest(unittest.TestCase):
        for key in encoded_tok.keys():
            self.assertListEqual(encoded_tok[key], encoded_processor[key])

    def test_padding_argument_not_ignored(self):
        # padding, or any other kwarg shared by the audio extractor and the tokenizer,
        # should be passed through to both the text and audio components, not ignored
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
        batch_duration_in_seconds = [1, 3, 2, 6]
        input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]

        # padding=True should not raise an error, and will if the audio processor popped its value to None
        _ = processor(
            input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
        )

    def test_tokenizer_decode(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

@@ -16,6 +16,7 @@

import inspect
import json
import random
import tempfile
from typing import Optional

@@ -31,11 +32,7 @@ from transformers.testing_utils import (
from transformers.utils import is_vision_available


try:
    from typing import Unpack
except ImportError:
    from typing_extensions import Unpack

global_rng = random.Random()

if is_vision_available():
    from PIL import Image

@@ -48,6 +45,21 @@ def prepare_image_inputs():
    return image_inputs


# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list
def floats_list(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor"""
    if rng is None:
        rng = global_rng

    values = []
    for batch_idx in range(shape[0]):
        values.append([])
        for _ in range(shape[1]):
            values[-1].append(rng.random() * scale)

    return values
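
For scale: the audio tests below call `floats_list((3, 1000))`, i.e. three pseudo-utterances of 1,000 samples each, with values drawn uniformly from [0, scale):

    raw_speech = floats_list((3, 1000))  # len(raw_speech) == 3, len(raw_speech[0]) == 1000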

@require_torch
@require_vision
class ProcessorTesterMixin:

@@ -333,6 +345,135 @@ class ProcessorTesterMixin:
        self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)
        self.assertEqual(inputs[self.text_input_name].shape[-1], 76)

    # text + audio kwargs testing
    @require_torch
    def test_tokenizer_defaults_preserved_by_kwargs_audio(self):
        if "feature_extractor" not in self.processor_class.attributes:
            self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
        feature_extractor = self.get_component("feature_extractor")
        if hasattr(self, "get_tokenizer"):
            tokenizer = self.get_tokenizer(max_length=117, padding="max_length")
        elif hasattr(self, "get_component"):
            tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
        else:
            self.assertTrue(False, "Processor doesn't have get_tokenizer or get_component defined")
        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
        self.skip_processor_without_typed_kwargs(processor)
        input_str = "lower newer"
        raw_speech = floats_list((3, 1000))
        inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt")
        if "input_ids" in inputs:
            self.assertEqual(len(inputs["input_ids"][0]), 117)
        elif "labels" in inputs:
            self.assertEqual(len(inputs["labels"][0]), 117)

    @require_torch
    def test_kwargs_overrides_default_tokenizer_kwargs_audio(self):
        if "feature_extractor" not in self.processor_class.attributes:
            self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
        feature_extractor = self.get_component("feature_extractor")
        if hasattr(self, "get_tokenizer"):
            tokenizer = self.get_tokenizer(max_length=117)
        elif hasattr(self, "get_component"):
            tokenizer = self.get_component("tokenizer", max_length=117)
        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
        self.skip_processor_without_typed_kwargs(processor)
        input_str = "lower newer"
        raw_speech = floats_list((3, 1000))
        inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt", max_length=112, padding="max_length")
        if "input_ids" in inputs:
            self.assertEqual(len(inputs["input_ids"][0]), 112)
        elif "labels" in inputs:
            self.assertEqual(len(inputs["labels"][0]), 112)
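
Together with the preceding test, this pins down the intended precedence: call-time kwargs override tokenizer-init defaults. Roughly, with the values used above:

    tokenizer = self.get_tokenizer(max_length=117)       # init kwarg becomes the default -> length 117
    processor(text=..., audio=..., max_length=112,
              padding="max_length")                      # call-time kwarg wins -> length 112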

    @require_torch
    def test_unstructured_kwargs_audio(self):
        if "feature_extractor" not in self.processor_class.attributes:
            self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
        feature_extractor = self.get_component("feature_extractor")
        if hasattr(self, "get_tokenizer"):
            tokenizer = self.get_tokenizer(max_length=117)
        elif hasattr(self, "get_component"):
            tokenizer = self.get_component("tokenizer", max_length=117)
        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
        self.skip_processor_without_typed_kwargs(processor)

        input_str = "lower newer"
        raw_speech = floats_list((3, 1000))
        inputs = processor(
            text=input_str,
            audio=raw_speech,
            return_tensors="pt",
            padding="max_length",
            max_length=76,
        )

        if "input_ids" in inputs:
            self.assertEqual(len(inputs["input_ids"][0]), 76)
        elif "labels" in inputs:
            self.assertEqual(len(inputs["labels"][0]), 76)

    @require_torch
    def test_doubly_passed_kwargs_audio(self):
        if "feature_extractor" not in self.processor_class.attributes:
            self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
        feature_extractor = self.get_component("feature_extractor")
        if hasattr(self, "get_tokenizer"):
            tokenizer = self.get_tokenizer()
        elif hasattr(self, "get_component"):
            tokenizer = self.get_component("tokenizer")
        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
        self.skip_processor_without_typed_kwargs(processor)

        input_str = ["lower newer"]
        raw_speech = floats_list((3, 1000))
        with self.assertRaises(ValueError):
            _ = processor(
                text=input_str,
                audio=raw_speech,
                audio_kwargs={"padding": "max_length"},
                padding="max_length",
            )
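
When the same kwarg arrives both at the top level and inside a per-modality dict, there is no safe winner, so the expected contract is to raise rather than silently pick one. The offending shape, as exercised above:

    processor(text=..., audio=..., audio_kwargs={"padding": "max_length"}, padding="max_length")
    # expected: ValueError, since padding is specified twice for the audio component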

    @require_torch
    @require_vision
    def test_structured_kwargs_audio_nested(self):
        if "feature_extractor" not in self.processor_class.attributes:
            self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
        feature_extractor = self.get_component("feature_extractor")
        if hasattr(self, "get_tokenizer"):
            tokenizer = self.get_tokenizer()
        elif hasattr(self, "get_component"):
            tokenizer = self.get_component("tokenizer")
        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
        self.skip_processor_without_typed_kwargs(processor)

        input_str = ["lower newer"]
        raw_speech = floats_list((3, 1000))

        # Define the kwargs for each modality
        all_kwargs = {
            "common_kwargs": {"return_tensors": "pt"},
            "text_kwargs": {"padding": "max_length", "max_length": 76},
            "audio_kwargs": {"padding": "max_length", "max_length": 66},
        }

        inputs = processor(text=input_str, audio=raw_speech, **all_kwargs)
        if "input_ids" in inputs:
            self.assertEqual(len(inputs["input_ids"][0]), 76)
        elif "labels" in inputs:
            self.assertEqual(len(inputs["labels"][0]), 76)
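
In the nested form, `common_kwargs` fan out to every component while `text_kwargs` and `audio_kwargs` stay modality-local, which is why `input_ids` pads to 76 even though the audio branch independently pads to 66; the flat equivalent of this call would be ambiguous, which is exactly what the structured schema avoids.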

    # TODO: the same test, but for audio + text processors that have strong overlap in kwargs
    # TODO (molbap): use the same structure of attribute kwargs for other tests to avoid duplication
    def test_overlapping_text_kwargs_handling(self):