Deprecate prepare_seq2seq_batch (#10287)

* Deprecate prepare_seq2seq_batch

* Fix last tests

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* More review comments

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Sylvain Gugger 2021-02-22 12:36:16 -05:00 committed by GitHub
parent e73a3e1891
commit 9e147d31f6
31 changed files with 325 additions and 320 deletions


@ -56,7 +56,7 @@ FSMTTokenizer
.. autoclass:: transformers.FSMTTokenizer
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
create_token_type_ids_from_sequences, prepare_seq2seq_batch, save_vocabulary
create_token_type_ids_from_sequences, save_vocabulary
FSMTModel


@ -76,27 +76,29 @@ require 3 character language codes:
.. code-block:: python
from transformers import MarianMTModel, MarianTokenizer
src_text = [
'>>fra<< this is a sentence in english that we want to translate to french',
'>>por<< This should go to portuguese',
'>>esp<< And this to Spanish'
]
>>> from transformers import MarianMTModel, MarianTokenizer
>>> src_text = [
... '>>fra<< this is a sentence in english that we want to translate to french',
... '>>por<< This should go to portuguese',
... '>>esp<< And this to Spanish'
... ]
model_name = 'Helsinki-NLP/opus-mt-en-roa'
tokenizer = MarianTokenizer.from_pretrained(model_name)
print(tokenizer.supported_language_codes)
model = MarianMTModel.from_pretrained(model_name)
translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt"))
tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
# ["c'est une phrase en anglais que nous voulons traduire en français",
# 'Isto deve ir para o português.',
# 'Y esto al español']
>>> model_name = 'Helsinki-NLP/opus-mt-en-roa'
>>> tokenizer = MarianTokenizer.from_pretrained(model_name)
>>> print(tokenizer.supported_language_codes)
['>>zlm_Latn<<', '>>mfe<<', '>>hat<<', '>>pap<<', '>>ast<<', '>>cat<<', '>>ind<<', '>>glg<<', '>>wln<<', '>>spa<<', '>>fra<<', '>>ron<<', '>>por<<', '>>ita<<', '>>oci<<', '>>arg<<', '>>min<<']
>>> model = MarianMTModel.from_pretrained(model_name)
>>> translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True))
>>> [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
["c'est une phrase en anglais que nous voulons traduire en français",
'Isto deve ir para o português.',
'Y esto al español']
Code to see available pretrained models:
Here is the code to see all available pretrained models on the hub:
.. code-block:: python
@ -147,21 +149,22 @@ Example of translating english to many romance languages, using old-style 2 char
.. code-block:: python
from transformers import MarianMTModel, MarianTokenizer
src_text = [
'>>fr<< this is a sentence in english that we want to translate to french',
'>>pt<< This should go to portuguese',
'>>es<< And this to Spanish'
]
>>> from transformers import MarianMTModel, MarianTokenizer
>>> src_text = [
... '>>fr<< this is a sentence in english that we want to translate to french',
... '>>pt<< This should go to portuguese',
... '>>es<< And this to Spanish'
... ]
model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
tokenizer = MarianTokenizer.from_pretrained(model_name)
print(tokenizer.supported_language_codes)
>>> model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
>>> tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)
translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt"))
tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
# ["c'est une phrase en anglais que nous voulons traduire en français", 'Isto deve ir para o português.', 'Y esto al español']
>>> model = MarianMTModel.from_pretrained(model_name)
>>> translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True))
>>> tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
["c'est une phrase en anglais que nous voulons traduire en français",
'Isto deve ir para o português.',
'Y esto al español']
@ -176,7 +179,7 @@ MarianTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MarianTokenizer
:members: prepare_seq2seq_batch
:members: as_target_tokenizer
MarianModel


@ -34,22 +34,31 @@ The Authors' code can be found `here <https://github.com/pytorch/fairseq/tree/ma
Training of MBart
_______________________________________________________________________________________________________________________
MBart is a multilingual encoder-decoder (seq-to-seq) model primarily intended for translation task. As the model is
multilingual it expects the sequences in a different format. A special language id token is added in both the source
and target text. The source text format is :obj:`X [eos, src_lang_code]` where :obj:`X` is the source text. The target
text format is :obj:`[tgt_lang_code] X [eos]`. :obj:`bos` is never used.
MBart is a multilingual encoder-decoder (sequence-to-sequence) model primarily intended for translation tasks. As the
model is multilingual, it expects the sequences in a different format: a special language id token is added to both the
source and the target text. The source text format is :obj:`X [eos, src_lang_code]` where :obj:`X` is the source text.
The target text format is :obj:`[tgt_lang_code] X [eos]`. :obj:`bos` is never used.
The :meth:`~transformers.MBartTokenizer.prepare_seq2seq_batch` handles this automatically and should be used to encode
the sequences for sequence-to-sequence fine-tuning.
The regular :meth:`~transformers.MBartTokenizer.__call__` encodes the source text format, and it should be wrapped
inside the :meth:`~transformers.MBartTokenizer.as_target_tokenizer` context manager to encode the target text format.
- Supervised training
.. code-block::
example_english_phrase = "UN Chief Says There Is No Military Solution in Syria"
expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
batch = tokenizer.prepare_seq2seq_batch(example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt")
model(input_ids=batch['input_ids'], labels=batch['labels']) # forward pass
>>> from transformers import MBartForConditionalGeneration, MBartTokenizer
>>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX", tgt_lang="ro_RO")
>>> example_english_phrase = "UN Chief Says There Is No Military Solution in Syria"
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
>>> inputs = tokenizer(example_english_phrase, return_tensors="pt", src_lang="en_XX", tgt_lang="ro_RO")
>>> with tokenizer.as_target_tokenizer():
... labels = tokenizer(expected_translation_romanian, return_tensors="pt")
>>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
>>> # forward pass
>>> outputs = model(**inputs, labels=labels["input_ids"])
- Generation
@ -58,14 +67,14 @@ the sequences for sequence-to-sequence fine-tuning.
.. code-block::
from transformers import MBartForConditionalGeneration, MBartTokenizer
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
article = "UN Chief Says There Is No Military Solution in Syria"
batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], src_lang="en_XX", return_tensors="pt")
translated_tokens = model.generate(**batch, decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"])
translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
assert translation == "Şeful ONU declară că nu există o soluţie militară în Siria"
>>> from transformers import MBartForConditionalGeneration, MBartTokenizer
>>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX")
>>> article = "UN Chief Says There Is No Military Solution in Syria"
>>> inputs = tokenizer(article, return_tensors="pt")
>>> translated_tokens = model.generate(**inputs, decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"])
>>> tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
"Şeful ONU declară că nu există o soluţie militară în Siria"
Overview of MBart-50
@ -160,7 +169,7 @@ MBartTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MBartTokenizer
:members: build_inputs_with_special_tokens, prepare_seq2seq_batch
:members: as_target_tokenizer, build_inputs_with_special_tokens
MBartTokenizerFast


@ -78,20 +78,20 @@ Usage Example
.. code-block:: python
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch
src_text = [
""" PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
]
>>> from transformers import PegasusForConditionalGeneration, PegasusTokenizer
>>> import torch
>>> src_text = [
... """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
... ]
model_name = 'google/pegasus-xsum'
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
batch = tokenizer.prepare_seq2seq_batch(src_text, truncation=True, padding='longest', return_tensors="pt").to(torch_device)
translated = model.generate(**batch)
tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
assert tgt_text[0] == "California's largest electricity provider has turned off power to hundreds of thousands of customers."
>>> model_name = 'google/pegasus-xsum'
>>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
>>> tokenizer = PegasusTokenizer.from_pretrained(model_name)
>>> model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
>>> batch = tokenizer(src_text, truncation=True, padding='longest', return_tensors="pt").to(device)
>>> translated = model.generate(**batch)
>>> tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
>>> assert tgt_text[0] == "California's largest electricity provider has turned off power to hundreds of thousands of customers."
@ -107,7 +107,7 @@ PegasusTokenizer
warning: ``add_tokens`` does not work at the moment.
.. autoclass:: transformers.PegasusTokenizer
:members: __call__, prepare_seq2seq_batch
:members:
PegasusTokenizerFast


@ -56,7 +56,7 @@ RagTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RagTokenizer
:members: prepare_seq2seq_batch
:members:
Rag specific outputs


@ -104,7 +104,7 @@ T5Tokenizer
.. autoclass:: transformers.T5Tokenizer
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
create_token_type_ids_from_sequences, prepare_seq2seq_batch, save_vocabulary
create_token_type_ids_from_sequences, save_vocabulary
T5TokenizerFast


@ -71,7 +71,7 @@ tiny_model = FSMTForConditionalGeneration(config)
print(f"num of params {tiny_model.num_parameters()}")
# Test
batch = tokenizer.prepare_seq2seq_batch(["Making tiny model"], return_tensors="pt")
batch = tokenizer(["Making tiny model"], return_tensors="pt")
outputs = tiny_model(**batch)
print("test output:", len(outputs.logits[0]))


@ -42,7 +42,7 @@ tiny_model = FSMTForConditionalGeneration(config)
print(f"num of params {tiny_model.num_parameters()}")
# Test
batch = tokenizer.prepare_seq2seq_batch(["Making tiny model"], return_tensors="pt")
batch = tokenizer(["Making tiny model"], return_tensors="pt")
outputs = tiny_model(**batch)
print("test output:", len(outputs.logits[0]))


@ -522,13 +522,14 @@ MARIAN_GENERATION_EXAMPLE = r"""
>>> src = 'fr' # source language
>>> trg = 'en' # target language
>>> sample_text = "où est l'arrêt de bus ?"
>>> mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'
>>> model_name = f'Helsinki-NLP/opus-mt-{src}-{trg}'
>>> model = MarianMTModel.from_pretrained(mname)
>>> tok = MarianTokenizer.from_pretrained(mname)
>>> batch = tok.prepare_seq2seq_batch(src_texts=[sample_text], return_tensors="pt") # don't need tgt_text for inference
>>> model = MarianMTModel.from_pretrained(model_name)
>>> tokenizer = MarianTokenizer.from_pretrained(model_name)
>>> batch = tokenizer([sample_text], return_tensors="pt")
>>> gen = model.generate(**batch)
>>> words: List[str] = tok.batch_decode(gen, skip_special_tokens=True) # returns "Where is the bus stop ?"
>>> tokenizer.batch_decode(gen, skip_special_tokens=True)
"Where is the bus stop ?"
"""
MARIAN_INPUTS_DOCSTRING = r"""


@ -557,13 +557,14 @@ MARIAN_GENERATION_EXAMPLE = r"""
>>> src = 'fr' # source language
>>> trg = 'en' # target language
>>> sample_text = "où est l'arrêt de bus ?"
>>> mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'
>>> model_name = f'Helsinki-NLP/opus-mt-{src}-{trg}'
>>> model = MarianMTModel.from_pretrained(mname)
>>> tok = MarianTokenizer.from_pretrained(mname)
>>> batch = tok.prepare_seq2seq_batch(src_texts=[sample_text], return_tensors="tf") # don't need tgt_text for inference
>>> model = TFMarianMTModel.from_pretrained(model_name)
>>> tokenizer = MarianTokenizer.from_pretrained(model_name)
>>> batch = tokenizer([sample_text], return_tensors="tf")
>>> gen = model.generate(**batch)
>>> words: List[str] = tok.batch_decode(gen, skip_special_tokens=True) # returns "Where is the bus stop ?"
>>> tokenizer.batch_decode(gen, skip_special_tokens=True)
"Where is the bus stop ?"
"""
MARIAN_INPUTS_DOCSTRING = r"""


@ -80,12 +80,15 @@ class MarianTokenizer(PreTrainedTokenizer):
Examples::
>>> from transformers import MarianTokenizer
>>> tok = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
>>> tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
>>> src_texts = [ "I am a small frog.", "Tom asked his teacher for advice."]
>>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."] # optional
>>> batch_enc = tok.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, return_tensors="pt")
>>> # keys [input_ids, attention_mask, labels].
>>> # model(**batch) should work
>>> inputs = tokenizer(src_texts, return_tensors="pt", padding=True)
>>> with tokenizer.as_target_tokenizer():
... labels = tokenizer(tgt_texts, return_tensors="pt", padding=True)
>>> inputs["labels"] = labels["input_ids"]
>>> # keys [input_ids, attention_mask, labels].
>>> outputs = model(**inputs)  # should work
"""
vocab_files_names = vocab_files_names
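For reference, here is a minimal end-to-end sketch of the pattern the docstring above illustrates; the checkpoint and the example sentences are assumptions chosen for illustration, not taken from this diff:

.. code-block:: python

    from transformers import MarianMTModel, MarianTokenizer

    tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
    model = MarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-en-de')

    # Sources go through the regular __call__; targets are encoded under
    # the as_target_tokenizer context manager that replaces prepare_seq2seq_batch.
    inputs = tokenizer(["I am a small frog."], return_tensors="pt", padding=True)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(["Ich bin ein kleiner Frosch."], return_tensors="pt", padding=True)
    inputs["labels"] = labels["input_ids"]

    outputs = model(**inputs)  # supervised forward pass: outputs.loss, outputs.logits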


@ -59,30 +59,23 @@ class MBartTokenizer(XLMRobertaTokenizer):
"""
Construct an MBART tokenizer.
:class:`~transformers.MBartTokenizer` is a subclass of :class:`~transformers.XLMRobertaTokenizer` and adds a new
:meth:`~transformers.MBartTokenizer.prepare_seq2seq_batch`
Refer to superclass :class:`~transformers.XLMRobertaTokenizer` for usage examples and documentation concerning the
:class:`~transformers.MBartTokenizer` is a subclass of :class:`~transformers.XLMRobertaTokenizer`. Refer to
superclass :class:`~transformers.XLMRobertaTokenizer` for usage examples and documentation concerning the
initialization parameters and other methods.
.. warning::
``prepare_seq2seq_batch`` should be used to encode inputs. Other tokenizer methods like ``encode`` do not work
properly.
The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and ``<language code>
<tokens> <eos>`` for target language documents.
Examples::
>>> from transformers import MBartTokenizer
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro', src_lang="en_XX", tgt_lang="ro_RO")
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
>>> batch: dict = tokenizer.prepare_seq2seq_batch(
... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt"
... )
>>> inputs = tokenizer(example_english_phrase, return_tensors="pt)
>>> with tokenizer.as_target_tokenizer():
... labels = tokenizer(expected_translation_romanian, return_tensors="pt")
>>> inputs["labels"] = labels["input_ids"]
"""
vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
@ -92,26 +85,38 @@ class MBartTokenizer(XLMRobertaTokenizer):
prefix_tokens: List[int] = []
suffix_tokens: List[int] = []
def __init__(self, *args, tokenizer_file=None, **kwargs):
super().__init__(*args, tokenizer_file=tokenizer_file, **kwargs)
def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)
self.sp_model_size = len(self.sp_model)
self.lang_code_to_id = {
code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
}
self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
self.cur_lang_code = self.lang_code_to_id["en_XX"]
self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
self._additional_special_tokens = list(self.lang_code_to_id.keys())
self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
self._src_lang = src_lang if src_lang is not None else "en_XX"
self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
self.tgt_lang = tgt_lang
self.set_src_lang_special_tokens(self._src_lang)
@property
def vocab_size(self):
return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1 # Plus 1 for the mask token
@property
def src_lang(self) -> str:
return self._src_lang
@src_lang.setter
def src_lang(self, new_src_lang: str) -> None:
self._src_lang = new_src_lang
self.set_src_lang_special_tokens(self._src_lang)
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
@ -181,7 +186,6 @@ class MBartTokenizer(XLMRobertaTokenizer):
) -> BatchEncoding:
self.src_lang = src_lang
self.tgt_lang = tgt_lang
self.set_src_lang_special_tokens(self.src_lang)
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
@contextmanager
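The new ``src_lang`` property setter above re-applies the source-language special tokens whenever the attribute is assigned, so a tokenizer can be switched between languages after construction. A small sketch of the expected behavior, assuming the ``facebook/mbart-large-en-ro`` checkpoint:

.. code-block:: python

    from transformers import MBartTokenizer

    tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX")
    print(tokenizer.suffix_tokens)  # [eos_token_id, lang_code_to_id["en_XX"]]

    # Assignment goes through the setter, which calls set_src_lang_special_tokens,
    # so subsequent encodings are suffixed with the new language code.
    tokenizer.src_lang = "ro_RO"
    print(tokenizer.suffix_tokens)  # [eos_token_id, lang_code_to_id["ro_RO"]]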


@ -70,15 +70,9 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
Construct a "fast" MBART tokenizer (backed by HuggingFace's `tokenizers` library). Based on `BPE
<https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models>`__.
:class:`~transformers.MBartTokenizerFast` is a subclass of :class:`~transformers.XLMRobertaTokenizerFast` and adds
a new :meth:`~transformers.MBartTokenizerFast.prepare_seq2seq_batch`.
Refer to superclass :class:`~transformers.XLMRobertaTokenizerFast` for usage examples and documentation concerning
the initialization parameters and other methods.
.. warning::
``prepare_seq2seq_batch`` should be used to encode inputs. Other tokenizer methods like ``encode`` do not work
properly.
:class:`~transformers.MBartTokenizerFast` is a subclass of :class:`~transformers.XLMRobertaTokenizerFast`. Refer to
superclass :class:`~transformers.XLMRobertaTokenizerFast` for usage examples and documentation concerning the
initialization parameters and other methods.
The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and ``<language code>
<tokens> <eos>`` for target language documents.
@ -86,12 +80,13 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
Examples::
>>> from transformers import MBartTokenizerFast
>>> tokenizer = MBartTokenizerFast.from_pretrained('facebook/mbart-large-en-ro')
>>> tokenizer = MBartTokenizerFast.from_pretrained('facebook/mbart-large-en-ro', src_lang="en_XX", tgt_lang="ro_RO")
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
>>> batch: dict = tokenizer.prepare_seq2seq_batch(
... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt"
... )
>>> inputs = tokenizer(example_english_phrase, return_tensors="pt)
>>> with tokenizer.as_target_tokenizer():
... labels = tokenizer(expected_translation_romanian, return_tensors="pt")
>>> inputs["labels"] = labels["input_ids"]
"""
vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
@ -102,14 +97,25 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
prefix_tokens: List[int] = []
suffix_tokens: List[int] = []
def __init__(self, *args, tokenizer_file=None, **kwargs):
super().__init__(*args, tokenizer_file=tokenizer_file, **kwargs)
self.cur_lang_code = self.convert_tokens_to_ids("en_XX")
self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)
self.add_special_tokens({"additional_special_tokens": FAIRSEQ_LANGUAGE_CODES})
self._src_lang = src_lang if src_lang is not None else "en_XX"
self.cur_lang_code = self.convert_tokens_to_ids(self._src_lang)
self.tgt_lang = tgt_lang
self.set_src_lang_special_tokens(self._src_lang)
@property
def src_lang(self) -> str:
return self._src_lang
@src_lang.setter
def src_lang(self, new_src_lang: str) -> None:
self._src_lang = new_src_lang
self.set_src_lang_special_tokens(self._src_lang)
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
@ -181,7 +187,6 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
) -> BatchEncoding:
self.src_lang = src_lang
self.tgt_lang = tgt_lang
self.set_src_lang_special_tokens(self.src_lang)
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
@contextmanager


@ -31,13 +31,17 @@ class MT5Model(T5Model):
alongside usage examples.
Examples::
>>> from transformers import MT5Model, T5Tokenizer
>>> model = MT5Model.from_pretrained("google/mt5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> summary = "Weiter Verhandlung in Syrien."
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
>>> outputs = model(input_ids=batch.input_ids, decoder_input_ids=batch.labels)
>>> inputs = tokenizer(article, return_tensors="pt")
>>> with tokenizer.as_target_tokenizer():
... labels = tokenizer(summary, return_tensors="pt")
>>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
>>> hidden_states = outputs.last_hidden_state
"""
model_type = "mt5"
@ -59,13 +63,17 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration):
appropriate documentation alongside usage examples.
Examples::
>>> from transformers import MT5ForConditionalGeneration, T5Tokenizer
>>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> summary = "Weiter Verhandlung in Syrien."
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
>>> outputs = model(**batch)
>>> inputs = tokenizer(article, return_tensors="pt")
>>> with tokenizer.as_target_tokenizer():
... labels = tokenizer(summary, return_tensors="pt")
>>> outputs = model(**inputs, labels=labels["input_ids"])
>>> loss = outputs.loss
"""


@ -31,15 +31,17 @@ class TFMT5Model(TFT5Model):
documentation alongside usage examples.
Examples::
>>> from transformers import TFMT5Model, T5Tokenizer
>>> model = TFMT5Model.from_pretrained("google/mt5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> summary = "Weiter Verhandlung in Syrien."
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf")
>>> batch["decoder_input_ids"] = batch["labels"]
>>> del batch["labels"]
>>> outputs = model(batch)
>>> inputs = tokenizer(article, return_tensors="tf")
>>> with tokenizer.as_target_tokenizer():
... labels = tokenizer(summary, return_tensors="tf")
>>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
>>> hidden_states = outputs.last_hidden_state
"""
model_type = "mt5"
@ -52,13 +54,17 @@ class TFMT5ForConditionalGeneration(TFT5ForConditionalGeneration):
appropriate documentation alongside usage examples.
Examples::
>>> from transformers import TFMT5ForConditionalGeneration, T5Tokenizer
>>> model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> summary = "Weiter Verhandlung in Syrien."
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf")
>>> outputs = model(batch)
>>> inputs = tokenizer(article, return_tensors="tf")
>>> with tokenizer.as_target_tokenizer():
... labels = tokenizer(summary, return_tensors="tf")
>>> outputs = model(**inputs, labels=labels["input_ids"])
>>> loss = outputs.loss
"""


@ -550,10 +550,8 @@ class RagModel(RagPreTrainedModel):
>>> # initialize with RagRetriever to do everything in one forward call
>>> model = RagModel.from_pretrained("facebook/rag-token-base", retriever=retriever)
>>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt")
>>> input_ids = input_dict["input_ids"]
>>> outputs = model(input_ids=input_ids)
>>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
>>> outputs = model(input_ids=inputs["input_ids"])
"""
n_docs = n_docs if n_docs is not None else self.config.n_docs
use_cache = use_cache if use_cache is not None else self.config.use_cache
@ -752,9 +750,12 @@ class RagSequenceForGeneration(RagPreTrainedModel):
>>> # initialize with RagRetriever to do everything in one forward call
>>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
>>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt")
>>> input_ids = input_dict["input_ids"]
>>> outputs = model(input_ids=input_ids, labels=input_dict["labels"])
>>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
>>> with tokenizer.as_target_tokenizer():
... targets = tokenizer("In Paris, there are 10 million people.", return_tensors="pt")
>>> input_ids = inputs["input_ids"]
>>> labels = targets["input_ids"]
>>> outputs = model(input_ids=input_ids, labels=labels)
>>> # or use retriever separately
>>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
@ -764,7 +765,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
>>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
>>> doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
>>> # 3. Forward to generator
>>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"])
>>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=labels)
"""
n_docs = n_docs if n_docs is not None else self.config.n_docs
exclude_bos_score = exclude_bos_score if exclude_bos_score is not None else self.config.exclude_bos_score
@ -1203,9 +1204,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
>>> # initialize with RagRetriever to do everything in one forward call
>>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
>>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt")
>>> input_ids = input_dict["input_ids"]
>>> outputs = model(input_ids=input_ids, labels=input_dict["labels"])
>>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
>>> with tokenizer.as_target_tokenizer():
... targets = tokenizer("In Paris, there are 10 million people.", return_tensors="pt")
>>> input_ids = inputs["input_ids"]
>>> labels = targets["input_ids"]
>>> outputs = model(input_ids=input_ids, labels=labels)
>>> # or use retriever separately
>>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True)
@ -1215,7 +1219,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
>>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
>>> doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
>>> # 3. Forward to generator
>>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"])
>>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=labels)
>>> # or directly generate
>>> generated = model.generate(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores)
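For completeness, the generated ids from the snippet above would normally be decoded back to strings; a one-line sketch using the tokenizer already defined in the example:

.. code-block:: python

    generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)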


@ -14,6 +14,7 @@
# limitations under the License.
"""Tokenization classes for RAG."""
import os
import warnings
from contextlib import contextmanager
from typing import List, Optional
@ -88,6 +89,13 @@ class RagTokenizer:
truncation: bool = True,
**kwargs,
) -> BatchEncoding:
warnings.warn(
"`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of 🤗 Transformers. Use the "
"regular `__call__` method to prepare your inputs and the tokenizer under the `with_target_tokenizer` "
"context manager to prepare your targets. See the documentation of your specific tokenizer for more "
"details",
FutureWarning,
)
if max_length is None:
max_length = self.current_tokenizer.model_max_length
model_inputs = self(


@ -3303,6 +3303,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
The full set of keys ``[input_ids, attention_mask, labels]`` will only be returned if :obj:`tgt_texts` is passed.
Otherwise, ``input_ids`` and ``attention_mask`` will be the only keys.
"""
warnings.warn(
"`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of 🤗 Transformers. Use the "
"regular `__call__` method to prepare your inputs and the tokenizer under the `with_target_tokenizer` "
"context manager to prepare your targets. See the documentation of your specific tokenizer for more "
"details",
FutureWarning,
)
# mBART-specific kwargs that should be ignored by other models.
kwargs.pop("src_lang", None)
kwargs.pop("tgt_lang", None)


@ -354,9 +354,7 @@ class MarianIntegrationTest(unittest.TestCase):
self.assertListEqual(self.expected_text, generated_words)
def translate_src_text(self, **tokenizer_kwargs):
model_inputs = self.tokenizer.prepare_seq2seq_batch(
src_texts=self.src_text, return_tensors="pt", **tokenizer_kwargs
).to(torch_device)
model_inputs = self.tokenizer(self.src_text, return_tensors="pt", **tokenizer_kwargs).to(torch_device)
self.assertEqual(self.model.device, model_inputs.input_ids.device)
generated_ids = self.model.generate(
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
@ -373,9 +371,10 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
expected_ids = [38, 121, 14, 697, 38848, 0]
model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt, return_tensors="pt").to(
torch_device
)
model_inputs = self.tokenizer(src, return_tensors="pt").to(torch_device)
with self.tokenizer.as_target_tokenizer():
targets = self.tokenizer(tgt, return_tensors="pt")
model_inputs["labels"] = targets["input_ids"].to(torch_device)
self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
@ -397,16 +396,12 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
def test_unk_support(self):
t = self.tokenizer
ids = t.prepare_seq2seq_batch(["||"], return_tensors="pt").to(torch_device).input_ids[0].tolist()
ids = t(["||"], return_tensors="pt").to(torch_device).input_ids[0].tolist()
expected = [t.unk_token_id, t.unk_token_id, t.eos_token_id]
self.assertEqual(expected, ids)
def test_pad_not_split(self):
input_ids_w_pad = (
self.tokenizer.prepare_seq2seq_batch(["I am a small frog <pad>"], return_tensors="pt")
.input_ids[0]
.tolist()
)
input_ids_w_pad = self.tokenizer(["I am a small frog <pad>"], return_tensors="pt").input_ids[0].tolist()
expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0] # pad
self.assertListEqual(expected_w_pad, input_ids_w_pad)


@ -349,7 +349,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
@slow
def test_enro_generate_one(self):
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
batch: BatchEncoding = self.tokenizer(
["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt"
).to(torch_device)
translated_tokens = self.model.generate(**batch)
@ -359,9 +359,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
@slow
def test_enro_generate_batch(self):
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text, return_tensors="pt").to(
torch_device
)
batch: BatchEncoding = self.tokenizer(self.src_text, return_tensors="pt").to(torch_device)
translated_tokens = self.model.generate(**batch)
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
assert self.tgt_text == decoded
@ -412,7 +410,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
@unittest.skip("This test is broken, still generates english")
def test_cc25_generate(self):
inputs = self.tokenizer.prepare_seq2seq_batch([self.src_text[0]], return_tensors="pt").to(torch_device)
inputs = self.tokenizer([self.src_text[0]], return_tensors="pt").to(torch_device)
translated_tokens = self.model.generate(
input_ids=inputs["input_ids"].to(torch_device),
decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"],
@ -422,9 +420,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
@slow
def test_fill_mask(self):
inputs = self.tokenizer.prepare_seq2seq_batch(["One of the best <mask> I ever read!"], return_tensors="pt").to(
torch_device
)
inputs = self.tokenizer(["One of the best <mask> I ever read!"], return_tensors="pt").to(torch_device)
outputs = self.model.generate(
inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1
)


@ -363,9 +363,7 @@ class AbstractMarianIntegrationTest(unittest.TestCase):
self.assertListEqual(self.expected_text, generated_words)
def translate_src_text(self, **tokenizer_kwargs):
model_inputs = self.tokenizer.prepare_seq2seq_batch(
src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
)
model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf")
generated_ids = self.model.generate(
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
)


@ -330,9 +330,7 @@ class TFMBartModelIntegrationTest(unittest.TestCase):
self.assertListEqual(self.expected_text, generated_words)
def translate_src_text(self, **tokenizer_kwargs):
model_inputs = self.tokenizer.prepare_seq2seq_batch(
src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
)
model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf")
generated_ids = self.model.generate(
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2
)


@ -356,9 +356,7 @@ class TFPegasusIntegrationTests(unittest.TestCase):
assert self.expected_text == generated_words
def translate_src_text(self, **tokenizer_kwargs):
model_inputs = self.tokenizer.prepare_seq2seq_batch(
src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
)
model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf")
generated_ids = self.model.generate(
model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,


@ -86,18 +86,12 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
return BartTokenizerFast.from_pretrained("facebook/bart-large")
@require_torch
def test_prepare_seq2seq_batch(self):
def test_prepare_batch(self):
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [0, 250, 251, 17818, 13, 39186, 1938, 4, 2]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
)
batch = tokenizer(src_text, max_length=len(expected_src_tokens), padding=True, return_tensors="pt")
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 9), batch.input_ids.shape)
@ -106,12 +100,11 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
self.assertListEqual(expected_src_tokens, result)
# Test that special tokens are reset
# Test Prepare Seq
@require_torch
def test_seq2seq_batch_empty_target_text(self):
def test_prepare_batch_empty_target_text(self):
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
batch = tokenizer(src_text, padding=True, return_tensors="pt")
# check if input_ids are returned and no labels
self.assertIn("input_ids", batch)
self.assertIn("attention_mask", batch)
@ -119,29 +112,21 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
self.assertNotIn("decoder_attention_mask", batch)
@require_torch
def test_seq2seq_batch_max_target_length(self):
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
def test_as_target_tokenizer_target_length(self):
tgt_text = [
"Summary of the text.",
"Another summary.",
]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["labels"].shape[1])
# test None max_target_length
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["labels"].shape[1])
with tokenizer.as_target_tokenizer():
targets = tokenizer(tgt_text, max_length=32, padding="max_length", return_tensors="pt")
self.assertEqual(32, targets["input_ids"].shape[1])
@require_torch
def test_seq2seq_batch_not_longer_than_maxlen(self):
def test_prepare_batch_not_longer_than_maxlen(self):
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(
["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt"
batch = tokenizer(
["I am a small frog" * 1024, "I am a small frog"], padding=True, truncation=True, return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual(batch.input_ids.shape, (2, 1024))
@ -154,9 +139,11 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
"Summary of the text.",
]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")
input_ids = batch["input_ids"]
labels = batch["labels"]
inputs = tokenizer(src_text, return_tensors="pt")
with tokenizer.as_target_tokenizer():
targets = tokenizer(tgt_text, return_tensors="pt")
input_ids = inputs["input_ids"]
labels = targets["input_ids"]
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((labels[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())


@ -38,16 +38,12 @@ class BarthezTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.tokenizer = tokenizer
@require_torch
def test_prepare_seq2seq_batch(self):
def test_prepare_batch(self):
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [0, 57, 3018, 70307, 91, 2]
batch = self.tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
batch = self.tokenizer(
src_text, max_length=len(expected_src_tokens), padding=True, truncation=True, return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)


@ -70,7 +70,7 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_tokenizer_equivalence_en_de(self):
en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de")
batch = en_de_tokenizer.prepare_seq2seq_batch(["I am a small frog"], return_tensors=None)
batch = en_de_tokenizer(["I am a small frog"], return_tensors=None)
self.assertIsInstance(batch, BatchEncoding)
expected = [38, 121, 14, 697, 38848, 0]
self.assertListEqual(expected, batch.input_ids[0])
@ -84,12 +84,14 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_outputs_not_longer_than_maxlen(self):
tok = self.get_tokenizer()
batch = tok.prepare_seq2seq_batch(["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK)
batch = tok(
["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual(batch.input_ids.shape, (2, 512))
def test_outputs_can_be_shorter(self):
tok = self.get_tokenizer()
batch_smaller = tok.prepare_seq2seq_batch(["I am a tiny frog", "I am a small frog"], return_tensors=FRAMEWORK)
batch_smaller = tok(["I am a tiny frog", "I am a small frog"], padding=True, return_tensors=FRAMEWORK)
self.assertIsInstance(batch_smaller, BatchEncoding)
self.assertEqual(batch_smaller.input_ids.shape, (2, 10))


@ -141,7 +141,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.tokenizer: MBartTokenizer = MBartTokenizer.from_pretrained(cls.checkpoint_name)
cls.tokenizer: MBartTokenizer = MBartTokenizer.from_pretrained(
cls.checkpoint_name, src_lang="en_XX", tgt_lang="ro_RO"
)
cls.pad_token_id = 1
return cls
@ -166,10 +168,7 @@ class MBartEnroIntegrationTest(unittest.TestCase):
src_text = ["this is gunna be a long sentence " * 20]
assert isinstance(src_text[0], str)
desired_max_length = 10
ids = self.tokenizer.prepare_seq2seq_batch(
src_text,
max_length=desired_max_length,
).input_ids[0]
ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
self.assertEqual(ids[-2], 2)
self.assertEqual(ids[-1], EN_CODE)
self.assertEqual(len(ids), desired_max_length)
@ -184,31 +183,36 @@ class MBartEnroIntegrationTest(unittest.TestCase):
new_tok = MBartTokenizer.from_pretrained(tmpdirname)
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
# prepare_seq2seq_batch tests below
@require_torch
def test_batch_fairseq_parity(self):
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, return_tensors="pt"
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
batch = self.tokenizer(self.src_text, padding=True)
with self.tokenizer.as_target_tokenizer():
targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
labels = targets["input_ids"]
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
for k in batch:
batch[k] = batch[k].tolist()
# batch = {k: v.tolist() for k,v in batch.items()}
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
# batch.decoder_inputs_ids[0][0] ==
assert batch.input_ids[1][-2:] == [2, EN_CODE]
assert batch.decoder_input_ids[1][0] == RO_CODE
assert batch.decoder_input_ids[1][-1] == 2
assert batch.labels[1][-2:] == [2, RO_CODE]
assert labels[1][-2:].tolist() == [2, RO_CODE]
@require_torch
def test_enro_tokenizer_prepare_seq2seq_batch(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), return_tensors="pt"
def test_enro_tokenizer_prepare_batch(self):
batch = self.tokenizer(
self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
with self.tokenizer.as_target_tokenizer():
targets = self.tokenizer(
self.tgt_text,
padding=True,
truncation=True,
max_length=len(self.expected_src_tokens),
return_tensors="pt",
)
labels = targets["input_ids"]
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 14), batch.input_ids.shape)
@ -220,17 +224,12 @@ class MBartEnroIntegrationTest(unittest.TestCase):
self.assertEqual(self.tokenizer.prefix_tokens, [])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
def test_seq2seq_max_target_length(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
def test_seq2seq_max_length(self):
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
with self.tokenizer.as_target_tokenizer():
targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
labels = targets["input_ids"]
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
# max_target_length will default to max_length if not specified
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=3, return_tensors="pt"
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 3)


@ -129,10 +129,7 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
src_text = ["this is gunna be a long sentence " * 20]
assert isinstance(src_text[0], str)
desired_max_length = 10
ids = self.tokenizer.prepare_seq2seq_batch(
src_text,
max_length=desired_max_length,
).input_ids[0]
ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
self.assertEqual(ids[0], EN_CODE)
self.assertEqual(ids[-1], 2)
self.assertEqual(len(ids), desired_max_length)
@ -147,32 +144,38 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
new_tok = MBart50Tokenizer.from_pretrained(tmpdirname)
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
# prepare_seq2seq_batch tests below
@require_torch
def test_batch_fairseq_parity(self):
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, return_tensors="pt"
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
batch = self.tokenizer(self.src_text, padding=True)
with self.tokenizer.as_target_tokenizer():
targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
labels = targets["input_ids"]
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
labels = labels.tolist()
for k in batch:
batch[k] = batch[k].tolist()
# batch = {k: v.tolist() for k,v in batch.items()}
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
# batch.decoder_inputs_ids[0][0] ==
assert batch.input_ids[1][0] == EN_CODE
assert batch.input_ids[1][-1] == 2
assert batch.labels[1][0] == RO_CODE
assert batch.labels[1][-1] == 2
assert labels[1][0] == RO_CODE
assert labels[1][-1] == 2
assert batch.decoder_input_ids[1][:2] == [2, RO_CODE]
@require_torch
def test_tokenizer_prepare_seq2seq_batch(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), return_tensors="pt"
def test_tokenizer_prepare_batch(self):
batch = self.tokenizer(
self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
with self.tokenizer.as_target_tokenizer():
targets = self.tokenizer(
self.tgt_text,
padding=True,
truncation=True,
max_length=len(self.expected_src_tokens),
return_tensors="pt",
)
labels = targets["input_ids"]
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 14), batch.input_ids.shape)
@ -185,16 +188,11 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
def test_seq2seq_max_target_length(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
with self.tokenizer.as_target_tokenizer():
targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
labels = targets["input_ids"]
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
# max_target_length will default to max_length if not specified
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=3, return_tensors="pt"
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 3)


@ -86,11 +86,13 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_large_seq2seq_truncation(self):
src_texts = ["This is going to be way too long." * 150, "short example"]
tgt_texts = ["not super long but more than 5 tokens", "tiny"]
batch = self._large_tokenizer.prepare_seq2seq_batch(
src_texts, tgt_texts=tgt_texts, max_target_length=5, return_tensors="pt"
)
batch = self._large_tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt")
with self._large_tokenizer.as_target_tokenizer():
targets = self._large_tokenizer(
tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
)
assert batch.input_ids.shape == (2, 1024)
assert batch.attention_mask.shape == (2, 1024)
assert "labels" in batch # because tgt_texts was specified
assert batch.labels.shape == (2, 5)
assert len(batch) == 3 # input_ids, attention_mask, labels. Other things make by BartModel
assert targets["input_ids"].shape == (2, 5)
assert len(batch) == 2 # input_ids, attention_mask.


@ -152,20 +152,12 @@ class ProphetNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
@require_torch
def test_prepare_seq2seq_batch(self):
def test_prepare_batch(self):
tokenizer = self.tokenizer_class.from_pretrained("microsoft/prophetnet-large-uncased")
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [1037, 2146, 20423, 2005, 7680, 7849, 3989, 1012, 102]
batch = tokenizer.prepare_seq2seq_batch(
src_text,
tgt_texts=tgt_text,
return_tensors="pt",
)
batch = tokenizer(src_text, padding=True, return_tensors="pt")
self.assertIsInstance(batch, BatchEncoding)
result = list(batch.input_ids.numpy()[0])
self.assertListEqual(expected_src_tokens, result)


@ -151,19 +151,11 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""])
self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"])
def test_prepare_seq2seq_batch(self):
def test_prepare_batch(self):
tokenizer = self.t5_base_tokenizer
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]
batch = tokenizer.prepare_seq2seq_batch(
src_text,
tgt_texts=tgt_text,
return_tensors=FRAMEWORK,
)
batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
self.assertIsInstance(batch, BatchEncoding)
result = list(batch.input_ids.numpy()[0])
self.assertListEqual(expected_src_tokens, result)
@ -174,36 +166,30 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_empty_target_text(self):
tokenizer = self.t5_base_tokenizer
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK)
batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
# check if input_ids are returned and no decoder_input_ids
self.assertIn("input_ids", batch)
self.assertIn("attention_mask", batch)
self.assertNotIn("decoder_input_ids", batch)
self.assertNotIn("decoder_attention_mask", batch)
def test_max_target_length(self):
def test_max_length(self):
tokenizer = self.t5_base_tokenizer
src_text = ["A short paragraph for summarization.", "Another short paragraph for summarization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors=FRAMEWORK
)
self.assertEqual(32, batch["labels"].shape[1])
# test None max_target_length
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors=FRAMEWORK
)
self.assertEqual(32, batch["labels"].shape[1])
with tokenizer.as_target_tokenizer():
targets = tokenizer(
tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
)
self.assertEqual(32, targets["input_ids"].shape[1])
def test_outputs_not_longer_than_maxlen(self):
tokenizer = self.t5_base_tokenizer
batch = tokenizer.prepare_seq2seq_batch(
["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
batch = tokenizer(
["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual(batch.input_ids.shape, (2, 512))
@ -215,13 +201,12 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1]
expected_tgt_tokens = [20698, 13, 8, 1499, 5, 1]
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)
batch = tokenizer(src_text)
with tokenizer.as_target_tokenizer():
targets = tokenizer(tgt_text)
src_ids = list(batch.input_ids.numpy()[0])
tgt_ids = list(batch.labels.numpy()[0])
self.assertEqual(expected_src_tokens, src_ids)
self.assertEqual(expected_tgt_tokens, tgt_ids)
self.assertEqual(expected_src_tokens, batch["input_ids"][0])
self.assertEqual(expected_tgt_tokens, targets["input_ids"][0])
def test_token_type_ids(self):
src_text_1 = ["A first paragraph for summarization."]