mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Moving translation
pipeline to new testing scheme. (#13297)
* Moving `translation` pipeline to new testing scheme. * Update tokenization mbart tests.
This commit is contained in:
parent
319d840b46
commit
a3f96f366a
@ -201,12 +201,14 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||
|
||||
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
||||
def _build_translation_inputs(
|
||||
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
|
||||
):
|
||||
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||
if src_lang is None or tgt_lang is None:
|
||||
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||
self.src_lang = src_lang
|
||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
|
||||
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
||||
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||
return inputs
|
||||
|
@ -186,12 +186,14 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||
|
||||
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
||||
def _build_translation_inputs(
|
||||
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
|
||||
):
|
||||
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||
if src_lang is None or tgt_lang is None:
|
||||
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||
self.src_lang = src_lang
|
||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
|
||||
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
||||
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||
return inputs
|
||||
|
@ -16,31 +16,85 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import pipeline
|
||||
from transformers.testing_utils import is_pipeline_test, is_torch_available, require_torch, slow
|
||||
from transformers import (
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
MBart50TokenizerFast,
|
||||
MBartForConditionalGeneration,
|
||||
TranslationPipeline,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, slow
|
||||
|
||||
from .test_pipelines_common import MonoInputPipelineCommonMixin
|
||||
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.models.mbart import MBartForConditionalGeneration
|
||||
from transformers.models.mbart50 import MBart50TokenizerFast
|
||||
@is_pipeline_test
|
||||
class TranslationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||
model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
|
||||
def run_pipeline_test(self, model, tokenizer, feature_extractor):
|
||||
translator = TranslationPipeline(model=model, tokenizer=tokenizer)
|
||||
try:
|
||||
outputs = translator("Some string")
|
||||
except ValueError:
|
||||
# Triggered by m2m langages
|
||||
src_lang, tgt_lang = list(translator.tokenizer.lang_code_to_id.keys())[:2]
|
||||
outputs = translator("Some string", src_lang=src_lang, tgt_lang=tgt_lang)
|
||||
self.assertEqual(outputs, [{"translation_text": ANY(str)}])
|
||||
|
||||
class TranslationEnToDePipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
pipeline_task = "translation_en_to_de"
|
||||
small_models = ["patrickvonplaten/t5-tiny-random"] # Default model - Models tested without the @slow decorator
|
||||
large_models = [None] # Models tested with the @slow decorator
|
||||
invalid_inputs = [4, "<mask>"]
|
||||
mandatory_keys = ["translation_text"]
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
translator = pipeline("translation_en_to_ro", model="patrickvonplaten/t5-tiny-random", framework="pt")
|
||||
outputs = translator("This is a test string", max_length=20)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{
|
||||
"translation_text": "Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide"
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@require_tf
|
||||
def test_small_model_tf(self):
|
||||
translator = pipeline("translation_en_to_ro", model="patrickvonplaten/t5-tiny-random", framework="tf")
|
||||
outputs = translator("This is a test string", max_length=20)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{
|
||||
"translation_text": "Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide"
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
class TranslationEnToRoPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
pipeline_task = "translation_en_to_ro"
|
||||
small_models = ["patrickvonplaten/t5-tiny-random"] # Default model - Models tested without the @slow decorator
|
||||
large_models = [None] # Models tested with the @slow decorator
|
||||
invalid_inputs = [4, "<mask>"]
|
||||
mandatory_keys = ["translation_text"]
|
||||
@require_torch
|
||||
def test_en_to_de_pt(self):
|
||||
translator = pipeline("translation_en_to_de", model="patrickvonplaten/t5-tiny-random", framework="pt")
|
||||
outputs = translator("This is a test string", max_length=20)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{
|
||||
"translation_text": "monoton monoton monoton monoton monoton monoton monoton monoton monoton monoton urine urine urine urine urine urine urine urine urine"
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@require_tf
|
||||
def test_en_to_de_tf(self):
|
||||
translator = pipeline("translation_en_to_de", model="patrickvonplaten/t5-tiny-random", framework="tf")
|
||||
outputs = translator("This is a test string", max_length=20)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{
|
||||
"translation_text": "monoton monoton monoton monoton monoton monoton monoton monoton monoton monoton urine urine urine urine urine urine urine urine urine"
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
@ -92,8 +146,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
|
||||
with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
|
||||
translator = pipeline(task="translation", model=model)
|
||||
self.assertEqual(translator.task, "translation_en_to_de")
|
||||
self.assertEquals(translator.src_lang, "en")
|
||||
self.assertEquals(translator.tgt_lang, "de")
|
||||
self.assertEqual(translator.src_lang, "en")
|
||||
self.assertEqual(translator.tgt_lang, "de")
|
||||
|
||||
@require_torch
|
||||
def test_translation_with_no_language_no_model_fails(self):
|
||||
|
@ -235,7 +235,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
def test_tokenizer_translation(self):
|
||||
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
|
||||
inputs = self.tokenizer._build_translation_inputs(
|
||||
"A test", return_tensors="pt", src_lang="en_XX", tgt_lang="ar_AR"
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(inputs),
|
||||
|
Loading…
Reference in New Issue
Block a user