Mirror of https://github.com/huggingface/transformers.git
Change back pipeline signatures (#3105)

* Change back pipeline signatures
* String types for non-imported objects
parent d6df9a8ffe
commit 0ae91c80aa
@@ -296,19 +296,13 @@ class Pipeline(_ScikitCompat):
     pickle format.
 
     Arguments:
-        model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
-            The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
-            checkpoint identifier or an actual pre-trained model inheriting from
+        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
+            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
             :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
             TensorFlow.
-
-            If :obj:`None`, the default of the pipeline will be loaded.
-        tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
-            The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
-            a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
+        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
+            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
             :class:`~transformers.PreTrainedTokenizer`.
-
-            If :obj:`None`, the default of the pipeline will be loaded.
         modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
             Model card attributed to the model for this pipeline.
         framework (:obj:`str`, `optional`, defaults to :obj:`None`):
@@ -334,12 +328,11 @@ class Pipeline(_ScikitCompat):
     """
 
     default_input_names = None
-    task = None
 
     def __init__(
         self,
-        model: Optional = None,
-        tokenizer: PreTrainedTokenizer = None,
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        tokenizer: PreTrainedTokenizer,
         modelcard: Optional[ModelCard] = None,
         framework: Optional[str] = None,
         args_parser: ArgumentHandler = None,
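The "String types for non-imported objects" part of this change is the standard PEP 484 forward-reference pattern: writing the annotation as Union["PreTrainedModel", "TFPreTrainedModel"] lets the signature name both base classes even when only one backend (PyTorch or TensorFlow) is installed, because quoted names are not looked up when the module is imported. A minimal sketch of the pattern outside the transformers codebase (the Dummy* names and backend modules are purely illustrative)::

    from typing import TYPE_CHECKING, Optional, Union

    if TYPE_CHECKING:
        # Imported only for static type checkers; at runtime these backends
        # may be absent (e.g. TensorFlow not installed), which is fine.
        from dummy_pt_backend import DummyPreTrainedModel
        from dummy_tf_backend import DummyTFPreTrainedModel


    class DummyPipeline:
        def __init__(
            self,
            # Quoted names are stored as forward references, so they do not
            # need to be importable when this module is loaded.
            model: Union["DummyPreTrainedModel", "DummyTFPreTrainedModel"],
            framework: Optional[str] = None,
        ):
            self.model = model
            self.framework = framework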
@@ -350,8 +343,6 @@ class Pipeline(_ScikitCompat):
         if framework is None:
             framework = get_framework()
 
-        model, tokenizer = self.get_defaults(model, tokenizer, framework)
-
         self.model = model
         self.tokenizer = tokenizer
         self.modelcard = modelcard
@@ -483,26 +474,6 @@ class Pipeline(_ScikitCompat):
         else:
             return predictions.numpy()
 
-    def get_defaults(self, model, tokenizer, framework):
-        task_defaults = SUPPORTED_TASKS[self.task]
-        if model is None:
-            if framework == "tf":
-                model = task_defaults["tf"].from_pretrained(task_defaults["default"]["model"]["tf"])
-            elif framework == "pt":
-                model = task_defaults["pt"].from_pretrained(task_defaults["default"]["model"]["pt"])
-            else:
-                raise ValueError("Provided framework should be either 'tf' for TensorFlow or 'pt' for PyTorch.")
-
-        if tokenizer is None:
-            default_tokenizer = task_defaults["default"]["tokenizer"]
-            if isinstance(default_tokenizer, tuple):
-                # For tuple we have (tokenizer name, {kwargs})
-                tokenizer = AutoTokenizer.from_pretrained(default_tokenizer[0], **default_tokenizer[1])
-            else:
-                tokenizer = AutoTokenizer.from_pretrained(default_tokenizer)
-
-        return model, tokenizer
-
 
 class FeatureExtractionPipeline(Pipeline):
     """
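With get_defaults removed, instantiating a pipeline class directly no longer falls back to a task-default checkpoint: the caller loads a model and tokenizer first and passes both in, as the now-required arguments demand. A hedged sketch of that call pattern after this change (the checkpoint name is only an example, and the keyword arguments are limited to those visible in this diff)::

    from transformers import (
        AutoModelForSequenceClassification,
        AutoTokenizer,
        TextClassificationPipeline,
    )

    # The caller now resolves the checkpoint; Pipeline.__init__ no longer
    # loads anything implicitly.
    checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"  # example only
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer, framework="pt")
    print(classifier("This change makes the constructor explicit."))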
@@ -518,19 +489,13 @@ class FeatureExtractionPipeline(Pipeline):
     `huggingface.co/models <https://huggingface.co/models>`__.
 
     Arguments:
-        model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
-            The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
-            checkpoint identifier or an actual pre-trained model inheriting from
+        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
+            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
             :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
             TensorFlow.
-
-            If :obj:`None`, the default of the pipeline will be loaded.
-        tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
-            The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
-            a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
+        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
+            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
             :class:`~transformers.PreTrainedTokenizer`.
-
-            If :obj:`None`, the default of the pipeline will be loaded.
         modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
             Model card attributed to the model for this pipeline.
         framework (:obj:`str`, `optional`, defaults to :obj:`None`):
@@ -546,12 +511,10 @@ class FeatureExtractionPipeline(Pipeline):
         on the associated CUDA device id.
     """
 
-    task = "feature-extraction"
-
     def __init__(
         self,
-        model: Optional = None,
-        tokenizer: PreTrainedTokenizer = None,
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        tokenizer: PreTrainedTokenizer,
         modelcard: Optional[ModelCard] = None,
         framework: Optional[str] = None,
         args_parser: ArgumentHandler = None,
@@ -586,19 +549,13 @@ class TextClassificationPipeline(Pipeline):
     `huggingface.co/models <https://huggingface.co/models?search=&filter=text-classification>`__.
 
     Arguments:
-        model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
-            The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
-            checkpoint identifier or an actual pre-trained model inheriting from
+        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
+            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
             :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
             TensorFlow.
-
-            If :obj:`None`, the default of the pipeline will be loaded.
-        tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
-            The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
-            a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
+        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
+            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
             :class:`~transformers.PreTrainedTokenizer`.
-
-            If :obj:`None`, the default of the pipeline will be loaded.
         modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
             Model card attributed to the model for this pipeline.
         framework (:obj:`str`, `optional`, defaults to :obj:`None`):
@@ -614,8 +571,6 @@ class TextClassificationPipeline(Pipeline):
         on the associated CUDA device id.
     """
 
-    task = "sentiment-analysis"
-
     def __call__(self, *args, **kwargs):
         outputs = super().__call__(*args, **kwargs)
         scores = np.exp(outputs) / np.exp(outputs).sum(-1)
@@ -638,19 +593,13 @@ class FillMaskPipeline(Pipeline):
     `huggingface.co/models <https://huggingface.co/models?search=&filter=lm-head>`__.
 
     Arguments:
-        model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
-            The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
-            checkpoint identifier or an actual pre-trained model inheriting from
+        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
+            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
             :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
             TensorFlow.
-
-            If :obj:`None`, the default of the pipeline will be loaded.
-        tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
-            The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
-            a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
+        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
+            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
             :class:`~transformers.PreTrainedTokenizer`.
-
-            If :obj:`None`, the default of the pipeline will be loaded.
         modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
             Model card attributed to the model for this pipeline.
         framework (:obj:`str`, `optional`, defaults to :obj:`None`):
@@ -666,12 +615,10 @@ class FillMaskPipeline(Pipeline):
         on the associated CUDA device id.
     """
 
-    task = "fill-mask"
-
    def __init__(
        self,
-        model: Optional = None,
-        tokenizer: PreTrainedTokenizer = None,
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
@@ -743,19 +690,13 @@ class NerPipeline(Pipeline):
     `huggingface.co/models <https://huggingface.co/models?search=&filter=token-classification>`__.
 
     Arguments:
-        model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
-            The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
-            checkpoint identifier or an actual pre-trained model inheriting from
+        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
+            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
             :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
             TensorFlow.
-
-            If :obj:`None`, the default of the pipeline will be loaded.
-        tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
-            The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
-            a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
+        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
+            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
             :class:`~transformers.PreTrainedTokenizer`.
-
-            If :obj:`None`, the default of the pipeline will be loaded.
         modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
             Model card attributed to the model for this pipeline.
         framework (:obj:`str`, `optional`, defaults to :obj:`None`):
@@ -769,19 +710,14 @@ class NerPipeline(Pipeline):
         device (:obj:`int`, `optional`, defaults to :obj:`-1`):
             Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model
             on the associated CUDA device id.
-
-    Example::
-
-        from transformers import pi
     """
 
     default_input_names = "sequences"
-    task = "ner"
 
    def __init__(
        self,
-        model: Optional = None,
-        tokenizer: PreTrainedTokenizer = None,
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
@@ -928,19 +864,13 @@ class QuestionAnsweringPipeline(Pipeline):
     `huggingface.co/models <https://huggingface.co/models?search=&filter=question-answering>`__.
 
     Arguments:
-        model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
-            The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
-            checkpoint identifier or an actual pre-trained model inheriting from
+        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
+            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
             :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
             TensorFlow.
-
-            If :obj:`None`, the default of the pipeline will be loaded.
-        tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
-            The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
-            a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
+        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
+            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
             :class:`~transformers.PreTrainedTokenizer`.
-
-            If :obj:`None`, the default of the pipeline will be loaded.
         modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
             Model card attributed to the model for this pipeline.
         framework (:obj:`str`, `optional`, defaults to :obj:`None`):
@@ -957,12 +887,11 @@ class QuestionAnsweringPipeline(Pipeline):
     """
 
     default_input_names = "question,context"
-    task = "question-answering"
 
    def __init__(
        self,
-        model: Optional = None,
-        tokenizer: Optional[PreTrainedTokenizer] = None,
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        device: int = -1,