Custom pipeline (#18079)

* Initial work

* More work

* Add tests for custom pipelines on the Hub

* Protect import

* Make the test work for TF as well

* Last PyTorch specific bit

* Add documentation

* Style

* Title in toc

* Bad names!

* Update docs/source/en/add_new_pipeline.mdx

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>

* Auto stash before merge of "custom_pipeline" and "origin/custom_pipeline"

* Address review comments

* Address more review comments

* Update src/transformers/pipelines/__init__.py

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
Sylvain Gugger, 2022-07-19 12:02:35 +02:00, committed by GitHub
parent 3bb6356d4d
commit dc9147ff36
6 changed files with 439 additions and 57 deletions

docs/source/en/_toctree.yml

@@ -102,7 +102,7 @@
   - local: add_new_model
     title: How to add a model to 🤗 Transformers?
   - local: add_new_pipeline
-    title: How to add a pipeline to 🤗 Transformers?
+    title: How to create a custom pipeline?
   - local: testing
     title: Testing
   - local: pr_checks

docs/source/en/add_new_pipeline.mdx

@@ -9,7 +9,10 @@ Unless required by applicable law or agreed to in writing, software distributed
 an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 specific language governing permissions and limitations under the License.
 -->

-# How to add a pipeline to 🤗 Transformers?
+# How to create a custom pipeline?
+
+In this guide, we will see how to create a custom pipeline and share it on the [Hub](https://hf.co/models) or add it to
+the Transformers library.

 First and foremost, you need to decide the raw entries the pipeline will be able to take. It can be strings, raw bytes,
 dictionaries or whatever seems to be the most likely desired input. Try to keep these inputs as pure Python as possible
@@ -111,39 +114,123 @@ of arguments for ease of use (audio files, can be filenames, URLs or pure bytes)

 ## Adding it to the list of supported tasks

-To register your `new-task` to the list of supported tasks, provide the
-following task template:
-
-```python
-my_new_task = {
-    "impl": MyPipeline,
-    "tf": (),
-    "pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
-    "default": {"model": {"pt": "user/awesome_model"}},
-    "type": "audio",  # current support type: text, audio, image, multimodal
-}
-```
-
-<Tip>
-
-Take a look at the `src/transformers/pipelines/__init__.py` and the dictionary `SUPPORTED_TASKS` to see how a task is defined.
-If possible your custom task should provide a default model.
-
-</Tip>
-
-Then add your custom task to the list of supported tasks via
-`PIPELINE_REGISTRY.register_pipeline()`:
+To register your `new-task` to the list of supported tasks, you have to add it to the `PIPELINE_REGISTRY`:

 ```python
 from transformers.pipelines import PIPELINE_REGISTRY

-PIPELINE_REGISTRY.register_pipeline("new-task", my_new_task)
+PIPELINE_REGISTRY.register_pipeline(
+    "new-task",
+    pipeline_class=MyPipeline,
+    pt_model=AutoModelForSequenceClassification,
+)
 ```

+You can specify a default model if you want, in which case it should come with a specific revision (which can be the
+name of a branch or a commit hash, here we took `"abcdef"`) as well as the type:
-## Adding tests
+
+```python
+PIPELINE_REGISTRY.register_pipeline(
+    "new-task",
+    pipeline_class=MyPipeline,
+    pt_model=AutoModelForSequenceClassification,
+    default={"pt": ("user/awesome_model", "abcdef")},
+    type="text",  # current support type: text, audio, image, multimodal
+)
+```
-Create a new file `tests/test_pipelines_MY_PIPELINE.py` with example with the other tests.
+
+## Share your pipeline on the Hub
+
+To share your custom pipeline on the Hub, you just have to save the custom code of your `Pipeline` subclass in a
+Python file. For instance, let's say we want to use a custom pipeline for sentence pair classification like this:
+
+```py
+import numpy as np
+
+from transformers import Pipeline
+
+
+def softmax(outputs):
+    maxes = np.max(outputs, axis=-1, keepdims=True)
+    shifted_exp = np.exp(outputs - maxes)
+    return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
+
+
+class PairClassificationPipeline(Pipeline):
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_kwargs = {}
+        if "second_text" in kwargs:
+            preprocess_kwargs["second_text"] = kwargs["second_text"]
+        return preprocess_kwargs, {}, {}
+
+    def preprocess(self, text, second_text=None):
+        return self.tokenizer(text, text_pair=second_text, return_tensors=self.framework)
+
+    def _forward(self, model_inputs):
+        return self.model(**model_inputs)
+
+    def postprocess(self, model_outputs):
+        logits = model_outputs.logits[0].numpy()
+        probabilities = softmax(logits)
+
+        best_class = np.argmax(probabilities)
+        label = self.model.config.id2label[best_class]
+        score = probabilities[best_class].item()
+        logits = logits.tolist()
+        return {"label": label, "score": score, "logits": logits}
+```
+
+The implementation is framework agnostic, and will work for PyTorch and TensorFlow models. If we have saved this in
+a file named `pair_classification.py`, we can then import it and register it like this:
+
+```py
+from pair_classification import PairClassificationPipeline
+from transformers.pipelines import PIPELINE_REGISTRY
+from transformers import AutoModelForSequenceClassification, TFAutoModelForSequenceClassification
+
+PIPELINE_REGISTRY.register_pipeline(
+    "pair-classification",
+    pipeline_class=PairClassificationPipeline,
+    pt_model=AutoModelForSequenceClassification,
+    tf_model=TFAutoModelForSequenceClassification,
+)
+```
+
+Once this is done, we can use it with a pretrained model. For instance, `sgugger/finetuned-bert-mrpc` has been
+fine-tuned on the MRPC dataset, and classifies pairs of sentences as paraphrases or not.
+
+```py
+from transformers import pipeline
+
+classifier = pipeline("pair-classification", model="sgugger/finetuned-bert-mrpc")
+```
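As a quick illustration of calling the registered pipeline (an editorial sketch, not part of the commit; the output keys follow `postprocess` above and exact values depend on the checkpoint):

```py
from transformers import pipeline

# Assumes PairClassificationPipeline was registered as shown above.
classifier = pipeline("pair-classification", model="sgugger/finetuned-bert-mrpc")

# `second_text` is picked up by _sanitize_parameters and routed to preprocess.
result = classifier("I hate you", second_text="I love you")
# result is a dict with the keys built in postprocess: "label", "score" and "logits"
```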
+
+Then we can share it on the Hub by calling the pipeline's `save_pretrained` method inside a `Repository`:
+
+```py
+from huggingface_hub import Repository
+
+repo = Repository("test-dynamic-pipeline", clone_from="{your_username}/test-dynamic-pipeline")
+classifier.save_pretrained("test-dynamic-pipeline")
+repo.push_to_hub()
+```
+
+This will copy the file where you defined `PairClassificationPipeline` inside the folder `"test-dynamic-pipeline"`,
+along with saving the model and tokenizer of the pipeline, before pushing everything into the repository
+`{your_username}/test-dynamic-pipeline`. After that, anyone can use it as long as they provide the option
+`trust_remote_code=True`:
+
+```py
+from transformers import pipeline
+
+classifier = pipeline(model="{your_username}/test-dynamic-pipeline", trust_remote_code=True)
+```
+
+## Add the pipeline to Transformers
+
+If you want to contribute your pipeline to Transformers, you will need to add a new module in the `pipelines` submodule
+with the code of your pipeline, then add it to the list of tasks defined in `pipelines/__init__.py`.
+
+Then you will need to add tests. Create a new file `tests/test_pipelines_MY_PIPELINE.py` with examples based on the other tests.

 The `run_pipeline_test` function will be very generic and run on small random models on every possible
 architecture as defined by `model_mapping` and `tf_model_mapping`.
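For concreteness, a hedged sketch of what such a test file could contain, reusing `pair_classification.py` from above with a tiny test checkpoint (`hf-internal-testing/tiny-random-bert`, also used in the tests below) and skipping the generic `run_pipeline_test` scaffolding:

```py
import unittest

from transformers import AutoModelForSequenceClassification, pipeline
from transformers.pipelines import PIPELINE_REGISTRY

# The file written earlier in this guide.
from pair_classification import PairClassificationPipeline


class PairClassificationPipelineTest(unittest.TestCase):
    def test_small_model_pt(self):
        # Register the custom task, then build the pipeline on a tiny random model.
        PIPELINE_REGISTRY.register_pipeline(
            "pair-classification",
            pipeline_class=PairClassificationPipeline,
            pt_model=AutoModelForSequenceClassification,
        )
        classifier = pipeline("pair-classification", model="hf-internal-testing/tiny-random-bert")

        outputs = classifier("I hate you", second_text="I love you")
        # Values depend on the random checkpoint, so only check the output contract.
        self.assertEqual(set(outputs.keys()), {"label", "score", "logits"})
```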

src/transformers/pipelines/__init__.py

@@ -23,7 +23,10 @@ import os
 import warnings
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 from numpy import isin

 from ..configuration_utils import PretrainedConfig
+from ..dynamic_module_utils import get_class_from_dynamic_module
 from ..feature_extraction_utils import PreTrainedFeatureExtractor
 from ..models.auto.configuration_auto import AutoConfig
 from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
@@ -379,6 +382,22 @@ def check_task(task: str) -> Tuple[Dict, Any]:
     return PIPELINE_REGISTRY.check_task(task)


+def clean_custom_task(task_info):
+    import transformers
+
+    if "impl" not in task_info:
+        raise RuntimeError("This model introduces a custom pipeline without specifying its implementation.")
+    pt_class_names = task_info.get("pt", ())
+    if isinstance(pt_class_names, str):
+        pt_class_names = [pt_class_names]
+    task_info["pt"] = tuple(getattr(transformers, c) for c in pt_class_names)
+    tf_class_names = task_info.get("tf", ())
+    if isinstance(tf_class_names, str):
+        tf_class_names = [tf_class_names]
+    task_info["tf"] = tuple(getattr(transformers, c) for c in tf_class_names)
+    return task_info, None
+
+
 def pipeline(
     task: str = None,
     model: Optional = None,
@@ -391,6 +410,7 @@ def pipeline(
     use_auth_token: Optional[Union[str, bool]] = None,
     device_map=None,
     torch_dtype=None,
+    trust_remote_code: Optional[bool] = None,
     model_kwargs: Dict[str, Any] = None,
     pipeline_class: Optional[Any] = None,
     **kwargs
@@ -488,6 +508,10 @@ def pipeline(
         torch_dtype (`str` or `torch.dtype`, *optional*):
             Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model
             (`torch.float16`, `torch.bfloat16`, ... or `"auto"`).
+        trust_remote_code (`bool`, *optional*, defaults to `False`):
+            Whether or not to allow for custom code defined on the Hub in their own modeling, configuration,
+            tokenization or even pipeline files. This option should only be set to `True` for repositories you trust
+            and in which you have read the code, as it will execute code present on the Hub on your local machine.
         model_kwargs:
             Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
             **model_kwargs)` function.
@@ -516,6 +540,10 @@ def pipeline(
     ```"""
     if model_kwargs is None:
         model_kwargs = {}
+    # Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
+    # this is to keep BC).
+    use_auth_token = model_kwargs.pop("use_auth_token", use_auth_token)
+    hub_kwargs = {"revision": revision, "use_auth_token": use_auth_token, "trust_remote_code": trust_remote_code}

     if task is None and model is None:
         raise RuntimeError(
@@ -537,6 +565,25 @@ def pipeline(
             " or a path/identifier to a pretrained model when providing feature_extractor."
         )

+    # Config is the primordial information item.
+    # Instantiate config if needed
+    if isinstance(config, str):
+        config = AutoConfig.from_pretrained(config, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+    elif config is None and isinstance(model, str):
+        config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+
+    custom_tasks = {}
+    if config is not None and len(getattr(config, "custom_pipelines", {})) > 0:
+        custom_tasks = config.custom_pipelines
+        if task is None and trust_remote_code is not False:
+            if len(custom_tasks) == 1:
+                task = list(custom_tasks.keys())[0]
+            else:
+                raise RuntimeError(
+                    "We can't infer the task automatically for this model as there are multiple tasks available. Pick "
+                    f"one in {', '.join(custom_tasks.keys())}"
+                )
+
     if task is None and model is not None:
         if not isinstance(model, str):
             raise RuntimeError(
@@ -546,9 +593,24 @@ def pipeline(
         task = get_task(model, use_auth_token)

     # Retrieve the task
-    targeted_task, task_options = check_task(task)
-    if pipeline_class is None:
-        pipeline_class = targeted_task["impl"]
+    if task in custom_tasks:
+        targeted_task, task_options = clean_custom_task(custom_tasks[task])
+        if pipeline_class is None:
+            if not trust_remote_code:
+                raise ValueError(
+                    "Loading this pipeline requires you to execute the code in the pipeline file in that"
+                    " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
+                    " set the option `trust_remote_code=True` to remove this error."
+                )
+            class_ref = targeted_task["impl"]
+            module_file, class_name = class_ref.split(".")
+            pipeline_class = get_class_from_dynamic_module(
+                model, module_file + ".py", class_name, revision=revision, use_auth_token=use_auth_token
+            )
+    else:
+        targeted_task, task_options = check_task(task)
+        if pipeline_class is None:
+            pipeline_class = targeted_task["impl"]

     # Use default model/config/tokenizer for the task if no model is provided
     if model is None:
@@ -560,9 +622,9 @@ def pipeline(
             f" {revision} ({HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model}).\n"
             "Using a pipeline without specifying a model name and revision in production is not recommended."
         )

+    if config is None and isinstance(model, str):
+        config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
-    # Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
-    model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)

     if device_map is not None:
         if "device_map" in model_kwargs:
             raise ValueError(
@@ -578,13 +640,6 @@ def pipeline(
         )
         model_kwargs["torch_dtype"] = torch_dtype

-    # Config is the primordial information item.
-    # Instantiate config if needed
-    if isinstance(config, str):
-        config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
-    elif config is None and isinstance(model, str):
-        config = AutoConfig.from_pretrained(model, revision=revision, _from_pipeline=task, **model_kwargs)
-
     model_name = model if isinstance(model, str) else None

     # Infer the framework from the model
@@ -596,8 +651,8 @@ def pipeline(
             model_classes=model_classes,
             config=config,
             framework=framework,
-            revision=revision,
             task=task,
+            **hub_kwargs,
             **model_kwargs,
         )
@@ -641,7 +696,7 @@ def pipeline(
             tokenizer_kwargs = model_kwargs

         tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer_identifier, revision=revision, use_fast=use_fast, _from_pipeline=task, **tokenizer_kwargs
+            tokenizer_identifier, use_fast=use_fast, _from_pipeline=task, **hub_kwargs, **tokenizer_kwargs
         )

     if load_feature_extractor:
@@ -662,7 +717,7 @@ def pipeline(
         # Instantiate feature_extractor if needed
         if isinstance(feature_extractor, (str, tuple)):
             feature_extractor = AutoFeatureExtractor.from_pretrained(
-                feature_extractor, revision=revision, _from_pipeline=task, **model_kwargs
+                feature_extractor, _from_pipeline=task, **hub_kwargs, **model_kwargs
             )

     if (
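To make the new resolution path concrete: `config.json` on a repo can carry a `custom_pipelines` mapping in which classes are stored by name, and `clean_custom_task` turns the auto-class names back into classes while leaving `impl` as a string for `get_class_from_dynamic_module` to resolve. A minimal sketch of that contract:

```py
from transformers import AutoModelForSequenceClassification
from transformers.pipelines import clean_custom_task

# Shape of one `config.custom_pipelines` entry, as written by Pipeline.save_pretrained (see base.py below).
task_info = {
    "impl": "custom_pipeline.PairClassificationPipeline",  # "<module_file>.<class_name>" inside the repo
    "pt": ("AutoModelForSequenceClassification",),
    "tf": (),
}

cleaned, default_revision = clean_custom_task(task_info)
assert cleaned["pt"] == (AutoModelForSequenceClassification,)
assert default_revision is None  # this helper never returns a default revision
```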

src/transformers/pipelines/base.py

@@ -29,6 +29,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 from packaging import version

+from ..dynamic_module_utils import custom_object_save
 from ..feature_extraction_utils import PreTrainedFeatureExtractor
 from ..modelcard import ModelCard
 from ..models.auto.configuration_auto import AutoConfig
@@ -794,6 +795,27 @@ class Pipeline(_ScikitCompat):
             return
         os.makedirs(save_directory, exist_ok=True)

+        if hasattr(self, "_registered_impl"):
+            # Add info to the config
+            pipeline_info = self._registered_impl.copy()
+            custom_pipelines = {}
+            for task, info in pipeline_info.items():
+                if info["impl"] != self.__class__:
+                    continue
+
+                info = info.copy()
+                module_name = info["impl"].__module__
+                last_module = module_name.split(".")[-1]
+                # Change classes into their names/full names
+                info["impl"] = f"{last_module}.{info['impl'].__name__}"
+                info["pt"] = tuple(c.__name__ for c in info["pt"])
+                info["tf"] = tuple(c.__name__ for c in info["tf"])
+
+                custom_pipelines[task] = info
+            self.model.config.custom_pipelines = custom_pipelines
+            # Save the pipeline custom code
+            custom_object_save(self, save_directory)
+
         self.model.save_pretrained(save_directory)

         if self.tokenizer is not None:
@@ -1117,11 +1139,40 @@ class PipelineRegistry:
                 f"Unknown task {task}, available tasks are {self.get_supported_tasks() + ['translation_XX_to_YY']}"
             )

-    def register_pipeline(self, task: str, task_impl: Dict[str, Any]) -> None:
+    def register_pipeline(
+        self,
+        task: str,
+        pipeline_class: type,
+        pt_model: Optional[Union[type, Tuple[type]]] = None,
+        tf_model: Optional[Union[type, Tuple[type]]] = None,
+        default: Optional[Dict] = None,
+        type: Optional[str] = None,
+    ) -> None:
         if task in self.supported_tasks:
             logger.warning(f"{task} is already registered. Overwriting pipeline for task {task}...")

+        if pt_model is None:
+            pt_model = ()
+        elif not isinstance(pt_model, tuple):
+            pt_model = (pt_model,)
+
+        if tf_model is None:
+            tf_model = ()
+        elif not isinstance(tf_model, tuple):
+            tf_model = (tf_model,)
+
+        task_impl = {"impl": pipeline_class, "pt": pt_model, "tf": tf_model}
+
+        if default is not None:
+            if "model" not in default and ("pt" in default or "tf" in default):
+                default = {"model": default}
+            task_impl["default"] = default
+
+        if type is not None:
+            task_impl["type"] = type
+
         self.supported_tasks[task] = task_impl
+        pipeline_class._registered_impl = {task: task_impl}

     def to_dict(self):
         return self.supported_tasks
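A short sketch of the normalization the new signature performs: a single model class is wrapped into a tuple, and a bare `{"pt": ...}` default is nested under `"model"` (the `demo-task` name below is hypothetical, for illustration only):

```py
from transformers import AutoModelForSequenceClassification, TextClassificationPipeline
from transformers.pipelines import PIPELINE_REGISTRY

PIPELINE_REGISTRY.register_pipeline(
    "demo-task",  # hypothetical task name
    pipeline_class=TextClassificationPipeline,
    pt_model=AutoModelForSequenceClassification,  # a single class, not a tuple
    default={"pt": "hf-internal-testing/tiny-random-distilbert"},  # no "model" key
    type="text",
)

task_def, _ = PIPELINE_REGISTRY.check_task("demo-task")
assert task_def["pt"] == (AutoModelForSequenceClassification,)
assert task_def["default"] == {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}}

# The class also remembers how it was registered; save_pretrained uses this later.
assert TextClassificationPipeline._registered_impl == {"demo-task": task_def}

# Clean up so the registry is left unchanged.
del PIPELINE_REGISTRY.supported_tasks["demo-task"]
```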

tests/pipelines/test_pipelines_common.py

@@ -15,15 +15,21 @@
 import copy
 import importlib
 import logging
+import os
+import random
+import string
+import sys
 import tempfile
 import unittest
 from abc import abstractmethod
 from functools import lru_cache
+from pathlib import Path
 from unittest import skipIf

 import numpy as np
+from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
+from requests.exceptions import HTTPError

 from transformers import (
     FEATURE_EXTRACTOR_MAPPING,
     TOKENIZER_MAPPING,
@@ -34,13 +40,17 @@ from transformers import (
     IBertConfig,
     RobertaConfig,
     TextClassificationPipeline,
+    TFAutoModelForSequenceClassification,
     pipeline,
 )
 from transformers.pipelines import PIPELINE_REGISTRY, get_task
 from transformers.pipelines.base import Pipeline, _pad
 from transformers.testing_utils import (
+    TOKEN,
+    USER,
     CaptureLogger,
     is_pipeline_test,
+    is_staging_test,
     nested_simplify,
     require_scatter,
     require_tensorflow_probability,
@@ -48,9 +58,15 @@ from transformers.testing_utils import (
     require_torch,
     slow,
 )
+from transformers.utils import is_tf_available, is_torch_available
 from transformers.utils import logging as transformers_logging

+sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
+
+from test_module.custom_pipeline import PairClassificationPipeline  # noqa E402
+
+
 logger = logging.getLogger(__name__)
@@ -771,7 +787,7 @@ class CustomPipeline(Pipeline):

 @is_pipeline_test
-class PipelineRegistryTest(unittest.TestCase):
+class CustomPipelineTest(unittest.TestCase):
     def test_warning_logs(self):
         transformers_logging.set_verbosity_debug()
         logger_ = transformers_logging.get_logger("transformers.pipelines.base")
@@ -783,25 +799,165 @@
         try:
             with CaptureLogger(logger_) as cm:
-                PIPELINE_REGISTRY.register_pipeline(alias, {})
+                PIPELINE_REGISTRY.register_pipeline(alias, PairClassificationPipeline)
             self.assertIn(f"{alias} is already registered", cm.out)
         finally:
             # restore
-            PIPELINE_REGISTRY.register_pipeline(alias, original_task)
+            PIPELINE_REGISTRY.supported_tasks[alias] = original_task
     @require_torch
     def test_register_pipeline(self):
-        custom_text_classification = {
-            "impl": CustomPipeline,
-            "tf": (),
-            "pt": (AutoModelForSequenceClassification,),
-            "default": {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}},
-            "type": "text",
-        }
-        PIPELINE_REGISTRY.register_pipeline("custom-text-classification", custom_text_classification)
+        PIPELINE_REGISTRY.register_pipeline(
+            "custom-text-classification",
+            pipeline_class=PairClassificationPipeline,
+            pt_model=AutoModelForSequenceClassification if is_torch_available() else None,
+            tf_model=TFAutoModelForSequenceClassification if is_tf_available() else None,
+            default={"pt": "hf-internal-testing/tiny-random-distilbert"},
+            type="text",
+        )
         assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks()

         task_def, _ = PIPELINE_REGISTRY.check_task("custom-text-classification")
-        self.assertEqual(task_def, custom_text_classification)
+        self.assertEqual(task_def["pt"], (AutoModelForSequenceClassification,) if is_torch_available() else ())
+        self.assertEqual(task_def["tf"], (TFAutoModelForSequenceClassification,) if is_tf_available() else ())
         self.assertEqual(task_def["type"], "text")
-        self.assertEqual(task_def["impl"], CustomPipeline)
+        self.assertEqual(task_def["impl"], PairClassificationPipeline)
         self.assertEqual(task_def["default"], {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}})

+        # Clean registry for next tests.
+        del PIPELINE_REGISTRY.supported_tasks["custom-text-classification"]
+
+    def test_dynamic_pipeline(self):
+        PIPELINE_REGISTRY.register_pipeline(
+            "pair-classification",
+            pipeline_class=PairClassificationPipeline,
+            pt_model=AutoModelForSequenceClassification if is_torch_available() else None,
+            tf_model=TFAutoModelForSequenceClassification if is_tf_available() else None,
+        )
+
+        classifier = pipeline("pair-classification", model="hf-internal-testing/tiny-random-bert")
+
+        # Clean registry as we won't need the pipeline to be in it for the rest to work.
+        del PIPELINE_REGISTRY.supported_tasks["pair-classification"]
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            classifier.save_pretrained(tmp_dir)
+            # checks
+            self.assertDictEqual(
+                classifier.model.config.custom_pipelines,
+                {
+                    "pair-classification": {
+                        "impl": "custom_pipeline.PairClassificationPipeline",
+                        "pt": ("AutoModelForSequenceClassification",) if is_torch_available() else (),
+                        "tf": ("TFAutoModelForSequenceClassification",) if is_tf_available() else (),
+                    }
+                },
+            )
+
+            # Fails if the user forgets to pass along `trust_remote_code=True`
+            with self.assertRaises(ValueError):
+                _ = pipeline(model=tmp_dir)
+
+            new_classifier = pipeline(model=tmp_dir, trust_remote_code=True)
+            # Using trust_remote_code=False forces the traditional pipeline tag
+            old_classifier = pipeline("text-classification", model=tmp_dir, trust_remote_code=False)
+
+        # Can't make an isinstance check because the new_classifier is from the PairClassificationPipeline class of a
+        # dynamic module
+        self.assertEqual(new_classifier.__class__.__name__, "PairClassificationPipeline")
+        self.assertEqual(new_classifier.task, "pair-classification")
+        results = new_classifier("I hate you", second_text="I love you")
+        self.assertDictEqual(
+            nested_simplify(results),
+            {"label": "LABEL_0", "score": 0.505, "logits": [-0.003, -0.024]},
+        )
+
+        self.assertEqual(old_classifier.__class__.__name__, "TextClassificationPipeline")
+        self.assertEqual(old_classifier.task, "text-classification")
+        results = old_classifier("I hate you", text_pair="I love you")
+        self.assertListEqual(
+            nested_simplify(results),
+            [{"label": "LABEL_0", "score": 0.505}],
+        )
+
+
+@require_torch
+@is_staging_test
+class DynamicPipelineTester(unittest.TestCase):
+    vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "I", "love", "hate", "you"]
+
+    @classmethod
+    def setUpClass(cls):
+        cls._token = TOKEN
+        set_access_token(TOKEN)
+        HfFolder.save_token(TOKEN)
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            delete_repo(token=cls._token, repo_id="test-dynamic-pipeline")
+        except HTTPError:
+            pass
+
+    def test_push_to_hub_dynamic_pipeline(self):
+        from transformers import BertConfig, BertForSequenceClassification, BertTokenizer
+
+        PIPELINE_REGISTRY.register_pipeline(
+            "pair-classification",
+            pipeline_class=PairClassificationPipeline,
+            pt_model=AutoModelForSequenceClassification,
+        )
+
+        config = BertConfig(
+            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
+        )
+        model = BertForSequenceClassification(config).eval()
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-pipeline", use_auth_token=self._token)
+
+            vocab_file = os.path.join(tmp_dir, "vocab.txt")
+            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
+                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
+            tokenizer = BertTokenizer(vocab_file)
+
+            classifier = pipeline("pair-classification", model=model, tokenizer=tokenizer)
+            # Clean registry as we won't need the pipeline to be in it for the rest to work.
+            del PIPELINE_REGISTRY.supported_tasks["pair-classification"]
+
+            classifier.save_pretrained(tmp_dir)
+            # checks
+            self.assertDictEqual(
+                classifier.model.config.custom_pipelines,
+                {
+                    "pair-classification": {
+                        "impl": "custom_pipeline.PairClassificationPipeline",
+                        "pt": ("AutoModelForSequenceClassification",),
+                        "tf": (),
+                    }
+                },
+            )
+
+            repo.push_to_hub()
+
+        # Fails if the user forgets to pass along `trust_remote_code=True`
+        with self.assertRaises(ValueError):
+            _ = pipeline(model=f"{USER}/test-dynamic-pipeline")
+
+        new_classifier = pipeline(model=f"{USER}/test-dynamic-pipeline", trust_remote_code=True)
+        # Can't make an isinstance check because the new_classifier is from the PairClassificationPipeline class of a
+        # dynamic module
+        self.assertEqual(new_classifier.__class__.__name__, "PairClassificationPipeline")
+
+        results = classifier("I hate you", second_text="I love you")
+        new_results = new_classifier("I hate you", second_text="I love you")
+        self.assertDictEqual(nested_simplify(results), nested_simplify(new_results))
+
+        # Using trust_remote_code=False forces the traditional pipeline tag
+        old_classifier = pipeline(
+            "text-classification", model=f"{USER}/test-dynamic-pipeline", trust_remote_code=False
+        )
+        self.assertEqual(old_classifier.__class__.__name__, "TextClassificationPipeline")
+        self.assertEqual(old_classifier.task, "text-classification")
+        new_results = old_classifier("I hate you", text_pair="I love you")
+        self.assertListEqual(
+            nested_simplify([{"label": results["label"], "score": results["score"]}]), nested_simplify(new_results)
+        )

utils/test_module/custom_pipeline.py

@@ -0,0 +1,33 @@
+import numpy as np
+
+from transformers import Pipeline
+
+
+def softmax(outputs):
+    maxes = np.max(outputs, axis=-1, keepdims=True)
+    shifted_exp = np.exp(outputs - maxes)
+    return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
+
+
+class PairClassificationPipeline(Pipeline):
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_kwargs = {}
+        if "second_text" in kwargs:
+            preprocess_kwargs["second_text"] = kwargs["second_text"]
+        return preprocess_kwargs, {}, {}
+
+    def preprocess(self, text, second_text=None):
+        return self.tokenizer(text, text_pair=second_text, return_tensors=self.framework)
+
+    def _forward(self, model_inputs):
+        return self.model(**model_inputs)
+
+    def postprocess(self, model_outputs):
+        logits = model_outputs.logits[0].numpy()
+        probabilities = softmax(logits)
+
+        best_class = np.argmax(probabilities)
+        label = self.model.config.id2label[best_class]
+        score = probabilities[best_class].item()
+        logits = logits.tolist()
+        return {"label": label, "score": score, "logits": logits}