From 43cb03a93d5fce9dc78376d8ab5639e9de7c7d7a Mon Sep 17 00:00:00 2001
From: Sam Shleifer
Date: Wed, 1 Jul 2020 10:32:50 -0400
Subject: [PATCH] MarianTokenizer.prepare_translation_batch uses new tokenizer API (#5182)

---
 src/transformers/tokenization_marian.py | 17 ++++++++---------
 tests/test_tokenization_marian.py       | 19 +++++++++++++++++++
 2 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/src/transformers/tokenization_marian.py b/src/transformers/tokenization_marian.py
index fb0d327a207..46ff3ff457c 100644
--- a/src/transformers/tokenization_marian.py
+++ b/src/transformers/tokenization_marian.py
@@ -129,6 +129,8 @@ class MarianTokenizer(PreTrainedTokenizer):
         max_length: Optional[int] = None,
         pad_to_max_length: bool = True,
         return_tensors: str = "pt",
+        truncation_strategy="only_first",
+        padding="longest",
     ) -> BatchEncoding:
         """Prepare model inputs for translation. For best performance, translate one sentence at a time.
         Arguments:
@@ -147,24 +149,21 @@ class MarianTokenizer(PreTrainedTokenizer):
             raise ValueError(f"found empty string in src_texts: {src_texts}")
         self.current_spm = self.spm_source
         src_texts = [self.normalize(t) for t in src_texts]  # this does not appear to do much
-        model_inputs: BatchEncoding = self.batch_encode_plus(
-            src_texts,
+        tokenizer_kwargs = dict(
             add_special_tokens=True,
             return_tensors=return_tensors,
             max_length=max_length,
             pad_to_max_length=pad_to_max_length,
+            truncation_strategy=truncation_strategy,
+            padding=padding,
         )
+        model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
+
         if tgt_texts is None:
             return model_inputs
         self.current_spm = self.spm_target
-        decoder_inputs: BatchEncoding = self.batch_encode_plus(
-            tgt_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            max_length=max_length,
-            pad_to_max_length=pad_to_max_length,
-        )
+        decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
         for k, v in decoder_inputs.items():
             model_inputs[f"decoder_{k}"] = v
         self.current_spm = self.spm_source
         return model_inputs
diff --git a/tests/test_tokenization_marian.py b/tests/test_tokenization_marian.py
index 91b4438a715..8a189651aa6 100644
--- a/tests/test_tokenization_marian.py
+++ b/tests/test_tokenization_marian.py
@@ -24,6 +24,7 @@ from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_f
 from transformers.tokenization_utils import BatchEncoding
 
 from .test_tokenization_common import TokenizerTesterMixin
+from .utils import _torch_available
 
 
 SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
@@ -31,6 +32,7 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t
 mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
 zh_code = ">>zh<<"
 ORG_NAME = "Helsinki-NLP/"
+FRAMEWORK = "pt" if _torch_available else "tf"
 
 
 class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@@ -72,3 +74,20 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         contents = [x.name for x in Path(save_dir).glob("*")]
         self.assertIn("source.spm", contents)
         MarianTokenizer.from_pretrained(save_dir)
+
+    def test_outputs_not_longer_than_maxlen(self):
+        tok = self.get_tokenizer()
+
+        batch = tok.prepare_translation_batch(
+            ["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
+        )
+        self.assertIsInstance(batch, BatchEncoding)
+        self.assertEqual(batch.input_ids.shape, (2, 512))
+
+    def test_outputs_can_be_shorter(self):
+        tok = self.get_tokenizer()
+        batch_smaller = tok.prepare_translation_batch(
+            ["I am a tiny frog", "I am a small frog"], return_tensors=FRAMEWORK
+        )
+        self.assertIsInstance(batch_smaller, BatchEncoding)
+        self.assertEqual(batch_smaller.input_ids.shape, (2, 10))
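
For reference, a minimal usage sketch of the call path this patch switches to (not part of the commit; the checkpoint name "Helsinki-NLP/opus-mt-en-de", the German target string, and the printed keys are illustrative assumptions, and PyTorch is assumed for the default return_tensors="pt"):

    from transformers import MarianTokenizer

    tok = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")

    # Source-only batch: encoding now flows through the tokenizer's __call__
    # (self(src_texts, **tokenizer_kwargs) above) rather than batch_encode_plus,
    # so the new truncation_strategy / padding arguments are honored.
    batch = tok.prepare_translation_batch(["I am a small frog"])
    print(batch.input_ids.shape)  # (1, sequence_length)

    # With tgt_texts, target encodings are merged back under decoder_* keys.
    batch = tok.prepare_translation_batch(
        ["I am a small frog"], tgt_texts=["Ich bin ein kleiner Frosch"]
    )
    print(sorted(batch.keys()))
    # expected to include: input_ids, attention_mask,
    # decoder_input_ids, decoder_attention_mask

Routing both source and target through the single shared tokenizer_kwargs dict keeps the two calls in sync, which is why the decoder inputs pick up the same truncation and padding behavior as the encoder inputs.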