Add the ImageClassificationPipeline (#11598)

* Add the ImageClassificationPipeline

* Code review

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>

* Have `load_image` at the module level

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
Lysandre Debut 2021-05-07 14:08:40 +02:00 committed by GitHub
parent e7bff0aabe
commit 39084ca663
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 428 additions and 74 deletions

View File

@ -37,6 +37,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
pip install --upgrade pip pip install --upgrade pip
sudo apt -y update && sudo apt install -y libsndfile1-dev
pip install .[dev] pip install .[dev]
- name: Create model files - name: Create model files
run: | run: |

View File

@ -36,6 +36,7 @@ There are two categories of pipeline abstractions to be aware about:
- :class:`~transformers.ZeroShotClassificationPipeline` - :class:`~transformers.ZeroShotClassificationPipeline`
- :class:`~transformers.Text2TextGenerationPipeline` - :class:`~transformers.Text2TextGenerationPipeline`
- :class:`~transformers.TableQuestionAnsweringPipeline` - :class:`~transformers.TableQuestionAnsweringPipeline`
- :class:`~transformers.ImageClassificationPipeline`
The pipeline abstraction The pipeline abstraction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -79,6 +80,13 @@ FillMaskPipeline
:special-members: __call__ :special-members: __call__
:members: :members:
ImageClassificationPipeline
=======================================================================================================================
.. autoclass:: transformers.ImageClassificationPipeline
:special-members: __call__
:members:
NerPipeline NerPipeline
======================================================================================================================= =======================================================================================================================

View File

@ -128,6 +128,13 @@ AutoModelForTableQuestionAnswering
:members: :members:
AutoModelForImageClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.AutoModelForImageClassification
:members:
TFAutoModel TFAutoModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -244,6 +244,7 @@ _import_structure = {
"CsvPipelineDataFormat", "CsvPipelineDataFormat",
"FeatureExtractionPipeline", "FeatureExtractionPipeline",
"FillMaskPipeline", "FillMaskPipeline",
"ImageClassificationPipeline",
"JsonPipelineDataFormat", "JsonPipelineDataFormat",
"NerPipeline", "NerPipeline",
"PipedPipelineDataFormat", "PipedPipelineDataFormat",
@ -483,6 +484,7 @@ if is_torch_available():
"MODEL_WITH_LM_HEAD_MAPPING", "MODEL_WITH_LM_HEAD_MAPPING",
"AutoModel", "AutoModel",
"AutoModelForCausalLM", "AutoModelForCausalLM",
"AutoModelForImageClassification",
"AutoModelForMaskedLM", "AutoModelForMaskedLM",
"AutoModelForMultipleChoice", "AutoModelForMultipleChoice",
"AutoModelForNextSentencePrediction", "AutoModelForNextSentencePrediction",
@ -1640,6 +1642,7 @@ if TYPE_CHECKING:
CsvPipelineDataFormat, CsvPipelineDataFormat,
FeatureExtractionPipeline, FeatureExtractionPipeline,
FillMaskPipeline, FillMaskPipeline,
ImageClassificationPipeline,
JsonPipelineDataFormat, JsonPipelineDataFormat,
NerPipeline, NerPipeline,
PipedPipelineDataFormat, PipedPipelineDataFormat,
@ -1845,6 +1848,7 @@ if TYPE_CHECKING:
MODEL_WITH_LM_HEAD_MAPPING, MODEL_WITH_LM_HEAD_MAPPING,
AutoModel, AutoModel,
AutoModelForCausalLM, AutoModelForCausalLM,
AutoModelForImageClassification,
AutoModelForMaskedLM, AutoModelForMaskedLM,
AutoModelForMultipleChoice, AutoModelForMultipleChoice,
AutoModelForNextSentencePrediction, AutoModelForNextSentencePrediction,

View File

@ -226,7 +226,7 @@ class FeatureExtractionMixin:
:func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` method, e.g., :func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` method, e.g.,
``./my_model_directory/``. ``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g., - a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``. ``./my_model_directory/preprocessor_config.json``.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model feature extractor should be cached if the Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
standard cache should not be used. standard cache should not be used.

View File

@ -14,34 +14,26 @@
# limitations under the License. # limitations under the License.
""" AutoFeatureExtractor class. """ """ AutoFeatureExtractor class. """
import os
from collections import OrderedDict from collections import OrderedDict
from transformers import DeiTFeatureExtractor, Speech2TextFeatureExtractor, ViTFeatureExtractor
from ... import DeiTConfig, PretrainedConfig, Speech2TextConfig, ViTConfig, Wav2Vec2Config
from ...feature_extraction_utils import FeatureExtractionMixin from ...feature_extraction_utils import FeatureExtractionMixin
from ...file_utils import is_speech_available, is_vision_available
from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .configuration_auto import replace_list_option_in_docstrings
if is_speech_available():
from ..speech_to_text.feature_extraction_speech_to_text import Speech2TextFeatureExtractor
else:
Speech2TextFeatureExtractor = None
if is_vision_available():
from ..deit.feature_extraction_deit import DeiTFeatureExtractor
from ..vit.feature_extraction_vit import ViTFeatureExtractor
else:
DeiTFeatureExtractor = None
ViTFeatureExtractor = None
# Build the list of all feature extractors # Build the list of all feature extractors
from ...file_utils import FEATURE_EXTRACTOR_NAME
from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
FEATURE_EXTRACTOR_MAPPING = OrderedDict( FEATURE_EXTRACTOR_MAPPING = OrderedDict(
[ [
("deit", DeiTFeatureExtractor), (DeiTConfig, DeiTFeatureExtractor),
("s2t", Speech2TextFeatureExtractor), (Speech2TextConfig, Speech2TextFeatureExtractor),
("vit", ViTFeatureExtractor), (ViTConfig, ViTFeatureExtractor),
("wav2vec2", Wav2Vec2FeatureExtractor), (Wav2Vec2Config, Wav2Vec2FeatureExtractor),
] ]
) )
@ -89,7 +81,7 @@ class AutoFeatureExtractor:
:func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` method, e.g., :func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` method, e.g.,
``./my_model_directory/``. ``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g., - a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``. ``./my_model_directory/preprocessor_config.json``.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model feature extractor should be cached if the Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
standard cache should not be used. standard cache should not be used.
@ -134,20 +126,29 @@ class AutoFeatureExtractor:
>>> feature_extractor = AutoFeatureExtractor.from_pretrained('./test/saved_model/') >>> feature_extractor = AutoFeatureExtractor.from_pretrained('./test/saved_model/')
""" """
config = kwargs.pop("config", None)
kwargs["_from_auto"] = True
is_feature_extraction_file = os.path.isfile(pretrained_model_name_or_path)
is_directory = os.path.isdir(pretrained_model_name_or_path) and os.path.exists(
os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
)
if not is_feature_extraction_file and not is_directory:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
kwargs["_from_auto"] = True kwargs["_from_auto"] = True
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs) config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
if "feature_extractor_type" in config_dict: if type(config) in FEATURE_EXTRACTOR_MAPPING.keys():
return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs)
elif "feature_extractor_type" in config_dict:
feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"]) feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"])
return feature_extractor_class.from_dict(config_dict, **kwargs) return feature_extractor_class.from_dict(config_dict, **kwargs)
else:
# Fallback: use pattern matching on the string.
for pattern, feature_extractor_class in FEATURE_EXTRACTOR_MAPPING.items():
if pattern in str(pretrained_model_name_or_path):
return feature_extractor_class.from_dict(config_dict, **kwargs)
raise ValueError( raise ValueError(
f"Unrecognized model in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in " f"Unrecognized model in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in "
"its feature_extraction_config.json, or contain one of the following strings " f"its {FEATURE_EXTRACTOR_NAME}, or contain one of the following strings "
f"in its name: {', '.join(FEATURE_EXTRACTOR_MAPPING.keys())}" f"in its name: {', '.join(FEATURE_EXTRACTOR_MAPPING.keys())}"
) )

