Fixing the "translation", "translation_XX_to_YY" pipelines. (#7975)

* Actually make the "translation", "translation_XX_to_YY" task behave correctly.

Background:
- Currently "translation_cn_to_ar" does not work. (only 3 pairs are
supported)
- Some models contain in their config the correct values for the (src,
tgt) pair they can translate. It's usually just one pair, and we can
infer it automatically from the `model.config.task_specific_params`. If
it's not defined, we can probably still load the TranslationPipeline
anyway.

Proposed fix:
- A simplified version of what could become a more general concept: a
`parametrized` task. "translation" + (src, tgt) is, in this instance,
what we need in the general case. The way we go about it for now
is simply parsing "translation_XX_to_YY". If more cases of parametrized tasks arise,
we should preferably move to something closer to what `datasets` proposes,
which is having a secondary argument (`task_options`?) that closely matches
what that task requires.
- Should be backward compatible in all cases; for instance,
`pipeline(task="translation_en_to_de")` should work out of the box.
- Should provide a warning when a specific translation pair has been
selected on behalf of the user using
`model.config.task_specific_params`.

* Update src/transformers/pipelines.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
Nicolas Patry 2020-10-22 17:16:21 +02:00 committed by GitHub
parent 901e9b8eda
commit 18ce6b8ff3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 111 additions and 21 deletions

View File

@ -20,6 +20,7 @@ import os
import pickle import pickle
import sys import sys
import uuid import uuid
import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from itertools import chain from itertools import chain
@ -115,7 +116,7 @@ def get_framework(model):
return framework return framework
def get_default_model(targeted_task: Dict, framework: Optional[str]) -> str: def get_default_model(targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]) -> str:
""" """
Select a default model to use for a given task. Defaults to pytorch if ambiguous. Select a default model to use for a given task. Defaults to pytorch if ambiguous.
@ -126,6 +127,9 @@ def get_default_model(targeted_task: Dict, framework: Optional[str]) -> str:
framework (:obj:`str`, None) framework (:obj:`str`, None)
"pt", "tf" or None, representing a specific framework if it was specified, or None if we don't know yet. "pt", "tf" or None, representing a specific framework if it was specified, or None if we don't know yet.
task_options (:obj:`Any`, None)
Any further value required by the task to get fully specified, for instance (SRC, TGT) languages for translation task.
Returns Returns
:obj:`str` The model string representing the default model for this pipeline :obj:`str` The model string representing the default model for this pipeline
@ -135,7 +139,20 @@ def get_default_model(targeted_task: Dict, framework: Optional[str]) -> str:
elif is_tf_available() and not is_torch_available(): elif is_tf_available() and not is_torch_available():
framework = "tf" framework = "tf"
default_models = targeted_task["default"]["model"] defaults = targeted_task["default"]
if task_options:
if task_options not in defaults:
raise ValueError("The task does not provide any default models for options {}".format(task_options))
default_models = defaults[task_options]["model"]
elif "model" in defaults:
default_models = targeted_task["default"]["model"]
else:
# XXX This error message needs to be updated to be more generic if more tasks are going to become
# parametrized
raise ValueError(
'The task defaults can\'t be correctly selectionned. You probably meant "translation_XX_to_YY"'
)
if framework is None: if framework is None:
framework = "pt" framework = "pt"
@ -2582,23 +2599,16 @@ SUPPORTED_TASKS = {
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None, "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}}, "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
}, },
"translation_en_to_fr": { # This task is a special case as it's parametrized by SRC, TGT languages.
"translation": {
"impl": TranslationPipeline, "impl": TranslationPipeline,
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None, "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None, "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}}, "default": {
}, ("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
"translation_en_to_de": { ("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
"impl": TranslationPipeline, ("en", "ro"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None, },
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
},
"translation_en_to_ro": {
"impl": TranslationPipeline,
"tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
"pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
}, },
"text2text-generation": { "text2text-generation": {
"impl": Text2TextGenerationPipeline, "impl": Text2TextGenerationPipeline,
@ -2631,6 +2641,49 @@ SUPPORTED_TASKS = {
} }
def check_task(task: str) -> Tuple[Dict, Any]:
    """
    Resolve a task identifier to its entry in ``SUPPORTED_TASKS``.

    Identifiers that are keys of ``SUPPORTED_TASKS`` are returned as-is with no options. Identifiers of the
    form ``"translation_XX_to_YY"`` are mapped onto the generic ``"translation"`` entry, with ``(XX, YY)``
    returned as the task options.

    Args:
        task (:obj:`str`):
            The task defining which pipeline will be returned. Currently accepted tasks are:

            - :obj:`"feature-extraction"`
            - :obj:`"sentiment-analysis"`
            - :obj:`"ner"`
            - :obj:`"question-answering"`
            - :obj:`"fill-mask"`
            - :obj:`"summarization"`
            - :obj:`"translation_xx_to_yy"`
            - :obj:`"translation"`
            - :obj:`"text-generation"`
            - :obj:`"conversational"`

    Returns:
        (task_defaults: :obj:`dict`, task_options: (:obj:`tuple`, None))
            The dictionary required to initialize the pipeline, and the extra options for parametrized tasks
            such as "translation_XX_to_YY" (:obj:`None` for every other task).

    Raises:
        :obj:`KeyError`: If the task is unknown, or starts with "translation" but is not of the
        "translation_XX_to_YY" form.
    """
    # Exact matches need no parsing and carry no options.
    if task in SUPPORTED_TASKS:
        return SUPPORTED_TASKS[task], None

    if not task.startswith("translation"):
        raise KeyError(
            "Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys()) + ["translation_XX_to_YY"])
        )

    # Expected layout once split on "_": ["translation", SRC, "to", TGT].
    parts = task.split("_")
    well_formed = len(parts) == 4 and parts[0] == "translation" and parts[2] == "to"
    if not well_formed:
        raise KeyError("Invalid translation task {}, use 'translation_XX_to_YY' format".format(task))

    return SUPPORTED_TASKS["translation"], (parts[1], parts[3])
def pipeline( def pipeline(
task: str, task: str,
model: Optional = None, model: Optional = None,
@ -2709,15 +2762,12 @@ def pipeline(
>>> pipeline('ner', model=model, tokenizer=tokenizer) >>> pipeline('ner', model=model, tokenizer=tokenizer)
""" """
# Retrieve the task # Retrieve the task
if task not in SUPPORTED_TASKS: targeted_task, task_options = check_task(task)
raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))
targeted_task = SUPPORTED_TASKS[task]
# Use default model/config/tokenizer for the task if no model is provided # Use default model/config/tokenizer for the task if no model is provided
if model is None: if model is None:
# At that point framework might still be undetermined # At that point framework might still be undetermined
model = get_default_model(targeted_task, framework) model = get_default_model(targeted_task, framework, task_options)
framework = framework or get_framework(model) framework = framework or get_framework(model)
@ -2776,5 +2826,16 @@ def pipeline(
"Trying to load the model with Tensorflow." "Trying to load the model with Tensorflow."
) )
model = model_class.from_pretrained(model, config=config, **model_kwargs) model = model_class.from_pretrained(model, config=config, **model_kwargs)
if task == "translation" and model.config.task_specific_params:
for key in model.config.task_specific_params:
if key.startswith("translation"):
task = key
warnings.warn(
'"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{}"'.format(
task
),
UserWarning,
)
break
return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs) return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs)

