mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fixing the "translation", "translation_XX_to_YY" pipelines. (#7975)
* Actually make the "translation", "translation_XX_to_YY" tasks behave correctly. Background: - Currently "translation_cn_to_ar" does not work (only 3 pairs are supported). - Some models contain in their config the correct values for the (src, tgt) pair they can translate. It's usually just one pair, and we can infer it automatically from `model.config.task_specific_params`. If it's not defined, we can probably still load the TranslationPipeline nevertheless. Proposed fix: - A simplified version of what could become more general, which is a `parametrized` task. "translation" + (src, tgt) in this instance is what we need in the general case. The way we go about it for now is simply parsing "translation_XX_to_YY". If more cases of parametrized tasks arise, we should preferably move to something closer to what `datasets` proposes, which is having a secondary argument `task_options` that will be close to what that task requires. - Should be backward compatible in all cases; for instance, `pipeline(task="translation_en_to_de")` should work out of the box. - Should provide a warning when a specific translation pair has been selected on behalf of the user using `model.config.task_specific_params`. * Update src/transformers/pipelines.py Co-authored-by: Julien Chaumond <chaumond@gmail.com> Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
parent
901e9b8eda
commit
18ce6b8ff3
@ -20,6 +20,7 @@ import os
|
||||
import pickle
|
||||
import sys
|
||||
import uuid
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from itertools import chain
|
||||
@ -115,7 +116,7 @@ def get_framework(model):
|
||||
return framework
|
||||
|
||||
|
||||
def get_default_model(targeted_task: Dict, framework: Optional[str]) -> str:
|
||||
def get_default_model(targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]) -> str:
|
||||
"""
|
||||
Select a default model to use for a given task. Defaults to pytorch if ambiguous.
|
||||
|
||||
@ -126,6 +127,9 @@ def get_default_model(targeted_task: Dict, framework: Optional[str]) -> str:
|
||||
framework (:obj:`str`, None)
|
||||
"pt", "tf" or None, representing a specific framework if it was specified, or None if we don't know yet.
|
||||
|
||||
task_options (:obj:`Any`, None)
|
||||
Any further value required by the task to get fully specified, for instance (SRC, TGT) languages for translation task.
|
||||
|
||||
Returns
|
||||
|
||||
:obj:`str` The model string representing the default model for this pipeline
|
||||
@ -135,7 +139,20 @@ def get_default_model(targeted_task: Dict, framework: Optional[str]) -> str:
|
||||
elif is_tf_available() and not is_torch_available():
|
||||
framework = "tf"
|
||||
|
||||
default_models = targeted_task["default"]["model"]
|
||||
defaults = targeted_task["default"]
|
||||
if task_options:
|
||||
if task_options not in defaults:
|
||||
raise ValueError("The task does not provide any default models for options {}".format(task_options))
|
||||
default_models = defaults[task_options]["model"]
|
||||
elif "model" in defaults:
|
||||
default_models = targeted_task["default"]["model"]
|
||||
else:
|
||||
# XXX This error message needs to be updated to be more generic if more tasks are going to become
|
||||
# parametrized
|
||||
raise ValueError(
|
||||
'The task defaults can\'t be correctly selectionned. You probably meant "translation_XX_to_YY"'
|
||||
)
|
||||
|
||||
if framework is None:
|
||||
framework = "pt"
|
||||
|
||||
@ -2582,23 +2599,16 @@ SUPPORTED_TASKS = {
|
||||
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
|
||||
"default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
|
||||
},
|
||||
"translation_en_to_fr": {
|
||||
# This task is a special case as it's parametrized by SRC, TGT languages.
|
||||
"translation": {
|
||||
"impl": TranslationPipeline,
|
||||
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
|
||||
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
|
||||
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||
},
|
||||
"translation_en_to_de": {
|
||||
"impl": TranslationPipeline,
|
||||
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
|
||||
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
|
||||
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||
},
|
||||
"translation_en_to_ro": {
|
||||
"impl": TranslationPipeline,
|
||||
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
|
||||
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
|
||||
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||
"default": {
|
||||
("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||
("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||
("en", "ro"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||
},
|
||||
},
|
||||
"text2text-generation": {
|
||||
"impl": Text2TextGenerationPipeline,
|
||||
@ -2631,6 +2641,49 @@ SUPPORTED_TASKS = {
|
||||
}
|
||||
|
||||
|
||||
def check_task(task: str) -> Tuple[Dict, Any]:
    """
    Resolve a task string to its entry in ``SUPPORTED_TASKS`` and, for
    parametrized tasks, the options parsed out of the task name.

    Args:
        task (:obj:`str`):
            The task defining which pipeline will be returned. Currently accepted tasks are:

            - :obj:`"feature-extraction"`
            - :obj:`"sentiment-analysis"`
            - :obj:`"ner"`
            - :obj:`"question-answering"`
            - :obj:`"fill-mask"`
            - :obj:`"summarization"`
            - :obj:`"translation_xx_to_yy"`
            - :obj:`"translation"`
            - :obj:`"text-generation"`
            - :obj:`"conversational"`

    Returns:
        (task_defaults: :obj:`dict`, task_options: (:obj:`tuple`, None))
            The dictionary registered for the task in ``SUPPORTED_TASKS`` and the extra
            task options: ``None`` when ``task`` is a key of ``SUPPORTED_TASKS``, or the
            ``(XX, YY)`` pair of strings parsed from a ``"translation_XX_to_YY"`` name.

    Raises:
        KeyError: If ``task`` starts with ``"translation"`` but does not follow the
            ``"translation_XX_to_YY"`` format, or if it is not a known task at all.
    """
    # Exact match: the task is registered as-is and carries no extra options.
    if task in SUPPORTED_TASKS:
        return SUPPORTED_TASKS[task], None

    # Anything that is not translation-like cannot be a parametrized task.
    if not task.startswith("translation"):
        known_tasks = list(SUPPORTED_TASKS.keys()) + ["translation_XX_to_YY"]
        raise KeyError("Unknown task {}, available tasks are {}".format(task, known_tasks))

    # Expect exactly four "_"-separated parts: "translation", XX, "to", YY.
    parts = task.split("_")
    well_formed = len(parts) == 4 and parts[0] == "translation" and parts[2] == "to"
    if not well_formed:
        raise KeyError("Invalid translation task {}, use 'translation_XX_to_YY' format".format(task))

    _, source_lang, _, target_lang = parts
    return SUPPORTED_TASKS["translation"], (source_lang, target_lang)
|
||||
|
||||
|
||||
def pipeline(
|
||||
task: str,
|
||||
model: Optional = None,
|
||||
@ -2709,15 +2762,12 @@ def pipeline(
|
||||
>>> pipeline('ner', model=model, tokenizer=tokenizer)
|
||||
"""
|
||||
# Retrieve the task
|
||||
if task not in SUPPORTED_TASKS:
|
||||
raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))
|
||||
|
||||
targeted_task = SUPPORTED_TASKS[task]
|
||||
targeted_task, task_options = check_task(task)
|
||||
|
||||
# Use default model/config/tokenizer for the task if no model is provided
|
||||
if model is None:
|
||||
# At that point framework might still be undetermined
|
||||
model = get_default_model(targeted_task, framework)
|
||||
model = get_default_model(targeted_task, framework, task_options)
|
||||
|
||||
framework = framework or get_framework(model)
|
||||
|
||||
@ -2776,5 +2826,16 @@ def pipeline(
|
||||
"Trying to load the model with Tensorflow."
|
||||
)
|
||||
model = model_class.from_pretrained(model, config=config, **model_kwargs)
|
||||
if task == "translation" and model.config.task_specific_params:
|
||||
for key in model.config.task_specific_params:
|
||||
if key.startswith("translation"):
|
||||
task = key
|
||||
warnings.warn(
|
||||
'"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{}"'.format(
|
||||
task
|
||||
),
|
||||
UserWarning,
|
||||
)
|
||||
break
|
||||
|
||||
return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs)
|
||||
|
@ -1,6 +1,8 @@
|
||||
import unittest
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import SUPPORTED_TASKS, Conversation, DefaultArgumentHandler, Pipeline
|
||||
from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow, torch_device
|
||||
@ -392,6 +394,33 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
invalid_inputs,
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_default_translations(self):
|
||||
# We don't provide a default for this pair
|
||||
with self.assertRaises(ValueError):
|
||||
pipeline(task="translation_cn_to_ar")
|
||||
|
||||
# but we do for this one
|
||||
pipeline(task="translation_en_to_de")
|
||||
|
||||
@require_torch
|
||||
def test_translation_on_odd_language(self):
|
||||
model = TRANSLATION_FINETUNED_MODELS[0][0]
|
||||
pipeline(task="translation_cn_to_ar", model=model)
|
||||
|
||||
@require_torch
|
||||
def test_translation_default_language_selection(self):
|
||||
model = TRANSLATION_FINETUNED_MODELS[0][0]
|
||||
with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
|
||||
nlp = pipeline(task="translation", model=model)
|
||||
self.assertEqual(nlp.task, "translation_en_to_de")
|
||||
|
||||
@require_torch
|
||||
def test_translation_with_no_language_no_model_fails(self):
|
||||
with self.assertRaises(ValueError):
|
||||
pipeline(task="translation")
|
||||
|
||||
@require_tf
|
||||
@slow
|
||||
def test_tf_translation(self):
|
||||
|
Loading…
Reference in New Issue
Block a user