mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 11:41:51 +06:00
Moving translation
pipeline to new testing scheme. (#13297)
* Moving `translation` pipeline to new testing scheme. * Update tokenization mbart tests.
This commit is contained in:
parent
319d840b46
commit
a3f96f366a
@ -201,12 +201,14 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||||
|
|
||||||
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
def _build_translation_inputs(
|
||||||
|
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
|
||||||
|
):
|
||||||
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||||
if src_lang is None or tgt_lang is None:
|
if src_lang is None or tgt_lang is None:
|
||||||
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||||
self.src_lang = src_lang
|
self.src_lang = src_lang
|
||||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
|
||||||
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
||||||
inputs["forced_bos_token_id"] = tgt_lang_id
|
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||||
return inputs
|
return inputs
|
||||||
|
@ -186,12 +186,14 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
|||||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||||
|
|
||||||
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
def _build_translation_inputs(
|
||||||
|
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
|
||||||
|
):
|
||||||
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||||
if src_lang is None or tgt_lang is None:
|
if src_lang is None or tgt_lang is None:
|
||||||
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||||
self.src_lang = src_lang
|
self.src_lang = src_lang
|
||||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
|
||||||
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
||||||
inputs["forced_bos_token_id"] = tgt_lang_id
|
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||||
return inputs
|
return inputs
|
||||||
|
@ -16,31 +16,85 @@ import unittest
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from transformers import pipeline
|
from transformers import (
|
||||||
from transformers.testing_utils import is_pipeline_test, is_torch_available, require_torch, slow
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
|
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
|
MBart50TokenizerFast,
|
||||||
|
MBartForConditionalGeneration,
|
||||||
|
TranslationPipeline,
|
||||||
|
pipeline,
|
||||||
|
)
|
||||||
|
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, slow
|
||||||
|
|
||||||
from .test_pipelines_common import MonoInputPipelineCommonMixin
|
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
@is_pipeline_test
|
||||||
from transformers.models.mbart import MBartForConditionalGeneration
|
class TranslationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||||
from transformers.models.mbart50 import MBart50TokenizerFast
|
model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||||
|
tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||||
|
|
||||||
|
def run_pipeline_test(self, model, tokenizer, feature_extractor):
|
||||||
|
translator = TranslationPipeline(model=model, tokenizer=tokenizer)
|
||||||
|
try:
|
||||||
|
outputs = translator("Some string")
|
||||||
|
except ValueError:
|
||||||
|
# Triggered by m2m langages
|
||||||
|
src_lang, tgt_lang = list(translator.tokenizer.lang_code_to_id.keys())[:2]
|
||||||
|
outputs = translator("Some string", src_lang=src_lang, tgt_lang=tgt_lang)
|
||||||
|
self.assertEqual(outputs, [{"translation_text": ANY(str)}])
|
||||||
|
|
||||||
class TranslationEnToDePipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
@require_torch
|
||||||
pipeline_task = "translation_en_to_de"
|
def test_small_model_pt(self):
|
||||||
small_models = ["patrickvonplaten/t5-tiny-random"] # Default model - Models tested without the @slow decorator
|
translator = pipeline("translation_en_to_ro", model="patrickvonplaten/t5-tiny-random", framework="pt")
|
||||||
large_models = [None] # Models tested with the @slow decorator
|
outputs = translator("This is a test string", max_length=20)
|
||||||
invalid_inputs = [4, "<mask>"]
|
self.assertEqual(
|
||||||
mandatory_keys = ["translation_text"]
|
outputs,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"translation_text": "Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_small_model_tf(self):
|
||||||
|
translator = pipeline("translation_en_to_ro", model="patrickvonplaten/t5-tiny-random", framework="tf")
|
||||||
|
outputs = translator("This is a test string", max_length=20)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"translation_text": "Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
class TranslationEnToRoPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
@require_torch
|
||||||
pipeline_task = "translation_en_to_ro"
|
def test_en_to_de_pt(self):
|
||||||
small_models = ["patrickvonplaten/t5-tiny-random"] # Default model - Models tested without the @slow decorator
|
translator = pipeline("translation_en_to_de", model="patrickvonplaten/t5-tiny-random", framework="pt")
|
||||||
large_models = [None] # Models tested with the @slow decorator
|
outputs = translator("This is a test string", max_length=20)
|
||||||
invalid_inputs = [4, "<mask>"]
|
self.assertEqual(
|
||||||
mandatory_keys = ["translation_text"]
|
outputs,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"translation_text": "monoton monoton monoton monoton monoton monoton monoton monoton monoton monoton urine urine urine urine urine urine urine urine urine"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_en_to_de_tf(self):
|
||||||
|
translator = pipeline("translation_en_to_de", model="patrickvonplaten/t5-tiny-random", framework="tf")
|
||||||
|
outputs = translator("This is a test string", max_length=20)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"translation_text": "monoton monoton monoton monoton monoton monoton monoton monoton monoton monoton urine urine urine urine urine urine urine urine urine"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@is_pipeline_test
|
@is_pipeline_test
|
||||||
@ -92,8 +146,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
|
|||||||
with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
|
with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
|
||||||
translator = pipeline(task="translation", model=model)
|
translator = pipeline(task="translation", model=model)
|
||||||
self.assertEqual(translator.task, "translation_en_to_de")
|
self.assertEqual(translator.task, "translation_en_to_de")
|
||||||
self.assertEquals(translator.src_lang, "en")
|
self.assertEqual(translator.src_lang, "en")
|
||||||
self.assertEquals(translator.tgt_lang, "de")
|
self.assertEqual(translator.tgt_lang, "de")
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_translation_with_no_language_no_model_fails(self):
|
def test_translation_with_no_language_no_model_fails(self):
|
||||||
|
@ -235,7 +235,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_tokenizer_translation(self):
|
def test_tokenizer_translation(self):
|
||||||
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
|
inputs = self.tokenizer._build_translation_inputs(
|
||||||
|
"A test", return_tensors="pt", src_lang="en_XX", tgt_lang="ar_AR"
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
nested_simplify(inputs),
|
nested_simplify(inputs),
|
||||||
|
Loading…
Reference in New Issue
Block a user