Pipeline doc (#3055)

* Pipeline doc initial commit

* pipeline abstraction

* Remove modelcard argument from pipeline

* Task-specific pipelines can be instantiated with no model or tokenizer

* All pipelines doc
This commit is contained in:
Lysandre Debut 2020-03-02 14:07:10 -05:00 committed by GitHub
parent 2c7749784c
commit d3eb7d23a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 432 additions and 44 deletions

View File

@ -80,6 +80,7 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train
main_classes/configuration
main_classes/model
main_classes/tokenizer
main_classes/pipelines
main_classes/optimizer_schedules
main_classes/processors

View File

@ -0,0 +1,63 @@
Pipelines
----------------------------------------------------
The pipelines are a great and easy way to use models for inference. These pipelines are objects that abstract most
of the complex code from the library, offering a simple API dedicated to several tasks, including Named Entity
Recognition, Masked Language Modeling, Sentiment Analysis, Feature Extraction and Question Answering.
There are two categories of pipeline abstractions to be aware about:
- The :class:`~transformers.pipeline` which is the most powerful object encapsulating all other pipelines
- The other task-specific pipelines, such as :class:`~transformers.NerPipeline`
or :class:`~transformers.QuestionAnsweringPipeline`
The pipeline abstraction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The `pipeline` abstraction is a wrapper around all the other available pipelines. It is instantiated as any
other pipeline but requires an additional argument which is the `task`.
.. autoclass:: transformers.pipeline
:members:
The task specific pipelines
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Parent class: Pipeline
=========================================
.. autoclass:: transformers.Pipeline
:members: predict, transform, save_pretrained
NerPipeline
==========================================
.. autoclass:: transformers.NerPipeline
TokenClassificationPipeline
==========================================
This class is an alias of the :class:`~transformers.NerPipeline` defined above. Please refer to that pipeline for
documentation and usage examples.
FillMaskPipeline
==========================================
.. autoclass:: transformers.FillMaskPipeline
FeatureExtractionPipeline
==========================================
.. autoclass:: transformers.FeatureExtractionPipeline
TextClassificationPipeline
==========================================
.. autoclass:: transformers.TextClassificationPipeline
QuestionAnsweringPipeline
==========================================
.. autoclass:: transformers.QuestionAnsweringPipeline

View File

@ -279,6 +279,9 @@ class _ScikitCompat(ABC):
class Pipeline(_ScikitCompat):
"""
The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
different pipelines.
Base class implementing pipelined operations.
Pipeline workflow is defined as a sequence of the following operations:
Input -> Tokenization -> Model Inference -> Post-Processing (Task dependent) -> Output
@ -292,39 +295,49 @@ class Pipeline(_ScikitCompat):
pickle format.
Arguments:
**model**: ``(str, PretrainedModel, TFPretrainedModel)``:
Reference to the model to use through this pipeline.
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
checkpoint identifier or an actual pre-trained model inheriting from
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
TensorFlow.
**tokenizer**: ``(str, PreTrainedTokenizer)``:
Reference to the tokenizer to use through this pipeline.
If :obj:`None`, the default of the pipeline will be loaded.
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
:class:`~transformers.PreTrainedTokenizer`.
**args_parser**: ``ArgumentHandler``:
If :obj:`None`, the default of the pipeline will be loaded.
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
Model card attributed to the model for this pipeline.
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
installed.
If no framework is specified, will default to the one currently installed. If no framework is specified
and both frameworks are installed, will default to PyTorch.
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
Reference to the object in charge of parsing supplied pipeline parameters.
**device**: ``int``:
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model
on the associated CUDA device id.
**binary_output** ``bool`` (default: False):
binary_output (:obj:`bool`, `optional`, defaults to :obj:`False`):
Flag indicating if the output the pipeline should happen in a binary format (i.e. pickle) or as raw text.
Return:
:obj:`List` or :obj:`Dict`:
Pipeline returns list or dictionary depending on:
- Does the user provided multiple sample
- The pipeline expose multiple fields in the output object
Examples:
nlp = pipeline('ner')
nlp = pipeline('ner', model='...', config='...', tokenizer='...')
nlp = NerPipeline(model='...', config='...', tokenizer='...')
nlp = QuestionAnsweringPipeline(model=AutoModel.from_pretrained('...'), tokenizer='...')
- Whether the user supplied multiple samples
- Whether the pipeline exposes multiple fields in the output object
"""
default_input_names = None
task = None
def __init__(
self,
model,
model: Optional = None,
tokenizer: PreTrainedTokenizer = None,
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
@ -336,6 +349,8 @@ class Pipeline(_ScikitCompat):
if framework is None:
framework = get_framework()
model, tokenizer = self.get_defaults(model, tokenizer, framework)
self.model = model
self.tokenizer = tokenizer
self.modelcard = modelcard
@ -467,15 +482,74 @@ class Pipeline(_ScikitCompat):
else:
return predictions.numpy()
def get_defaults(self, model, tokenizer, framework):
task_defaults = SUPPORTED_TASKS[self.task]
if model is None:
if framework == "tf":
model = task_defaults["tf"].from_pretrained(task_defaults["default"]["model"]["tf"])
elif framework == "pt":
model = task_defaults["pt"].from_pretrained(task_defaults["default"]["model"]["pt"])
else:
raise ValueError("Provided framework should be either 'tf' for TensorFlow or 'pt' for PyTorch.")
if tokenizer is None:
default_tokenizer = task_defaults["default"]["tokenizer"]
if isinstance(default_tokenizer, tuple):
# For tuple we have (tokenizer name, {kwargs})
tokenizer = AutoTokenizer.from_pretrained(default_tokenizer[0], **default_tokenizer[1])
else:
tokenizer = AutoTokenizer.from_pretrained(default_tokenizer)
return model, tokenizer
class FeatureExtractionPipeline(Pipeline):
"""
Feature extraction pipeline using Model head.
Feature extraction pipeline using Model head. This pipeline extracts the hidden states from the base transformer,
which can be used as features in a downstream tasks.
This feature extraction pipeline can currently be loaded from the :func:`~transformers.pipeline` method using
the following task identifier(s):
- "feature-extraction", for extracting features of a sequence.
All models may be used for this pipeline. See a list of all models, including community-contributed models on
`huggingface.co/models <https://huggingface.co/models>`__.
Arguments:
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
checkpoint identifier or an actual pre-trained model inheriting from
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
TensorFlow.
If :obj:`None`, the default of the pipeline will be loaded.
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
:class:`~transformers.PreTrainedTokenizer`.
If :obj:`None`, the default of the pipeline will be loaded.
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
Model card attributed to the model for this pipeline.
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
installed.
If no framework is specified, will default to the one currently installed. If no framework is specified
and both frameworks are installed, will default to PyTorch.
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
Reference to the object in charge of parsing supplied pipeline parameters.
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model
on the associated CUDA device id.
"""
task = "feature-extraction"
def __init__(
self,
model,
model: Optional = None,
tokenizer: PreTrainedTokenizer = None,
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
@ -498,9 +572,49 @@ class FeatureExtractionPipeline(Pipeline):
class TextClassificationPipeline(Pipeline):
"""
Text classification pipeline using ModelForTextClassification head.
Text classification pipeline using ModelForSequenceClassification head. See the
`sequence classification usage <../usage.html#sequence-classification>`__ examples for more information.
This text classification pipeline can currently be loaded from the :func:`~transformers.pipeline` method using
the following task identifier(s):
- "sentiment-analysis", for classifying sequences according to positive or negative sentiments.
The models that this pipeline can use are models that have been fine-tuned on a sequence classification task.
See the list of available community models fine-tuned on such a task on
`huggingface.co/models <https://huggingface.co/models?search=&filter=text-classification>`__.
Arguments:
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
checkpoint identifier or an actual pre-trained model inheriting from
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
TensorFlow.
If :obj:`None`, the default of the pipeline will be loaded.
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
:class:`~transformers.PreTrainedTokenizer`.
If :obj:`None`, the default of the pipeline will be loaded.
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
Model card attributed to the model for this pipeline.
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
installed.
If no framework is specified, will default to the one currently installed. If no framework is specified
and both frameworks are installed, will default to PyTorch.
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
Reference to the object in charge of parsing supplied pipeline parameters.
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model
on the associated CUDA device id.
"""
task = "sentiment-analysis"
def __call__(self, *args, **kwargs):
outputs = super().__call__(*args, **kwargs)
scores = np.exp(outputs) / np.exp(outputs).sum(-1)
@ -509,12 +623,53 @@ class TextClassificationPipeline(Pipeline):
class FillMaskPipeline(Pipeline):
"""
Masked language modeling prediction pipeline using ModelWithLMHead head.
Masked language modeling prediction pipeline using ModelWithLMHead head. See the
`masked language modeling usage <../usage.html#masked-language-modeling>`__ examples for more information.
This mask filling pipeline can currently be loaded from the :func:`~transformers.pipeline` method using
the following task identifier(s):
- "fill-mask", for predicting masked tokens in a sequence.
The models that this pipeline can use are models that have been trained with a masked language modeling objective,
which includes the bi-directional models in the library.
See the list of available community models on
`huggingface.co/models <https://huggingface.co/models?search=&filter=lm-head>`__.
Arguments:
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
checkpoint identifier or an actual pre-trained model inheriting from
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
TensorFlow.
If :obj:`None`, the default of the pipeline will be loaded.
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
:class:`~transformers.PreTrainedTokenizer`.
If :obj:`None`, the default of the pipeline will be loaded.
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
Model card attributed to the model for this pipeline.
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
installed.
If no framework is specified, will default to the one currently installed. If no framework is specified
and both frameworks are installed, will default to PyTorch.
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
Reference to the object in charge of parsing supplied pipeline parameters.
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model
on the associated CUDA device id.
"""
task = "fill-mask"
def __init__(
self,
model,
model: Optional = None,
tokenizer: PreTrainedTokenizer = None,
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
@ -574,14 +729,57 @@ class FillMaskPipeline(Pipeline):
class NerPipeline(Pipeline):
"""
Named Entity Recognition pipeline using ModelForTokenClassification head.
Named Entity Recognition pipeline using ModelForTokenClassification head. See the
`named entity recognition usage <../usage.html#named-entity-recognition>`__ examples for more information.
This token recognition pipeline can currently be loaded from the :func:`~transformers.pipeline` method using
the following task identifier(s):
- "ner", for predicting the classes of tokens in a sequence: person, organisation, location or miscellaneous.
The models that this pipeline can use are models that have been fine-tuned on a token classification task.
See the list of available community models fine-tuned on such a task on
`huggingface.co/models <https://huggingface.co/models?search=&filter=token-classification>`__.
Arguments:
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
checkpoint identifier or an actual pre-trained model inheriting from
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
TensorFlow.
If :obj:`None`, the default of the pipeline will be loaded.
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
:class:`~transformers.PreTrainedTokenizer`.
If :obj:`None`, the default of the pipeline will be loaded.
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
Model card attributed to the model for this pipeline.
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
installed.
If no framework is specified, will default to the one currently installed. If no framework is specified
and both frameworks are installed, will default to PyTorch.
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
Reference to the object in charge of parsing supplied pipeline parameters.
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model
on the associated CUDA device id.
Example::
from transformers import pi
"""
default_input_names = "sequences"
task = "ner"
def __init__(
self,
model,
model: Optional = None,
tokenizer: PreTrainedTokenizer = None,
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
@ -716,15 +914,54 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
class QuestionAnsweringPipeline(Pipeline):
"""
Question Answering pipeline using ModelForQuestionAnswering head.
Question Answering pipeline using ModelForQuestionAnswering head. See the
`question answering usage <../usage.html#question-answering>`__ examples for more information.
This question answering can currently be loaded from the :func:`~transformers.pipeline` method using
the following task identifier(s):
- "question-answering", for answering questions given a context.
The models that this pipeline can use are models that have been fine-tuned on a question answering task.
See the list of available community models fine-tuned on such a task on
`huggingface.co/models <https://huggingface.co/models?search=&filter=question-answering>`__.
Arguments:
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
checkpoint identifier or an actual pre-trained model inheriting from
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
TensorFlow.
If :obj:`None`, the default of the pipeline will be loaded.
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
:class:`~transformers.PreTrainedTokenizer`.
If :obj:`None`, the default of the pipeline will be loaded.
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
Model card attributed to the model for this pipeline.
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
installed.
If no framework is specified, will default to the one currently installed. If no framework is specified
and both frameworks are installed, will default to PyTorch.
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
Reference to the object in charge of parsing supplied pipeline parameters.
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model
on the associated CUDA device id.
"""
default_input_names = "question,context"
task = "question-answering"
def __init__(
self,
model,
tokenizer: Optional[PreTrainedTokenizer],
model: Optional = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
device: int = -1,
@ -1003,23 +1240,77 @@ def pipeline(
model: Optional = None,
config: Optional[Union[str, PretrainedConfig]] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
modelcard: Optional[Union[str, ModelCard]] = None,
framework: Optional[str] = None,
**kwargs
) -> Pipeline:
"""
Utility factory method to build a pipeline.
Pipeline are made of:
A Tokenizer instance in charge of mapping raw textual input to token
A Model instance
Some (optional) post processing for enhancing model's output
Examples:
Pipeline are made of:
- A Tokenizer instance in charge of mapping raw textual input to token
- A Model instance
- Some (optional) post processing for enhancing model's output
Args:
task (:obj:`str`):
The task defining which pipeline will be returned. Currently accepted tasks are:
- "feature-extraction": will return a :class:`~transformers.FeatureExtractionPipeline`
- "sentiment-analysis": will return a :class:`~transformers.TextClassificationPipeline`
- "ner": will return a :class:`~transformers.NerPipeline`
- "question-answering": will return a :class:`~transformers.QuestionAnsweringPipeline`
- "fill-mask": will return a :class:`~transformers.FillMaskPipeline`
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
checkpoint identifier or an actual pre-trained model inheriting from
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
TensorFlow.
If :obj:`None`, the default of the pipeline will be loaded.
config (:obj:`str` or :obj:`~transformers.PretrainedConfig`, `optional`, defaults to :obj:`None`):
The configuration that will be used by the pipeline to instantiate the model. This can be :obj:`None`,
a string checkpoint identifier or an actual pre-trained model configuration inheriting from
:class:`~transformers.PretrainedConfig`.
If :obj:`None`, the default of the pipeline will be loaded.
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
:class:`~transformers.PreTrainedTokenizer`.
If :obj:`None`, the default of the pipeline will be loaded.
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
installed.
If no framework is specified, will default to the one currently installed. If no framework is specified
and both frameworks are installed, will default to PyTorch.
Returns:
:class:`~transformers.Pipeline`: Class inheriting from :class:`~transformers.Pipeline`, according to
the task.
Examples::
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
# Sentiment analysis pipeline
pipeline('sentiment-analysis')
# Question answering pipeline, specifying the checkpoint identifier
pipeline('question-answering', model='distilbert-base-cased-distilled-squad', tokenizer='bert-base-cased')
pipeline('ner', model=AutoModel.from_pretrained(...), tokenizer=AutoTokenizer.from_pretrained(...)
pipeline('ner', model='dbmdz/bert-large-cased-finetuned-conll03-english', tokenizer='bert-base-cased')
pipeline('ner', model='https://...pytorch-model.bin', config='https://...config.json', tokenizer='bert-base-cased')
# Named entity recognition pipeline, passing in a specific model and tokenizer
model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
pipeline('ner', model=model, tokenizer=tokenizer)
# Named entity recognition pipeline, passing a model and configuration with a HTTPS URL.
model_url = "https://s3.amazonaws.com/models.huggingface.co/bert/dbmdz/bert-large-cased-finetuned-conll03-english/pytorch_model.bin"
config_url = "https://s3.amazonaws.com/models.huggingface.co/bert/dbmdz/bert-large-cased-finetuned-conll03-english/config.json"
pipeline('ner', model=model_url, config=config_url, tokenizer='bert-base-cased')
"""
# Retrieve the task
if task not in SUPPORTED_TASKS:
@ -1048,13 +1339,12 @@ def pipeline(
"Please provided a PretrainedTokenizer class or a path/url/shortcut name to a pretrained tokenizer."
)
modelcard = None
# Try to infer modelcard from model or config name (if provided as str)
if modelcard is None:
# Try to fallback on one of the provided string for model or config (will replace the suffix)
if isinstance(model, str):
modelcard = model
elif isinstance(config, str):
modelcard = config
if isinstance(model, str):
modelcard = model
elif isinstance(config, str):
modelcard = config
# Instantiate tokenizer if needed
if isinstance(tokenizer, (str, tuple)):

View File

@ -2,9 +2,16 @@ import unittest
from typing import Iterable, List, Optional
from transformers import pipeline
from transformers.pipelines import Pipeline
from transformers.pipelines import (
FeatureExtractionPipeline,
FillMaskPipeline,
NerPipeline,
Pipeline,
QuestionAnsweringPipeline,
TextClassificationPipeline,
)
from .utils import require_tf, require_torch
from .utils import require_tf, require_torch, slow
QA_FINETUNED_MODELS = [
@ -304,3 +311,30 @@ class MultiColumnInputTestCase(unittest.TestCase):
for tokenizer, model, config in TF_QA_FINETUNED_MODELS:
nlp = pipeline(task="question-answering", model=model, config=config, tokenizer=tokenizer, framework="tf")
self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)
class PipelineCommonTests(unittest.TestCase):
pipelines = (
NerPipeline,
FeatureExtractionPipeline,
QuestionAnsweringPipeline,
FillMaskPipeline,
TextClassificationPipeline,
)
@slow
@require_tf
def test_tf_defaults(self):
# Test that pipelines can be correctly loaded without any argument
for default_pipeline in self.pipelines:
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(default_pipeline.task)):
default_pipeline(framework="tf")
@slow
@require_torch
def test_pt_defaults(self):
# Test that pipelines can be correctly loaded without any argument
for default_pipeline in self.pipelines:
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(default_pipeline.task)):
default_pipeline(framework="pt")