diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 7ca43eca35a..2dbce02d4bc 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -102,7 +102,7 @@
   - local: add_new_model
     title: How to add a model to 🤗 Transformers?
   - local: add_new_pipeline
-    title: How to add a pipeline to 🤗 Transformers?
+    title: How to create a custom pipeline?
   - local: testing
     title: Testing
   - local: pr_checks
diff --git a/docs/source/en/add_new_pipeline.mdx b/docs/source/en/add_new_pipeline.mdx
index 1b07e651e60..6f2bb44acc4 100644
--- a/docs/source/en/add_new_pipeline.mdx
+++ b/docs/source/en/add_new_pipeline.mdx
@@ -9,7 +9,10 @@ Unless required by applicable law or agreed to in writing, software distributed
 an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 -->
 
-# How to add a pipeline to 🤗 Transformers?
+# How to create a custom pipeline?
+
+In this guide, we will see how to create a custom pipeline and share it on the [Hub](https://hf.co/models) or add it
+to the Transformers library.
 
 First and foremost, you need to decide the raw entries the pipeline will be able to take. It can be strings, raw bytes,
 dictionaries or whatever seems to be the most likely desired input. Try to keep these inputs as pure Python as possible
@@ -111,39 +114,123 @@ of arguments for ease of use (audio files, can be filenames, URLs or pure bytes)
 
 ## Adding it to the list of supported tasks
 
-To register your `new-task` to the list of supported tasks, provide the
-following task template:
-
-```python
-my_new_task = {
-    "impl": MyPipeline,
-    "tf": (),
-    "pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
-    "default": {"model": {"pt": "user/awesome_model"}},
-    "type": "audio",  # current support type: text, audio, image, multimodal
-}
-```
-
-Take a look at the `src/transformers/pipelines/__init__.py` and the dictionary `SUPPORTED_TASKS` to see how a task is defined.
-If possible your custom task should provide a default model.
-
-Then add your custom task to the list of supported tasks via
-`PIPELINE_REGISTRY.register_pipeline()`:
+To register your `new-task` to the list of supported tasks, you have to add it to the `PIPELINE_REGISTRY`:
 
 ```python
 from transformers.pipelines import PIPELINE_REGISTRY
 
-PIPELINE_REGISTRY.register_pipeline("new-task", my_new_task)
+PIPELINE_REGISTRY.register_pipeline(
+    "new-task",
+    pipeline_class=MyPipeline,
+    pt_model=AutoModelForSequenceClassification,
+)
 ```
 
+You can specify a default model if you want, in which case it should come with a specific revision (which can be the
+name of a branch or a commit hash; here we took `"abcdef"`) as well as the type:
+
+```python
+PIPELINE_REGISTRY.register_pipeline(
+    "new-task",
+    pipeline_class=MyPipeline,
+    pt_model=AutoModelForSequenceClassification,
+    default={"pt": ("user/awesome_model", "abcdef")},
+    type="text",  # currently supported types: text, audio, image, multimodal
+)
+```
 
-## Adding tests
+## Share your pipeline on the Hub
+
-Create a new file `tests/test_pipelines_MY_PIPELINE.py` with example with the other tests.
+To share your custom pipeline on the Hub, you just have to save the custom code of your `Pipeline` subclass in a
+python file.
+For instance, let's say we want to use a custom pipeline for sentence pair classification like this:
+
+```py
+import numpy as np
+
+from transformers import Pipeline
+
+
+def softmax(outputs):
+    maxes = np.max(outputs, axis=-1, keepdims=True)
+    shifted_exp = np.exp(outputs - maxes)
+    return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
+
+
+class PairClassificationPipeline(Pipeline):
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_kwargs = {}
+        if "second_text" in kwargs:
+            preprocess_kwargs["second_text"] = kwargs["second_text"]
+        return preprocess_kwargs, {}, {}
+
+    def preprocess(self, text, second_text=None):
+        return self.tokenizer(text, text_pair=second_text, return_tensors=self.framework)
+
+    def _forward(self, model_inputs):
+        return self.model(**model_inputs)
+
+    def postprocess(self, model_outputs):
+        logits = model_outputs.logits[0].numpy()
+        probabilities = softmax(logits)
+
+        best_class = np.argmax(probabilities)
+        label = self.model.config.id2label[best_class]
+        score = probabilities[best_class].item()
+        logits = logits.tolist()
+        return {"label": label, "score": score, "logits": logits}
+```
+
+The implementation is framework-agnostic, and will work for PyTorch and TensorFlow models. If we have saved this in
+a file named `pair_classification.py`, we can then import it and register it like this:
+
+```py
+from pair_classification import PairClassificationPipeline
+from transformers.pipelines import PIPELINE_REGISTRY
+from transformers import AutoModelForSequenceClassification, TFAutoModelForSequenceClassification
+
+PIPELINE_REGISTRY.register_pipeline(
+    "pair-classification",
+    pipeline_class=PairClassificationPipeline,
+    pt_model=AutoModelForSequenceClassification,
+    tf_model=TFAutoModelForSequenceClassification,
+)
+```
+
+Once this is done, we can use it with a pretrained model. For instance, `sgugger/finetuned-bert-mrpc` has been
+fine-tuned on the MRPC dataset to classify pairs of sentences as paraphrases or not.
+
+```py
+from transformers import pipeline
+
+classifier = pipeline("pair-classification", model="sgugger/finetuned-bert-mrpc")
+```
+
+Then we can share it on the Hub by using the `save_pretrained` method in a `Repository`:
+
+```py
+from huggingface_hub import Repository
+
+repo = Repository("test-dynamic-pipeline", clone_from="{your_username}/test-dynamic-pipeline")
+classifier.save_pretrained("test-dynamic-pipeline")
+repo.push_to_hub()
+```
+
+This will copy the file where you defined `PairClassificationPipeline` inside the folder `"test-dynamic-pipeline"`,
+along with saving the model and tokenizer of the pipeline, before pushing everything into the repository
+`{your_username}/test-dynamic-pipeline`. After that, anyone can use it as long as they provide the option
+`trust_remote_code=True`:
+
+```py
+from transformers import pipeline
+
+classifier = pipeline(model="{your_username}/test-dynamic-pipeline", trust_remote_code=True)
+```
+
+## Add the pipeline to Transformers
+
+If you want to contribute your pipeline to Transformers, you will need to add a new module in the `pipelines` submodule
+with the code of your pipeline, then add it to the list of tasks defined in `pipelines/__init__.py`.
+
+Then you will need to add tests. Create a new file `tests/test_pipelines_MY_PIPELINE.py` following the example of the
+other tests.
 
 The `run_pipeline_test` function will be very generic and run on small random models on every possible architecture
 as defined by `model_mapping` and `tf_model_mapping`.
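+
+As a rough sketch only (the exact helpers and decorators vary between pipeline test files, so treat the names here
+as assumptions rather than the canonical harness), a first smoke test for the pipeline above could look like this:
+
+```py
+import unittest
+
+from transformers import AutoModelForSequenceClassification, pipeline
+from transformers.pipelines import PIPELINE_REGISTRY
+
+from pair_classification import PairClassificationPipeline
+
+
+class PairClassificationPipelineTests(unittest.TestCase):
+    def test_small_model_pt(self):
+        # Register the task, then build the pipeline on a tiny random model.
+        PIPELINE_REGISTRY.register_pipeline(
+            "pair-classification",
+            pipeline_class=PairClassificationPipeline,
+            pt_model=AutoModelForSequenceClassification,
+        )
+        classifier = pipeline("pair-classification", model="hf-internal-testing/tiny-random-bert")
+
+        outputs = classifier("I hate you", second_text="I love you")
+        # The weights are random, so we only check the structure of the output, not the values.
+        self.assertEqual(sorted(outputs.keys()), ["label", "logits", "score"])
+```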
diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py
index e563b284272..9fecd6c27a8 100755
--- a/src/transformers/pipelines/__init__.py
+++ b/src/transformers/pipelines/__init__.py
@@ -23,7 +23,8 @@ import os
 import warnings
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 from ..configuration_utils import PretrainedConfig
+from ..dynamic_module_utils import get_class_from_dynamic_module
 from ..feature_extraction_utils import PreTrainedFeatureExtractor
 from ..models.auto.configuration_auto import AutoConfig
 from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
@@ -379,6 +382,22 @@ def check_task(task: str) -> Tuple[Dict, Any]:
     return PIPELINE_REGISTRY.check_task(task)
 
 
+def clean_custom_task(task_info):
+    import transformers
+
+    if "impl" not in task_info:
+        raise RuntimeError("This model introduces a custom pipeline without specifying its implementation.")
+    pt_class_names = task_info.get("pt", ())
+    if isinstance(pt_class_names, str):
+        pt_class_names = [pt_class_names]
+    task_info["pt"] = tuple(getattr(transformers, c) for c in pt_class_names)
+    tf_class_names = task_info.get("tf", ())
+    if isinstance(tf_class_names, str):
+        tf_class_names = [tf_class_names]
+    task_info["tf"] = tuple(getattr(transformers, c) for c in tf_class_names)
+    return task_info, None
+
+
 def pipeline(
     task: str = None,
     model: Optional = None,
@@ -391,6 +410,7 @@ def pipeline(
     use_auth_token: Optional[Union[str, bool]] = None,
     device_map=None,
     torch_dtype=None,
+    trust_remote_code: Optional[bool] = None,
     model_kwargs: Dict[str, Any] = None,
     pipeline_class: Optional[Any] = None,
     **kwargs
@@ -488,6 +508,10 @@ def pipeline(
         torch_dtype (`str` or `torch.dtype`, *optional*):
             Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model
             (`torch.float16`, `torch.bfloat16`, ... or `"auto"`).
+        trust_remote_code (`bool`, *optional*, defaults to `False`):
+            Whether or not to allow for custom code defined on the Hub in their own modeling, configuration,
+            tokenization or even pipeline files. This option should only be set to `True` for repositories you trust
+            and in which you have read the code, as it will execute code present on the Hub on your local machine.
         model_kwargs:
             Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
             **model_kwargs)` function.
@@ -516,6 +540,10 @@ def pipeline(
     ```"""
     if model_kwargs is None:
         model_kwargs = {}
+    # Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
+    # this is to keep BC).
+    use_auth_token = model_kwargs.pop("use_auth_token", use_auth_token)
+    hub_kwargs = {"revision": revision, "use_auth_token": use_auth_token, "trust_remote_code": trust_remote_code}
 
     if task is None and model is None:
         raise RuntimeError(
@@ -537,6 +565,25 @@ def pipeline(
             " or a path/identifier to a pretrained model when providing feature_extractor."
         )
 
+    # Config is the primordial information item.
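+    # It is loaded first (when given as a string, or inferred from the model name) so that we can detect any
+    # custom pipeline the model advertises through `config.custom_pipelines` before resolving the task below.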
+    # Instantiate config if needed
+    if isinstance(config, str):
+        config = AutoConfig.from_pretrained(config, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+    elif config is None and isinstance(model, str):
+        config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+
+    custom_tasks = {}
+    if config is not None and len(getattr(config, "custom_pipelines", {})) > 0:
+        custom_tasks = config.custom_pipelines
+        if task is None and trust_remote_code is not False:
+            if len(custom_tasks) == 1:
+                task = list(custom_tasks.keys())[0]
+            else:
+                raise RuntimeError(
+                    "We can't infer the task automatically for this model as there are multiple tasks available. Pick "
+                    f"one in {', '.join(custom_tasks.keys())}"
+                )
+
     if task is None and model is not None:
         if not isinstance(model, str):
             raise RuntimeError(
@@ -546,9 +593,24 @@ def pipeline(
         task = get_task(model, use_auth_token)
 
     # Retrieve the task
-    targeted_task, task_options = check_task(task)
-    if pipeline_class is None:
-        pipeline_class = targeted_task["impl"]
+    if task in custom_tasks:
+        targeted_task, task_options = clean_custom_task(custom_tasks[task])
+        if pipeline_class is None:
+            if not trust_remote_code:
+                raise ValueError(
+                    "Loading this pipeline requires you to execute the code in the pipeline file in that"
+                    " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
+                    " set the option `trust_remote_code=True` to remove this error."
+                )
+            class_ref = targeted_task["impl"]
+            module_file, class_name = class_ref.split(".")
+            pipeline_class = get_class_from_dynamic_module(
+                model, module_file + ".py", class_name, revision=revision, use_auth_token=use_auth_token
+            )
+    else:
+        targeted_task, task_options = check_task(task)
+        if pipeline_class is None:
+            pipeline_class = targeted_task["impl"]
 
     # Use default model/config/tokenizer for the task if no model is provided
     if model is None:
@@ -560,9 +622,9 @@ def pipeline(
             f" {revision} ({HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model}).\n"
             "Using a pipeline without specifying a model name and revision in production is not recommended."
         )
+        if config is None and isinstance(model, str):
+            config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
 
-    # Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
-    model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
     if device_map is not None:
         if "device_map" in model_kwargs:
             raise ValueError(
@@ -578,13 +640,6 @@ def pipeline(
             )
         model_kwargs["torch_dtype"] = torch_dtype
 
-    # Config is the primordial information item.
-    # Instantiate config if needed
-    if isinstance(config, str):
-        config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
-    elif config is None and isinstance(model, str):
-        config = AutoConfig.from_pretrained(model, revision=revision, _from_pipeline=task, **model_kwargs)
-
     model_name = model if isinstance(model, str) else None
 
     # Infer the framework from the model
@@ -596,8 +651,8 @@ def pipeline(
         model_classes=model_classes,
         config=config,
         framework=framework,
-        revision=revision,
         task=task,
+        **hub_kwargs,
         **model_kwargs,
     )
 
@@ -641,7 +696,7 @@ def pipeline(
         tokenizer_kwargs = model_kwargs
 
         tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer_identifier, revision=revision, use_fast=use_fast, _from_pipeline=task, **tokenizer_kwargs
+            tokenizer_identifier, use_fast=use_fast, _from_pipeline=task, **hub_kwargs, **tokenizer_kwargs
         )
 
     if load_feature_extractor:
@@ -662,7 +717,7 @@ def pipeline(
         # Instantiate feature_extractor if needed
         if isinstance(feature_extractor, (str, tuple)):
             feature_extractor = AutoFeatureExtractor.from_pretrained(
-                feature_extractor, revision=revision, _from_pipeline=task, **model_kwargs
+                feature_extractor, _from_pipeline=task, **hub_kwargs, **model_kwargs
             )
 
     if (
diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index 0e2b9ac2b87..29a12e7df22 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -29,6 +29,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 from packaging import version
 
+from ..dynamic_module_utils import custom_object_save
 from ..feature_extraction_utils import PreTrainedFeatureExtractor
 from ..modelcard import ModelCard
 from ..models.auto.configuration_auto import AutoConfig
@@ -794,6 +795,27 @@ class Pipeline(_ScikitCompat):
             return
         os.makedirs(save_directory, exist_ok=True)
 
+        if hasattr(self, "_registered_impl"):
+            # Add info to the config
+            pipeline_info = self._registered_impl.copy()
+            custom_pipelines = {}
+            for task, info in pipeline_info.items():
+                if info["impl"] != self.__class__:
+                    continue
+
+                info = info.copy()
+                module_name = info["impl"].__module__
+                last_module = module_name.split(".")[-1]
+                # Change classes into their names/full names
+                info["impl"] = f"{last_module}.{info['impl'].__name__}"
+                info["pt"] = tuple(c.__name__ for c in info["pt"])
+                info["tf"] = tuple(c.__name__ for c in info["tf"])
+
+                custom_pipelines[task] = info
+            self.model.config.custom_pipelines = custom_pipelines
+            # Save the pipeline custom code
+            custom_object_save(self, save_directory)
+
         self.model.save_pretrained(save_directory)
 
         if self.tokenizer is not None:
@@ -1117,11 +1139,40 @@ class PipelineRegistry:
                 f"Unknown task {task}, available tasks are {self.get_supported_tasks() + ['translation_XX_to_YY']}"
             )
 
-    def register_pipeline(self, task: str, task_impl: Dict[str, Any]) -> None:
+    def register_pipeline(
+        self,
+        task: str,
+        pipeline_class: type,
+        pt_model: Optional[Union[type, Tuple[type]]] = None,
+        tf_model: Optional[Union[type, Tuple[type]]] = None,
+        default: Optional[Dict] = None,
+        type: Optional[str] = None,
+    ) -> None:
         if task in self.supported_tasks:
             logger.warning(f"{task} is already registered. Overwriting pipeline for task {task}...")
 
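+        # Normalize the model classes to tuples so registry entries always have the same shape.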
+        if pt_model is None:
+            pt_model = ()
+        elif not isinstance(pt_model, tuple):
+            pt_model = (pt_model,)
+
+        if tf_model is None:
+            tf_model = ()
+        elif not isinstance(tf_model, tuple):
+            tf_model = (tf_model,)
+
+        task_impl = {"impl": pipeline_class, "pt": pt_model, "tf": tf_model}
+
+        if default is not None:
+            if "model" not in default and ("pt" in default or "tf" in default):
+                default = {"model": default}
+            task_impl["default"] = default
+
+        if type is not None:
+            task_impl["type"] = type
+
         self.supported_tasks[task] = task_impl
+        pipeline_class._registered_impl = {task: task_impl}
 
     def to_dict(self):
         return self.supported_tasks
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 9c3a94c64c6..83474a5ba04 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -15,15 +15,21 @@
 import copy
 import importlib
 import logging
+import os
 import random
 import string
+import sys
+import tempfile
 import unittest
 from abc import abstractmethod
 from functools import lru_cache
+from pathlib import Path
 from unittest import skipIf
 
 import numpy as np
+from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
+from requests.exceptions import HTTPError
 
 from transformers import (
     FEATURE_EXTRACTOR_MAPPING,
     TOKENIZER_MAPPING,
@@ -34,13 +40,17 @@ from transformers import (
     IBertConfig,
     RobertaConfig,
     TextClassificationPipeline,
+    TFAutoModelForSequenceClassification,
     pipeline,
 )
 from transformers.pipelines import PIPELINE_REGISTRY, get_task
 from transformers.pipelines.base import Pipeline, _pad
 from transformers.testing_utils import (
+    TOKEN,
+    USER,
     CaptureLogger,
     is_pipeline_test,
+    is_staging_test,
     nested_simplify,
     require_scatter,
     require_tensorflow_probability,
@@ -48,9 +58,15 @@ from transformers.testing_utils import (
     require_torch,
     slow,
 )
+from transformers.utils import is_tf_available, is_torch_available
 from transformers.utils import logging as transformers_logging
 
 
+sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
+
+from test_module.custom_pipeline import PairClassificationPipeline  # noqa E402
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -771,7 +787,7 @@ class CustomPipeline(Pipeline):
 
 
 @is_pipeline_test
-class PipelineRegistryTest(unittest.TestCase):
+class CustomPipelineTest(unittest.TestCase):
     def test_warning_logs(self):
         transformers_logging.set_verbosity_debug()
         logger_ = transformers_logging.get_logger("transformers.pipelines.base")
@@ -783,25 +799,165 @@ class PipelineRegistryTest(unittest.TestCase):
         try:
             with CaptureLogger(logger_) as cm:
-                PIPELINE_REGISTRY.register_pipeline(alias, {})
+                PIPELINE_REGISTRY.register_pipeline(alias, PairClassificationPipeline)
             self.assertIn(f"{alias} is already registered", cm.out)
         finally:
             # restore
-            PIPELINE_REGISTRY.register_pipeline(alias, original_task)
+            PIPELINE_REGISTRY.supported_tasks[alias] = original_task
 
-    @require_torch
     def test_register_pipeline(self):
-        custom_text_classification = {
-            "impl": CustomPipeline,
-            "tf": (),
-            "pt": (AutoModelForSequenceClassification,),
-            "default": {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}},
-            "type": "text",
-        }
-        PIPELINE_REGISTRY.register_pipeline("custom-text-classification", custom_text_classification)
+        PIPELINE_REGISTRY.register_pipeline(
+            "custom-text-classification",
+            pipeline_class=PairClassificationPipeline,
+            pt_model=AutoModelForSequenceClassification if is_torch_available() else None,
+            tf_model=TFAutoModelForSequenceClassification if is_tf_available() else None,
+            default={"pt": "hf-internal-testing/tiny-random-distilbert"},
+            type="text",
+        )
 
         assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks()
 
         task_def, _ = PIPELINE_REGISTRY.check_task("custom-text-classification")
-        self.assertEqual(task_def, custom_text_classification)
+        self.assertEqual(task_def["pt"], (AutoModelForSequenceClassification,) if is_torch_available() else ())
+        self.assertEqual(task_def["tf"], (TFAutoModelForSequenceClassification,) if is_tf_available() else ())
         self.assertEqual(task_def["type"], "text")
-        self.assertEqual(task_def["impl"], CustomPipeline)
+        self.assertEqual(task_def["impl"], PairClassificationPipeline)
+        self.assertEqual(task_def["default"], {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}})
+
+        # Clean registry for next tests.
+        del PIPELINE_REGISTRY.supported_tasks["custom-text-classification"]
+
+    def test_dynamic_pipeline(self):
+        PIPELINE_REGISTRY.register_pipeline(
+            "pair-classification",
+            pipeline_class=PairClassificationPipeline,
+            pt_model=AutoModelForSequenceClassification if is_torch_available() else None,
+            tf_model=TFAutoModelForSequenceClassification if is_tf_available() else None,
+        )
+
+        classifier = pipeline("pair-classification", model="hf-internal-testing/tiny-random-bert")
+
+        # Clean registry as we won't need the pipeline to be in it for the rest to work.
+        del PIPELINE_REGISTRY.supported_tasks["pair-classification"]
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            classifier.save_pretrained(tmp_dir)
+            # checks
+            self.assertDictEqual(
+                classifier.model.config.custom_pipelines,
+                {
+                    "pair-classification": {
+                        "impl": "custom_pipeline.PairClassificationPipeline",
+                        "pt": ("AutoModelForSequenceClassification",) if is_torch_available() else (),
+                        "tf": ("TFAutoModelForSequenceClassification",) if is_tf_available() else (),
+                    }
+                },
+            )
+            # Fails if the user forgets to pass along `trust_remote_code=True`
+            with self.assertRaises(ValueError):
+                _ = pipeline(model=tmp_dir)
+
+            new_classifier = pipeline(model=tmp_dir, trust_remote_code=True)
+            # Using trust_remote_code=False forces the traditional pipeline tag
+            old_classifier = pipeline("text-classification", model=tmp_dir, trust_remote_code=False)
+            # Can't make an isinstance check because the new_classifier is from the PairClassificationPipeline class of a
+            # dynamic module
+            self.assertEqual(new_classifier.__class__.__name__, "PairClassificationPipeline")
+            self.assertEqual(new_classifier.task, "pair-classification")
+            results = new_classifier("I hate you", second_text="I love you")
+            self.assertDictEqual(
+                nested_simplify(results),
+                {"label": "LABEL_0", "score": 0.505, "logits": [-0.003, -0.024]},
+            )
+
+            self.assertEqual(old_classifier.__class__.__name__, "TextClassificationPipeline")
+            self.assertEqual(old_classifier.task, "text-classification")
+            results = old_classifier("I hate you", text_pair="I love you")
+            self.assertListEqual(
+                nested_simplify(results),
+                [{"label": "LABEL_0", "score": 0.505}],
+            )
+
+
+@require_torch
+@is_staging_test
+class DynamicPipelineTester(unittest.TestCase):
+    vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "I", "love", "hate", "you"]
+
+    @classmethod
+    def setUpClass(cls):
+        cls._token = TOKEN
+        set_access_token(TOKEN)
+        HfFolder.save_token(TOKEN)
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            delete_repo(token=cls._token, repo_id="test-dynamic-pipeline")
+        except HTTPError:
+            pass
+
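+    # End-to-end check: register the custom task, push the model, tokenizer and pipeline code to the Hub,
+    # then reload the pipeline from the Hub with `trust_remote_code=True` and compare the outputs.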
+    def test_push_to_hub_dynamic_pipeline(self):
+        from transformers import BertConfig, BertForSequenceClassification, BertTokenizer
+
+        PIPELINE_REGISTRY.register_pipeline(
+            "pair-classification",
+            pipeline_class=PairClassificationPipeline,
+            pt_model=AutoModelForSequenceClassification,
+        )
+
+        config = BertConfig(
+            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
+        )
+        model = BertForSequenceClassification(config).eval()
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-pipeline", use_auth_token=self._token)
+
+            vocab_file = os.path.join(tmp_dir, "vocab.txt")
+            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
+                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
+            tokenizer = BertTokenizer(vocab_file)
+
+            classifier = pipeline("pair-classification", model=model, tokenizer=tokenizer)
+
+            # Clean registry as we won't need the pipeline to be in it for the rest to work.
+            del PIPELINE_REGISTRY.supported_tasks["pair-classification"]
+
+            classifier.save_pretrained(tmp_dir)
+            # checks
+            self.assertDictEqual(
+                classifier.model.config.custom_pipelines,
+                {
+                    "pair-classification": {
+                        "impl": "custom_pipeline.PairClassificationPipeline",
+                        "pt": ("AutoModelForSequenceClassification",),
+                        "tf": (),
+                    }
+                },
+            )
+
+            repo.push_to_hub()
+
+        # Fails if the user forgets to pass along `trust_remote_code=True`
+        with self.assertRaises(ValueError):
+            _ = pipeline(model=f"{USER}/test-dynamic-pipeline")
+
+        new_classifier = pipeline(model=f"{USER}/test-dynamic-pipeline", trust_remote_code=True)
+        # Can't make an isinstance check because the new_classifier is from the PairClassificationPipeline class of a
+        # dynamic module
+        self.assertEqual(new_classifier.__class__.__name__, "PairClassificationPipeline")
+
+        results = classifier("I hate you", second_text="I love you")
+        new_results = new_classifier("I hate you", second_text="I love you")
+        self.assertDictEqual(nested_simplify(results), nested_simplify(new_results))
+
+        # Using trust_remote_code=False forces the traditional pipeline tag
+        old_classifier = pipeline(
+            "text-classification", model=f"{USER}/test-dynamic-pipeline", trust_remote_code=False
+        )
+        self.assertEqual(old_classifier.__class__.__name__, "TextClassificationPipeline")
+        self.assertEqual(old_classifier.task, "text-classification")
+        new_results = old_classifier("I hate you", text_pair="I love you")
+        self.assertListEqual(
+            nested_simplify([{"label": results["label"], "score": results["score"]}]), nested_simplify(new_results)
+        )
diff --git a/utils/test_module/custom_pipeline.py b/utils/test_module/custom_pipeline.py
new file mode 100644
index 00000000000..4c7928b1ccd
--- /dev/null
+++ b/utils/test_module/custom_pipeline.py
@@ -0,0 +1,33 @@
+import numpy as np
+
+from transformers import Pipeline
+
+
+def softmax(outputs):
+    maxes = np.max(outputs, axis=-1, keepdims=True)
+    shifted_exp = np.exp(outputs - maxes)
+    return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
+
+
+class PairClassificationPipeline(Pipeline):
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_kwargs = {}
+        if "second_text" in kwargs:
+            preprocess_kwargs["second_text"] = kwargs["second_text"]
+        return preprocess_kwargs, {}, {}
+
+    def preprocess(self, text, second_text=None):
+        return self.tokenizer(text, text_pair=second_text, return_tensors=self.framework)
+
+    def _forward(self, model_inputs):
+        return self.model(**model_inputs)
+
+    def postprocess(self, model_outputs):
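+        # Softmax the logits to probabilities, then return the best label, its score and the raw logits.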
+        logits = model_outputs.logits[0].numpy()
+        probabilities = softmax(logits)
+
+        best_class = np.argmax(probabilities)
+        label = self.model.config.id2label[best_class]
+        score = probabilities[best_class].item()
+        logits = logits.tolist()
+        return {"label": label, "score": score, "logits": logits}
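+
+
+# A sketch of how this pipeline can be used once the task has been registered with `PIPELINE_REGISTRY`
+# (see `tests/pipelines/test_pipelines_common.py`); any sequence classification checkpoint would work:
+#
+#   from transformers import pipeline
+#
+#   classifier = pipeline("pair-classification", model="hf-internal-testing/tiny-random-bert")
+#   classifier("I hate you", second_text="I love you")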