mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Pipeline update & tests (#12207)
This commit is contained in:
parent
700cee3446
commit
b56848c8c8
@ -87,7 +87,8 @@ class ImageClassificationPipeline(Pipeline):
|
||||
Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
|
||||
images.
|
||||
top_k (:obj:`int`, `optional`, defaults to 5):
|
||||
The number of top labels that will be returned by the pipeline.
|
||||
The number of top labels that will be returned by the pipeline. If the provided number is higher than
|
||||
the number of labels available in the model configuration, it will default to the number of labels.
|
||||
|
||||
Return:
|
||||
A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
|
||||
@ -106,6 +107,9 @@ class ImageClassificationPipeline(Pipeline):
|
||||
|
||||
images = [self.load_image(image) for image in images]
|
||||
|
||||
if top_k > self.model.config.num_labels:
|
||||
top_k = self.model.config.num_labels
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = self.feature_extractor(images=images, return_tensors="pt")
|
||||
outputs = self.model(**inputs)
|
||||
|
@ -15,6 +15,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoFeatureExtractor,
|
||||
AutoModelForImageClassification,
|
||||
PreTrainedTokenizer,
|
||||
@ -128,3 +129,33 @@ class ImageClassificationPipelineTests(unittest.TestCase):
|
||||
image_classifier = pipeline("image-classification", model=self.small_models[0], tokenizer=tokenizer)
|
||||
|
||||
self.assertIs(image_classifier.tokenizer, tokenizer)
|
||||
|
||||
def test_num_labels_inferior_to_topk(self):
|
||||
for small_model in self.small_models:
|
||||
|
||||
num_labels = 2
|
||||
model = AutoModelForImageClassification.from_config(
|
||||
AutoConfig.from_pretrained(small_model, num_labels=num_labels)
|
||||
)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(small_model)
|
||||
image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)
|
||||
|
||||
for valid_input in self.valid_inputs:
|
||||
output = image_classifier(**valid_input)
|
||||
|
||||
def assert_valid_pipeline_output(pipeline_output):
|
||||
self.assertTrue(isinstance(pipeline_output, list))
|
||||
self.assertEqual(len(pipeline_output), num_labels)
|
||||
for label_result in pipeline_output:
|
||||
self.assertTrue(isinstance(label_result, dict))
|
||||
self.assertIn("label", label_result)
|
||||
self.assertIn("score", label_result)
|
||||
|
||||
if isinstance(valid_input["images"], list):
|
||||
# When images are batched, pipeline output is a list of lists of dictionaries
|
||||
self.assertEqual(len(valid_input["images"]), len(output))
|
||||
for individual_output in output:
|
||||
assert_valid_pipeline_output(individual_output)
|
||||
else:
|
||||
# When images are batched, pipeline output is a list of dictionaries
|
||||
assert_valid_pipeline_output(output)
|
||||
|
Loading…
Reference in New Issue
Block a user