Pipeline update & tests (#12207)

This commit is contained in:
Lysandre Debut 2021-06-17 09:41:16 +02:00 committed by GitHub
parent 700cee3446
commit b56848c8c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 1 deletions

View File

@ -87,7 +87,8 @@ class ImageClassificationPipeline(Pipeline):
Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
images.
top_k (:obj:`int`, `optional`, defaults to 5):
The number of top labels that will be returned by the pipeline.
The number of top labels that will be returned by the pipeline. If the provided number is higher than
the number of labels available in the model configuration, it will default to the number of labels.
Return:
A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
@ -106,6 +107,9 @@ class ImageClassificationPipeline(Pipeline):
images = [self.load_image(image) for image in images]
if top_k > self.model.config.num_labels:
top_k = self.model.config.num_labels
with torch.no_grad():
inputs = self.feature_extractor(images=images, return_tensors="pt")
outputs = self.model(**inputs)

View File

@ -15,6 +15,7 @@
import unittest
from transformers import (
AutoConfig,
AutoFeatureExtractor,
AutoModelForImageClassification,
PreTrainedTokenizer,
@ -128,3 +129,33 @@ class ImageClassificationPipelineTests(unittest.TestCase):
image_classifier = pipeline("image-classification", model=self.small_models[0], tokenizer=tokenizer)
self.assertIs(image_classifier.tokenizer, tokenizer)
def test_num_labels_inferior_to_topk(self):
    """Check that the pipeline returns `num_labels` results when `top_k` exceeds it.

    Each small model is re-instantiated from its config with only two labels and
    the pipeline is called without an explicit `top_k`. Every per-image result
    must therefore contain exactly `num_labels` entries, each a dict carrying a
    "label" and a "score" key.
    """
    # Deliberately small label count; the assertions below expect exactly this
    # many entries per image.
    num_labels = 2

    def assert_valid_pipeline_output(pipeline_output):
        # One image's result: a list of `num_labels` {"label": ..., "score": ...} dicts.
        self.assertIsInstance(pipeline_output, list)
        self.assertEqual(len(pipeline_output), num_labels)
        for label_result in pipeline_output:
            self.assertIsInstance(label_result, dict)
            self.assertIn("label", label_result)
            self.assertIn("score", label_result)

    for small_model in self.small_models:
        model = AutoModelForImageClassification.from_config(
            AutoConfig.from_pretrained(small_model, num_labels=num_labels)
        )
        feature_extractor = AutoFeatureExtractor.from_pretrained(small_model)
        image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)

        for valid_input in self.valid_inputs:
            output = image_classifier(**valid_input)

            if isinstance(valid_input["images"], list):
                # When images are batched, pipeline output is a list of lists of dictionaries
                self.assertEqual(len(valid_input["images"]), len(output))
                for individual_output in output:
                    assert_valid_pipeline_output(individual_output)
            else:
                # When a single image is passed, pipeline output is a list of dictionaries
                assert_valid_pipeline_output(output)