View File

@ -97,7 +97,7 @@ class Speech2TextProcessor:
:meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g., :meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g.,
``./my_model_directory/``. ``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g., - a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``. ``./my_model_directory/preprocessor_config.json``.
**kwargs **kwargs
Additional keyword arguments passed along to both :class:`~transformers.PreTrainedFeatureExtractor` and Additional keyword arguments passed along to both :class:`~transformers.PreTrainedFeatureExtractor` and
:class:`~transformers.PreTrainedTokenizer` :class:`~transformers.PreTrainedTokenizer`

View File

@ -96,7 +96,7 @@ class Wav2Vec2Processor:
:meth:`~transformers.SequenceFeatureExtractor.save_pretrained` method, e.g., :meth:`~transformers.SequenceFeatureExtractor.save_pretrained` method, e.g.,
``./my_model_directory/``. ``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g., - a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``. ``./my_model_directory/preprocessor_config.json``.
**kwargs **kwargs
Additional keyword arguments passed along to both :class:`~transformers.SequenceFeatureExtractor` and Additional keyword arguments passed along to both :class:`~transformers.SequenceFeatureExtractor` and
:class:`~transformers.PreTrainedTokenizer` :class:`~transformers.PreTrainedTokenizer`

View File

@ -20,9 +20,12 @@ import warnings
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from ..configuration_utils import PretrainedConfig from ..configuration_utils import PretrainedConfig
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import is_tf_available, is_torch_available from ..file_utils import is_tf_available, is_torch_available
from ..modelcard import ModelCard from ..modelcard import ModelCard
from ..models.auto.tokenization_auto import AutoTokenizer from ..models.auto.configuration_auto import AutoConfig
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from ..tokenization_utils import PreTrainedTokenizer from ..tokenization_utils import PreTrainedTokenizer
from ..utils import logging from ..utils import logging
from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
@ -40,6 +43,7 @@ from .base import (
from .conversational import Conversation, ConversationalPipeline from .conversational import Conversation, ConversationalPipeline
from .feature_extraction import FeatureExtractionPipeline from .feature_extraction import FeatureExtractionPipeline
from .fill_mask import FillMaskPipeline from .fill_mask import FillMaskPipeline
from .image_classification import ImageClassificationPipeline
from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
@ -79,6 +83,7 @@ if is_torch_available():
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
AutoModel, AutoModel,
AutoModelForCausalLM, AutoModelForCausalLM,
AutoModelForImageClassification,
AutoModelForMaskedLM, AutoModelForMaskedLM,
AutoModelForQuestionAnswering, AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM, AutoModelForSeq2SeqLM,
@ -198,6 +203,12 @@ SUPPORTED_TASKS = {
"pt": AutoModelForCausalLM if is_torch_available() else None, "pt": AutoModelForCausalLM if is_torch_available() else None,
"default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}}, "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
}, },
"image-classification": {
"impl": ImageClassificationPipeline,
"tf": None,
"pt": AutoModelForImageClassification if is_torch_available() else None,
"default": {"model": {"pt": "google/vit-base-patch16-224"}},
},
} }
@ -252,6 +263,7 @@ def pipeline(
model: Optional = None, model: Optional = None,
config: Optional[Union[str, PretrainedConfig]] = None, config: Optional[Union[str, PretrainedConfig]] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
framework: Optional[str] = None, framework: Optional[str] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
use_fast: bool = True, use_fast: bool = True,
@ -309,6 +321,18 @@ def pipeline(
:obj:`model` is not specified or not a string, then the default tokenizer for :obj:`config` is loaded (if :obj:`model` is not specified or not a string, then the default tokenizer for :obj:`config` is loaded (if
it is a string). However, if :obj:`config` is also not given or not a string, then the default tokenizer it is a string). However, if :obj:`config` is also not given or not a string, then the default tokenizer
for the given :obj:`task` will be loaded. for the given :obj:`task` will be loaded.
feature_extractor (:obj:`str` or :obj:`~transformers.PreTrainedFeatureExtractor`, `optional`):
The feature extractor that will be used by the pipeline to encode data for the model. This can be a model
identifier or an actual pretrained feature extractor inheriting from
:class:`~transformers.PreTrainedFeatureExtractor`.
Feature extractors are used for non-NLP models, such as Speech or Vision models as well as multi-modal
models. Multi-modal models will also require a tokenizer to be passed.
If not provided, the default feature extractor for the given :obj:`model` will be loaded (if it is a
string). If :obj:`model` is not specified or not a string, then the default feature extractor for
:obj:`config` is loaded (if it is a string). However, if :obj:`config` is also not given or not a string,
then the default feature extractor for the given :obj:`task` will be loaded.
framework (:obj:`str`, `optional`): framework (:obj:`str`, `optional`):
The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
must be installed. must be installed.
@ -359,19 +383,7 @@ def pipeline(
# At that point framework might still be undetermined # At that point framework might still be undetermined
model = get_default_model(targeted_task, framework, task_options) model = get_default_model(targeted_task, framework, task_options)
# Try to infer tokenizer from model or config name (if provided as str) model_name = model if isinstance(model, str) else None
if tokenizer is None:
if isinstance(model, str):
tokenizer = model
elif isinstance(config, str):
tokenizer = config
else:
# Impossible to guest what is the right tokenizer here
raise Exception(
"Impossible to guess which tokenizer to use. "
"Please provided a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
)
modelcard = None modelcard = None
# Try to infer modelcard from model or config name (if provided as str) # Try to infer modelcard from model or config name (if provided as str)
if isinstance(model, str): if isinstance(model, str):
@ -388,19 +400,6 @@ def pipeline(
# Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained # Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token) model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
# Instantiate tokenizer if needed
if isinstance(tokenizer, (str, tuple)):
if isinstance(tokenizer, tuple):
# For tuple we have (tokenizer name, {kwargs})
use_fast = tokenizer[1].pop("use_fast", use_fast)
tokenizer = AutoTokenizer.from_pretrained(
tokenizer[0], use_fast=use_fast, revision=revision, _from_pipeline=task, **tokenizer[1]
)
else:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer, revision=revision, use_fast=use_fast, _from_pipeline=task, **model_kwargs
)
# Instantiate config if needed # Instantiate config if needed
if isinstance(config, str): if isinstance(config, str):
config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs) config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
@ -434,6 +433,61 @@ def pipeline(
model, config=config, revision=revision, _from_pipeline=task, **model_kwargs model, config=config, revision=revision, _from_pipeline=task, **model_kwargs
) )
model_config = model.config
load_tokenizer = type(model_config) in TOKENIZER_MAPPING
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING
if load_tokenizer:
# Try to infer tokenizer from model or config name (if provided as str)
if tokenizer is None:
if isinstance(model_name, str):
tokenizer = model_name
elif isinstance(config, str):
tokenizer = config
else:
# Impossible to guess what is the right tokenizer here
raise Exception(
"Impossible to guess which tokenizer to use. "
"Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
)
# Instantiate tokenizer if needed
if isinstance(tokenizer, (str, tuple)):
if isinstance(tokenizer, tuple):
# For tuple we have (tokenizer name, {kwargs})
use_fast = tokenizer[1].pop("use_fast", use_fast)
tokenizer_identifier = tokenizer[0]
tokenizer_kwargs = tokenizer[1]
else:
tokenizer_identifier = tokenizer
tokenizer_kwargs = model_kwargs
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_identifier, revision=revision, use_fast=use_fast, _from_pipeline=task, **tokenizer_kwargs
)
if load_feature_extractor:
# Try to infer feature extractor from model or config name (if provided as str)
if feature_extractor is None:
if isinstance(model_name, str):
feature_extractor = model_name
elif isinstance(config, str):
feature_extractor = config
else:
# Impossible to guess what is the right feature_extractor here
raise Exception(
"Impossible to guess which feature extractor to use. "
"Please provide a PreTrainedFeatureExtractor class or a path/identifier "
"to a pretrained feature extractor."
)
# Instantiate feature_extractor if needed
if isinstance(feature_extractor, (str, tuple)):
feature_extractor = AutoFeatureExtractor.from_pretrained(
feature_extractor, revision=revision, _from_pipeline=task, **model_kwargs
)
if task == "translation" and model.config.task_specific_params: if task == "translation" and model.config.task_specific_params:
for key in model.config.task_specific_params: for key in model.config.task_specific_params:
if key.startswith("translation"): if key.startswith("translation"):
@ -444,4 +498,16 @@ def pipeline(
) )
break break
return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs) if tokenizer is not None:
kwargs["tokenizer"] = tokenizer
if feature_extractor is not None:
kwargs["feature_extractor"] = feature_extractor
return task_class(
model=model,
modelcard=modelcard,
framework=framework,
task=task,
**kwargs,
)

