From 9e147d31f67a03ea4f5b11a5c7c3b7f8d252bfb7 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 22 Feb 2021 12:36:16 -0500 Subject: [PATCH] Deprecate prepare_seq2seq_batch (#10287) * Deprecate prepare_seq2seq_batch * Fix last tests * Apply suggestions from code review Co-authored-by: Lysandre Debut Co-authored-by: Suraj Patil * More review comments Co-authored-by: Lysandre Debut Co-authored-by: Suraj Patil --- docs/source/model_doc/fsmt.rst | 2 +- docs/source/model_doc/marian.rst | 63 ++++++++++--------- docs/source/model_doc/mbart.rst | 47 ++++++++------ docs/source/model_doc/pegasus.rst | 28 ++++----- docs/source/model_doc/rag.rst | 2 +- docs/source/model_doc/t5.rst | 2 +- scripts/fsmt/fsmt-make-super-tiny-model.py | 2 +- scripts/fsmt/fsmt-make-tiny-model.py | 2 +- .../models/marian/modeling_marian.py | 11 ++-- .../models/marian/modeling_tf_marian.py | 11 ++-- .../models/marian/tokenization_marian.py | 11 ++-- .../models/mbart/tokenization_mbart.py | 42 +++++++------ .../models/mbart/tokenization_mbart_fast.py | 43 +++++++------ src/transformers/models/mt5/modeling_mt5.py | 16 +++-- .../models/mt5/modeling_tf_mt5.py | 18 ++++-- src/transformers/models/rag/modeling_rag.py | 28 +++++---- .../models/rag/tokenization_rag.py | 8 +++ src/transformers/tokenization_utils_base.py | 7 +++ tests/test_modeling_marian.py | 19 +++--- tests/test_modeling_mbart.py | 12 ++-- tests/test_modeling_tf_marian.py | 4 +- tests/test_modeling_tf_mbart.py | 4 +- tests/test_modeling_tf_pegasus.py | 4 +- tests/test_tokenization_bart.py | 45 +++++-------- tests/test_tokenization_barthez.py | 10 +-- tests/test_tokenization_marian.py | 8 ++- tests/test_tokenization_mbart.py | 63 +++++++++---------- tests/test_tokenization_mbart50.py | 60 +++++++++--------- tests/test_tokenization_pegasus.py | 14 +++-- tests/test_tokenization_prophetnet.py | 12 +--- tests/test_tokenization_t5.py | 47 +++++--------- 31 files changed, 325 insertions(+), 320 deletions(-) diff --git a/docs/source/model_doc/fsmt.rst b/docs/source/model_doc/fsmt.rst index eb9a21859ea..c60909f88d2 100644 --- a/docs/source/model_doc/fsmt.rst +++ b/docs/source/model_doc/fsmt.rst @@ -56,7 +56,7 @@ FSMTTokenizer .. autoclass:: transformers.FSMTTokenizer :members: build_inputs_with_special_tokens, get_special_tokens_mask, - create_token_type_ids_from_sequences, prepare_seq2seq_batch, save_vocabulary + create_token_type_ids_from_sequences, save_vocabulary FSMTModel diff --git a/docs/source/model_doc/marian.rst b/docs/source/model_doc/marian.rst index 18d515a8695..51018a4f79f 100644 --- a/docs/source/model_doc/marian.rst +++ b/docs/source/model_doc/marian.rst @@ -76,27 +76,29 @@ require 3 character language codes: .. code-block:: python - from transformers import MarianMTModel, MarianTokenizer - src_text = [ - '>>fra<< this is a sentence in english that we want to translate to french', - '>>por<< This should go to portuguese', - '>>esp<< And this to Spanish' - ] + >>> from transformers import MarianMTModel, MarianTokenizer + >>> src_text = [ + ... '>>fra<< this is a sentence in english that we want to translate to french', + ... '>>por<< This should go to portuguese', + ... 
'>>esp<< And this to Spanish' + >>> ] - model_name = 'Helsinki-NLP/opus-mt-en-roa' - tokenizer = MarianTokenizer.from_pretrained(model_name) - print(tokenizer.supported_language_codes) - model = MarianMTModel.from_pretrained(model_name) - translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")) - tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated] - # ["c'est une phrase en anglais que nous voulons traduire en français", - # 'Isto deve ir para o português.', - # 'Y esto al español'] + >>> model_name = 'Helsinki-NLP/opus-mt-en-roa' + >>> tokenizer = MarianTokenizer.from_pretrained(model_name) + >>> print(tokenizer.supported_language_codes) + ['>>zlm_Latn<<', '>>mfe<<', '>>hat<<', '>>pap<<', '>>ast<<', '>>cat<<', '>>ind<<', '>>glg<<', '>>wln<<', '>>spa<<', '>>fra<<', '>>ron<<', '>>por<<', '>>ita<<', '>>oci<<', '>>arg<<', '>>min<<'] + + >>> model = MarianMTModel.from_pretrained(model_name) + >>> translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True)) + >>> [tokenizer.decode(t, skip_special_tokens=True) for t in translated] + ["c'est une phrase en anglais que nous voulons traduire en français", + 'Isto deve ir para o português.', + 'Y esto al español'] -Code to see available pretrained models: +Here is the code to see all available pretrained models on the hub: .. code-block:: python @@ -147,21 +149,22 @@ Example of translating english to many romance languages, using old-style 2 char .. code-block::python - from transformers import MarianMTModel, MarianTokenizer - src_text = [ - '>>fr<< this is a sentence in english that we want to translate to french', - '>>pt<< This should go to portuguese', - '>>es<< And this to Spanish' - ] + >>> from transformers import MarianMTModel, MarianTokenizer + >>> src_text = [ + ... '>>fr<< this is a sentence in english that we want to translate to french', + ... '>>pt<< This should go to portuguese', + ... '>>es<< And this to Spanish' + >>> ] - model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE' - tokenizer = MarianTokenizer.from_pretrained(model_name) - print(tokenizer.supported_language_codes) + >>> model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE' + >>> tokenizer = MarianTokenizer.from_pretrained(model_name) - model = MarianMTModel.from_pretrained(model_name) - translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")) - tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated] - # ["c'est une phrase en anglais que nous voulons traduire en français", 'Isto deve ir para o português.', 'Y esto al español'] + >>> model = MarianMTModel.from_pretrained(model_name) + >>> translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True)) + >>> tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated] + ["c'est une phrase en anglais que nous voulons traduire en français", + 'Isto deve ir para o português.', + 'Y esto al español'] @@ -176,7 +179,7 @@ MarianTokenizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. 
autoclass:: transformers.MarianTokenizer - :members: prepare_seq2seq_batch + :members: as_target_tokenizer MarianModel diff --git a/docs/source/model_doc/mbart.rst b/docs/source/model_doc/mbart.rst index 05bbec1cbe7..05631ab0cab 100644 --- a/docs/source/model_doc/mbart.rst +++ b/docs/source/model_doc/mbart.rst @@ -34,22 +34,31 @@ The Authors' code can be found `here >> from transformers import MBartForConditionalGeneration, MBartTokenizer + + >>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro") + >>> example_english_phrase = "UN Chief Says There Is No Military Solution in Syria" + >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria" + + >>> inputs = tokenizer(example_english_phrase, return_tensors="pt", src_lang="en_XX", tgt_lang="ro_RO") + >>> with tokenizer.as_target_tokenizer(): + ... labels = tokenizer(expected_translation_romanian, return_tensors="pt") + + >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro") + >>> # forward pass + >>> model(**inputs, labels=batch['labels']) - Generation @@ -58,14 +67,14 @@ the sequences for sequence-to-sequence fine-tuning. .. code-block:: - from transformers import MBartForConditionalGeneration, MBartTokenizer - model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro") - tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro") - article = "UN Chief Says There Is No Military Solution in Syria" - batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], src_lang="en_XX", return_tensors="pt") - translated_tokens = model.generate(**batch, decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"]) - translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0] - assert translation == "Şeful ONU declară că nu există o soluţie militară în Siria" + >>> from transformers import MBartForConditionalGeneration, MBartTokenizer + + >>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX") + >>> article = "UN Chief Says There Is No Military Solution in Syria" + >>> inputs = tokenizer(article, return_tensors="pt") + >>> translated_tokens = model.generate(**inputs, decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"]) + >>> tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0] + "Şeful ONU declară că nu există o soluţie militară în Siria" Overview of MBart-50 @@ -160,7 +169,7 @@ MBartTokenizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.MBartTokenizer - :members: build_inputs_with_special_tokens, prepare_seq2seq_batch + :members: as_target_tokenizer, build_inputs_with_special_tokens MBartTokenizerFast diff --git a/docs/source/model_doc/pegasus.rst b/docs/source/model_doc/pegasus.rst index ad582230e9a..61a37b07f77 100644 --- a/docs/source/model_doc/pegasus.rst +++ b/docs/source/model_doc/pegasus.rst @@ -78,20 +78,20 @@ Usage Example .. code-block:: python - from transformers import PegasusForConditionalGeneration, PegasusTokenizer - import torch - src_text = [ - """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. 
Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow.""" - ] + >>> from transformers import PegasusForConditionalGeneration, PegasusTokenizer + >>> import torch + >>> src_text = [ + ... """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow.""" + ... ] - model_name = 'google/pegasus-xsum' - torch_device = 'cuda' if torch.cuda.is_available() else 'cpu' - tokenizer = PegasusTokenizer.from_pretrained(model_name) - model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device) - batch = tokenizer.prepare_seq2seq_batch(src_text, truncation=True, padding='longest', return_tensors="pt").to(torch_device) - translated = model.generate(**batch) - tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True) - assert tgt_text[0] == "California's largest electricity provider has turned off power to hundreds of thousands of customers." + >>> model_name = 'google/pegasus-xsum' + >>> device = 'cuda' if torch.cuda.is_available() else 'cpu' + >>> tokenizer = PegasusTokenizer.from_pretrained(model_name) + >>> model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device) + >>> batch = tokenizer(src_text, truncation=True, padding='longest', return_tensors="pt").to(device) + >>> translated = model.generate(**batch) + >>> tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True) + >>> assert tgt_text[0] == "California's largest electricity provider has turned off power to hundreds of thousands of customers." @@ -107,7 +107,7 @@ PegasusTokenizer warning: ``add_tokens`` does not work at the moment. .. autoclass:: transformers.PegasusTokenizer - :members: __call__, prepare_seq2seq_batch + :members: PegasusTokenizerFast diff --git a/docs/source/model_doc/rag.rst index 06205a8cb59..3b7361b1657 100644 --- a/docs/source/model_doc/rag.rst +++ b/docs/source/model_doc/rag.rst @@ -56,7 +56,7 @@ RagTokenizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.RagTokenizer - :members: prepare_seq2seq_batch + :members: Rag specific outputs diff --git a/docs/source/model_doc/t5.rst index 0ff96d0a424..27425218d27 100644 --- a/docs/source/model_doc/t5.rst +++ b/docs/source/model_doc/t5.rst @@ -104,7 +104,7 @@ T5Tokenizer ..
autoclass:: transformers.T5Tokenizer :members: build_inputs_with_special_tokens, get_special_tokens_mask, - create_token_type_ids_from_sequences, prepare_seq2seq_batch, save_vocabulary + create_token_type_ids_from_sequences, save_vocabulary T5TokenizerFast diff --git a/scripts/fsmt/fsmt-make-super-tiny-model.py b/scripts/fsmt/fsmt-make-super-tiny-model.py index 9821343faf8..4a6b8e0c1b4 100755 --- a/scripts/fsmt/fsmt-make-super-tiny-model.py +++ b/scripts/fsmt/fsmt-make-super-tiny-model.py @@ -71,7 +71,7 @@ tiny_model = FSMTForConditionalGeneration(config) print(f"num of params {tiny_model.num_parameters()}") # Test -batch = tokenizer.prepare_seq2seq_batch(["Making tiny model"], return_tensors="pt") +batch = tokenizer(["Making tiny model"], return_tensors="pt") outputs = tiny_model(**batch) print("test output:", len(outputs.logits[0])) diff --git a/scripts/fsmt/fsmt-make-tiny-model.py b/scripts/fsmt/fsmt-make-tiny-model.py index dc0beffef9d..431942c05dd 100755 --- a/scripts/fsmt/fsmt-make-tiny-model.py +++ b/scripts/fsmt/fsmt-make-tiny-model.py @@ -42,7 +42,7 @@ tiny_model = FSMTForConditionalGeneration(config) print(f"num of params {tiny_model.num_parameters()}") # Test -batch = tokenizer.prepare_seq2seq_batch(["Making tiny model"], return_tensors="pt") +batch = tokenizer(["Making tiny model"], return_tensors="pt") outputs = tiny_model(**batch) print("test output:", len(outputs.logits[0])) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 5dbd782090f..ebffe7d8615 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -522,13 +522,14 @@ MARIAN_GENERATION_EXAMPLE = r""" >>> src = 'fr' # source language >>> trg = 'en' # target language >>> sample_text = "où est l'arrêt de bus ?" - >>> mname = f'Helsinki-NLP/opus-mt-{src}-{trg}' + >>> model_name = f'Helsinki-NLP/opus-mt-{src}-{trg}' - >>> model = MarianMTModel.from_pretrained(mname) - >>> tok = MarianTokenizer.from_pretrained(mname) - >>> batch = tok.prepare_seq2seq_batch(src_texts=[sample_text], return_tensors="pt") # don't need tgt_text for inference + >>> model = MarianMTModel.from_pretrained(model_name) + >>> tokenizer = MarianTokenizer.from_pretrained(model_name) + >>> batch = tokenizer([sample_text], return_tensors="pt") >>> gen = model.generate(**batch) - >>> words: List[str] = tok.batch_decode(gen, skip_special_tokens=True) # returns "Where is the bus stop ?" + >>> tokenizer.batch_decode(gen, skip_special_tokens=True) + "Where is the bus stop ?" """ MARIAN_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py index c45189cb1c4..578493bace3 100644 --- a/src/transformers/models/marian/modeling_tf_marian.py +++ b/src/transformers/models/marian/modeling_tf_marian.py @@ -557,13 +557,14 @@ MARIAN_GENERATION_EXAMPLE = r""" >>> src = 'fr' # source language >>> trg = 'en' # target language >>> sample_text = "où est l'arrêt de bus ?" 
- >>> mname = f'Helsinki-NLP/opus-mt-{src}-{trg}' + >>> model_name = f'Helsinki-NLP/opus-mt-{src}-{trg}' - >>> model = MarianMTModel.from_pretrained(mname) - >>> tok = MarianTokenizer.from_pretrained(mname) - >>> batch = tok.prepare_seq2seq_batch(src_texts=[sample_text], return_tensors="tf") # don't need tgt_text for inference + >>> model = TFMarianMTModel.from_pretrained(model_name) + >>> tokenizer = MarianTokenizer.from_pretrained(model_name) + >>> batch = tokenizer([sample_text], return_tensors="tf") >>> gen = model.generate(**batch) - >>> words: List[str] = tok.batch_decode(gen, skip_special_tokens=True) # returns "Where is the bus stop ?" + >>> tokenizer.batch_decode(gen, skip_special_tokens=True) + "Where is the bus stop ?" """ MARIAN_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/marian/tokenization_marian.py b/src/transformers/models/marian/tokenization_marian.py index c026aa6539f..a12f8451a91 100644 --- a/src/transformers/models/marian/tokenization_marian.py +++ b/src/transformers/models/marian/tokenization_marian.py @@ -80,12 +80,15 @@ class MarianTokenizer(PreTrainedTokenizer): Examples:: >>> from transformers import MarianTokenizer - >>> tok = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de') + >>> tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de') >>> src_texts = [ "I am a small frog.", "Tom asked his teacher for advice."] >>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."] # optional - >>> batch_enc = tok.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, return_tensors="pt") - >>> # keys [input_ids, attention_mask, labels]. - >>> # model(**batch) should work + >>> inputs = tokenizer(src_texts, return_tensors="pt", padding=True) + >>> with tokenizer.as_target_tokenizer(): + ... labels = tokenizer(tgt_texts, return_tensors="pt", padding=True) + >>> inputs["labels"] = labels["input_ids"] + # keys [input_ids, attention_mask, labels]. + >>> outputs = model(**inputs) should work """ vocab_files_names = vocab_files_names diff --git a/src/transformers/models/mbart/tokenization_mbart.py b/src/transformers/models/mbart/tokenization_mbart.py index 8b88c98e680..752ff3effed 100644 --- a/src/transformers/models/mbart/tokenization_mbart.py +++ b/src/transformers/models/mbart/tokenization_mbart.py @@ -59,30 +59,23 @@ class MBartTokenizer(XLMRobertaTokenizer): """ Construct an MBART tokenizer. - :class:`~transformers.MBartTokenizer` is a subclass of :class:`~transformers.XLMRobertaTokenizer` and adds a new - :meth:`~transformers.MBartTokenizer.prepare_seq2seq_batch` - - Refer to superclass :class:`~transformers.XLMRobertaTokenizer` for usage examples and documentation concerning the + :class:`~transformers.MBartTokenizer` is a subclass of :class:`~transformers.XLMRobertaTokenizer`. Refer to + superclass :class:`~transformers.XLMRobertaTokenizer` for usage examples and documentation concerning the initialization parameters and other methods. - .. warning:: - - ``prepare_seq2seq_batch`` should be used to encode inputs. Other tokenizer methods like ``encode`` do not work - properly. - The tokenization method is `` `` for source language documents, and `` ``` for target language documents. 
Examples:: >>> from transformers import MBartTokenizer - >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro') + >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro', src_lang="en_XX", tgt_lang="ro_RO") >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria" >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria" - >>> batch: dict = tokenizer.prepare_seq2seq_batch( - ... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt" - ... ) - + >>> inputs = tokenizer(example_english_phrase, return_tensors="pt) + >>> with tokenizer.as_target_tokenizer(): + ... labels = tokenizer(expected_translation_romanian, return_tensors="pt") + >>> inputs["labels"] = labels["input_ids"] """ vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"} @@ -92,26 +85,38 @@ class MBartTokenizer(XLMRobertaTokenizer): prefix_tokens: List[int] = [] suffix_tokens: List[int] = [] - def __init__(self, *args, tokenizer_file=None, **kwargs): - super().__init__(*args, tokenizer_file=tokenizer_file, **kwargs) + def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs): + super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs) self.sp_model_size = len(self.sp_model) self.lang_code_to_id = { code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES) } self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()} - self.cur_lang_code = self.lang_code_to_id["en_XX"] self.fairseq_tokens_to_ids[""] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset self.fairseq_tokens_to_ids.update(self.lang_code_to_id) self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()} self._additional_special_tokens = list(self.lang_code_to_id.keys()) - self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX")) + + self._src_lang = src_lang if src_lang is not None else "en_XX" + self.cur_lang_code_id = self.lang_code_to_id[self._src_lang] + self.tgt_lang = tgt_lang + self.set_src_lang_special_tokens(self._src_lang) @property def vocab_size(self): return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1 # Plus 1 for the mask token + @property + def src_lang(self) -> str: + return self._src_lang + + @src_lang.setter + def src_lang(self, new_src_lang: str) -> None: + self._src_lang = new_src_lang + self.set_src_lang_special_tokens(self._src_lang) + def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False ) -> List[int]: @@ -181,7 +186,6 @@ class MBartTokenizer(XLMRobertaTokenizer): ) -> BatchEncoding: self.src_lang = src_lang self.tgt_lang = tgt_lang - self.set_src_lang_special_tokens(self.src_lang) return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) @contextmanager diff --git a/src/transformers/models/mbart/tokenization_mbart_fast.py b/src/transformers/models/mbart/tokenization_mbart_fast.py index 80e0efed804..a449895a068 100644 --- a/src/transformers/models/mbart/tokenization_mbart_fast.py +++ b/src/transformers/models/mbart/tokenization_mbart_fast.py @@ -70,15 +70,9 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast): Construct a "fast" MBART tokenizer (backed by HuggingFace's `tokenizers` library). Based on `BPE `__. 
- :class:`~transformers.MBartTokenizerFast` is a subclass of :class:`~transformers.XLMRobertaTokenizerFast` and adds - a new :meth:`~transformers.MBartTokenizerFast.prepare_seq2seq_batch`. - - Refer to superclass :class:`~transformers.XLMRobertaTokenizerFast` for usage examples and documentation concerning - the initialization parameters and other methods. - - .. warning:: - ``prepare_seq2seq_batch`` should be used to encode inputs. Other tokenizer methods like ``encode`` do not work - properly. + :class:`~transformers.MBartTokenizerFast` is a subclass of :class:`~transformers.XLMRobertaTokenizerFast`. Refer to + superclass :class:`~transformers.XLMRobertaTokenizerFast` for usage examples and documentation concerning the + initialization parameters and other methods. The tokenization method is `` `` for source language documents, and `` ``` for target language documents. @@ -86,12 +80,13 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast): Examples:: >>> from transformers import MBartTokenizerFast - >>> tokenizer = MBartTokenizerFast.from_pretrained('facebook/mbart-large-en-ro') + >>> tokenizer = MBartTokenizerFast.from_pretrained('facebook/mbart-large-en-ro', src_lang="en_XX", tgt_lang="ro_RO") >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria" >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria" - >>> batch: dict = tokenizer.prepare_seq2seq_batch( - ... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt" - ... ) + >>> inputs = tokenizer(example_english_phrase, return_tensors="pt) + >>> with tokenizer.as_target_tokenizer(): + ... labels = tokenizer(expected_translation_romanian, return_tensors="pt") + >>> inputs["labels"] = labels["input_ids"] """ vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"} @@ -102,14 +97,25 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast): prefix_tokens: List[int] = [] suffix_tokens: List[int] = [] - def __init__(self, *args, tokenizer_file=None, **kwargs): - super().__init__(*args, tokenizer_file=tokenizer_file, **kwargs) - - self.cur_lang_code = self.convert_tokens_to_ids("en_XX") - self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX")) + def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs): + super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs) self.add_special_tokens({"additional_special_tokens": FAIRSEQ_LANGUAGE_CODES}) + self._src_lang = src_lang if src_lang is not None else "en_XX" + self.cur_lang_code = self.convert_tokens_to_ids(self._src_lang) + self.tgt_lang = tgt_lang + self.set_src_lang_special_tokens(self._src_lang) + + @property + def src_lang(self) -> str: + return self._src_lang + + @src_lang.setter + def src_lang(self, new_src_lang: str) -> None: + self._src_lang = new_src_lang + self.set_src_lang_special_tokens(self._src_lang) + def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False ) -> List[int]: @@ -181,7 +187,6 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast): ) -> BatchEncoding: self.src_lang = src_lang self.tgt_lang = tgt_lang - self.set_src_lang_special_tokens(self.src_lang) return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) @contextmanager diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 9a504662d1e..8276dd472b2 100644 
--- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -31,13 +31,17 @@ class MT5Model(T5Model): alongside usage examples. Examples:: + >>> from transformers import MT5Model, T5Tokenizer >>> model = MT5Model.from_pretrained("google/mt5-small") >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small") >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." >>> summary = "Weiter Verhandlung in Syrien." - >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt") - >>> outputs = model(input_ids=batch.input_ids, decoder_input_ids=batch.labels) + >>> inputs = tokenizer(article, return_tensors="pt") + >>> with tokenizer.as_target_tokenizer(): + ... labels = tokenizer(summary, return_tensors="pt") + + >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"]) >>> hidden_states = outputs.last_hidden_state """ model_type = "mt5" @@ -59,13 +63,17 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration): appropriate documentation alongside usage examples. Examples:: + >>> from transformers import MT5ForConditionalGeneration, T5Tokenizer >>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small") >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small") >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." >>> summary = "Weiter Verhandlung in Syrien." - >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt") - >>> outputs = model(**batch) + >>> inputs = tokenizer(article, return_tensors="pt") + >>> with tokenizer.as_target_tokenizer(): + ... labels = tokenizer(summary, return_tensors="pt") + + >>> outputs = model(**inputs,labels=labels["input_ids"]) >>> loss = outputs.loss """ diff --git a/src/transformers/models/mt5/modeling_tf_mt5.py b/src/transformers/models/mt5/modeling_tf_mt5.py index 166b83565b1..cd160676937 100644 --- a/src/transformers/models/mt5/modeling_tf_mt5.py +++ b/src/transformers/models/mt5/modeling_tf_mt5.py @@ -31,15 +31,17 @@ class TFMT5Model(TFT5Model): documentation alongside usage examples. Examples:: + >>> from transformers import TFMT5Model, T5Tokenizer >>> model = TFMT5Model.from_pretrained("google/mt5-small") >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small") >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." >>> summary = "Weiter Verhandlung in Syrien." - >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf") - >>> batch["decoder_input_ids"] = batch["labels"] - >>> del batch["labels"] - >>> outputs = model(batch) + >>> inputs = tokenizer(article, return_tensors="tf") + >>> with tokenizer.as_target_tokenizer(): + ... labels = tokenizer(summary, return_tensors="tf") + + >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"]) >>> hidden_states = outputs.last_hidden_state """ model_type = "mt5" @@ -52,13 +54,17 @@ class TFMT5ForConditionalGeneration(TFT5ForConditionalGeneration): appropriate documentation alongside usage examples. Examples:: + >>> from transformers import TFMT5ForConditionalGeneration, T5Tokenizer >>> model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small") >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small") >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." >>> summary = "Weiter Verhandlung in Syrien." 
- >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf") - >>> outputs = model(batch) + >>> inputs = tokenizer(article, return_tensors="tf") + >>> with tokenizer.as_target_tokenizer(): + ... labels = tokenizer(summary, return_tensors="tf") + + >>> outputs = model(**inputs,labels=labels["input_ids"]) >>> loss = outputs.loss """ diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 5f893e11cde..5e9e8c356af 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -550,10 +550,8 @@ class RagModel(RagPreTrainedModel): >>> # initialize with RagRetriever to do everything in one forward call >>> model = RagModel.from_pretrained("facebook/rag-token-base", retriever=retriever) - >>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt") - >>> input_ids = input_dict["input_ids"] - >>> outputs = model(input_ids=input_ids) - + >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt") + >>> outputs = model(input_ids=inputs["input_ids"]) """ n_docs = n_docs if n_docs is not None else self.config.n_docs use_cache = use_cache if use_cache is not None else self.config.use_cache @@ -752,9 +750,12 @@ class RagSequenceForGeneration(RagPreTrainedModel): >>> # initialize with RagRetriever to do everything in one forward call >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever) - >>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt") - >>> input_ids = input_dict["input_ids"] - >>> outputs = model(input_ids=input_ids, labels=input_dict["labels"]) + >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt") + >>> with tokenizer.as_target_tokenizer(): + ... targets = tokenizer("In Paris, there are 10 million people.", return_tensors="pt") + >>> input_ids = inputs["input_ids"] + >>> labels = targets["input_ids"] + >>> outputs = model(input_ids=input_ids, labels=labels) >>> # or use retriever separately >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True) @@ -764,7 +765,7 @@ class RagSequenceForGeneration(RagPreTrainedModel): >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt") >>> doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1) >>> # 3. 
Forward to generator - >>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"]) + >>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=labels) """ n_docs = n_docs if n_docs is not None else self.config.n_docs exclude_bos_score = exclude_bos_score if exclude_bos_score is not None else self.config.exclude_bos_score @@ -1203,9 +1204,12 @@ class RagTokenForGeneration(RagPreTrainedModel): >>> # initialize with RagRetriever to do everything in one forward call >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever) - >>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt") - >>> input_ids = input_dict["input_ids"] - >>> outputs = model(input_ids=input_ids, labels=input_dict["labels"]) + >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt") + >>> with tokenizer.as_target_tokenizer(): + ... targets = tokenizer("In Paris, there are 10 million people.", return_tensors="pt") + >>> input_ids = inputs["input_ids"] + >>> labels = targets["input_ids"] + >>> outputs = model(input_ids=input_ids, labels=labels) >>> # or use retriever separately >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True) @@ -1215,7 +1219,7 @@ >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt") >>> doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1) >>> # 3. Forward to generator - >>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"]) + >>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=labels) >>> # or directly generate >>> generated = model.generate(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores) diff --git a/src/transformers/models/rag/tokenization_rag.py index 7b5916b78dd..d78a087bc76 100644 --- a/src/transformers/models/rag/tokenization_rag.py +++ b/src/transformers/models/rag/tokenization_rag.py @@ -14,6 +14,7 @@ # limitations under the License. """Tokenization classes for RAG.""" import os +import warnings from contextlib import contextmanager from typing import List, Optional @@ -88,6 +89,13 @@ class RagTokenizer: truncation: bool = True, **kwargs, ) -> BatchEncoding: + warnings.warn( + "`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of 🤗 Transformers. Use the " + "regular `__call__` method to prepare your inputs and the tokenizer under the `as_target_tokenizer` " + "context manager to prepare your targets.
See the documentation of your specific tokenizer for more " + "details", + FutureWarning, + ) if max_length is None: max_length = self.current_tokenizer.model_max_length model_inputs = self( diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 00ecdc7e400..d4825bcbaea 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -3303,6 +3303,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. """ + warnings.warn( + "`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of 🤗 Transformers. Use the " + "regular `__call__` method to prepare your inputs and the tokenizer under the `as_target_tokenizer` " + "context manager to prepare your targets. See the documentation of your specific tokenizer for more " + "details", + FutureWarning, + ) # mBART-specific kwargs that should be ignored by other models. kwargs.pop("src_lang", None) kwargs.pop("tgt_lang", None) diff --git a/tests/test_modeling_marian.py b/tests/test_modeling_marian.py index 7da01e043bc..8e2b5fc513f 100644 --- a/tests/test_modeling_marian.py +++ b/tests/test_modeling_marian.py @@ -354,9 +354,7 @@ class MarianIntegrationTest(unittest.TestCase): self.assertListEqual(self.expected_text, generated_words) def translate_src_text(self, **tokenizer_kwargs): - model_inputs = self.tokenizer.prepare_seq2seq_batch( - src_texts=self.src_text, return_tensors="pt", **tokenizer_kwargs - ).to(torch_device) + model_inputs = self.tokenizer(self.src_text, return_tensors="pt", **tokenizer_kwargs).to(torch_device) self.assertEqual(self.model.device, model_inputs.input_ids.device) generated_ids = self.model.generate( model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128 @@ -373,9 +371,10 @@ class TestMarian_EN_DE_More(MarianIntegrationTest): src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."] expected_ids = [38, 121, 14, 697, 38848, 0] - model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt, return_tensors="pt").to( - torch_device - ) + model_inputs = self.tokenizer(src, return_tensors="pt").to(torch_device) + with self.tokenizer.as_target_tokenizer(): + targets = self.tokenizer(tgt, return_tensors="pt") + model_inputs["labels"] = targets["input_ids"].to(torch_device) self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist()) @@ -397,16 +396,12 @@ class TestMarian_EN_DE_More(MarianIntegrationTest): def test_unk_support(self): t = self.tokenizer - ids = t.prepare_seq2seq_batch(["||"], return_tensors="pt").to(torch_device).input_ids[0].tolist() + ids = t(["||"], return_tensors="pt").to(torch_device).input_ids[0].tolist() expected = [t.unk_token_id, t.unk_token_id, t.eos_token_id] self.assertEqual(expected, ids) def test_pad_not_split(self): - input_ids_w_pad = ( - self.tokenizer.prepare_seq2seq_batch(["I am a small frog "], return_tensors="pt") .input_ids[0] .tolist() - ) + input_ids_w_pad = self.tokenizer(["I am a small frog "], return_tensors="pt").input_ids[0].tolist() expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0] # pad self.assertListEqual(expected_w_pad, input_ids_w_pad) diff --git a/tests/test_modeling_mbart.py b/tests/test_modeling_mbart.py index 7355eb146a8..d51e6056bd5 100644 --- a/tests/test_modeling_mbart.py +++
b/tests/test_modeling_mbart.py @@ -349,7 +349,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest): @slow def test_enro_generate_one(self): - batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch( + batch: BatchEncoding = self.tokenizer( ["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt" ).to(torch_device) translated_tokens = self.model.generate(**batch) @@ -359,9 +359,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest): @slow def test_enro_generate_batch(self): - batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text, return_tensors="pt").to( - torch_device - ) + batch: BatchEncoding = self.tokenizer(self.src_text, return_tensors="pt").to(torch_device) translated_tokens = self.model.generate(**batch) decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True) assert self.tgt_text == decoded @@ -412,7 +410,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest): @unittest.skip("This test is broken, still generates english") def test_cc25_generate(self): - inputs = self.tokenizer.prepare_seq2seq_batch([self.src_text[0]], return_tensors="pt").to(torch_device) + inputs = self.tokenizer([self.src_text[0]], return_tensors="pt").to(torch_device) translated_tokens = self.model.generate( input_ids=inputs["input_ids"].to(torch_device), decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"], @@ -422,9 +420,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest): @slow def test_fill_mask(self): - inputs = self.tokenizer.prepare_seq2seq_batch(["One of the best I ever read!"], return_tensors="pt").to( - torch_device - ) + inputs = self.tokenizer(["One of the best I ever read!"], return_tensors="pt").to(torch_device) outputs = self.model.generate( inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1 ) diff --git a/tests/test_modeling_tf_marian.py b/tests/test_modeling_tf_marian.py index a49c47bb197..8000e41b5fe 100644 --- a/tests/test_modeling_tf_marian.py +++ b/tests/test_modeling_tf_marian.py @@ -363,9 +363,7 @@ class AbstractMarianIntegrationTest(unittest.TestCase): self.assertListEqual(self.expected_text, generated_words) def translate_src_text(self, **tokenizer_kwargs): - model_inputs = self.tokenizer.prepare_seq2seq_batch( - src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf" - ) + model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf") generated_ids = self.model.generate( model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128 ) diff --git a/tests/test_modeling_tf_mbart.py b/tests/test_modeling_tf_mbart.py index 4891b00c382..228fe6a57b4 100644 --- a/tests/test_modeling_tf_mbart.py +++ b/tests/test_modeling_tf_mbart.py @@ -330,9 +330,7 @@ class TFMBartModelIntegrationTest(unittest.TestCase): self.assertListEqual(self.expected_text, generated_words) def translate_src_text(self, **tokenizer_kwargs): - model_inputs = self.tokenizer.prepare_seq2seq_batch( - src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf" - ) + model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf") generated_ids = self.model.generate( model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2 ) diff --git a/tests/test_modeling_tf_pegasus.py b/tests/test_modeling_tf_pegasus.py index 46ff69ec16a..adbd618859b 100644 --- a/tests/test_modeling_tf_pegasus.py +++ b/tests/test_modeling_tf_pegasus.py @@ -356,9 +356,7 
@@ class TFPegasusIntegrationTests(unittest.TestCase): assert self.expected_text == generated_words def translate_src_text(self, **tokenizer_kwargs): - model_inputs = self.tokenizer.prepare_seq2seq_batch( - src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf" - ) + model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf") generated_ids = self.model.generate( model_inputs.input_ids, attention_mask=model_inputs.attention_mask, diff --git a/tests/test_tokenization_bart.py b/tests/test_tokenization_bart.py index 7075359827f..1e5574e9dd6 100644 --- a/tests/test_tokenization_bart.py +++ b/tests/test_tokenization_bart.py @@ -86,18 +86,12 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase): return BartTokenizerFast.from_pretrained("facebook/bart-large") @require_torch - def test_prepare_seq2seq_batch(self): + def test_prepare_batch(self): src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] - tgt_text = [ - "Summary of the text.", - "Another summary.", - ] expected_src_tokens = [0, 250, 251, 17818, 13, 39186, 1938, 4, 2] for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]: - batch = tokenizer.prepare_seq2seq_batch( - src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt" - ) + batch = tokenizer(src_text, max_length=len(expected_src_tokens), padding=True, return_tensors="pt") self.assertIsInstance(batch, BatchEncoding) self.assertEqual((2, 9), batch.input_ids.shape) @@ -106,12 +100,11 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase): self.assertListEqual(expected_src_tokens, result) # Test that special tokens are reset - # Test Prepare Seq @require_torch - def test_seq2seq_batch_empty_target_text(self): + def test_prepare_batch_empty_target_text(self): src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]: - batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt") + batch = tokenizer(src_text, padding=True, return_tensors="pt") # check if input_ids are returned and no labels self.assertIn("input_ids", batch) self.assertIn("attention_mask", batch) @@ -119,29 +112,21 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase): self.assertNotIn("decoder_attention_mask", batch) @require_torch - def test_seq2seq_batch_max_target_length(self): - src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] + def test_as_target_tokenizer_target_length(self): tgt_text = [ "Summary of the text.", "Another summary.", ] for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]: - batch = tokenizer.prepare_seq2seq_batch( - src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt" - ) - self.assertEqual(32, batch["labels"].shape[1]) - - # test None max_target_length - batch = tokenizer.prepare_seq2seq_batch( - src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt" - ) - self.assertEqual(32, batch["labels"].shape[1]) + with tokenizer.as_target_tokenizer(): + targets = tokenizer(tgt_text, max_length=32, padding="max_length", return_tensors="pt") + self.assertEqual(32, targets["input_ids"].shape[1]) @require_torch - def test_seq2seq_batch_not_longer_than_maxlen(self): + def test_prepare_batch_not_longer_than_maxlen(self): for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]: - batch = 
tokenizer.prepare_seq2seq_batch( - ["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt" + batch = tokenizer( + ["I am a small frog" * 1024, "I am a small frog"], padding=True, truncation=True, return_tensors="pt" ) self.assertIsInstance(batch, BatchEncoding) self.assertEqual(batch.input_ids.shape, (2, 1024)) @@ -154,9 +139,11 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase): "Summary of the text.", ] for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]: - batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt") - input_ids = batch["input_ids"] - labels = batch["labels"] + inputs = tokenizer(src_text, return_tensors="pt") + with tokenizer.as_target_tokenizer(): + targets = tokenizer(tgt_text, return_tensors="pt") + input_ids = inputs["input_ids"] + labels = targets["input_ids"] self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item()) self.assertTrue((labels[:, 0] == tokenizer.bos_token_id).all().item()) self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item()) diff --git a/tests/test_tokenization_barthez.py b/tests/test_tokenization_barthez.py index c5b89711ad7..afb8e48de34 100644 --- a/tests/test_tokenization_barthez.py +++ b/tests/test_tokenization_barthez.py @@ -38,16 +38,12 @@ class BarthezTokenizationTest(TokenizerTesterMixin, unittest.TestCase): self.tokenizer = tokenizer @require_torch - def test_prepare_seq2seq_batch(self): + def test_prepare_batch(self): src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] - tgt_text = [ - "Summary of the text.", - "Another summary.", - ] expected_src_tokens = [0, 57, 3018, 70307, 91, 2] - batch = self.tokenizer.prepare_seq2seq_batch( - src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt" + batch = self.tokenizer( + src_text, max_length=len(expected_src_tokens), padding=True, truncation=True, return_tensors="pt" ) self.assertIsInstance(batch, BatchEncoding) diff --git a/tests/test_tokenization_marian.py b/tests/test_tokenization_marian.py index 7f9e776a063..b5e02fb64bd 100644 --- a/tests/test_tokenization_marian.py +++ b/tests/test_tokenization_marian.py @@ -70,7 +70,7 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase): def test_tokenizer_equivalence_en_de(self): en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de") - batch = en_de_tokenizer.prepare_seq2seq_batch(["I am a small frog"], return_tensors=None) + batch = en_de_tokenizer(["I am a small frog"], return_tensors=None) self.assertIsInstance(batch, BatchEncoding) expected = [38, 121, 14, 697, 38848, 0] self.assertListEqual(expected, batch.input_ids[0]) @@ -84,12 +84,14 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase): def test_outputs_not_longer_than_maxlen(self): tok = self.get_tokenizer() - batch = tok.prepare_seq2seq_batch(["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK) + batch = tok( + ["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK + ) self.assertIsInstance(batch, BatchEncoding) self.assertEqual(batch.input_ids.shape, (2, 512)) def test_outputs_can_be_shorter(self): tok = self.get_tokenizer() - batch_smaller = tok.prepare_seq2seq_batch(["I am a tiny frog", "I am a small frog"], return_tensors=FRAMEWORK) + batch_smaller = tok(["I am a tiny frog", "I am a small frog"], padding=True, return_tensors=FRAMEWORK) 
self.assertIsInstance(batch_smaller, BatchEncoding) self.assertEqual(batch_smaller.input_ids.shape, (2, 10)) diff --git a/tests/test_tokenization_mbart.py b/tests/test_tokenization_mbart.py index 1376cd7e8bb..a67c75e1f49 100644 --- a/tests/test_tokenization_mbart.py +++ b/tests/test_tokenization_mbart.py @@ -141,7 +141,9 @@ class MBartEnroIntegrationTest(unittest.TestCase): @classmethod def setUpClass(cls): - cls.tokenizer: MBartTokenizer = MBartTokenizer.from_pretrained(cls.checkpoint_name) + cls.tokenizer: MBartTokenizer = MBartTokenizer.from_pretrained( + cls.checkpoint_name, src_lang="en_XX", tgt_lang="ro_RO" + ) cls.pad_token_id = 1 return cls @@ -166,10 +168,7 @@ class MBartEnroIntegrationTest(unittest.TestCase): src_text = ["this is gunna be a long sentence " * 20] assert isinstance(src_text[0], str) desired_max_length = 10 - ids = self.tokenizer.prepare_seq2seq_batch( - src_text, - max_length=desired_max_length, - ).input_ids[0] + ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0] self.assertEqual(ids[-2], 2) self.assertEqual(ids[-1], EN_CODE) self.assertEqual(len(ids), desired_max_length) @@ -184,31 +183,36 @@ class MBartEnroIntegrationTest(unittest.TestCase): new_tok = MBartTokenizer.from_pretrained(tmpdirname) self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens) - # prepare_seq2seq_batch tests below - @require_torch def test_batch_fairseq_parity(self): - batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch( - self.src_text, tgt_texts=self.tgt_text, return_tensors="pt" - ) - batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id) + batch = self.tokenizer(self.src_text, padding=True) + with self.tokenizer.as_target_tokenizer(): + targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt") + labels = targets["input_ids"] + batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist() - for k in batch: - batch[k] = batch[k].tolist() - # batch = {k: v.tolist() for k,v in batch.items()} # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4 - # batch.decoder_inputs_ids[0][0] == assert batch.input_ids[1][-2:] == [2, EN_CODE] assert batch.decoder_input_ids[1][0] == RO_CODE assert batch.decoder_input_ids[1][-1] == 2 - assert batch.labels[1][-2:] == [2, RO_CODE] + assert labels[1][-2:].tolist() == [2, RO_CODE] @require_torch - def test_enro_tokenizer_prepare_seq2seq_batch(self): - batch = self.tokenizer.prepare_seq2seq_batch( - self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), return_tensors="pt" + def test_enro_tokenizer_prepare_batch(self): + batch = self.tokenizer( + self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt" ) - batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id) + with self.tokenizer.as_target_tokenizer(): + targets = self.tokenizer( + self.tgt_text, + padding=True, + truncation=True, + max_length=len(self.expected_src_tokens), + return_tensors="pt", + ) + labels = targets["input_ids"] + batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id) + self.assertIsInstance(batch, BatchEncoding) self.assertEqual((2, 14), batch.input_ids.shape) @@ -220,17 +224,12 @@ class MBartEnroIntegrationTest(unittest.TestCase): self.assertEqual(self.tokenizer.prefix_tokens, []) self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE]) 
- def test_seq2seq_max_target_length(self): - batch = self.tokenizer.prepare_seq2seq_batch( - self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10, return_tensors="pt" - ) - batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id) + def test_seq2seq_max_length(self): + batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt") + with self.tokenizer.as_target_tokenizer(): + targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt") + labels = targets["input_ids"] + batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id) + self.assertEqual(batch.input_ids.shape[1], 3) self.assertEqual(batch.decoder_input_ids.shape[1], 10) - # max_target_length will default to max_length if not specified - batch = self.tokenizer.prepare_seq2seq_batch( - self.src_text, tgt_texts=self.tgt_text, max_length=3, return_tensors="pt" - ) - batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id) - self.assertEqual(batch.input_ids.shape[1], 3) - self.assertEqual(batch.decoder_input_ids.shape[1], 3) diff --git a/tests/test_tokenization_mbart50.py b/tests/test_tokenization_mbart50.py index ddb95d6ec2b..f31d030c93a 100644 --- a/tests/test_tokenization_mbart50.py +++ b/tests/test_tokenization_mbart50.py @@ -129,10 +129,7 @@ class MBartOneToManyIntegrationTest(unittest.TestCase): src_text = ["this is gunna be a long sentence " * 20] assert isinstance(src_text[0], str) desired_max_length = 10 - ids = self.tokenizer.prepare_seq2seq_batch( - src_text, - max_length=desired_max_length, - ).input_ids[0] + ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0] self.assertEqual(ids[0], EN_CODE) self.assertEqual(ids[-1], 2) self.assertEqual(len(ids), desired_max_length) @@ -147,32 +144,38 @@ class MBartOneToManyIntegrationTest(unittest.TestCase): new_tok = MBart50Tokenizer.from_pretrained(tmpdirname) self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens) - # prepare_seq2seq_batch tests below - @require_torch def test_batch_fairseq_parity(self): - batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch( - self.src_text, tgt_texts=self.tgt_text, return_tensors="pt" - ) - batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id) + batch = self.tokenizer(self.src_text, padding=True) + with self.tokenizer.as_target_tokenizer(): + targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt") + labels = targets["input_ids"] + batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist() + labels = labels.tolist() - for k in batch: - batch[k] = batch[k].tolist() - # batch = {k: v.tolist() for k,v in batch.items()} # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4 - # batch.decoder_inputs_ids[0][0] == assert batch.input_ids[1][0] == EN_CODE assert batch.input_ids[1][-1] == 2 - assert batch.labels[1][0] == RO_CODE - assert batch.labels[1][-1] == 2 + assert labels[1][0] == RO_CODE + assert labels[1][-1] == 2 assert batch.decoder_input_ids[1][:2] == [2, RO_CODE] @require_torch - def test_tokenizer_prepare_seq2seq_batch(self): - batch = self.tokenizer.prepare_seq2seq_batch( - self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), return_tensors="pt" + def test_tokenizer_prepare_batch(self): + batch = self.tokenizer( + self.src_text, 
padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt" ) - batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id) + with self.tokenizer.as_target_tokenizer(): + targets = self.tokenizer( + self.tgt_text, + padding=True, + truncation=True, + max_length=len(self.expected_src_tokens), + return_tensors="pt", + ) + labels = targets["input_ids"] + batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id) + self.assertIsInstance(batch, BatchEncoding) self.assertEqual((2, 14), batch.input_ids.shape) @@ -185,16 +188,11 @@ class MBartOneToManyIntegrationTest(unittest.TestCase): self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id]) def test_seq2seq_max_target_length(self): - batch = self.tokenizer.prepare_seq2seq_batch( - self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10, return_tensors="pt" - ) - batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id) + batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt") + with self.tokenizer.as_target_tokenizer(): + targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt") + labels = targets["input_ids"] + batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id) + self.assertEqual(batch.input_ids.shape[1], 3) self.assertEqual(batch.decoder_input_ids.shape[1], 10) - # max_target_length will default to max_length if not specified - batch = self.tokenizer.prepare_seq2seq_batch( - self.src_text, tgt_texts=self.tgt_text, max_length=3, return_tensors="pt" - ) - batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id) - self.assertEqual(batch.input_ids.shape[1], 3) - self.assertEqual(batch.decoder_input_ids.shape[1], 3) diff --git a/tests/test_tokenization_pegasus.py b/tests/test_tokenization_pegasus.py index 56889b96cf6..c9ee3ee09e1 100644 --- a/tests/test_tokenization_pegasus.py +++ b/tests/test_tokenization_pegasus.py @@ -86,11 +86,13 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase): def test_large_seq2seq_truncation(self): src_texts = ["This is going to be way too long." * 150, "short example"] tgt_texts = ["not super long but more than 5 tokens", "tiny"] - batch = self._large_tokenizer.prepare_seq2seq_batch( - src_texts, tgt_texts=tgt_texts, max_target_length=5, return_tensors="pt" - ) + batch = self._large_tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt") + with self._large_tokenizer.as_target_tokenizer(): + targets = self._large_tokenizer( + tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt" + ) + assert batch.input_ids.shape == (2, 1024) assert batch.attention_mask.shape == (2, 1024) - assert "labels" in batch # because tgt_texts was specified - assert batch.labels.shape == (2, 5) - assert len(batch) == 3 # input_ids, attention_mask, labels. Other things make by BartModel + assert targets["input_ids"].shape == (2, 5) + assert len(batch) == 2 # input_ids, attention_mask. 
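The FutureWarning added in src/transformers/tokenization_utils_base.py above also suggests a small regression check for downstream suites that still call the legacy helper while they migrate. A minimal sketch, assuming a T5 checkpoint purely for illustration (the test class and checkpoint name are not part of this patch)::

    import unittest

    from transformers import T5Tokenizer


    class LegacyPrepareSeq2SeqBatchTest(unittest.TestCase):
        # Illustrative only: the deprecated helper keeps working for now but
        # should emit the FutureWarning introduced by this patch.
        def test_prepare_seq2seq_batch_warns(self):
            tokenizer = T5Tokenizer.from_pretrained("t5-small")
            with self.assertWarns(FutureWarning):
                batch = tokenizer.prepare_seq2seq_batch(
                    ["A long paragraph for summarization."],
                    tgt_texts=["Summary of the text."],
                    return_tensors="pt",
                )
            # Behaviour is unchanged: input_ids, attention_mask and labels come back.
            self.assertIn("labels", batch)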
diff --git a/tests/test_tokenization_prophetnet.py b/tests/test_tokenization_prophetnet.py
index 918612329ff..c073304aa90 100644
--- a/tests/test_tokenization_prophetnet.py
+++ b/tests/test_tokenization_prophetnet.py
@@ -152,20 +152,12 @@ class ProphetNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])

     @require_torch
-    def test_prepare_seq2seq_batch(self):
+    def test_prepare_batch(self):
         tokenizer = self.tokenizer_class.from_pretrained("microsoft/prophetnet-large-uncased")

         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
-        tgt_text = [
-            "Summary of the text.",
-            "Another summary.",
-        ]
         expected_src_tokens = [1037, 2146, 20423, 2005, 7680, 7849, 3989, 1012, 102]
-        batch = tokenizer.prepare_seq2seq_batch(
-            src_text,
-            tgt_texts=tgt_text,
-            return_tensors="pt",
-        )
+        batch = tokenizer(src_text, padding=True, return_tensors="pt")
         self.assertIsInstance(batch, BatchEncoding)
         result = list(batch.input_ids.numpy()[0])
         self.assertListEqual(expected_src_tokens, result)
diff --git a/tests/test_tokenization_t5.py b/tests/test_tokenization_t5.py
index 9fbd50eaf5e..27cdf612cea 100644
--- a/tests/test_tokenization_t5.py
+++ b/tests/test_tokenization_t5.py
@@ -151,19 +151,11 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""])
         self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"])

-    def test_prepare_seq2seq_batch(self):
+    def test_prepare_batch(self):
         tokenizer = self.t5_base_tokenizer
         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
-        tgt_text = [
-            "Summary of the text.",
-            "Another summary.",
-        ]
         expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]
-        batch = tokenizer.prepare_seq2seq_batch(
-            src_text,
-            tgt_texts=tgt_text,
-            return_tensors=FRAMEWORK,
-        )
+        batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
         self.assertIsInstance(batch, BatchEncoding)
         result = list(batch.input_ids.numpy()[0])
         self.assertListEqual(expected_src_tokens, result)
@@ -174,36 +166,30 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_empty_target_text(self):
         tokenizer = self.t5_base_tokenizer
         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
-        batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK)
+        batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
         # check if input_ids are returned and no decoder_input_ids
         self.assertIn("input_ids", batch)
         self.assertIn("attention_mask", batch)
         self.assertNotIn("decoder_input_ids", batch)
         self.assertNotIn("decoder_attention_mask", batch)

-    def test_max_target_length(self):
+    def test_max_length(self):
         tokenizer = self.t5_base_tokenizer
-        src_text = ["A short paragraph for summarization.", "Another short paragraph for summarization."]
         tgt_text = [
             "Summary of the text.",
             "Another summary.",
         ]
-        batch = tokenizer.prepare_seq2seq_batch(
-            src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors=FRAMEWORK
-        )
-        self.assertEqual(32, batch["labels"].shape[1])
-
-        # test None max_target_length
-        batch = tokenizer.prepare_seq2seq_batch(
-            src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors=FRAMEWORK
-        )
-        self.assertEqual(32, batch["labels"].shape[1])
+        with tokenizer.as_target_tokenizer():
+            targets = tokenizer(
+                tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
+            )
+        self.assertEqual(32, targets["input_ids"].shape[1])

     def test_outputs_not_longer_than_maxlen(self):
         tokenizer = self.t5_base_tokenizer

-        batch = tokenizer.prepare_seq2seq_batch(
-            ["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
+        batch = tokenizer(
+            ["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK
         )
         self.assertIsInstance(batch, BatchEncoding)
         self.assertEqual(batch.input_ids.shape, (2, 512))
@@ -215,13 +201,12 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1]
         expected_tgt_tokens = [20698, 13, 8, 1499, 5, 1]

-        batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)
+        batch = tokenizer(src_text)
+        with tokenizer.as_target_tokenizer():
+            targets = tokenizer(tgt_text)

-        src_ids = list(batch.input_ids.numpy()[0])
-        tgt_ids = list(batch.labels.numpy()[0])
-
-        self.assertEqual(expected_src_tokens, src_ids)
-        self.assertEqual(expected_tgt_tokens, tgt_ids)
+        self.assertEqual(expected_src_tokens, batch["input_ids"][0])
+        self.assertEqual(expected_tgt_tokens, targets["input_ids"][0])

     def test_token_type_ids(self):
         src_text_1 = ["A first paragraph for summarization."]