mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Pipeline doc (#3055)
* Pipeline doc initial commit * pipeline abstraction * Remove modelcard argument from pipeline * Task-specific pipelines can be instantiated with no model or tokenizer * All pipelines doc
This commit is contained in:
parent
2c7749784c
commit
d3eb7d23a4
@ -80,6 +80,7 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train
|
||||
main_classes/configuration
|
||||
main_classes/model
|
||||
main_classes/tokenizer
|
||||
main_classes/pipelines
|
||||
main_classes/optimizer_schedules
|
||||
main_classes/processors
|
||||
|
||||
|
63
docs/source/main_classes/pipelines.rst
Normal file
63
docs/source/main_classes/pipelines.rst
Normal file
@ -0,0 +1,63 @@
|
||||
Pipelines
|
||||
----------------------------------------------------
|
||||
|
||||
The pipelines are a great and easy way to use models for inference. These pipelines are objects that abstract most
|
||||
of the complex code from the library, offering a simple API dedicated to several tasks, including Named Entity
|
||||
Recognition, Masked Language Modeling, Sentiment Analysis, Feature Extraction and Question Answering.
|
||||
|
||||
There are two categories of pipeline abstractions to be aware of:
|
||||
|
||||
- The :func:`~transformers.pipeline` which is the most powerful object encapsulating all other pipelines
|
||||
- The other task-specific pipelines, such as :class:`~transformers.NerPipeline`
|
||||
or :class:`~transformers.QuestionAnsweringPipeline`
|
||||
|
||||
The pipeline abstraction
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The `pipeline` abstraction is a wrapper around all the other available pipelines. It is instantiated as any
|
||||
other pipeline but requires an additional argument which is the `task`.
|
||||
|
||||
.. autoclass:: transformers.pipeline
|
||||
:members:
|
||||
|
||||
|
||||
The task specific pipelines
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Parent class: Pipeline
|
||||
=========================================
|
||||
|
||||
.. autoclass:: transformers.Pipeline
|
||||
:members: predict, transform, save_pretrained
|
||||
|
||||
NerPipeline
|
||||
==========================================
|
||||
|
||||
.. autoclass:: transformers.NerPipeline
|
||||
|
||||
TokenClassificationPipeline
|
||||
==========================================
|
||||
|
||||
This class is an alias of the :class:`~transformers.NerPipeline` defined above. Please refer to that pipeline for
|
||||
documentation and usage examples.
|
||||
|
||||
FillMaskPipeline
|
||||
==========================================
|
||||
|
||||
.. autoclass:: transformers.FillMaskPipeline
|
||||
|
||||
FeatureExtractionPipeline
|
||||
==========================================
|
||||
|
||||
.. autoclass:: transformers.FeatureExtractionPipeline
|
||||
|
||||
TextClassificationPipeline
|
||||
==========================================
|
||||
|
||||
.. autoclass:: transformers.TextClassificationPipeline
|
||||
|
||||
QuestionAnsweringPipeline
|
||||
==========================================
|
||||
|
||||
.. autoclass:: transformers.QuestionAnsweringPipeline
|
||||
|
@ -279,6 +279,9 @@ class _ScikitCompat(ABC):
|
||||
|
||||
class Pipeline(_ScikitCompat):
|
||||
"""
|
||||
The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
|
||||
different pipelines.
|
||||
|
||||
Base class implementing pipelined operations.
|
||||
Pipeline workflow is defined as a sequence of the following operations:
|
||||
Input -> Tokenization -> Model Inference -> Post-Processing (Task dependent) -> Output
|
||||
@ -292,39 +295,49 @@ class Pipeline(_ScikitCompat):
|
||||
pickle format.
|
||||
|
||||
Arguments:
|
||||
**model**: ``(str, PretrainedModel, TFPretrainedModel)``:
|
||||
Reference to the model to use through this pipeline.
|
||||
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
|
||||
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
|
||||
checkpoint identifier or an actual pre-trained model inheriting from
|
||||
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
|
||||
TensorFlow.
|
||||
|
||||
**tokenizer**: ``(str, PreTrainedTokenizer)``:
|
||||
Reference to the tokenizer to use through this pipeline.
|
||||
If :obj:`None`, the default of the pipeline will be loaded.
|
||||
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
|
||||
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
|
||||
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
|
||||
:class:`~transformers.PreTrainedTokenizer`.
|
||||
|
||||
**args_parser**: ``ArgumentHandler``:
|
||||
If :obj:`None`, the default of the pipeline will be loaded.
|
||||
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
|
||||
Model card attributed to the model for this pipeline.
|
||||
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
|
||||
installed.
|
||||
|
||||
If no framework is specified, will default to the one currently installed. If no framework is specified
|
||||
and both frameworks are installed, will default to PyTorch.
|
||||
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
|
||||
Reference to the object in charge of parsing supplied pipeline parameters.
|
||||
|
||||
**device**: ``int``:
|
||||
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
|
||||
Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, >=0 will run the model
|
||||
on the associated CUDA device id.
|
||||
|
||||
**binary_output** ``bool`` (default: False):
|
||||
binary_output (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Flag indicating if the output of the pipeline should happen in a binary format (i.e. pickle) or as raw text.
|
||||
|
||||
Return:
|
||||
:obj:`List` or :obj:`Dict`:
|
||||
Pipeline returns list or dictionary depending on:
|
||||
- Does the user provided multiple sample
|
||||
- The pipeline expose multiple fields in the output object
|
||||
|
||||
Examples:
|
||||
nlp = pipeline('ner')
|
||||
nlp = pipeline('ner', model='...', config='...', tokenizer='...')
|
||||
nlp = NerPipeline(model='...', config='...', tokenizer='...')
|
||||
nlp = QuestionAnsweringPipeline(model=AutoModel.from_pretrained('...'), tokenizer='...')
|
||||
- Whether the user supplied multiple samples
|
||||
- Whether the pipeline exposes multiple fields in the output object
|
||||
"""
|
||||
|
||||
default_input_names = None
|
||||
task = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
model: Optional = None,
|
||||
tokenizer: PreTrainedTokenizer = None,
|
||||
modelcard: Optional[ModelCard] = None,
|
||||
framework: Optional[str] = None,
|
||||
@ -336,6 +349,8 @@ class Pipeline(_ScikitCompat):
|
||||
if framework is None:
|
||||
framework = get_framework()
|
||||
|
||||
model, tokenizer = self.get_defaults(model, tokenizer, framework)
|
||||
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.modelcard = modelcard
|
||||
@ -467,15 +482,74 @@ class Pipeline(_ScikitCompat):
|
||||
else:
|
||||
return predictions.numpy()
|
||||
|
||||
def get_defaults(self, model, tokenizer, framework):
    """
    Resolve the model and tokenizer to use, loading the task defaults for whichever one was not supplied.

    Arguments:
        model: The model given by the caller, or :obj:`None` to load the default model registered for
            ``self.task`` in ``SUPPORTED_TASKS``.
        tokenizer: The tokenizer given by the caller, or :obj:`None` to load the default tokenizer registered
            for ``self.task`` in ``SUPPORTED_TASKS``.
        framework (:obj:`str`): Either "pt" or "tf". Only consulted when the default model has to be loaded.

    Return:
        :obj:`tuple`: ``(model, tokenizer)``, where any value that was :obj:`None` has been replaced by the
        task default.

    Raises:
        ValueError: If the default model is needed and ``framework`` is neither "tf" nor "pt".
    """
    if model is not None and tokenizer is not None:
        # Nothing to resolve. Returning here avoids the SUPPORTED_TASKS lookup below, which cannot succeed
        # for a pipeline that does not declare a task (the base class leaves ``task`` set to None).
        return model, tokenizer

    task_defaults = SUPPORTED_TASKS[self.task]
    defaults = task_defaults["default"]

    if model is None:
        if framework not in ("tf", "pt"):
            raise ValueError("Provided framework should be either 'tf' for TensorFlow or 'pt' for PyTorch.")
        # The task entry maps each framework key to its model class, and the defaults map the same key
        # to the checkpoint identifier to load.
        model = task_defaults[framework].from_pretrained(defaults["model"][framework])

    if tokenizer is None:
        default_tokenizer = defaults["tokenizer"]
        if isinstance(default_tokenizer, tuple):
            # For tuple we have (tokenizer name, {kwargs})
            tokenizer = AutoTokenizer.from_pretrained(default_tokenizer[0], **default_tokenizer[1])
        else:
            tokenizer = AutoTokenizer.from_pretrained(default_tokenizer)

    return model, tokenizer
|
||||
|
||||
|
||||
class FeatureExtractionPipeline(Pipeline):
|
||||
"""
|
||||
Feature extraction pipeline using Model head.
|
||||
Feature extraction pipeline using Model head. This pipeline extracts the hidden states from the base transformer,
|
||||
which can be used as features in downstream tasks.
|
||||
|
||||
This feature extraction pipeline can currently be loaded from the :func:`~transformers.pipeline` method using
|
||||
the following task identifier(s):
|
||||
|
||||
- "feature-extraction", for extracting features of a sequence.
|
||||
|
||||
All models may be used for this pipeline. See a list of all models, including community-contributed models on
|
||||
`huggingface.co/models <https://huggingface.co/models>`__.
|
||||
|
||||
Arguments:
|
||||
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
|
||||
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
|
||||
checkpoint identifier or an actual pre-trained model inheriting from
|
||||
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
|
||||
TensorFlow.
|
||||
|
||||
If :obj:`None`, the default of the pipeline will be loaded.
|
||||
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
|
||||
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
|
||||
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
|
||||
:class:`~transformers.PreTrainedTokenizer`.
|
||||
|
||||
If :obj:`None`, the default of the pipeline will be loaded.
|
||||
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
|
||||
Model card attributed to the model for this pipeline.
|
||||
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
|
||||
installed.
|
||||
|
||||
If no framework is specified, will default to the one currently installed. If no framework is specified
|
||||
and both frameworks are installed, will default to PyTorch.
|
||||
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
|
||||
Reference to the object in charge of parsing supplied pipeline parameters.
|
||||
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
|
||||
Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, >=0 will run the model
|
||||
on the associated CUDA device id.
|
||||
"""
|
||||
|
||||
task = "feature-extraction"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
model: Optional = None,
|
||||
tokenizer: PreTrainedTokenizer = None,
|
||||
modelcard: Optional[ModelCard] = None,
|
||||
framework: Optional[str] = None,
|
||||
@ -498,9 +572,49 @@ class FeatureExtractionPipeline(Pipeline):
|
||||
|
||||
class TextClassificationPipeline(Pipeline):
|
||||
"""
|
||||
Text classification pipeline using ModelForTextClassification head.
|
||||
Text classification pipeline using ModelForSequenceClassification head. See the
|
||||
`sequence classification usage <../usage.html#sequence-classification>`__ examples for more information.
|
||||
|
||||
This text classification pipeline can currently be loaded from the :func:`~transformers.pipeline` method using
|
||||
the following task identifier(s):
|
||||
|
||||
- "sentiment-analysis", for classifying sequences according to positive or negative sentiments.
|
||||
|
||||
The models that this pipeline can use are models that have been fine-tuned on a sequence classification task.
|
||||
See the list of available community models fine-tuned on such a task on
|
||||
`huggingface.co/models <https://huggingface.co/models?search=&filter=text-classification>`__.
|
||||
|
||||
Arguments:
|
||||
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
|
||||
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
|
||||
checkpoint identifier or an actual pre-trained model inheriting from
|
||||
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
|
||||
TensorFlow.
|
||||
|
||||
If :obj:`None`, the default of the pipeline will be loaded.
|
||||
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
|
||||
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
|
||||
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
|
||||
:class:`~transformers.PreTrainedTokenizer`.
|
||||
|
||||
If :obj:`None`, the default of the pipeline will be loaded.
|
||||
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
|
||||
Model card attributed to the model for this pipeline.
|
||||
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
|
||||
installed.
|
||||
|
||||
If no framework is specified, will default to the one currently installed. If no framework is specified
|
||||
and both frameworks are installed, will default to PyTorch.
|
||||
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
|
||||
Reference to the object in charge of parsing supplied pipeline parameters.
|
||||
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
|
||||
Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, >=0 will run the model
|
||||
on the associated CUDA device id.
|
||||
"""
|
||||
|
||||
task = "sentiment-analysis"
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
outputs = super().__call__(*args, **kwargs)
|
||||
scores = np.exp(outputs) / np.exp(outputs).sum(-1)
|
||||
@ -509,12 +623,53 @@ class TextClassificationPipeline(Pipeline):
|
||||
|
||||
class FillMaskPipeline(Pipeline):
|
||||
"""
|
||||
Masked language modeling prediction pipeline using ModelWithLMHead head.
|
||||
Masked language modeling prediction pipeline using ModelWithLMHead head. See the
|
||||
`masked language modeling usage <../usage.html#masked-language-modeling>`__ examples for more information.
|
||||
|
||||
This mask filling pipeline can currently be loaded from the :func:`~transformers.pipeline` method using
|
||||
the following task identifier(s):
|
||||
|
||||
- "fill-mask", for predicting masked tokens in a sequence.
|
||||
|
||||
The models that this pipeline can use are models that have been trained with a masked language modeling objective,
|
||||
which includes the bi-directional models in the library.
|
||||
See the list of available community models on
|
||||
`huggingface.co/models <https://huggingface.co/models?search=&filter=lm-head>`__.
|
||||
|
||||
Arguments:
|
||||
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
|
||||
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
|
||||
checkpoint identifier or an actual pre-trained model inheriting from
|
||||
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
|
||||
TensorFlow.
|
||||
|
||||
If :obj:`None`, the default of the pipeline will be loaded.
|
||||
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
|
||||
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
|
||||
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
|
||||
:class:`~transformers.PreTrainedTokenizer`.
|
||||
|
||||
If :obj:`None`, the default of the pipeline will be loaded.
|
||||
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
|
||||
Model card attributed to the model for this pipeline.
|
||||
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
|
||||
installed.
|
||||
|
||||
If no framework is specified, will default to the one currently installed. If no framework is specified
|
||||
and both frameworks are installed, will default to PyTorch.
|
||||
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
|
||||
Reference to the object in charge of parsing supplied pipeline parameters.
|
||||
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
|
||||
Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, >=0 will run the model
|
||||
on the associated CUDA device id.
|
||||
"""
|
||||
|
||||
task = "fill-mask"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
model: Optional = None,
|
||||
tokenizer: PreTrainedTokenizer = None,
|
||||
modelcard: Optional[ModelCard] = None,
|
||||
framework: Optional[str] = None,
|
||||
@ -574,14 +729,57 @@ class FillMaskPipeline(Pipeline):
|
||||
|
||||
class NerPipeline(Pipeline):
|
||||
"""
|
||||
Named Entity Recognition pipeline using ModelForTokenClassification head.
|
||||
Named Entity Recognition pipeline using ModelForTokenClassification head. See the
|
||||
`named entity recognition usage <../usage.html#named-entity-recognition>`__ examples for more information.
|
||||
|
||||
This token recognition pipeline can currently be loaded from the :func:`~transformers.pipeline` method using
|
||||
the following task identifier(s):
|
||||
|
||||
- "ner", for predicting the classes of tokens in a sequence: person, organisation, location or miscellaneous.
|
||||
|
||||
The models that this pipeline can use are models that have been fine-tuned on a token classification task.
|
||||
See the list of available community models fine-tuned on such a task on
|
||||
`huggingface.co/models <https://huggingface.co/models?search=&filter=token-classification>`__.
|
||||
|
||||
Arguments:
|
||||
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
|
||||
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
|
||||
checkpoint identifier or an actual pre-trained model inheriting from
|
||||
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
|
||||
TensorFlow.
|
||||
|
||||
If :obj:`None`, the default of the pipeline will be loaded.
|
||||
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
|
||||
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
|
||||
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
|
||||
:class:`~transformers.PreTrainedTokenizer`.
|
||||
|
||||
If :obj:`None`, the default of the pipeline will be loaded.
|
||||
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
|
||||
Model card attributed to the model for this pipeline.
|
||||
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
|
||||
installed.
|
||||
|
||||
If no framework is specified, will default to the one currently installed. If no framework is specified
|
||||
and both frameworks are installed, will default to PyTorch.
|
||||
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
|
||||
Reference to the object in charge of parsing supplied pipeline parameters.
|
||||
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
|
||||
Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, >=0 will run the model
|
||||
on the associated CUDA device id.
|
||||
|
||||
Example::
|
||||
|
||||
from transformers import pipeline
|
||||
"""
|
||||
|
||||
default_input_names = "sequences"
|
||||
task = "ner"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
model: Optional = None,
|
||||
tokenizer: PreTrainedTokenizer = None,
|
||||
modelcard: Optional[ModelCard] = None,
|
||||
framework: Optional[str] = None,
|
||||
@ -716,15 +914,54 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
|
||||
|
||||
class QuestionAnsweringPipeline(Pipeline):
|
||||
"""
|
||||
Question Answering pipeline using ModelForQuestionAnswering head.
|
||||
Question Answering pipeline using ModelForQuestionAnswering head. See the
|
||||
`question answering usage <../usage.html#question-answering>`__ examples for more information.
|
||||
|
||||
This question answering pipeline can currently be loaded from the :func:`~transformers.pipeline` method using
|
||||
the following task identifier(s):
|
||||
|
||||
- "question-answering", for answering questions given a context.
|
||||
|
||||
The models that this pipeline can use are models that have been fine-tuned on a question answering task.
|
||||
See the list of available community models fine-tuned on such a task on
|
||||
`huggingface.co/models <https://huggingface.co/models?search=&filter=question-answering>`__.
|
||||
|
||||
Arguments:
|
||||
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
|
||||
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
|
||||
checkpoint identifier or an actual pre-trained model inheriting from
|
||||
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
|
||||
TensorFlow.
|
||||
|
||||
If :obj:`None`, the default of the pipeline will be loaded.
|
||||
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
|
||||
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
|
||||
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
|
||||
:class:`~transformers.PreTrainedTokenizer`.
|
||||
|
||||
If :obj:`None`, the default of the pipeline will be loaded.
|
||||
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
|
||||
Model card attributed to the model for this pipeline.
|
||||
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
|
||||
installed.
|
||||
|
||||
If no framework is specified, will default to the one currently installed. If no framework is specified
|
||||
and both frameworks are installed, will default to PyTorch.
|
||||
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
|
||||
Reference to the object in charge of parsing supplied pipeline parameters.
|
||||
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
|
||||
Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, >=0 will run the model
|
||||
on the associated CUDA device id.
|
||||
"""
|
||||
|
||||
default_input_names = "question,context"
|
||||
task = "question-answering"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
tokenizer: Optional[PreTrainedTokenizer],
|
||||
model: Optional = None,
|
||||
tokenizer: Optional[PreTrainedTokenizer] = None,
|
||||
modelcard: Optional[ModelCard] = None,
|
||||
framework: Optional[str] = None,
|
||||
device: int = -1,
|
||||
@ -1003,23 +1240,77 @@ def pipeline(
|
||||
model: Optional = None,
|
||||
config: Optional[Union[str, PretrainedConfig]] = None,
|
||||
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
|
||||
modelcard: Optional[Union[str, ModelCard]] = None,
|
||||
framework: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> Pipeline:
|
||||
"""
|
||||
Utility factory method to build a pipeline.
|
||||
Pipeline are made of:
|
||||
A Tokenizer instance in charge of mapping raw textual input to token
|
||||
A Model instance
|
||||
Some (optional) post processing for enhancing model's output
|
||||
|
||||
Examples:
|
||||
Pipelines are made of:
|
||||
|
||||
- A Tokenizer instance in charge of mapping raw textual input to tokens
|
||||
- A Model instance
|
||||
- Some (optional) post-processing for enhancing the model's output
|
||||
|
||||
|
||||
Args:
|
||||
task (:obj:`str`):
|
||||
The task defining which pipeline will be returned. Currently accepted tasks are:
|
||||
|
||||
- "feature-extraction": will return a :class:`~transformers.FeatureExtractionPipeline`
|
||||
- "sentiment-analysis": will return a :class:`~transformers.TextClassificationPipeline`
|
||||
- "ner": will return a :class:`~transformers.NerPipeline`
|
||||
- "question-answering": will return a :class:`~transformers.QuestionAnsweringPipeline`
|
||||
- "fill-mask": will return a :class:`~transformers.FillMaskPipeline`
|
||||
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
|
||||
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
|
||||
checkpoint identifier or an actual pre-trained model inheriting from
|
||||
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
|
||||
TensorFlow.
|
||||
|
||||
If :obj:`None`, the default of the pipeline will be loaded.
|
||||
config (:obj:`str` or :obj:`~transformers.PretrainedConfig`, `optional`, defaults to :obj:`None`):
|
||||
The configuration that will be used by the pipeline to instantiate the model. This can be :obj:`None`,
|
||||
a string checkpoint identifier or an actual pre-trained model configuration inheriting from
|
||||
:class:`~transformers.PretrainedConfig`.
|
||||
|
||||
If :obj:`None`, the default of the pipeline will be loaded.
|
||||
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
|
||||
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
|
||||
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
|
||||
:class:`~transformers.PreTrainedTokenizer`.
|
||||
|
||||
If :obj:`None`, the default of the pipeline will be loaded.
|
||||
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
|
||||
installed.
|
||||
|
||||
If no framework is specified, will default to the one currently installed. If no framework is specified
|
||||
and both frameworks are installed, will default to PyTorch.
|
||||
|
||||
Returns:
|
||||
:class:`~transformers.Pipeline`: Class inheriting from :class:`~transformers.Pipeline`, according to
|
||||
the task.
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
|
||||
|
||||
# Sentiment analysis pipeline
|
||||
pipeline('sentiment-analysis')
|
||||
|
||||
# Question answering pipeline, specifying the checkpoint identifier
|
||||
pipeline('question-answering', model='distilbert-base-cased-distilled-squad', tokenizer='bert-base-cased')
|
||||
pipeline('ner', model=AutoModel.from_pretrained(...), tokenizer=AutoTokenizer.from_pretrained(...)
|
||||
pipeline('ner', model='dbmdz/bert-large-cased-finetuned-conll03-english', tokenizer='bert-base-cased')
|
||||
pipeline('ner', model='https://...pytorch-model.bin', config='https://...config.json', tokenizer='bert-base-cased')
|
||||
|
||||
# Named entity recognition pipeline, passing in a specific model and tokenizer
|
||||
model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
|
||||
pipeline('ner', model=model, tokenizer=tokenizer)
|
||||
|
||||
# Named entity recognition pipeline, passing a model and configuration with an HTTPS URL.
|
||||
model_url = "https://s3.amazonaws.com/models.huggingface.co/bert/dbmdz/bert-large-cased-finetuned-conll03-english/pytorch_model.bin"
|
||||
config_url = "https://s3.amazonaws.com/models.huggingface.co/bert/dbmdz/bert-large-cased-finetuned-conll03-english/config.json"
|
||||
pipeline('ner', model=model_url, config=config_url, tokenizer='bert-base-cased')
|
||||
"""
|
||||
# Retrieve the task
|
||||
if task not in SUPPORTED_TASKS:
|
||||
@ -1048,13 +1339,12 @@ def pipeline(
|
||||
"Please provided a PretrainedTokenizer class or a path/url/shortcut name to a pretrained tokenizer."
|
||||
)
|
||||
|
||||
modelcard = None
|
||||
# Try to infer modelcard from model or config name (if provided as str)
|
||||
if modelcard is None:
|
||||
# Try to fallback on one of the provided string for model or config (will replace the suffix)
|
||||
if isinstance(model, str):
|
||||
modelcard = model
|
||||
elif isinstance(config, str):
|
||||
modelcard = config
|
||||
if isinstance(model, str):
|
||||
modelcard = model
|
||||
elif isinstance(config, str):
|
||||
modelcard = config
|
||||
|
||||
# Instantiate tokenizer if needed
|
||||
if isinstance(tokenizer, (str, tuple)):
|
||||
|
@ -2,9 +2,16 @@ import unittest
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import Pipeline
|
||||
from transformers.pipelines import (
|
||||
FeatureExtractionPipeline,
|
||||
FillMaskPipeline,
|
||||
NerPipeline,
|
||||
Pipeline,
|
||||
QuestionAnsweringPipeline,
|
||||
TextClassificationPipeline,
|
||||
)
|
||||
|
||||
from .utils import require_tf, require_torch
|
||||
from .utils import require_tf, require_torch, slow
|
||||
|
||||
|
||||
QA_FINETUNED_MODELS = [
|
||||
@ -304,3 +311,30 @@ class MultiColumnInputTestCase(unittest.TestCase):
|
||||
for tokenizer, model, config in TF_QA_FINETUNED_MODELS:
|
||||
nlp = pipeline(task="question-answering", model=model, config=config, tokenizer=tokenizer, framework="tf")
|
||||
self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)
|
||||
|
||||
|
||||
class PipelineCommonTests(unittest.TestCase):
    """Check that each task-specific pipeline can be instantiated with only a framework argument."""

    # Task-specific pipeline classes exercised by the default-loading tests below.
    pipelines = (
        NerPipeline,
        FeatureExtractionPipeline,
        QuestionAnsweringPipeline,
        FillMaskPipeline,
        TextClassificationPipeline,
    )

    @slow
    @require_tf
    def test_tf_defaults(self):
        # Test that pipelines can be correctly loaded without any argument
        for default_pipeline in self.pipelines:
            # The label names TensorFlow so a failing subtest is attributed to the right framework.
            with self.subTest(msg="Testing TF defaults with TF and {}".format(default_pipeline.task)):
                default_pipeline(framework="tf")

    @slow
    @require_torch
    def test_pt_defaults(self):
        # Test that pipelines can be correctly loaded without any argument
        for default_pipeline in self.pipelines:
            with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(default_pipeline.task)):
                default_pipeline(framework="pt")
|
||||
|
Loading…
Reference in New Issue
Block a user