Moving translation pipeline to new testing scheme. (#13297)

* Moving `translation` pipeline to new testing scheme.

* Update tokenization mbart tests.
Nicolas Patry 2021-08-27 12:26:17 +02:00 committed by GitHub
parent 319d840b46
commit a3f96f366a
4 changed files with 85 additions and 25 deletions

src/transformers/models/mbart/tokenization_mbart.py

@@ -201,12 +201,14 @@ class MBartTokenizer(XLMRobertaTokenizer):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
 
-    def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
+    def _build_translation_inputs(
+        self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
+    ):
         """Used by translation pipeline, to prepare inputs for the generate function"""
         if src_lang is None or tgt_lang is None:
             raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
         self.src_lang = src_lang
-        inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
+        inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
         tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
         inputs["forced_bos_token_id"] = tgt_lang_id
         return inputs
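
Aside (not part of the diff): a minimal sketch of how the updated method can now be driven from the caller's side. The checkpoint name and the generate() call below are illustrative assumptions, not code from this commit.

from transformers import MBartForConditionalGeneration, MBartTokenizer

# Illustrative checkpoint; any MBart translation checkpoint with en_XX/ro_RO codes would do.
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")

# The tensor framework is now chosen by the caller instead of being hard-coded to "pt".
inputs = tokenizer._build_translation_inputs(
    "UN Chief Says There Is No Military Solution in Syria",
    return_tensors="pt",
    src_lang="en_XX",
    tgt_lang="ro_RO",
)

# The returned encoding carries `forced_bos_token_id`, which generate() uses to force
# the first generated token to be the target-language code.
generated = model.generate(**inputs)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))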

src/transformers/models/mbart/tokenization_mbart_fast.py

@@ -186,12 +186,14 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
 
-    def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
+    def _build_translation_inputs(
+        self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
+    ):
         """Used by translation pipeline, to prepare inputs for the generate function"""
         if src_lang is None or tgt_lang is None:
             raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
         self.src_lang = src_lang
-        inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
+        inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
         tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
         inputs["forced_bos_token_id"] = tgt_lang_id
         return inputs
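
The reason `return_tensors` is threaded through both the slow and fast tokenizers is that the translation pipeline serves PyTorch and TensorFlow models alike. Below is a rough sketch of the kind of dispatch a pipeline can do on top of this API; the helper name and the `framework` argument are assumptions for illustration, not the actual pipeline code.

from typing import Optional


def build_translation_inputs(tokenizer, text: str, framework: str,
                             src_lang: Optional[str] = None, tgt_lang: Optional[str] = None):
    """Illustrative helper: derive the tensor type from the pipeline's framework."""
    return_tensors = "tf" if framework == "tf" else "pt"
    if hasattr(tokenizer, "_build_translation_inputs"):
        # Multilingual tokenizers such as MBart need explicit language codes.
        return tokenizer._build_translation_inputs(
            text, return_tensors=return_tensors, src_lang=src_lang, tgt_lang=tgt_lang
        )
    # Plain seq2seq tokenizers (e.g. T5 with a task prefix) are called directly.
    return tokenizer(text, return_tensors=return_tensors)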

tests/test_pipelines_translation.py

@@ -16,31 +16,85 @@
 import unittest
 
 import pytest
 
-from transformers import pipeline
-from transformers.testing_utils import is_pipeline_test, is_torch_available, require_torch, slow
+from transformers import (
+    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+    TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+    MBart50TokenizerFast,
+    MBartForConditionalGeneration,
+    TranslationPipeline,
+    pipeline,
+)
+from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, slow
 
-from .test_pipelines_common import MonoInputPipelineCommonMixin
+from .test_pipelines_common import ANY, PipelineTestCaseMeta
 
-if is_torch_available():
-    from transformers.models.mbart import MBartForConditionalGeneration
-    from transformers.models.mbart50 import MBart50TokenizerFast
 
+@is_pipeline_test
+class TranslationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
+    model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+    tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+
+    def run_pipeline_test(self, model, tokenizer, feature_extractor):
+        translator = TranslationPipeline(model=model, tokenizer=tokenizer)
+        try:
+            outputs = translator("Some string")
+        except ValueError:
+            # Triggered by m2m languages
+            src_lang, tgt_lang = list(translator.tokenizer.lang_code_to_id.keys())[:2]
+            outputs = translator("Some string", src_lang=src_lang, tgt_lang=tgt_lang)
+        self.assertEqual(outputs, [{"translation_text": ANY(str)}])
 
-class TranslationEnToDePipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
-    pipeline_task = "translation_en_to_de"
-    small_models = ["patrickvonplaten/t5-tiny-random"]  # Default model - Models tested without the @slow decorator
-    large_models = [None]  # Models tested with the @slow decorator
-    invalid_inputs = [4, "<mask>"]
-    mandatory_keys = ["translation_text"]
 
+    @require_torch
+    def test_small_model_pt(self):
+        translator = pipeline("translation_en_to_ro", model="patrickvonplaten/t5-tiny-random", framework="pt")
+        outputs = translator("This is a test string", max_length=20)
+        self.assertEqual(
+            outputs,
+            [
+                {
+                    "translation_text": "Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide"
+                }
+            ],
+        )
+
+    @require_tf
+    def test_small_model_tf(self):
+        translator = pipeline("translation_en_to_ro", model="patrickvonplaten/t5-tiny-random", framework="tf")
+        outputs = translator("This is a test string", max_length=20)
+        self.assertEqual(
+            outputs,
+            [
+                {
+                    "translation_text": "Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide"
+                }
+            ],
+        )
 
-class TranslationEnToRoPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
-    pipeline_task = "translation_en_to_ro"
-    small_models = ["patrickvonplaten/t5-tiny-random"]  # Default model - Models tested without the @slow decorator
-    large_models = [None]  # Models tested with the @slow decorator
-    invalid_inputs = [4, "<mask>"]
-    mandatory_keys = ["translation_text"]
 
+    @require_torch
+    def test_en_to_de_pt(self):
+        translator = pipeline("translation_en_to_de", model="patrickvonplaten/t5-tiny-random", framework="pt")
+        outputs = translator("This is a test string", max_length=20)
+        self.assertEqual(
+            outputs,
+            [
+                {
+                    "translation_text": "monoton monoton monoton monoton monoton monoton monoton monoton monoton monoton urine urine urine urine urine urine urine urine urine"
+                }
+            ],
+        )
+
+    @require_tf
+    def test_en_to_de_tf(self):
+        translator = pipeline("translation_en_to_de", model="patrickvonplaten/t5-tiny-random", framework="tf")
+        outputs = translator("This is a test string", max_length=20)
+        self.assertEqual(
+            outputs,
+            [
+                {
+                    "translation_text": "monoton monoton monoton monoton monoton monoton monoton monoton monoton monoton urine urine urine urine urine urine urine urine urine"
+                }
+            ],
+        )
 
 @is_pipeline_test
@@ -92,8 +146,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
         with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
             translator = pipeline(task="translation", model=model)
         self.assertEqual(translator.task, "translation_en_to_de")
-        self.assertEquals(translator.src_lang, "en")
-        self.assertEquals(translator.tgt_lang, "de")
+        self.assertEqual(translator.src_lang, "en")
+        self.assertEqual(translator.tgt_lang, "de")
 
     @require_torch
     def test_translation_with_no_language_no_model_fails(self):
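
Aside: `ANY`, imported above from test_pipelines_common, lets the shared test assert only the type of the generated text, since a randomly initialized model produces arbitrary strings. A minimal sketch of such a matcher, assuming it simply compares by isinstance (the real helper lives in test_pipelines_common and may differ):

class ANY:
    """Compares equal to any object of the given type(s)."""

    def __init__(self, *expected_types):
        self.expected_types = expected_types

    def __eq__(self, other):
        return isinstance(other, self.expected_types)

    def __repr__(self):
        return f"ANY({', '.join(t.__name__ for t in self.expected_types)})"

With a matcher like this, assertEqual(outputs, [{"translation_text": ANY(str)}]) passes for any string value while still checking the output structure.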

tests/test_tokenization_mbart.py

@@ -235,7 +235,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
     @require_torch
     def test_tokenizer_translation(self):
-        inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
+        inputs = self.tokenizer._build_translation_inputs(
+            "A test", return_tensors="pt", src_lang="en_XX", tgt_lang="ar_AR"
+        )
         self.assertEqual(
             nested_simplify(inputs),
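
Aside: `nested_simplify` (from transformers.testing_utils) normalizes the returned encoding before comparison, so the expected value can be written as plain Python literals. A simplified sketch of that idea, not the exact implementation:

def simplify(obj, decimals=3):
    """Recursively turn tensors into lists and round floats for stable comparisons."""
    if hasattr(obj, "tolist"):  # torch / numpy tensors
        return simplify(obj.tolist(), decimals)
    if isinstance(obj, dict):
        return {k: simplify(v, decimals) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [simplify(x, decimals) for x in obj]
    if isinstance(obj, float):
        return round(obj, decimals)
    return obj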