Custom pipeline (#18079)

* Initial work
* More work
* Add tests for custom pipelines on the Hub
* Protect import
* Make the test work for TF as well
* Last PyTorch specific bit
* Add documentation
* Style
* Title in toc
* Bad names!
* Update docs/source/en/add_new_pipeline.mdx
* Auto stash before merge of "custom_pipeline" and "origin/custom_pipeline"
* Address review comments
* Address more review comments
* Update src/transformers/pipelines/__init__.py

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
Parent: 3bb6356d4d
Commit: dc9147ff36
```diff
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -102,7 +102,7 @@
   - local: add_new_model
     title: How to add a model to 🤗 Transformers?
   - local: add_new_pipeline
-    title: How to add a pipeline to 🤗 Transformers?
+    title: How to create a custom pipeline?
   - local: testing
     title: Testing
   - local: pr_checks
```
````diff
--- a/docs/source/en/add_new_pipeline.mdx
+++ b/docs/source/en/add_new_pipeline.mdx
@@ -9,7 +9,10 @@ Unless required by applicable law or agreed to in writing, software distributed
 an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 -->
 
-# How to add a pipeline to 🤗 Transformers?
+# How to create a custom pipeline?
+
+In this guide, we will see how to create a custom pipeline and share it on the [Hub](https://hf.co/models) or add it
+to the Transformers library.
 
 First and foremost, you need to decide the raw entries the pipeline will be able to take. It can be strings, raw bytes,
 dictionaries or whatever seems to be the most likely desired input. Try to keep these inputs as pure Python as possible
@@ -111,39 +114,123 @@ of arguments for ease of use (audio files, can be filenames, URLs or pure bytes)
 
 ## Adding it to the list of supported tasks
 
-To register your `new-task` to the list of supported tasks, provide the
-following task template:
-
-```python
-my_new_task = {
-    "impl": MyPipeline,
-    "tf": (),
-    "pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
-    "default": {"model": {"pt": "user/awesome_model"}},
-    "type": "audio",  # current support type: text, audio, image, multimodal
-}
-```
-
-<Tip>
-
-Take a look at the `src/transformers/pipelines/__init__.py` and the dictionary `SUPPORTED_TASKS` to see how a task is defined.
-If possible your custom task should provide a default model.
-
-</Tip>
-
-Then add your custom task to the list of supported tasks via
-`PIPELINE_REGISTRY.register_pipeline()`:
+To register your `new-task` to the list of supported tasks, you have to add it to the `PIPELINE_REGISTRY`:
 
 ```python
 from transformers.pipelines import PIPELINE_REGISTRY
 
-PIPELINE_REGISTRY.register_pipeline("new-task", my_new_task)
+PIPELINE_REGISTRY.register_pipeline(
+    "new-task",
+    pipeline_class=MyPipeline,
+    pt_model=AutoModelForSequenceClassification,
+)
 ```
 
-## Adding tests
-
-Create a new file `tests/test_pipelines_MY_PIPELINE.py` with example with the other tests.
+You can specify a default model if you want, in which case it should come with a specific revision (which can be the
+name of a branch or a commit hash, here we took `"abcdef"`) as well as the type:
+
+```python
+PIPELINE_REGISTRY.register_pipeline(
+    "new-task",
+    pipeline_class=MyPipeline,
+    pt_model=AutoModelForSequenceClassification,
+    default={"pt": ("user/awesome_model", "abcdef")},
+    type="text",  # current support type: text, audio, image, multimodal
+)
+```
+
+## Share your pipeline on the Hub
+
+To share your custom pipeline on the Hub, you just have to save the custom code of your `Pipeline` subclass in a
+python file. For instance, let's say we want to use a custom pipeline for sentence pair classification like this:
+
+```py
+import numpy as np
+
+from transformers import Pipeline
+
+
+def softmax(outputs):
+    maxes = np.max(outputs, axis=-1, keepdims=True)
+    shifted_exp = np.exp(outputs - maxes)
+    return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
+
+
+class PairClassificationPipeline(Pipeline):
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_kwargs = {}
+        if "second_text" in kwargs:
+            preprocess_kwargs["second_text"] = kwargs["second_text"]
+        return preprocess_kwargs, {}, {}
+
+    def preprocess(self, text, second_text=None):
+        return self.tokenizer(text, text_pair=second_text, return_tensors=self.framework)
+
+    def _forward(self, model_inputs):
+        return self.model(**model_inputs)
+
+    def postprocess(self, model_outputs):
+        logits = model_outputs.logits[0].numpy()
+        probabilities = softmax(logits)
+
+        best_class = np.argmax(probabilities)
+        label = self.model.config.id2label[best_class]
+        score = probabilities[best_class].item()
+        logits = logits.tolist()
+        return {"label": label, "score": score, "logits": logits}
+```
+
+The implementation is framework agnostic, and will work for PyTorch and TensorFlow models. If we have saved this in
+a file named `pair_classification.py`, we can then import it and register it like this:
+
+```py
+from pair_classification import PairClassificationPipeline
+from transformers.pipelines import PIPELINE_REGISTRY
+from transformers import AutoModelForSequenceClassification, TFAutoModelForSequenceClassification
+
+PIPELINE_REGISTRY.register_pipeline(
+    "pair-classification",
+    pipeline_class=PairClassificationPipeline,
+    pt_model=AutoModelForSequenceClassification,
+    tf_model=TFAutoModelForSequenceClassification,
+)
+```
+
+Once this is done, we can use it with a pretrained model. For instance, `sgugger/finetuned-bert-mrpc` has been
+fine-tuned on the MRPC dataset, which classifies pairs of sentences as paraphrases or not.
+
+```py
+from transformers import pipeline
+
+classifier = pipeline("pair-classification", model="sgugger/finetuned-bert-mrpc")
+```
+
+Then we can share it on the Hub by using the `save_pretrained` method in a `Repository`:
+
+```py
+from huggingface_hub import Repository
+
+repo = Repository("test-dynamic-pipeline", clone_from="{your_username}/test-dynamic-pipeline")
+classifier.save_pretrained("test-dynamic-pipeline")
+repo.push_to_hub()
+```
+
+This will copy the file where you defined `PairClassificationPipeline` inside the folder `"test-dynamic-pipeline"`,
+along with saving the model and tokenizer of the pipeline, before pushing everything into the repository
+`{your_username}/test-dynamic-pipeline`. After that, anyone can use it as long as they provide the option
+`trust_remote_code=True`:
+
+```py
+from transformers import pipeline
+
+classifier = pipeline(model="{your_username}/test-dynamic-pipeline", trust_remote_code=True)
+```
+
+## Add the pipeline to Transformers
+
+If you want to contribute your pipeline to Transformers, you will need to add a new module in the `pipelines` submodule
+with the code of your pipeline, then add it to the list of tasks defined in `pipelines/__init__.py`.
+
+Then you will need to add tests. Create a new file `tests/test_pipelines_MY_PIPELINE.py` with examples of the other tests.
 
 The `run_pipeline_test` function will be very generic and run on small random models on every possible
 architecture as defined by `model_mapping` and `tf_model_mapping`.
````
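Note: putting the documented pieces together, kwargs like `second_text` are routed by `_sanitize_parameters` into `preprocess`, whether passed at call time or when constructing the pipeline. A minimal end-to-end sketch, not part of the commit itself, assuming the `pair_classification.py` file from the diff above has been saved next to the script:

```py
from pair_classification import PairClassificationPipeline
from transformers import AutoModelForSequenceClassification, pipeline
from transformers.pipelines import PIPELINE_REGISTRY

# Register the custom task so `pipeline()` can resolve it by name.
PIPELINE_REGISTRY.register_pipeline(
    "pair-classification",
    pipeline_class=PairClassificationPipeline,
    pt_model=AutoModelForSequenceClassification,
)

# `sgugger/finetuned-bert-mrpc` is the example checkpoint from the documentation above.
classifier = pipeline("pair-classification", model="sgugger/finetuned-bert-mrpc")

# `second_text` is routed by `_sanitize_parameters` into the preprocess kwargs, the first
# of the three dicts it returns (the forward and postprocess dicts stay empty here).
print(classifier("I hate you", second_text="I love you"))
```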
```diff
--- a/src/transformers/pipelines/__init__.py
+++ b/src/transformers/pipelines/__init__.py
@@ -23,7 +23,10 @@ import os
 import warnings
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
+from numpy import isin
+
 from ..configuration_utils import PretrainedConfig
+from ..dynamic_module_utils import get_class_from_dynamic_module
 from ..feature_extraction_utils import PreTrainedFeatureExtractor
 from ..models.auto.configuration_auto import AutoConfig
 from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
@@ -379,6 +382,22 @@ def check_task(task: str) -> Tuple[Dict, Any]:
     return PIPELINE_REGISTRY.check_task(task)
 
 
+def clean_custom_task(task_info):
+    import transformers
+
+    if "impl" not in task_info:
+        raise RuntimeError("This model introduces a custom pipeline without specifying its implementation.")
+    pt_class_names = task_info.get("pt", ())
+    if isinstance(pt_class_names, str):
+        pt_class_names = [pt_class_names]
+    task_info["pt"] = tuple(getattr(transformers, c) for c in pt_class_names)
+    tf_class_names = task_info.get("tf", ())
+    if isinstance(tf_class_names, str):
+        tf_class_names = [tf_class_names]
+    task_info["tf"] = tuple(getattr(transformers, c) for c in tf_class_names)
+    return task_info, None
+
+
 def pipeline(
     task: str = None,
     model: Optional = None,
```
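Note: `clean_custom_task` is the inverse of the serialization done in `Pipeline.save_pretrained` (see the `base.py` changes below): the config stores auto-class *names* as strings, and this helper resolves them back to the actual classes on the `transformers` module. A small sketch of that round trip, with the entry shape taken from the tests further down (requires PyTorch so the auto class exists):

```py
from transformers import AutoModelForSequenceClassification
from transformers.pipelines import clean_custom_task

# A `config.custom_pipelines` entry as written by `Pipeline.save_pretrained`.
task_info = {
    "impl": "custom_pipeline.PairClassificationPipeline",
    "pt": ("AutoModelForSequenceClassification",),
    "tf": (),
}

cleaned, _ = clean_custom_task(task_info)
assert cleaned["pt"] == (AutoModelForSequenceClassification,)
assert cleaned["tf"] == ()
# The "impl" string is resolved separately, through `get_class_from_dynamic_module`.
```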
```diff
@@ -391,6 +410,7 @@ def pipeline(
     use_auth_token: Optional[Union[str, bool]] = None,
     device_map=None,
     torch_dtype=None,
+    trust_remote_code: Optional[bool] = None,
     model_kwargs: Dict[str, Any] = None,
     pipeline_class: Optional[Any] = None,
     **kwargs
@@ -488,6 +508,10 @@ def pipeline(
         torch_dtype (`str` or `torch.dtype`, *optional*):
             Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model
             (`torch.float16`, `torch.bfloat16`, ... or `"auto"`).
+        trust_remote_code (`bool`, *optional*, defaults to `False`):
+            Whether or not to allow for custom code defined on the Hub in their own modeling, configuration,
+            tokenization or even pipeline files. This option should only be set to `True` for repositories you trust
+            and in which you have read the code, as it will execute code present on the Hub on your local machine.
         model_kwargs:
             Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
             **model_kwargs)` function.
@@ -516,6 +540,10 @@ def pipeline(
     ```"""
     if model_kwargs is None:
         model_kwargs = {}
+    # Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
+    # this is to keep BC).
+    use_auth_token = model_kwargs.pop("use_auth_token", use_auth_token)
+    hub_kwargs = {"revision": revision, "use_auth_token": use_auth_token, "trust_remote_code": trust_remote_code}
 
     if task is None and model is None:
         raise RuntimeError(
@@ -537,6 +565,25 @@ def pipeline(
             " or a path/identifier to a pretrained model when providing feature_extractor."
         )
 
+    # Config is the primordial information item.
+    # Instantiate config if needed
+    if isinstance(config, str):
+        config = AutoConfig.from_pretrained(config, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+    elif config is None and isinstance(model, str):
+        config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+
+    custom_tasks = {}
+    if config is not None and len(getattr(config, "custom_pipelines", {})) > 0:
+        custom_tasks = config.custom_pipelines
+        if task is None and trust_remote_code is not False:
+            if len(custom_tasks) == 1:
+                task = list(custom_tasks.keys())[0]
+            else:
+                raise RuntimeError(
+                    "We can't infer the task automatically for this model as there are multiple tasks available. Pick "
+                    f"one in {', '.join(custom_tasks.keys())}"
+                )
+
     if task is None and model is not None:
         if not isinstance(model, str):
             raise RuntimeError(
@@ -546,6 +593,21 @@ def pipeline(
         task = get_task(model, use_auth_token)
 
     # Retrieve the task
-    targeted_task, task_options = check_task(task)
-    if pipeline_class is None:
-        pipeline_class = targeted_task["impl"]
+    if task in custom_tasks:
+        targeted_task, task_options = clean_custom_task(custom_tasks[task])
+        if pipeline_class is None:
+            if not trust_remote_code:
+                raise ValueError(
+                    "Loading this pipeline requires you to execute the code in the pipeline file in that"
+                    " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
+                    " set the option `trust_remote_code=True` to remove this error."
+                )
+            class_ref = targeted_task["impl"]
+            module_file, class_name = class_ref.split(".")
+            pipeline_class = get_class_from_dynamic_module(
+                model, module_file + ".py", class_name, revision=revision, use_auth_token=use_auth_token
+            )
+    else:
+        targeted_task, task_options = check_task(task)
+        if pipeline_class is None:
+            pipeline_class = targeted_task["impl"]
@@ -560,9 +622,9 @@ def pipeline(
                 f" {revision} ({HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model}).\n"
                 "Using a pipeline without specifying a model name and revision in production is not recommended."
             )
+        if config is None and isinstance(model, str):
+            config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
 
-    # Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
-    model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
     if device_map is not None:
         if "device_map" in model_kwargs:
             raise ValueError(
@@ -578,13 +640,6 @@ def pipeline(
         )
         model_kwargs["torch_dtype"] = torch_dtype
 
-    # Config is the primordial information item.
-    # Instantiate config if needed
-    if isinstance(config, str):
-        config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
-    elif config is None and isinstance(model, str):
-        config = AutoConfig.from_pretrained(model, revision=revision, _from_pipeline=task, **model_kwargs)
-
     model_name = model if isinstance(model, str) else None
 
     # Infer the framework from the model
@@ -596,8 +651,8 @@ def pipeline(
         model_classes=model_classes,
         config=config,
         framework=framework,
-        revision=revision,
         task=task,
+        **hub_kwargs,
         **model_kwargs,
     )
@@ -641,7 +696,7 @@ def pipeline(
             tokenizer_kwargs = model_kwargs

             tokenizer = AutoTokenizer.from_pretrained(
-                tokenizer_identifier, revision=revision, use_fast=use_fast, _from_pipeline=task, **tokenizer_kwargs
+                tokenizer_identifier, use_fast=use_fast, _from_pipeline=task, **hub_kwargs, **tokenizer_kwargs
             )
 
     if load_feature_extractor:
@@ -662,7 +717,7 @@ def pipeline(
         # Instantiate feature_extractor if needed
         if isinstance(feature_extractor, (str, tuple)):
             feature_extractor = AutoFeatureExtractor.from_pretrained(
-                feature_extractor, revision=revision, _from_pipeline=task, **model_kwargs
+                feature_extractor, _from_pipeline=task, **hub_kwargs, **model_kwargs
            )
 
     if (
```
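Note: from the user's side, the net effect of these `pipeline()` changes is that a checkpoint whose config declares exactly one entry in `custom_pipelines` can be loaded without naming the task, but executing the pipeline file fetched from the Hub always requires an explicit opt-in. A sketch (the repo name is the placeholder from the documentation above):

```py
from transformers import pipeline

# With no task given, `pipeline()` reads `config.custom_pipelines` and, if a single custom
# task is declared, selects it automatically. Running the remote pipeline code still requires
# `trust_remote_code=True`; omitting it raises a ValueError instead of executing anything.
classifier = pipeline(model="{your_username}/test-dynamic-pipeline", trust_remote_code=True)
```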
```diff
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -29,6 +29,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 from packaging import version
 
+from ..dynamic_module_utils import custom_object_save
 from ..feature_extraction_utils import PreTrainedFeatureExtractor
 from ..modelcard import ModelCard
 from ..models.auto.configuration_auto import AutoConfig
@@ -794,6 +795,27 @@ class Pipeline(_ScikitCompat):
             return
         os.makedirs(save_directory, exist_ok=True)
 
+        if hasattr(self, "_registered_impl"):
+            # Add info to the config
+            pipeline_info = self._registered_impl.copy()
+            custom_pipelines = {}
+            for task, info in pipeline_info.items():
+                if info["impl"] != self.__class__:
+                    continue
+
+                info = info.copy()
+                module_name = info["impl"].__module__
+                last_module = module_name.split(".")[-1]
+                # Change classes into their names/full names
+                info["impl"] = f"{last_module}.{info['impl'].__name__}"
+                info["pt"] = tuple(c.__name__ for c in info["pt"])
+                info["tf"] = tuple(c.__name__ for c in info["tf"])
+
+                custom_pipelines[task] = info
+            self.model.config.custom_pipelines = custom_pipelines
+            # Save the pipeline custom code
+            custom_object_save(self, save_directory)
+
         self.model.save_pretrained(save_directory)
 
         if self.tokenizer is not None:
```
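Note: this `save_pretrained` branch is what makes a pipeline shareable. For a `PairClassificationPipeline` registered with a PyTorch auto class only, the dictionary written into the config serializes to the value below, mirroring the assertions in the tests further down:

```py
# Class *names*, not classes, so the config stays JSON-serializable; the module prefix
# ("custom_pipeline") names the file that `custom_object_save` copies next to the weights.
expected_custom_pipelines = {
    "pair-classification": {
        "impl": "custom_pipeline.PairClassificationPipeline",
        "pt": ("AutoModelForSequenceClassification",),
        "tf": (),
    }
}
```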
```diff
@@ -1117,11 +1139,40 @@ class PipelineRegistry:
                 f"Unknown task {task}, available tasks are {self.get_supported_tasks() + ['translation_XX_to_YY']}"
             )
 
-    def register_pipeline(self, task: str, task_impl: Dict[str, Any]) -> None:
+    def register_pipeline(
+        self,
+        task: str,
+        pipeline_class: type,
+        pt_model: Optional[Union[type, Tuple[type]]] = None,
+        tf_model: Optional[Union[type, Tuple[type]]] = None,
+        default: Optional[Dict] = None,
+        type: Optional[str] = None,
+    ) -> None:
         if task in self.supported_tasks:
             logger.warning(f"{task} is already registered. Overwriting pipeline for task {task}...")
 
+        if pt_model is None:
+            pt_model = ()
+        elif not isinstance(pt_model, tuple):
+            pt_model = (pt_model,)
+
+        if tf_model is None:
+            tf_model = ()
+        elif not isinstance(tf_model, tuple):
+            tf_model = (tf_model,)
+
+        task_impl = {"impl": pipeline_class, "pt": pt_model, "tf": tf_model}
+
+        if default is not None:
+            if "model" not in default and ("pt" in default or "tf" in default):
+                default = {"model": default}
+            task_impl["default"] = default
+
+        if type is not None:
+            task_impl["type"] = type
+
         self.supported_tasks[task] = task_impl
+        pipeline_class._registered_impl = {task: task_impl}
 
     def to_dict(self):
         return self.supported_tasks
```
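Note: one detail of the new `register_pipeline` signature worth spelling out: single model classes are wrapped into tuples, and a bare `default={"pt": ...}` is normalized into the nested `{"model": {"pt": ...}}` layout the registry uses internally. A quick sketch against the new API (the checkpoint name is the tiny test model used in the tests below, and `pair_classification.py` is the docs example above):

```py
from pair_classification import PairClassificationPipeline
from transformers import AutoModelForSequenceClassification
from transformers.pipelines import PIPELINE_REGISTRY

PIPELINE_REGISTRY.register_pipeline(
    "pair-classification",
    pipeline_class=PairClassificationPipeline,
    pt_model=AutoModelForSequenceClassification,  # wrapped into (AutoModelForSequenceClassification,)
    default={"pt": "hf-internal-testing/tiny-random-distilbert"},  # shorthand form
)

task_def, _ = PIPELINE_REGISTRY.check_task("pair-classification")
assert task_def["pt"] == (AutoModelForSequenceClassification,)
assert task_def["default"] == {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}}

# Clean up the registry afterwards, as the tests below do.
del PIPELINE_REGISTRY.supported_tasks["pair-classification"]
```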
```diff
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -15,15 +15,21 @@
 import copy
 import importlib
 import logging
+import os
 import random
 import string
+import sys
+import tempfile
 import unittest
 from abc import abstractmethod
 from functools import lru_cache
+from pathlib import Path
 from unittest import skipIf
 
 import numpy as np
 
+from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
+from requests.exceptions import HTTPError
 from transformers import (
     FEATURE_EXTRACTOR_MAPPING,
     TOKENIZER_MAPPING,
@@ -34,13 +40,17 @@ from transformers import (
     IBertConfig,
     RobertaConfig,
     TextClassificationPipeline,
+    TFAutoModelForSequenceClassification,
     pipeline,
 )
 from transformers.pipelines import PIPELINE_REGISTRY, get_task
 from transformers.pipelines.base import Pipeline, _pad
 from transformers.testing_utils import (
+    TOKEN,
+    USER,
     CaptureLogger,
     is_pipeline_test,
+    is_staging_test,
     nested_simplify,
     require_scatter,
     require_tensorflow_probability,
@@ -48,9 +58,15 @@ from transformers.testing_utils import (
     require_torch,
     slow,
 )
+from transformers.utils import is_tf_available, is_torch_available
 from transformers.utils import logging as transformers_logging
 
 
+sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
+
+from test_module.custom_pipeline import PairClassificationPipeline  # noqa E402
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -771,7 +787,7 @@ class CustomPipeline(Pipeline):
 
 
 @is_pipeline_test
-class PipelineRegistryTest(unittest.TestCase):
+class CustomPipelineTest(unittest.TestCase):
     def test_warning_logs(self):
         transformers_logging.set_verbosity_debug()
         logger_ = transformers_logging.get_logger("transformers.pipelines.base")
@@ -783,25 +799,165 @@ class PipelineRegistryTest(unittest.TestCase):
 
         try:
             with CaptureLogger(logger_) as cm:
-                PIPELINE_REGISTRY.register_pipeline(alias, {})
+                PIPELINE_REGISTRY.register_pipeline(alias, PairClassificationPipeline)
             self.assertIn(f"{alias} is already registered", cm.out)
         finally:
             # restore
-            PIPELINE_REGISTRY.register_pipeline(alias, original_task)
+            PIPELINE_REGISTRY.supported_tasks[alias] = original_task
 
-    @require_torch
     def test_register_pipeline(self):
-        custom_text_classification = {
-            "impl": CustomPipeline,
-            "tf": (),
-            "pt": (AutoModelForSequenceClassification,),
-            "default": {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}},
-            "type": "text",
-        }
-        PIPELINE_REGISTRY.register_pipeline("custom-text-classification", custom_text_classification)
+        PIPELINE_REGISTRY.register_pipeline(
+            "custom-text-classification",
+            pipeline_class=PairClassificationPipeline,
+            pt_model=AutoModelForSequenceClassification if is_torch_available() else None,
+            tf_model=TFAutoModelForSequenceClassification if is_tf_available() else None,
+            default={"pt": "hf-internal-testing/tiny-random-distilbert"},
+            type="text",
+        )
         assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks()
 
         task_def, _ = PIPELINE_REGISTRY.check_task("custom-text-classification")
-        self.assertEqual(task_def, custom_text_classification)
+        self.assertEqual(task_def["pt"], (AutoModelForSequenceClassification,) if is_torch_available() else ())
+        self.assertEqual(task_def["tf"], (TFAutoModelForSequenceClassification,) if is_tf_available() else ())
         self.assertEqual(task_def["type"], "text")
-        self.assertEqual(task_def["impl"], CustomPipeline)
+        self.assertEqual(task_def["impl"], PairClassificationPipeline)
+        self.assertEqual(task_def["default"], {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}})
+
+        # Clean registry for next tests.
+        del PIPELINE_REGISTRY.supported_tasks["custom-text-classification"]
+
+    def test_dynamic_pipeline(self):
+        PIPELINE_REGISTRY.register_pipeline(
+            "pair-classification",
+            pipeline_class=PairClassificationPipeline,
+            pt_model=AutoModelForSequenceClassification if is_torch_available() else None,
+            tf_model=TFAutoModelForSequenceClassification if is_tf_available() else None,
+        )
+
+        classifier = pipeline("pair-classification", model="hf-internal-testing/tiny-random-bert")
+
+        # Clean registry as we won't need the pipeline to be in it for the rest to work.
+        del PIPELINE_REGISTRY.supported_tasks["pair-classification"]
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            classifier.save_pretrained(tmp_dir)
+            # checks
+            self.assertDictEqual(
+                classifier.model.config.custom_pipelines,
+                {
+                    "pair-classification": {
+                        "impl": "custom_pipeline.PairClassificationPipeline",
+                        "pt": ("AutoModelForSequenceClassification",) if is_torch_available() else (),
+                        "tf": ("TFAutoModelForSequenceClassification",) if is_tf_available() else (),
+                    }
+                },
+            )
+            # Fails if the user forgets to pass along `trust_remote_code=True`
+            with self.assertRaises(ValueError):
+                _ = pipeline(model=tmp_dir)
+
+            new_classifier = pipeline(model=tmp_dir, trust_remote_code=True)
+            # Using trust_remote_code=False forces the traditional pipeline tag
+            old_classifier = pipeline("text-classification", model=tmp_dir, trust_remote_code=False)
+        # Can't make an isinstance check because the new_classifier is from the PairClassificationPipeline class of a
+        # dynamic module
+        self.assertEqual(new_classifier.__class__.__name__, "PairClassificationPipeline")
+        self.assertEqual(new_classifier.task, "pair-classification")
+        results = new_classifier("I hate you", second_text="I love you")
+        self.assertDictEqual(
+            nested_simplify(results),
+            {"label": "LABEL_0", "score": 0.505, "logits": [-0.003, -0.024]},
+        )
+
+        self.assertEqual(old_classifier.__class__.__name__, "TextClassificationPipeline")
+        self.assertEqual(old_classifier.task, "text-classification")
+        results = old_classifier("I hate you", text_pair="I love you")
+        self.assertListEqual(
+            nested_simplify(results),
+            [{"label": "LABEL_0", "score": 0.505}],
+        )
+
+
+@require_torch
+@is_staging_test
+class DynamicPipelineTester(unittest.TestCase):
+    vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "I", "love", "hate", "you"]
+
+    @classmethod
+    def setUpClass(cls):
+        cls._token = TOKEN
+        set_access_token(TOKEN)
+        HfFolder.save_token(TOKEN)
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            delete_repo(token=cls._token, repo_id="test-dynamic-pipeline")
+        except HTTPError:
+            pass
+
+    def test_push_to_hub_dynamic_pipeline(self):
+        from transformers import BertConfig, BertForSequenceClassification, BertTokenizer
+
+        PIPELINE_REGISTRY.register_pipeline(
+            "pair-classification",
+            pipeline_class=PairClassificationPipeline,
+            pt_model=AutoModelForSequenceClassification,
+        )
+
+        config = BertConfig(
+            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
+        )
+        model = BertForSequenceClassification(config).eval()
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-pipeline", use_auth_token=self._token)
+
+            vocab_file = os.path.join(tmp_dir, "vocab.txt")
+            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
+                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
+            tokenizer = BertTokenizer(vocab_file)
+
+            classifier = pipeline("pair-classification", model=model, tokenizer=tokenizer)
+
+            # Clean registry as we won't need the pipeline to be in it for the rest to work.
+            del PIPELINE_REGISTRY.supported_tasks["pair-classification"]
+
+            classifier.save_pretrained(tmp_dir)
+            # checks
+            self.assertDictEqual(
+                classifier.model.config.custom_pipelines,
+                {
+                    "pair-classification": {
+                        "impl": "custom_pipeline.PairClassificationPipeline",
+                        "pt": ("AutoModelForSequenceClassification",),
+                        "tf": (),
+                    }
+                },
+            )
+
+            repo.push_to_hub()
+
+        # Fails if the user forgets to pass along `trust_remote_code=True`
+        with self.assertRaises(ValueError):
+            _ = pipeline(model=f"{USER}/test-dynamic-pipeline")
+
+        new_classifier = pipeline(model=f"{USER}/test-dynamic-pipeline", trust_remote_code=True)
+        # Can't make an isinstance check because the new_classifier is from the PairClassificationPipeline class of a
+        # dynamic module
+        self.assertEqual(new_classifier.__class__.__name__, "PairClassificationPipeline")
+
+        results = classifier("I hate you", second_text="I love you")
+        new_results = new_classifier("I hate you", second_text="I love you")
+        self.assertDictEqual(nested_simplify(results), nested_simplify(new_results))
+
+        # Using trust_remote_code=False forces the traditional pipeline tag
+        old_classifier = pipeline(
+            "text-classification", model=f"{USER}/test-dynamic-pipeline", trust_remote_code=False
+        )
+        self.assertEqual(old_classifier.__class__.__name__, "TextClassificationPipeline")
+        self.assertEqual(old_classifier.task, "text-classification")
+        new_results = old_classifier("I hate you", text_pair="I love you")
+        self.assertListEqual(
+            nested_simplify([{"label": results["label"], "score": results["score"]}]), nested_simplify(new_results)
+        )
```
```diff
--- /dev/null
+++ b/utils/test_module/custom_pipeline.py
@@ -0,0 +1,33 @@
+import numpy as np
+
+from transformers import Pipeline
+
+
+def softmax(outputs):
+    maxes = np.max(outputs, axis=-1, keepdims=True)
+    shifted_exp = np.exp(outputs - maxes)
+    return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
+
+
+class PairClassificationPipeline(Pipeline):
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_kwargs = {}
+        if "second_text" in kwargs:
+            preprocess_kwargs["second_text"] = kwargs["second_text"]
+        return preprocess_kwargs, {}, {}
+
+    def preprocess(self, text, second_text=None):
+        return self.tokenizer(text, text_pair=second_text, return_tensors=self.framework)
+
+    def _forward(self, model_inputs):
+        return self.model(**model_inputs)
+
+    def postprocess(self, model_outputs):
+        logits = model_outputs.logits[0].numpy()
+        probabilities = softmax(logits)
+
+        best_class = np.argmax(probabilities)
+        label = self.model.config.id2label[best_class]
+        score = probabilities[best_class].item()
+        logits = logits.tolist()
+        return {"label": label, "score": score, "logits": logits}
```
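Note: the test module can also be exercised by hand without going through the registry at all, since `Pipeline` subclasses accept a model and tokenizer instance directly. A sketch using the tiny checkpoint from the tests above (requires `utils/` on `sys.path`, as the test file arranges):

```py
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from test_module.custom_pipeline import PairClassificationPipeline

model = AutoModelForSequenceClassification.from_pretrained("hf-internal-testing/tiny-random-bert")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")

classifier = PairClassificationPipeline(model=model, tokenizer=tokenizer)
print(classifier("I hate you", second_text="I love you"))
# -> {"label": ..., "score": ..., "logits": [...]}
```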