View File

@ -1,6 +1,8 @@
import unittest import unittest
from typing import Iterable, List, Optional from typing import Iterable, List, Optional
import pytest
from transformers import pipeline from transformers import pipeline
from transformers.pipelines import SUPPORTED_TASKS, Conversation, DefaultArgumentHandler, Pipeline from transformers.pipelines import SUPPORTED_TASKS, Conversation, DefaultArgumentHandler, Pipeline
from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow, torch_device from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow, torch_device
@ -392,6 +394,33 @@ class MonoColumnInputTestCase(unittest.TestCase):
invalid_inputs, invalid_inputs,
) )
@require_torch
@slow
def test_default_translations(self):
    """Check default-model resolution for "translation_XX_to_YY" tasks when no model is passed."""
    # We don't provide a default for this pair: the "translation" defaults shown in
    # SUPPORTED_TASKS only cover (en, fr), (en, de) and (en, ro).
    with self.assertRaises(ValueError):
        pipeline(task="translation_cn_to_ar")

    # but we do for this one
    pipeline(task="translation_en_to_de")
@require_torch
def test_translation_on_odd_language(self):
    """Check that an arbitrary XX_to_YY pair is accepted once a model is given explicitly."""
    # NOTE(review): TRANSLATION_FINETUNED_MODELS is defined outside this view; presumably
    # each entry starts with a model identifier -- confirm.
    model = TRANSLATION_FINETUNED_MODELS[0][0]
    # The default-model lookup only runs when no model is passed, so the missing
    # default for (cn, ar) is not an obstacle here.
    pipeline(task="translation_cn_to_ar", model=model)
@require_torch
def test_translation_default_language_selection(self):
    """Check that bare "translation" with a model warns and resolves to a concrete translation_XX_to_YY task."""
    model = TRANSLATION_FINETUNED_MODELS[0][0]
    # pipeline() takes the first key of model.config.task_specific_params that starts
    # with "translation", uses it as the task, and emits a UserWarning naming it.
    # NOTE(review): assumes this model's config declares "translation_en_to_de" -- confirm.
    with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
        nlp = pipeline(task="translation", model=model)
    self.assertEqual(nlp.task, "translation_en_to_de")
@require_torch
def test_translation_with_no_language_no_model_fails(self):
    """Check that bare "translation" without a model cannot pick a default and raises."""
    # With no language pair there are no task options, and the "translation" defaults
    # are keyed by (src, tgt) only, so get_default_model raises ValueError.
    with self.assertRaises(ValueError):
        pipeline(task="translation")
@require_tf @require_tf
@slow @slow
def test_tf_translation(self): def test_tf_translation(self):