mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Enable multi-label image classification in pipeline (#28433)
Enable multi-label image classification
This commit is contained in:
parent
8205b2647c
commit
66964c00f6
@ -1,6 +1,9 @@
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..utils import (
|
||||
ExplicitEnum,
|
||||
add_end_docstrings,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
@ -17,10 +20,7 @@ if is_vision_available():
|
||||
from ..image_utils import load_image
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||
from ..tf_utils import stable_softmax
|
||||
|
||||
if is_torch_available():
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||
@ -28,7 +28,38 @@ if is_torch_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
# Copied from transformers.pipelines.text_classification.sigmoid
|
||||
def sigmoid(_outputs):
    """
    Apply the logistic function `1 / (1 + exp(-x))` element-wise.

    Args:
        _outputs: Raw scores (logits) as a NumPy array or scalar.

    Returns:
        The element-wise sigmoid of `_outputs`, with the same shape as the input.
    """
    # sigmoid(x) == exp(-log(1 + exp(-x))). `np.logaddexp` evaluates the inner log without ever
    # materialising exp(-x), so very negative logits saturate to 0.0 instead of overflowing and
    # emitting a RuntimeWarning.
    # NOTE(review): this function is marked as copied from the text-classification pipeline;
    # mirror the change there (or drop the marker) so the copy check stays consistent.
    return np.exp(-np.logaddexp(0.0, -_outputs))
|
||||
|
||||
|
||||
# Copied from transformers.pipelines.text_classification.softmax
|
||||
def softmax(_outputs):
    """
    Normalise scores along the last axis so that they are positive and sum to one.

    Args:
        _outputs: Raw scores (logits); the last axis enumerates the classes.

    Returns:
        An array of the same shape as `_outputs` whose last axis sums to 1.
    """
    # Subtracting the per-row maximum leaves the result unchanged mathematically but keeps the
    # largest exponent at exp(0) == 1, so np.exp cannot overflow.
    exponentials = np.exp(_outputs - np.max(_outputs, axis=-1, keepdims=True))
    normalizer = np.sum(exponentials, axis=-1, keepdims=True)
    return exponentials / normalizer
|
||||
|
||||
|
||||
# Copied from transformers.pipelines.text_classification.ClassificationFunction
|
||||
class ClassificationFunction(ExplicitEnum):
    """Names of the functions that can be applied to raw model outputs to obtain scores."""

    # Element-wise logistic function.
    SIGMOID = "sigmoid"
    # Normalisation over the last axis so that the scores sum to one.
    SOFTMAX = "softmax"
    # Leave the raw model outputs untouched.
    NONE = "none"
|
||||
|
||||
|
||||
@add_end_docstrings(
|
||||
PIPELINE_INIT_ARGS,
|
||||
r"""
|
||||
function_to_apply (`str`, *optional*, defaults to `"default"`):
|
||||
The function to apply to the model outputs in order to retrieve the scores. Accepts four different values:
|
||||
|
||||
- `"default"`: if the model has a single label, will apply the sigmoid function on the output. If the model
|
||||
has several labels, will apply the softmax function on the output.
|
||||
- `"sigmoid"`: Applies the sigmoid function on the output.
|
||||
- `"softmax"`: Applies the softmax function on the output.
|
||||
- `"none"`: Does not apply any function on the output.
|
||||
""",
|
||||
)
|
||||
class ImageClassificationPipeline(Pipeline):
|
||||
"""
|
||||
Image classification pipeline using any `AutoModelForImageClassification`. This pipeline predicts the class of an
|
||||
@ -53,6 +84,8 @@ class ImageClassificationPipeline(Pipeline):
|
||||
[huggingface.co/models](https://huggingface.co/models?filter=image-classification).
|
||||
"""
|
||||
|
||||
function_to_apply: ClassificationFunction = ClassificationFunction.NONE
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
requires_backends(self, "vision")
|
||||
@ -62,13 +95,17 @@ class ImageClassificationPipeline(Pipeline):
|
||||
else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||
)
|
||||
|
||||
def _sanitize_parameters(self, top_k=None, function_to_apply=None, timeout=None):
    """
    Sort user-supplied keyword arguments into per-stage parameter dictionaries.

    Args:
        top_k: Number of labels to return; forwarded to post-processing when given.
        function_to_apply: A `ClassificationFunction` or its string name (case-insensitive).
            The documented value `"default"` is treated as "not specified".
        timeout: Forwarded to pre-processing when given.

    Returns:
        A 3-tuple `(preprocess_params, {}, postprocess_params)`; the middle dictionary is
        always empty.

    Raises:
        ValueError: presumably raised by `ClassificationFunction` for an unknown name
            (TODO confirm against `ExplicitEnum`).
    """
    preprocess_params = {}
    if timeout is not None:
        preprocess_params["timeout"] = timeout

    postprocess_params = {}
    if top_k is not None:
        postprocess_params["top_k"] = top_k

    if isinstance(function_to_apply, str):
        requested = function_to_apply.lower()
        # "default" is advertised in the pipeline documentation but is not a member of
        # ClassificationFunction; map it to None so post-processing selects the function itself.
        function_to_apply = None if requested == "default" else ClassificationFunction(requested)
    if function_to_apply is not None:
        postprocess_params["function_to_apply"] = function_to_apply

    return preprocess_params, {}, postprocess_params
|
||||
|
||||
def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
|
||||
@ -86,6 +123,21 @@ class ImageClassificationPipeline(Pipeline):
|
||||
The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
|
||||
Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
|
||||
images.
|
||||
function_to_apply (`str`, *optional*, defaults to `"default"`):
|
||||
The function to apply to the model outputs in order to retrieve the scores. Accepts four different
|
||||
values:
|
||||
|
||||
If this argument is not specified, then it will apply the following functions according to the number
|
||||
of labels:
|
||||
|
||||
- If the model has a single label, will apply the sigmoid function on the output.
|
||||
- If the model has several labels, will apply the softmax function on the output.
|
||||
|
||||
Possible values are:
|
||||
|
||||
- `"sigmoid"`: Applies the sigmoid function on the output.
|
||||
- `"softmax"`: Applies the softmax function on the output.
|
||||
- `"none"`: Does not apply any function on the output.
|
||||
top_k (`int`, *optional*, defaults to 5):
|
||||
The number of top labels that will be returned by the pipeline. If the provided number is higher than
|
||||
the number of labels available in the model configuration, it will default to the number of labels.
|
||||
@ -114,20 +166,37 @@ class ImageClassificationPipeline(Pipeline):
|
||||
model_outputs = self.model(**model_inputs)
|
||||
return model_outputs
|
||||
|
||||
def postprocess(self, model_outputs, function_to_apply=None, top_k=5):
    """
    Turn raw model outputs into a list of `{"label", "score"}` dictionaries sorted by score.

    Args:
        model_outputs: Mapping with a `"logits"` entry; only its first element along the
            leading axis is used (presumably a batch of one - TODO confirm with the caller).
        function_to_apply: `ClassificationFunction` to apply to the logits. When `None`, it
            is chosen from the model configuration: sigmoid for multi-label problems or a
            single label, softmax for single-label problems or several labels, then the
            configuration's own `function_to_apply` if present, otherwise no function.
        top_k: Maximum number of entries to return, capped at the number of labels.
            `None` returns every label.

    Returns:
        A list of dictionaries with keys `"label"` and `"score"`, highest score first.

    Raises:
        ValueError: If `function_to_apply` is not one of the recognised functions.
    """
    config = self.model.config

    if function_to_apply is None:
        if config.problem_type == "multi_label_classification" or config.num_labels == 1:
            function_to_apply = ClassificationFunction.SIGMOID
        elif config.problem_type == "single_label_classification" or config.num_labels > 1:
            function_to_apply = ClassificationFunction.SOFTMAX
        elif hasattr(config, "function_to_apply"):
            function_to_apply = config.function_to_apply
        else:
            function_to_apply = ClassificationFunction.NONE

    # Guard against top_k=None: comparing None with an int raises TypeError, while the
    # truncation step below already treats None as "keep everything".
    if top_k is not None and top_k > config.num_labels:
        top_k = config.num_labels

    outputs = model_outputs["logits"][0]
    outputs = outputs.numpy()

    if function_to_apply == ClassificationFunction.SIGMOID:
        scores = sigmoid(outputs)
    elif function_to_apply == ClassificationFunction.SOFTMAX:
        scores = softmax(outputs)
    elif function_to_apply == ClassificationFunction.NONE:
        scores = outputs
    else:
        raise ValueError(f"Unrecognized `function_to_apply` argument: {function_to_apply}")

    dict_scores = [{"label": config.id2label[i], "score": score.item()} for i, score in enumerate(scores)]
    dict_scores.sort(key=lambda x: x["score"], reverse=True)
    if top_k is not None:
        dict_scores = dict_scores[:top_k]

    return dict_scores
|
||||
|
@ -221,3 +221,49 @@ class ImageClassificationPipelineTests(unittest.TestCase):
|
||||
{"score": 0.0096, "label": "quilt, comforter, comfort, puff"},
|
||||
],
|
||||
)
|
||||
|
||||
@slow
@require_torch
def test_multilabel_classification(self):
    """Check the scores for a multi-label checkpoint, for one image and for a batch of two."""
    image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    expected_scores = [{"label": "LABEL_1", "score": 0.5356}, {"label": "LABEL_0", "score": 0.4612}]

    image_classifier = pipeline("image-classification", model="hf-internal-testing/tiny-random-vit")
    # Sigmoid is applied for multi-label classification
    image_classifier.model.config.problem_type = "multi_label_classification"

    single_output = image_classifier(image_url)
    self.assertEqual(nested_simplify(single_output, decimals=4), expected_scores)

    # The same image twice must yield the same scores for each batch entry.
    batch_output = image_classifier([image_url, image_url])
    self.assertEqual(nested_simplify(batch_output, decimals=4), [expected_scores, expected_scores])
|
||||
|
||||
@slow
@require_torch
def test_function_to_apply(self):
    """Check the scores when `function_to_apply="sigmoid"` is passed at call time."""
    image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    expected_scores = [{"label": "LABEL_1", "score": 0.5356}, {"label": "LABEL_0", "score": 0.4612}]

    image_classifier = pipeline("image-classification", model="hf-internal-testing/tiny-random-vit")

    # Sigmoid is requested explicitly here; the model configuration is left untouched.
    outputs = image_classifier(image_url, function_to_apply="sigmoid")
    self.assertEqual(nested_simplify(outputs, decimals=4), expected_scores)
|
||||
|
Loading…
Reference in New Issue
Block a user