mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
Fixing tests for Perceiver (#14739)
* Adding some slow test to check for perceiver at least from a high level. * Re-enabling fast tests for Perceiver ImageClassification. * Perceiver might try to run without Tokenizer (Fast doesn't exist) and with FeatureExtractor some text only pipelines. * Oops. * Adding a comment for `update_config_with_model_class`. * Remove `model_architecture` to get `tiny_config`. * Finalize rebase. * Smarter way to handle undefined FastTokenizer. * Remove old code. * Addressing some nits. * Don't instantiate `None`.
This commit is contained in:
parent
322d416916
commit
546a91abe9
@ -1268,6 +1268,7 @@ class PerceiverForImageClassificationFourier(PerceiverPreTrainedModel):
|
|||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
|
pixel_values=None,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||||
@ -1296,6 +1297,10 @@ class PerceiverForImageClassificationFourier(PerceiverPreTrainedModel):
|
|||||||
>>> predicted_class_idx = logits.argmax(-1).item()
|
>>> predicted_class_idx = logits.argmax(-1).item()
|
||||||
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
||||||
"""
|
"""
|
||||||
|
if inputs is not None and pixel_values is not None:
|
||||||
|
raise ValueError("You cannot use both `inputs` and `pixel_values`")
|
||||||
|
elif inputs is None and pixel_values is not None:
|
||||||
|
inputs = pixel_values
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
outputs = self.perceiver(
|
outputs = self.perceiver(
|
||||||
@ -1399,6 +1404,7 @@ class PerceiverForImageClassificationConvProcessing(PerceiverPreTrainedModel):
|
|||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
|
pixel_values=None,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||||
@ -1427,6 +1433,10 @@ class PerceiverForImageClassificationConvProcessing(PerceiverPreTrainedModel):
|
|||||||
>>> predicted_class_idx = logits.argmax(-1).item()
|
>>> predicted_class_idx = logits.argmax(-1).item()
|
||||||
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
||||||
"""
|
"""
|
||||||
|
if inputs is not None and pixel_values is not None:
|
||||||
|
raise ValueError("You cannot use both `inputs` and `pixel_values`")
|
||||||
|
elif inputs is None and pixel_values is not None:
|
||||||
|
inputs = pixel_values
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
outputs = self.perceiver(
|
outputs = self.perceiver(
|
||||||
|
@ -528,8 +528,8 @@ def pipeline(
|
|||||||
load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
|
load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
|
||||||
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
|
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
|
||||||
|
|
||||||
if task in {"audio-classification"}:
|
if task in {"audio-classification", "image-classification"}:
|
||||||
# Audio classification will never require a tokenizer.
|
# These will never require a tokenizer.
|
||||||
# the model on the other hand might have a tokenizer, but
|
# the model on the other hand might have a tokenizer, but
|
||||||
# the files could be missing from the hub, instead of failing
|
# the files could be missing from the hub, instead of failing
|
||||||
# on such repos, we just force to not load it.
|
# on such repos, we just force to not load it.
|
||||||
|
@ -77,12 +77,15 @@ def get_tiny_config_from_class(configuration_class):
|
|||||||
model_tester = model_tester_class(parent=None)
|
model_tester = model_tester_class(parent=None)
|
||||||
|
|
||||||
if hasattr(model_tester, "get_pipeline_config"):
|
if hasattr(model_tester, "get_pipeline_config"):
|
||||||
return model_tester.get_pipeline_config()
|
config = model_tester.get_pipeline_config()
|
||||||
elif hasattr(model_tester, "get_config"):
|
elif hasattr(model_tester, "get_config"):
|
||||||
return model_tester.get_config()
|
config = model_tester.get_config()
|
||||||
else:
|
else:
|
||||||
|
config = None
|
||||||
logger.warning(f"Model tester {model_tester_class.__name__} has no `get_config()`.")
|
logger.warning(f"Model tester {model_tester_class.__name__} has no `get_config()`.")
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=100)
|
@lru_cache(maxsize=100)
|
||||||
def get_tiny_tokenizer_from_checkpoint(checkpoint):
|
def get_tiny_tokenizer_from_checkpoint(checkpoint):
|
||||||
@ -100,11 +103,17 @@ def get_tiny_tokenizer_from_checkpoint(checkpoint):
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
def get_tiny_feature_extractor_from_checkpoint(checkpoint, tiny_config):
|
def get_tiny_feature_extractor_from_checkpoint(checkpoint, tiny_config, feature_extractor_class):
|
||||||
try:
|
try:
|
||||||
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
|
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
|
||||||
except Exception:
|
except Exception:
|
||||||
feature_extractor = None
|
try:
|
||||||
|
if feature_extractor_class is not None:
|
||||||
|
feature_extractor = feature_extractor_class()
|
||||||
|
else:
|
||||||
|
feature_extractor = None
|
||||||
|
except Exception:
|
||||||
|
feature_extractor = None
|
||||||
if hasattr(tiny_config, "image_size") and feature_extractor:
|
if hasattr(tiny_config, "image_size") and feature_extractor:
|
||||||
feature_extractor = feature_extractor.__class__(size=tiny_config.image_size, crop_size=tiny_config.image_size)
|
feature_extractor = feature_extractor.__class__(size=tiny_config.image_size, crop_size=tiny_config.image_size)
|
||||||
|
|
||||||
@ -168,7 +177,9 @@ class PipelineTestCaseMeta(type):
|
|||||||
self.skipTest(f"Ignoring {ModelClass}, cannot create a simple tokenizer")
|
self.skipTest(f"Ignoring {ModelClass}, cannot create a simple tokenizer")
|
||||||
else:
|
else:
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
feature_extractor = get_tiny_feature_extractor_from_checkpoint(checkpoint, tiny_config)
|
feature_extractor = get_tiny_feature_extractor_from_checkpoint(
|
||||||
|
checkpoint, tiny_config, feature_extractor_class
|
||||||
|
)
|
||||||
|
|
||||||
if tokenizer is None and feature_extractor is None:
|
if tokenizer is None and feature_extractor is None:
|
||||||
self.skipTest(
|
self.skipTest(
|
||||||
@ -218,6 +229,13 @@ class PipelineTestCaseMeta(type):
|
|||||||
if not tokenizer_classes:
|
if not tokenizer_classes:
|
||||||
# We need to test even if there are no tokenizers.
|
# We need to test even if there are no tokenizers.
|
||||||
tokenizer_classes = [None]
|
tokenizer_classes = [None]
|
||||||
|
else:
|
||||||
|
# Remove the non defined tokenizers
|
||||||
|
# ByT5 and Perceiver are bytes-level and don't define
|
||||||
|
# FastTokenizer, we can just ignore those.
|
||||||
|
tokenizer_classes = [
|
||||||
|
tokenizer_class for tokenizer_class in tokenizer_classes if tokenizer_class is not None
|
||||||
|
]
|
||||||
|
|
||||||
for tokenizer_class in tokenizer_classes:
|
for tokenizer_class in tokenizer_classes:
|
||||||
if tokenizer_class is not None:
|
if tokenizer_class is not None:
|
||||||
|
@ -14,12 +14,7 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import (
|
from transformers import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, PreTrainedTokenizer, is_vision_available
|
||||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
|
||||||
PerceiverConfig,
|
|
||||||
PreTrainedTokenizer,
|
|
||||||
is_vision_available,
|
|
||||||
)
|
|
||||||
from transformers.pipelines import ImageClassificationPipeline, pipeline
|
from transformers.pipelines import ImageClassificationPipeline, pipeline
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
@ -28,6 +23,7 @@ from transformers.testing_utils import (
|
|||||||
require_tf,
|
require_tf,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_vision,
|
require_vision,
|
||||||
|
slow,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||||
@ -50,12 +46,7 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
|||||||
model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||||
|
|
||||||
def get_test_pipeline(self, model, tokenizer, feature_extractor):
|
def get_test_pipeline(self, model, tokenizer, feature_extractor):
|
||||||
if isinstance(model.config, PerceiverConfig):
|
image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor, top_k=2)
|
||||||
self.skipTest(
|
|
||||||
"Perceiver model tester is defined with a language one, which has no feature_extractor, so the automated test cannot work here"
|
|
||||||
)
|
|
||||||
|
|
||||||
image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)
|
|
||||||
examples = [
|
examples = [
|
||||||
Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
|
Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
|
||||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||||
@ -167,3 +158,48 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
|||||||
image_classifier = pipeline("image-classification", model="lysandre/tiny-vit-random", tokenizer=tokenizer)
|
image_classifier = pipeline("image-classification", model="lysandre/tiny-vit-random", tokenizer=tokenizer)
|
||||||
|
|
||||||
self.assertIs(image_classifier.tokenizer, tokenizer)
|
self.assertIs(image_classifier.tokenizer, tokenizer)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_perceiver(self):
|
||||||
|
# Perceiver is not tested by `run_pipeline_test` properly.
|
||||||
|
# That is because the type of feature_extractor and model preprocessor need to be kept
|
||||||
|
# in sync, which is not the case in the current design
|
||||||
|
image_classifier = pipeline("image-classification", model="deepmind/vision-perceiver-conv")
|
||||||
|
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=4),
|
||||||
|
[
|
||||||
|
{"score": 0.4385, "label": "tabby, tabby cat"},
|
||||||
|
{"score": 0.321, "label": "tiger cat"},
|
||||||
|
{"score": 0.0502, "label": "Egyptian cat"},
|
||||||
|
{"score": 0.0137, "label": "crib, cot"},
|
||||||
|
{"score": 0.007, "label": "radiator"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
image_classifier = pipeline("image-classification", model="deepmind/vision-perceiver-fourier")
|
||||||
|
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=4),
|
||||||
|
[
|
||||||
|
{"score": 0.5658, "label": "tabby, tabby cat"},
|
||||||
|
{"score": 0.1309, "label": "tiger cat"},
|
||||||
|
{"score": 0.0722, "label": "Egyptian cat"},
|
||||||
|
{"score": 0.0707, "label": "remote control, remote"},
|
||||||
|
{"score": 0.0082, "label": "computer keyboard, keypad"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
image_classifier = pipeline("image-classification", model="deepmind/vision-perceiver-learned")
|
||||||
|
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=4),
|
||||||
|
[
|
||||||
|
{"score": 0.3022, "label": "tabby, tabby cat"},
|
||||||
|
{"score": 0.2362, "label": "Egyptian cat"},
|
||||||
|
{"score": 0.1856, "label": "tiger cat"},
|
||||||
|
{"score": 0.0324, "label": "remote control, remote"},
|
||||||
|
{"score": 0.0096, "label": "quilt, comforter, comfort, puff"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user