View File

@ -23,6 +23,7 @@ from contextlib import contextmanager
from os.path import abspath, exists from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard from ..modelcard import ModelCard
from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
@ -522,7 +523,8 @@ class Pipeline(_ScikitCompat):
def __init__( def __init__(
self, self,
model: Union["PreTrainedModel", "TFPreTrainedModel"], model: Union["PreTrainedModel", "TFPreTrainedModel"],
tokenizer: PreTrainedTokenizer, tokenizer: Optional[PreTrainedTokenizer] = None,
feature_extractor: Optional[PreTrainedFeatureExtractor] = None,
modelcard: Optional[ModelCard] = None, modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None, framework: Optional[str] = None,
task: str = "", task: str = "",
@ -537,6 +539,7 @@ class Pipeline(_ScikitCompat):
self.task = task self.task = task
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.feature_extractor = feature_extractor
self.modelcard = modelcard self.modelcard = modelcard
self.framework = framework self.framework = framework
self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}") self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
@ -565,7 +568,13 @@ class Pipeline(_ScikitCompat):
os.makedirs(save_directory, exist_ok=True) os.makedirs(save_directory, exist_ok=True)
self.model.save_pretrained(save_directory) self.model.save_pretrained(save_directory)
self.tokenizer.save_pretrained(save_directory)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(save_directory)
if self.feature_extractor is not None:
self.feature_extractor.save_pretrained(save_directory)
if self.modelcard is not None: if self.modelcard is not None:
self.modelcard.save_pretrained(save_directory) self.modelcard.save_pretrained(save_directory)
@ -630,7 +639,14 @@ class Pipeline(_ScikitCompat):
The list of models supported by the pipeline, or a dictionary with model class values. The list of models supported by the pipeline, or a dictionary with model class values.
""" """
if not isinstance(supported_models, list): # Create from a model mapping if not isinstance(supported_models, list): # Create from a model mapping
supported_models = [item[1].__name__ for item in supported_models.items()] supported_models_names = []
for config, model in supported_models.items():
# Mapping can now contain tuples of models for the same configuration.
if isinstance(model, tuple):
supported_models_names.extend([_model.__name__ for _model in model])
else:
supported_models_names.append(model.__name__)
supported_models = supported_models_names
if self.model.__class__.__name__ not in supported_models: if self.model.__class__.__name__ not in supported_models:
raise PipelineException( raise PipelineException(
self.task, self.task,

View File

@ -0,0 +1,129 @@
import os
from typing import TYPE_CHECKING, List, Optional, Union
import requests
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline
if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
from ..modeling_utils import PreTrainedModel
if is_vision_available():
from PIL import Image
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
class ImageClassificationPipeline(Pipeline):
    """
    Image classification pipeline using any :obj:`AutoModelForImageClassification`. This pipeline predicts the class of
    an image.

    This image classification pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"image-classification"`.

    See the list of available models on `huggingface.co/models
    <https://huggingface.co/models?filter=image-classification>`__.
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        feature_extractor: PreTrainedFeatureExtractor,
        framework: Optional[str] = None,
        **kwargs
    ):
        super().__init__(model, feature_extractor=feature_extractor, framework=framework, **kwargs)

        # The forward pass below relies on torch tensors, so a TensorFlow pipeline is rejected up front.
        if self.framework == "tf":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")

        requires_backends(self, "vision")

        self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)

        self.feature_extractor = feature_extractor

    @staticmethod
    def load_image(image: Union[str, "Image.Image"]):
        """
        Turn a single pipeline input into a PIL image.

        Args:
            image (:obj:`str` or :obj:`PIL.Image.Image`):
                Either an ``http://`` / ``https://`` URL, a path to an existing local file, or an already loaded PIL
                image (returned unchanged).

        Returns:
            The PIL image that was opened, or :obj:`image` itself when it already is a PIL image.

        Raises:
            :obj:`ValueError`: If :obj:`image` is none of the supported formats (this includes a string that is
            neither a URL nor an existing file).
        """
        if isinstance(image, str):
            if image.startswith(("http://", "https://")):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                return Image.open(requests.get(image, stream=True).raw)
            elif os.path.isfile(image):
                return Image.open(image)
        elif isinstance(image, Image.Image):
            return image

        raise ValueError(
            "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
        )

    def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], top_k=5):
        """
        Assign labels to the image(s) passed as inputs.

        Args:
            images (:obj:`str`, :obj:`List[str]`, :obj:`PIL.Image` or :obj:`List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing a http link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images, which must then be passed as a list.
            top_k (:obj:`int`, `optional`, defaults to 5):
                The number of top labels that will be returned by the pipeline. If it is larger than the number of
                classes predicted by the model, every class is returned instead.

        Return:
            A list of dictionaries (one per returned label, best score first) if the input is a single image, or a
            list of such lists, one per image, if the input is a list of images.

            The dictionaries contain the following keys:

            - **label** (:obj:`str`) -- The label identified by the model.
            - **score** (:obj:`float`) -- The score attributed by the model for that label.
        """
        is_batched = isinstance(images, list)

        if not is_batched:
            images = [images]

        images = [self.load_image(image) for image in images]

        with torch.no_grad():
            inputs = self.feature_extractor(images=images, return_tensors="pt")
            outputs = self.model(**inputs)

            probs = outputs.logits.softmax(-1)
            # `topk` raises when asked for more entries than there are classes, so never request more than that.
            top_k = min(top_k, probs.shape[-1])
            scores, ids = probs.topk(top_k)

        id2label = self.model.config.id2label
        # One list of {"score", "label"} dictionaries per input image.
        labels = [
            [{"score": score, "label": id2label[_id]} for score, _id in zip(image_scores, image_ids)]
            for image_scores, image_ids in zip(scores.tolist(), ids.tolist())
        ]

        # A single (non-list) input gets its own list of predictions back rather than a one-element batch.
        return labels if is_batched else labels[0]

View File

@ -376,6 +376,15 @@ class AutoModelForCausalLM:
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class AutoModelForImageClassification:
    """
    Stand-in object: both construction and :obj:`from_pretrained` only delegate to :obj:`requires_backends` with the
    ``torch`` backend.
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Bound as a classmethod, so the receiver handed to `requires_backends` is the class itself.
        requires_backends(cls, ["torch"])
class AutoModelForMaskedLM: class AutoModelForMaskedLM:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])

BIN
tests/fixtures/coco.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 86 KiB

View File

@ -0,0 +1,3 @@
{
"feature_extractor_type": "Wav2Vec2FeatureExtractor"
}

View File

@ -16,9 +16,10 @@
import os import os
import unittest import unittest
from transformers import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor, Wav2Vec2FeatureExtractor from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join( SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json" os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
) )
@ -29,16 +30,10 @@ class AutoFeatureExtractorTest(unittest.TestCase):
config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h") config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsInstance(config, Wav2Vec2FeatureExtractor) self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_feature_extractor_from_local_directory(self):
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_feature_extractor_from_local_file(self): def test_feature_extractor_from_local_file(self):
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG) config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
self.assertIsInstance(config, Wav2Vec2FeatureExtractor) self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_pattern_matching_fallback(self):
    """
    Safety check on the ordering of ``FEATURE_EXTRACTOR_MAPPING``.

    Intended for the case where config.json doesn't include a model_type:
    an earlier key must never be contained in a key that comes after it
    (the typical failure case).
    """
    mapping_keys = list(FEATURE_EXTRACTOR_MAPPING.keys())
    for position, earlier_key in enumerate(mapping_keys):
        following_keys = mapping_keys[position + 1 :]
        is_contained_later = any(earlier_key in candidate for candidate in following_keys)
        self.assertFalse(is_contained_later)

View File

@ -0,0 +1,115 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import (
AutoFeatureExtractor,
AutoModelForImageClassification,
PreTrainedTokenizer,
is_vision_available,
)
from transformers.pipelines import ImageClassificationPipeline, pipeline
from transformers.testing_utils import require_torch, require_vision
if is_vision_available():
    from PIL import Image
else:

    class Image:
        """Placeholder bound to the name ``Image`` when ``is_vision_available()`` is false.

        It only provides ``open`` so that module-level code calling
        ``Image.open(...)`` still runs without PIL.
        """

        @staticmethod
        def open(*args, **kwargs):
            """Accept any arguments, do nothing, and return ``None``."""
            return None
@require_vision
@require_torch
class ImageClassificationPipelineTests(unittest.TestCase):
    """Tests for the ``image-classification`` pipeline using small models.

    Each entry of ``valid_inputs`` is a dict of keyword arguments passed to the
    pipeline call. The pipeline is built both through the ``pipeline`` factory
    and by instantiating ``ImageClassificationPipeline`` directly; in both cases
    the outputs are checked for their structure only (types, lengths, keys).
    """

    pipeline_task = "image-classification"
    small_models = ["lysandre/tiny-vit-random"]  # Models tested without the @slow decorator
    # Inputs cover URLs, local paths, opened images, and lists mixing them.
    # NOTE: this list is built when the class body executes, so `Image.open` is
    # called at import time, and the relative "tests/fixtures/..." paths are
    # resolved against the current working directory.
    valid_inputs = [
        {"images": "http://images.cocodataset.org/val2017/000000039769.jpg"},
        {
            "images": [
                "http://images.cocodataset.org/val2017/000000039769.jpg",
                "http://images.cocodataset.org/val2017/000000039769.jpg",
            ]
        },
        {"images": "tests/fixtures/coco.jpg"},
        {"images": ["tests/fixtures/coco.jpg", "tests/fixtures/coco.jpg"]},
        {"images": Image.open("tests/fixtures/coco.jpg")},
        {"images": [Image.open("tests/fixtures/coco.jpg"), Image.open("tests/fixtures/coco.jpg")]},
        {"images": [Image.open("tests/fixtures/coco.jpg"), "tests/fixtures/coco.jpg"]},
    ]

    def _assert_valid_pipeline_output(self, pipeline_output, top_k):
        """Assert that the result for one image is a list of ``top_k`` label/score dicts."""
        self.assertIsInstance(pipeline_output, list)
        self.assertEqual(len(pipeline_output), top_k)
        for label_result in pipeline_output:
            self.assertIsInstance(label_result, dict)
            self.assertIn("label", label_result)
            self.assertIn("score", label_result)

    def _assert_valid_outputs_for_all_inputs(self, image_classifier):
        """Call ``image_classifier`` on every entry of ``valid_inputs`` and validate each output."""
        for valid_input in self.valid_inputs:
            output = image_classifier(**valid_input)
            # The tests expect 5 results per image unless the input specifies `top_k`.
            top_k = valid_input.get("top_k", 5)

            if isinstance(valid_input["images"], list):
                # When images are batched, pipeline output is a list of lists of dictionaries
                self.assertEqual(len(valid_input["images"]), len(output))
                for individual_output in output:
                    self._assert_valid_pipeline_output(individual_output, top_k)
            else:
                # When images are not batched, pipeline output is a list of dictionaries
                self._assert_valid_pipeline_output(output, top_k)

    def test_small_model_from_factory(self):
        """Build the pipeline through the ``pipeline`` factory and validate its outputs."""
        for small_model in self.small_models:
            image_classifier = pipeline("image-classification", model=small_model)
            self._assert_valid_outputs_for_all_inputs(image_classifier)

    def test_small_model_from_pipeline(self):
        """Build ``ImageClassificationPipeline`` from an explicit model and feature extractor."""
        for small_model in self.small_models:
            model = AutoModelForImageClassification.from_pretrained(small_model)
            feature_extractor = AutoFeatureExtractor.from_pretrained(small_model)
            image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)
            self._assert_valid_outputs_for_all_inputs(image_classifier)

    def test_custom_tokenizer(self):
        """Check that an explicitly supplied tokenizer is kept by the pipeline."""
        tokenizer = PreTrainedTokenizer()

        # Assert that the pipeline can be initialized with an explicitly passed tokenizer
        # and exposes that exact object afterwards
        image_classifier = pipeline("image-classification", model=self.small_models[0], tokenizer=tokenizer)

        self.assertIs(image_classifier.tokenizer, tokenizer)