add molbap's commit
This commit is contained in:
parent
bb1f18bb3b
commit
a476c6ee88
src/transformers/models/grounding_dino/processing_grounding_dino.py
@@ -16,13 +16,14 @@
Processor class for Grounding DINO.
"""

import pathlib
import sys
from typing import List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

from ...image_processing_utils import BatchFeature
from ...image_transforms import center_to_corners_format
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin


if sys.version_info >= (3, 11):
@@ -31,12 +32,19 @@ else:
    from typing_extensions import Unpack

from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
from ...utils import TensorType, is_torch_available
from ...utils import ExplicitEnum, TensorType, is_torch_available


if is_torch_available():
    import torch

AnnotationType = Dict[str, Union[int, str, List[Dict]]]


class AnnotationFormat(ExplicitEnum):
    COCO_DETECTION = "coco_detection"
    COCO_PANOPTIC = "coco_panoptic"


def get_phrases_from_posmap(posmaps, input_ids):
    """Get token ids of phrases from posmaps and input_ids.
@@ -64,7 +72,16 @@ def get_phrases_from_posmap(posmaps, input_ids):
    return token_ids


class GroundingDinoImagesKwargs(ImagesKwargs, total=False):
    annotations: Optional[Union[AnnotationType, List[AnnotationType]]]
    return_segmentation_masks: Optional[bool]
    masks_path: Optional[Union[str, pathlib.Path]]
    do_convert_annotations: Optional[bool]
    format: Optional[Union[str, AnnotationFormat]]


class GroundingDinoProcessorKwargs(ProcessingKwargs, total=False):
    images_kwargs: GroundingDinoImagesKwargs
    _defaults = {
        "text_kwargs": {
            "add_special_tokens": True,
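A quick usage sketch of what the typed kwargs above enable (illustrative only, not part of the diff; the checkpoint name is an assumption and the option values are arbitrary): image- and text-specific options can be passed to the processor either flat or grouped per modality.

import numpy as np
from PIL import Image
from transformers import AutoProcessor

# Assumed checkpoint; any Grounding DINO checkpoint with a processor config should behave similarly.
processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))  # dummy image

inputs = processor(
    images=image,
    text="a cat. a remote control.",
    images_kwargs={"do_pad": True},                            # typed by GroundingDinoImagesKwargs
    text_kwargs={"padding": "max_length", "max_length": 64},   # typed by the text kwargs group
    return_tensors="pt",                                       # common kwarg, shared by all modalities
)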
src/transformers/processing_utils.py
@@ -150,6 +150,8 @@ class ImagesKwargs(TypedDict, total=False):
            Standard deviation to use if normalizing the image.
        do_pad (`bool`, *optional*):
            Whether to pad the image to the `(max_height, max_width)` of the images in the batch.
        pad_size (`Dict[str, int]`, *optional*):
            The size `{"height": int, "width" int}` to pad the images to.
        do_center_crop (`bool`, *optional*):
            Whether to center crop the image.
        data_format (`ChannelDimension` or `str`, *optional*):
@@ -169,6 +171,7 @@ class ImagesKwargs(TypedDict, total=False):
    image_mean: Optional[Union[float, List[float]]]
    image_std: Optional[Union[float, List[float]]]
    do_pad: Optional[bool]
    pad_size: Optional[Dict[str, int]]
    do_center_crop: Optional[bool]
    data_format: Optional[ChannelDimension]
    input_data_format: Optional[Union[str, ChannelDimension]]
@@ -320,7 +323,6 @@ class ProcessorMixin(PushToHubMixin):
    feature_extractor_class = None
    tokenizer_class = None
    _auto_class = None
    valid_kwargs: List[str] = []

    # args have to match the attributes class attribute
    def __init__(self, *args, **kwargs):
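For reference, a minimal sketch (no model download needed) of what the `pad_size` addition to ImagesKwargs means: the key becomes part of the typed annotations that the kwargs-merging logic uses to sort flat kwargs into the image bucket.

from transformers.processing_utils import ImagesKwargs

# ImagesKwargs is a TypedDict, so a newly declared field shows up in its annotations,
# which is what the processor kwargs merging keys on.
print("do_pad" in ImagesKwargs.__annotations__)    # True
print("pad_size" in ImagesKwargs.__annotations__)  # True once this change is in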
@@ -649,15 +651,14 @@ class ProcessorMixin(PushToHubMixin):
        processor_dict = processor_dict.copy()
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        # We have to pop up some unused (but specific) kwargs and then validate that it doesn't contain unused kwargs
        # If we don't pop, some specific kwargs will raise a warning
        # Unlike image processors or feature extractors whose `__init__` accept `kwargs`, processor don't have `kwargs`.
        # We have to pop up some unused (but specific) arguments to make it work.
        if "processor_class" in processor_dict:
            del processor_dict["processor_class"]

        if "auto_map" in processor_dict:
            del processor_dict["auto_map"]

        unused_kwargs = cls.validate_init_kwargs(processor_config=processor_dict, valid_kwargs=cls.valid_kwargs)
        processor = cls(*args, **processor_dict)

        # Update processor with kwargs if needed
@@ -665,7 +666,6 @@ class ProcessorMixin(PushToHubMixin):
            if hasattr(processor, key):
                setattr(processor, key, kwargs.pop(key))

        kwargs.update(unused_kwargs)
        logger.info(f"Processor {processor}")
        if return_unused_kwargs:
            return processor, kwargs
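To make the control flow of the two hunks above easier to follow, here is a hedged, simplified paraphrase of `from_args_and_dict` after the change; it is not the verbatim implementation, it trims logging, and it assumes the `validate_init_kwargs` call shown above is the line being dropped (consistent with the later hunk that removes that helper).

def from_args_and_dict_sketch(cls, args, processor_dict, **kwargs):
    processor_dict = processor_dict.copy()
    return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

    # processors have no catch-all **kwargs in __init__, so drop the bookkeeping keys
    processor_dict.pop("processor_class", None)
    processor_dict.pop("auto_map", None)

    processor = cls(*args, **processor_dict)

    # remaining kwargs that match existing attributes update the processor in place
    for key in list(kwargs):
        if hasattr(processor, key):
            setattr(processor, key, kwargs.pop(key))

    return (processor, kwargs) if return_unused_kwargs else processor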
@@ -743,38 +743,34 @@ class ProcessorMixin(PushToHubMixin):
                if modality_key in tokenizer_init_kwargs:
                    default_kwargs[modality][modality_key] = tokenizer_init_kwargs[modality_key]
        # now defaults kwargs are updated with the tokenizers defaults.
        # pass defaults to output dictionary
        output_kwargs.update(default_kwargs)

        # gather common kwargs and remove them from individual kwargs if present
        common_kwargs = {
            key: value
            for key, value in kwargs.items()
            if key not in ModelProcessorKwargs.__annotations__["text_kwargs"].__annotations__
            and key not in ModelProcessorKwargs.__annotations__["images_kwargs"].__annotations__
            and key not in ModelProcessorKwargs.__annotations__["audio_kwargs"].__annotations__
            and key not in ModelProcessorKwargs.__annotations__["videos_kwargs"].__annotations__
        }

        # ensure common kwargs are propagated to all relevant modalities
        for key, value in common_kwargs.items():
            for modality in output_kwargs:
                if modality != "common_kwargs":
                    output_kwargs[modality][key] = value

        # remove common kwargs from the kwargs to process the rest
        kwargs = {k: v for k, v in kwargs.items() if k not in common_kwargs}

        # update modality kwargs with passed kwargs
        non_modality_kwargs = set(kwargs) - set(output_kwargs)
        for modality in output_kwargs:
            for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys():
                # check if we received a structured kwarg dict or not to handle it correctly
                if modality in kwargs:
                    kwarg_value = kwargs[modality].pop(modality_key, "__empty__")
                    # check if this key was passed as a flat kwarg.
                    if kwarg_value != "__empty__" and modality_key in non_modality_kwargs:
                        raise ValueError(
                            f"Keyword argument {modality_key} was passed two times: in a dictionary for {modality} and as a **kwarg."
                        )
                if modality in kwargs and modality_key in kwargs[modality]:
                    output_kwargs[modality][modality_key] = kwargs[modality][modality_key]
                elif modality_key in kwargs:
                    kwarg_value = kwargs.pop(modality_key, "__empty__")
                else:
                    kwarg_value = "__empty__"
                if kwarg_value != "__empty__":
                    output_kwargs[modality][modality_key] = kwarg_value
        # if something remains in kwargs, it belongs to common after flattening
        if set(kwargs) & set(default_kwargs):
            # here kwargs is dictionary-based since it shares keys with default set
            [output_kwargs["common_kwargs"].update(subdict) for _, subdict in kwargs.items()]
        else:
            # here it's a flat dict
            output_kwargs["common_kwargs"].update(kwargs)

        # all modality-specific kwargs are updated with common kwargs
        for modality in output_kwargs:
            output_kwargs[modality].update(output_kwargs["common_kwargs"])
                    output_kwargs[modality][modality_key] = kwargs[modality_key]
        return output_kwargs

    @classmethod
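The hunk above is easier to digest with a toy reduction of the new merging rule. This is a hedged sketch, not the real `_merge_kwargs`: the MODALITY_KEYS table is an assumed stand-in for the real annotations, and tokenizer init defaults plus the audio/videos modalities are ignored. It does show the flat-vs-structured handling and the duplicate-kwarg error.

from typing import Any, Dict

# assumed, simplified stand-in for ModelProcessorKwargs.__annotations__
MODALITY_KEYS = {
    "text_kwargs": {"padding", "max_length", "add_special_tokens"},
    "images_kwargs": {"crop_size", "size", "do_pad", "pad_size"},
}


def merge_kwargs_sketch(passed: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    output: Dict[str, Dict[str, Any]] = {modality: {} for modality in MODALITY_KEYS}
    output["common_kwargs"] = {}
    flat_keys = set(passed) - set(output)  # keys not given as a structured per-modality dict

    for modality, keys in MODALITY_KEYS.items():
        for key in keys:
            if modality in passed and key in passed[modality]:
                if key in flat_keys:
                    # mirrors the ValueError raised by the real code
                    raise ValueError(f"Keyword argument {key} was passed two times")
                output[modality][key] = passed[modality][key]
            elif key in passed:
                output[modality][key] = passed[key]

    # whatever is neither a structured dict nor a known modality key is treated as common
    known = set().union(*MODALITY_KEYS.values())
    output["common_kwargs"].update(
        {k: v for k, v in passed.items() if k not in output and k not in known}
    )
    # common kwargs are propagated to every modality
    for modality in MODALITY_KEYS:
        output[modality].update(output["common_kwargs"])
    return output


# flat style
print(merge_kwargs_sketch({"padding": "max_length", "max_length": 76, "return_tensors": "pt"}))
# structured style
print(merge_kwargs_sketch({"text_kwargs": {"padding": "longest"}, "return_tensors": "pt"}))
# mixing both for the same key raises:
# merge_kwargs_sketch({"text_kwargs": {"padding": "longest"}, "padding": "longest"})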
@@ -890,19 +886,6 @@ class ProcessorMixin(PushToHubMixin):
        first_attribute = getattr(self, self.attributes[0])
        return getattr(first_attribute, "model_input_names", None)

    @staticmethod
    def validate_init_kwargs(processor_config, valid_kwargs):
        kwargs_from_config = processor_config.keys()
        unused_kwargs = {}
        unused_keys = set(kwargs_from_config) - set(valid_kwargs)
        if unused_keys:
            unused_key_str = ", ".join(unused_keys)
            logger.warning(
                f"Some kwargs in processor config are unused and will not have any effect: {unused_key_str}. "
            )
            unused_kwargs = {k: processor_config[k] for k in unused_keys}
        return unused_kwargs

    def apply_chat_template(
        self,
        conversation: Union[List[Dict[str, str]]],
tests/test_processing_common.py
@@ -16,6 +16,7 @@

import inspect
import json
import random
import tempfile


@@ -38,15 +39,31 @@ from transformers.testing_utils import (
from transformers.utils import is_vision_available


global_rng = random.Random()

if is_vision_available():
    from PIL import Image

    from transformers import CLIPImageProcessor


# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list
def floats_list(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor"""
    if rng is None:
        rng = global_rng

    values = []
    for batch_idx in range(shape[0]):
        values.append([])
        for _ in range(shape[1]):
            values[-1].append(rng.random() * scale)

    return values


@require_torch
@require_vision
@require_torch
class ProcessorTesterMixin:
    processor_class = None

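A small usage note on the helper just added (hedged; the shapes are arbitrary): despite its docstring, `floats_list` returns a nested Python list rather than a tensor, which is exactly what the audio tests below feed to processors as raw speech.

# assumes the floats_list definition shown above is in scope
raw_speech = floats_list((3, 1000))  # 3 pseudo-waveforms, 1000 float samples each
print(type(raw_speech), len(raw_speech), len(raw_speech[0]))  # <class 'list'> 3 1000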
@@ -60,7 +77,10 @@ class ProcessorTesterMixin:
            component_class_name = component_class_name[0]

        component_class = processor_class_from_name(component_class_name)
        component = component_class.from_pretrained(self.tmpdirname, **kwargs)  # noqa
        if hasattr(self, "tmpdirname"):
            component = component_class.from_pretrained(self.tmpdirname, **kwargs)  # noqa
        elif hasattr(self, "model_id"):
            component = component_class.from_pretrained(self.model_id, **kwargs)  # noqa

        return component

@@ -126,13 +146,13 @@ class ProcessorTesterMixin:
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer", max_length=117)

        tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
        self.skip_processor_without_typed_kwargs(processor)
        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input, return_tensors="pt")
        self.assertEqual(len(inputs["input_ids"][0]), 117)

@@ -141,15 +161,15 @@ class ProcessorTesterMixin:
    def test_image_processor_defaults_preserved_by_image_kwargs(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        image_processor = self.get_component("image_processor", crop_size=(234, 234))
        image_processor = self.get_component("image_processor", crop_size=(234, 234), size=(234, 234))
        tokenizer = self.get_component("tokenizer", max_length=117)

        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
        self.skip_processor_without_typed_kwargs(processor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)
        self.assertEqual(len(inputs["pixel_values"][0][0]), 234)

@@ -160,13 +180,15 @@ class ProcessorTesterMixin:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer", max_length=117)

        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
        self.skip_processor_without_typed_kwargs(processor)
        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input, return_tensors="pt", max_length=112)
        inputs = processor(
            text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length"
        )
        self.assertEqual(len(inputs["input_ids"][0]), 112)

    @require_torch
@@ -174,16 +196,17 @@ class ProcessorTesterMixin:
    def test_kwargs_overrides_default_image_processor_kwargs(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        image_processor = self.get_component("image_processor", crop_size=(234, 234))
        image_processor = self.get_component("image_processor", crop_size=(234, 234), size=(234, 234))
        tokenizer = self.get_component("tokenizer", max_length=117)

        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
        self.skip_processor_without_typed_kwargs(processor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input, crop_size=[224, 224])
        inputs = processor(text=input_str, images=image_input, crop_size=[224, 224], size=[224, 224])
        self.assertEqual(len(inputs["pixel_values"][0][0]), 224)

    @require_torch
@@ -193,7 +216,8 @@ class ProcessorTesterMixin:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer")

        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
        self.skip_processor_without_typed_kwargs(processor)

@@ -204,6 +228,7 @@ class ProcessorTesterMixin:
            images=image_input,
            return_tensors="pt",
            crop_size={"height": 214, "width": 214},
            size={"height": 214, "width": 214},
            padding="max_length",
            max_length=76,
        )
@@ -218,7 +243,8 @@ class ProcessorTesterMixin:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer")

        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
        self.skip_processor_without_typed_kwargs(processor)

@@ -229,10 +255,10 @@ class ProcessorTesterMixin:
            images=image_input,
            return_tensors="pt",
            crop_size={"height": 214, "width": 214},
            size={"height": 214, "width": 214},
            padding="longest",
            max_length=76,
        )

        self.assertEqual(inputs["pixel_values"].shape[2], 214)

        self.assertEqual(len(inputs["input_ids"][0]), 6)
@@ -244,7 +270,8 @@ class ProcessorTesterMixin:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer")

        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
        self.skip_processor_without_typed_kwargs(processor)

@@ -265,7 +292,8 @@ class ProcessorTesterMixin:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer")

        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
        self.skip_processor_without_typed_kwargs(processor)

@@ -275,7 +303,7 @@ class ProcessorTesterMixin:
        # Define the kwargs for each modality
        all_kwargs = {
            "common_kwargs": {"return_tensors": "pt"},
            "images_kwargs": {"crop_size": {"height": 214, "width": 214}},
            "images_kwargs": {"crop_size": {"height": 214, "width": 214}, "size": {"height": 214, "width": 214}},
            "text_kwargs": {"padding": "max_length", "max_length": 76},
        }

@@ -294,7 +322,8 @@ class ProcessorTesterMixin:

        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer")

        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
        self.skip_processor_without_typed_kwargs(processor)
        input_str = "lower newer"
@@ -303,7 +332,7 @@ class ProcessorTesterMixin:
        # Define the kwargs for each modality
        all_kwargs = {
            "common_kwargs": {"return_tensors": "pt"},
            "images_kwargs": {"crop_size": {"height": 214, "width": 214}},
            "images_kwargs": {"crop_size": {"height": 214, "width": 214}, "size": {"height": 214, "width": 214}},
            "text_kwargs": {"padding": "max_length", "max_length": 76},
        }

@@ -312,6 +341,133 @@ class ProcessorTesterMixin:

        self.assertEqual(len(inputs["input_ids"][0]), 76)

    # text + audio kwargs testing
    @require_torch
    def test_tokenizer_defaults_preserved_by_kwargs_audio(self):
        if "feature_extractor" not in self.processor_class.attributes:
            self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
        feature_extractor = self.get_component("feature_extractor")
        if hasattr(self, "get_tokenizer"):
            tokenizer = self.get_tokenizer(max_length=117, padding="max_length")
        elif hasattr(self, "get_component"):
            tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
        self.skip_processor_without_typed_kwargs(processor)
        input_str = "lower newer"
        raw_speech = floats_list((3, 1000))
        inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt")
        if "input_ids" in inputs:
            self.assertEqual(len(inputs["input_ids"][0]), 117)
        elif "labels" in inputs:
            self.assertEqual(len(inputs["labels"][0]), 117)

    @require_torch
    def test_kwargs_overrides_default_tokenizer_kwargs_audio(self):
        if "feature_extractor" not in self.processor_class.attributes:
            self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
        feature_extractor = self.get_component("feature_extractor")
        if hasattr(self, "get_tokenizer"):
            tokenizer = self.get_tokenizer(max_length=117)
        elif hasattr(self, "get_component"):
            tokenizer = self.get_component("tokenizer", max_length=117)
        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
        self.skip_processor_without_typed_kwargs(processor)
        input_str = "lower newer"
        raw_speech = floats_list((3, 1000))
        inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt", max_length=112, padding="max_length")
        if "input_ids" in inputs:
            self.assertEqual(len(inputs["input_ids"][0]), 112)
        elif "labels" in inputs:
            self.assertEqual(len(inputs["labels"][0]), 112)

    @require_torch
    def test_unstructured_kwargs_audio(self):
        if "feature_extractor" not in self.processor_class.attributes:
            self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
        feature_extractor = self.get_component("feature_extractor")
        if hasattr(self, "get_tokenizer"):
            tokenizer = self.get_tokenizer(max_length=117)
        elif hasattr(self, "get_component"):
            tokenizer = self.get_component("tokenizer", max_length=117)
        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
        self.skip_processor_without_typed_kwargs(processor)

        input_str = "lower newer"
        raw_speech = floats_list((3, 1000))
        inputs = processor(
            text=input_str,
            audio=raw_speech,
            return_tensors="pt",
            padding="max_length",
            max_length=76,
        )

        if "input_ids" in inputs:
            self.assertEqual(len(inputs["input_ids"][0]), 76)
        elif "labels" in inputs:
            self.assertEqual(len(inputs["labels"][0]), 76)

    @require_torch
    def test_doubly_passed_kwargs_audio(self):
        if "feature_extractor" not in self.processor_class.attributes:
            self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
        feature_extractor = self.get_component("feature_extractor")
        if hasattr(self, "get_tokenizer"):
            tokenizer = self.get_tokenizer()
        elif hasattr(self, "get_component"):
            tokenizer = self.get_component("tokenizer")
        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
        self.skip_processor_without_typed_kwargs(processor)

        input_str = ["lower newer"]
        raw_speech = floats_list((3, 1000))
        with self.assertRaises(ValueError):
            _ = processor(
                text=input_str,
                audio=raw_speech,
                audio_kwargs={"padding": "max_length"},
                padding="max_length",
            )

    @require_torch
    @require_vision
    def test_structured_kwargs_audio_nested(self):
        if "feature_extractor" not in self.processor_class.attributes:
            self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
        feature_extractor = self.get_component("feature_extractor")
        if hasattr(self, "get_tokenizer"):
            tokenizer = self.get_tokenizer()
        elif hasattr(self, "get_component"):
            tokenizer = self.get_component("tokenizer")
        if not tokenizer.pad_token:
            tokenizer.pad_token = "[TEST_PAD]"
        processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
        self.skip_processor_without_typed_kwargs(processor)

        input_str = ["lower newer"]
        raw_speech = floats_list((3, 1000))

        # Define the kwargs for each modality
        all_kwargs = {
            "common_kwargs": {"return_tensors": "pt"},
            "audio_kwargs": {"padding": "max_length", "max_length": 66},
            "text_kwargs": {"padding": "max_length", "max_length": 76},
        }

        inputs = processor(text=input_str, audio=raw_speech, **all_kwargs)
        if "input_ids" in inputs:
            self.assertEqual(len(inputs["input_ids"][0]), 76)
        elif "labels" in inputs:
            self.assertEqual(len(inputs["labels"][0]), 76)


class MyProcessor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]