Fixing the "translation", "translation_XX_to_YY" pipelines. (#7975)

* Actually make the "translation", "translation_XX_to_YY" task behave correctly.

Background:
- Currently "translation_cn_to_ar" does not work. (only 3 pairs are
supported)
- Some models, contain in their config the correct values for the (src,
tgt) pair they can translate. It's usually just one pair, and we can
infer it automatically from the `model.config.task_specific_params`. If
it's not defined we can still probably load the TranslationPipeline
nevertheless.

Proposed fix:
- A simplified version of what could become more general which is
a `parametrized` task. "translation" + (src, tgt) in this instance
it what we need in the general case. The way we go about it for now
is simply parsing "translation_XX_to_YY". If cases of parametrized task arise
we should preferably go in something closer to what `datasets` propose
which is having a secondary argument `task_options`? that will be close
to what that task requires.
- Should be backward compatible in all cases for instance
`pipeline(task="translation_en_to_de") should work out of the box.
- Should provide a warning when a specific translation pair has been
selected on behalf of the user using
`model.config.task_specific_params`.

* Update src/transformers/pipelines.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
Nicolas Patry 2020-10-22 17:16:21 +02:00 committed by GitHub
parent 901e9b8eda
commit 18ce6b8ff3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 111 additions and 21 deletions

View File

@ -20,6 +20,7 @@ import os
import pickle
import sys
import uuid
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from itertools import chain
@ -115,7 +116,7 @@ def get_framework(model):
return framework
def get_default_model(targeted_task: Dict, framework: Optional[str]) -> str:
def get_default_model(targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]) -> str:
"""
Select a default model to use for a given task. Defaults to pytorch if ambiguous.
@ -126,6 +127,9 @@ def get_default_model(targeted_task: Dict, framework: Optional[str]) -> str:
framework (:obj:`str`, None)
"pt", "tf" or None, representing a specific framework if it was specified, or None if we don't know yet.
task_options (:obj:`Any`, None)
Any further value required by the task to get fully specified, for instance (SRC, TGT) languages for translation task.
Returns
:obj:`str` The model string representing the default model for this pipeline
@ -135,7 +139,20 @@ def get_default_model(targeted_task: Dict, framework: Optional[str]) -> str:
elif is_tf_available() and not is_torch_available():
framework = "tf"
default_models = targeted_task["default"]["model"]
defaults = targeted_task["default"]
if task_options:
if task_options not in defaults:
raise ValueError("The task does not provide any default models for options {}".format(task_options))
default_models = defaults[task_options]["model"]
elif "model" in defaults:
default_models = targeted_task["default"]["model"]
else:
# XXX This error message needs to be updated to be more generic if more tasks are going to become
# parametrized
raise ValueError(
'The task defaults can\'t be correctly selectionned. You probably meant "translation_XX_to_YY"'
)
if framework is None:
framework = "pt"
@ -2582,23 +2599,16 @@ SUPPORTED_TASKS = {
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
},
"translation_en_to_fr": {
# This task is a special case as it's parametrized by SRC, TGT languages.
"translation": {
"impl": TranslationPipeline,
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
},
"translation_en_to_de": {
"impl": TranslationPipeline,
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
},
"translation_en_to_ro": {
"impl": TranslationPipeline,
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
"default": {
("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
("en", "ro"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
},
},
"text2text-generation": {
"impl": Text2TextGenerationPipeline,
@ -2631,6 +2641,49 @@ SUPPORTED_TASKS = {
}
def check_task(task: str) -> Tuple[Dict, Any]:
"""
Checks an incoming task string, to validate it's correct and return the
default Pipeline and Model classes, and default models if they exist.
Args:
task (:obj:`str`):
The task defining which pipeline will be returned. Currently accepted tasks are:
- :obj:`"feature-extraction"`
- :obj:`"sentiment-analysis"`
- :obj:`"ner"`
- :obj:`"question-answering"`
- :obj:`"fill-mask"`
- :obj:`"summarization"`
- :obj:`"translation_xx_to_yy"`
- :obj:`"translation"`
- :obj:`"text-generation"`
- :obj:`"conversational"`
Returns:
(task_defaults:obj:`dict`, task_options: (:obj:`tuple`, None))
The actual dictionnary required to initialize the pipeline and some
extra task options for parametrized tasks like "translation_XX_to_YY"
"""
if task in SUPPORTED_TASKS:
targeted_task = SUPPORTED_TASKS[task]
return targeted_task, None
if task.startswith("translation"):
tokens = task.split("_")
if len(tokens) == 4 and tokens[0] == "translation" and tokens[2] == "to":
targeted_task = SUPPORTED_TASKS["translation"]
return targeted_task, (tokens[1], tokens[3])
raise KeyError("Invalid translation task {}, use 'translation_XX_to_YY' format".format(task))
raise KeyError(
"Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys()) + ["translation_XX_to_YY"])
)
def pipeline(
task: str,
model: Optional = None,
@ -2709,15 +2762,12 @@ def pipeline(
>>> pipeline('ner', model=model, tokenizer=tokenizer)
"""
# Retrieve the task
if task not in SUPPORTED_TASKS:
raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))
targeted_task = SUPPORTED_TASKS[task]
targeted_task, task_options = check_task(task)
# Use default model/config/tokenizer for the task if no model is provided
if model is None:
# At that point framework might still be undetermined
model = get_default_model(targeted_task, framework)
model = get_default_model(targeted_task, framework, task_options)
framework = framework or get_framework(model)
@ -2776,5 +2826,16 @@ def pipeline(
"Trying to load the model with Tensorflow."
)
model = model_class.from_pretrained(model, config=config, **model_kwargs)
if task == "translation" and model.config.task_specific_params:
for key in model.config.task_specific_params:
if key.startswith("translation"):
task = key
warnings.warn(
'"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{}"'.format(
task
),
UserWarning,
)
break
return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs)

View File

@ -1,6 +1,8 @@
import unittest
from typing import Iterable, List, Optional
import pytest
from transformers import pipeline
from transformers.pipelines import SUPPORTED_TASKS, Conversation, DefaultArgumentHandler, Pipeline
from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow, torch_device
@ -392,6 +394,33 @@ class MonoColumnInputTestCase(unittest.TestCase):
invalid_inputs,
)
@require_torch
@slow
def test_default_translations(self):
# We don't provide a default for this pair
with self.assertRaises(ValueError):
pipeline(task="translation_cn_to_ar")
# but we do for this one
pipeline(task="translation_en_to_de")
@require_torch
def test_translation_on_odd_language(self):
model = TRANSLATION_FINETUNED_MODELS[0][0]
pipeline(task="translation_cn_to_ar", model=model)
@require_torch
def test_translation_default_language_selection(self):
model = TRANSLATION_FINETUNED_MODELS[0][0]
with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
nlp = pipeline(task="translation", model=model)
self.assertEqual(nlp.task, "translation_en_to_de")
@require_torch
def test_translation_with_no_language_no_model_fails(self):
with self.assertRaises(ValueError):
pipeline(task="translation")
@require_tf
@slow
def test_tf_translation(self):