mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Deprecate prepare_seq2seq_batch (#10287)
* Deprecate prepare_seq2seq_batch * Fix last tests * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Suraj Patil <surajp815@gmail.com> * More review comments Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
parent
e73a3e1891
commit
9e147d31f6
@ -56,7 +56,7 @@ FSMTTokenizer
|
||||
|
||||
.. autoclass:: transformers.FSMTTokenizer
|
||||
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
|
||||
create_token_type_ids_from_sequences, prepare_seq2seq_batch, save_vocabulary
|
||||
create_token_type_ids_from_sequences, save_vocabulary
|
||||
|
||||
|
||||
FSMTModel
|
||||
|
@ -76,27 +76,29 @@ require 3 character language codes:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from transformers import MarianMTModel, MarianTokenizer
|
||||
src_text = [
|
||||
'>>fra<< this is a sentence in english that we want to translate to french',
|
||||
'>>por<< This should go to portuguese',
|
||||
'>>esp<< And this to Spanish'
|
||||
]
|
||||
>>> from transformers import MarianMTModel, MarianTokenizer
|
||||
>>> src_text = [
|
||||
... '>>fra<< this is a sentence in english that we want to translate to french',
|
||||
... '>>por<< This should go to portuguese',
|
||||
... '>>esp<< And this to Spanish'
|
||||
>>> ]
|
||||
|
||||
model_name = 'Helsinki-NLP/opus-mt-en-roa'
|
||||
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||
print(tokenizer.supported_language_codes)
|
||||
model = MarianMTModel.from_pretrained(model_name)
|
||||
translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt"))
|
||||
tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
|
||||
# ["c'est une phrase en anglais que nous voulons traduire en français",
|
||||
# 'Isto deve ir para o português.',
|
||||
# 'Y esto al español']
|
||||
>>> model_name = 'Helsinki-NLP/opus-mt-en-roa'
|
||||
>>> tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||
>>> print(tokenizer.supported_language_codes)
|
||||
['>>zlm_Latn<<', '>>mfe<<', '>>hat<<', '>>pap<<', '>>ast<<', '>>cat<<', '>>ind<<', '>>glg<<', '>>wln<<', '>>spa<<', '>>fra<<', '>>ron<<', '>>por<<', '>>ita<<', '>>oci<<', '>>arg<<', '>>min<<']
|
||||
|
||||
>>> model = MarianMTModel.from_pretrained(model_name)
|
||||
>>> translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True))
|
||||
>>> [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
|
||||
["c'est une phrase en anglais que nous voulons traduire en français",
|
||||
'Isto deve ir para o português.',
|
||||
'Y esto al español']
|
||||
|
||||
|
||||
|
||||
|
||||
Code to see available pretrained models:
|
||||
Here is the code to see all available pretrained models on the hub:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@ -147,21 +149,22 @@ Example of translating english to many romance languages, using old-style 2 char
|
||||
|
||||
.. code-block::python
|
||||
|
||||
from transformers import MarianMTModel, MarianTokenizer
|
||||
src_text = [
|
||||
'>>fr<< this is a sentence in english that we want to translate to french',
|
||||
'>>pt<< This should go to portuguese',
|
||||
'>>es<< And this to Spanish'
|
||||
]
|
||||
>>> from transformers import MarianMTModel, MarianTokenizer
|
||||
>>> src_text = [
|
||||
... '>>fr<< this is a sentence in english that we want to translate to french',
|
||||
... '>>pt<< This should go to portuguese',
|
||||
... '>>es<< And this to Spanish'
|
||||
>>> ]
|
||||
|
||||
model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
|
||||
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||
print(tokenizer.supported_language_codes)
|
||||
>>> model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
|
||||
>>> tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||
|
||||
model = MarianMTModel.from_pretrained(model_name)
|
||||
translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt"))
|
||||
tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
|
||||
# ["c'est une phrase en anglais que nous voulons traduire en français", 'Isto deve ir para o português.', 'Y esto al español']
|
||||
>>> model = MarianMTModel.from_pretrained(model_name)
|
||||
>>> translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True))
|
||||
>>> tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
|
||||
["c'est une phrase en anglais que nous voulons traduire en français",
|
||||
'Isto deve ir para o português.',
|
||||
'Y esto al español']
|
||||
|
||||
|
||||
|
||||
@ -176,7 +179,7 @@ MarianTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.MarianTokenizer
|
||||
:members: prepare_seq2seq_batch
|
||||
:members: as_target_tokenizer
|
||||
|
||||
|
||||
MarianModel
|
||||
|
@ -34,22 +34,31 @@ The Authors' code can be found `here <https://github.com/pytorch/fairseq/tree/ma
|
||||
Training of MBart
|
||||
_______________________________________________________________________________________________________________________
|
||||
|
||||
MBart is a multilingual encoder-decoder (seq-to-seq) model primarily intended for translation task. As the model is
|
||||
multilingual it expects the sequences in a different format. A special language id token is added in both the source
|
||||
and target text. The source text format is :obj:`X [eos, src_lang_code]` where :obj:`X` is the source text. The target
|
||||
text format is :obj:`[tgt_lang_code] X [eos]`. :obj:`bos` is never used.
|
||||
MBart is a multilingual encoder-decoder (sequence-to-sequence) model primarily intended for translation task. As the
|
||||
model is multilingual it expects the sequences in a different format. A special language id token is added in both the
|
||||
source and target text. The source text format is :obj:`X [eos, src_lang_code]` where :obj:`X` is the source text. The
|
||||
target text format is :obj:`[tgt_lang_code] X [eos]`. :obj:`bos` is never used.
|
||||
|
||||
The :meth:`~transformers.MBartTokenizer.prepare_seq2seq_batch` handles this automatically and should be used to encode
|
||||
the sequences for sequence-to-sequence fine-tuning.
|
||||
The regular :meth:`~transformers.MBartTokenizer.__call__` will encode source text format, and it should be wrapped
|
||||
inside the context manager :meth:`~transformers.MBartTokenizer.as_target_tokenizer` to encode target text format.
|
||||
|
||||
- Supervised training
|
||||
|
||||
.. code-block::
|
||||
|
||||
example_english_phrase = "UN Chief Says There Is No Military Solution in Syria"
|
||||
expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
batch = tokenizer.prepare_seq2seq_batch(example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt")
|
||||
model(input_ids=batch['input_ids'], labels=batch['labels']) # forward pass
|
||||
>>> from transformers import MBartForConditionalGeneration, MBartTokenizer
|
||||
|
||||
>>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
|
||||
>>> example_english_phrase = "UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
|
||||
>>> inputs = tokenizer(example_english_phrase, return_tensors="pt", src_lang="en_XX", tgt_lang="ro_RO")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(expected_translation_romanian, return_tensors="pt")
|
||||
|
||||
>>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
|
||||
>>> # forward pass
|
||||
>>> model(**inputs, labels=batch['labels'])
|
||||
|
||||
- Generation
|
||||
|
||||
@ -58,14 +67,14 @@ the sequences for sequence-to-sequence fine-tuning.
|
||||
|
||||
.. code-block::
|
||||
|
||||
from transformers import MBartForConditionalGeneration, MBartTokenizer
|
||||
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
|
||||
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
|
||||
article = "UN Chief Says There Is No Military Solution in Syria"
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], src_lang="en_XX", return_tensors="pt")
|
||||
translated_tokens = model.generate(**batch, decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"])
|
||||
translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
||||
assert translation == "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
>>> from transformers import MBartForConditionalGeneration, MBartTokenizer
|
||||
|
||||
>>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX")
|
||||
>>> article = "UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> inputs = tokenizer(article, return_tensors="pt")
|
||||
>>> translated_tokens = model.generate(**inputs, decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"])
|
||||
>>> tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
||||
"Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
|
||||
|
||||
Overview of MBart-50
|
||||
@ -160,7 +169,7 @@ MBartTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.MBartTokenizer
|
||||
:members: build_inputs_with_special_tokens, prepare_seq2seq_batch
|
||||
:members: as_target_tokenizer, build_inputs_with_special_tokens
|
||||
|
||||
|
||||
MBartTokenizerFast
|
||||
|
@ -78,20 +78,20 @@ Usage Example
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
|
||||
import torch
|
||||
src_text = [
|
||||
""" PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
|
||||
]
|
||||
>>> from transformers import PegasusForConditionalGeneration, PegasusTokenizer
|
||||
>>> import torch
|
||||
>>> src_text = [
|
||||
... """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
|
||||
>>> ]
|
||||
|
||||
model_name = 'google/pegasus-xsum'
|
||||
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
tokenizer = PegasusTokenizer.from_pretrained(model_name)
|
||||
model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_text, truncation=True, padding='longest', return_tensors="pt").to(torch_device)
|
||||
translated = model.generate(**batch)
|
||||
tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
|
||||
assert tgt_text[0] == "California's largest electricity provider has turned off power to hundreds of thousands of customers."
|
||||
>>> model_name = 'google/pegasus-xsum'
|
||||
>>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
>>> tokenizer = PegasusTokenizer.from_pretrained(model_name)
|
||||
>>> model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
|
||||
>>> batch = tokenizer(src_text, truncation=True, padding='longest', return_tensors="pt").to(torch_device)
|
||||
>>> translated = model.generate(**batch)
|
||||
>>> tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
|
||||
>>> assert tgt_text[0] == "California's largest electricity provider has turned off power to hundreds of thousands of customers."
|
||||
|
||||
|
||||
|
||||
@ -107,7 +107,7 @@ PegasusTokenizer
|
||||
warning: ``add_tokens`` does not work at the moment.
|
||||
|
||||
.. autoclass:: transformers.PegasusTokenizer
|
||||
:members: __call__, prepare_seq2seq_batch
|
||||
:members:
|
||||
|
||||
|
||||
PegasusTokenizerFast
|
||||
|
@ -56,7 +56,7 @@ RagTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.RagTokenizer
|
||||
:members: prepare_seq2seq_batch
|
||||
:members:
|
||||
|
||||
|
||||
Rag specific outputs
|
||||
|
@ -104,7 +104,7 @@ T5Tokenizer
|
||||
|
||||
.. autoclass:: transformers.T5Tokenizer
|
||||
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
|
||||
create_token_type_ids_from_sequences, prepare_seq2seq_batch, save_vocabulary
|
||||
create_token_type_ids_from_sequences, save_vocabulary
|
||||
|
||||
|
||||
T5TokenizerFast
|
||||
|
@ -71,7 +71,7 @@ tiny_model = FSMTForConditionalGeneration(config)
|
||||
print(f"num of params {tiny_model.num_parameters()}")
|
||||
|
||||
# Test
|
||||
batch = tokenizer.prepare_seq2seq_batch(["Making tiny model"], return_tensors="pt")
|
||||
batch = tokenizer(["Making tiny model"], return_tensors="pt")
|
||||
outputs = tiny_model(**batch)
|
||||
|
||||
print("test output:", len(outputs.logits[0]))
|
||||
|
@ -42,7 +42,7 @@ tiny_model = FSMTForConditionalGeneration(config)
|
||||
print(f"num of params {tiny_model.num_parameters()}")
|
||||
|
||||
# Test
|
||||
batch = tokenizer.prepare_seq2seq_batch(["Making tiny model"], return_tensors="pt")
|
||||
batch = tokenizer(["Making tiny model"], return_tensors="pt")
|
||||
outputs = tiny_model(**batch)
|
||||
|
||||
print("test output:", len(outputs.logits[0]))
|
||||
|
@ -522,13 +522,14 @@ MARIAN_GENERATION_EXAMPLE = r"""
|
||||
>>> src = 'fr' # source language
|
||||
>>> trg = 'en' # target language
|
||||
>>> sample_text = "où est l'arrêt de bus ?"
|
||||
>>> mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'
|
||||
>>> model_name = f'Helsinki-NLP/opus-mt-{src}-{trg}'
|
||||
|
||||
>>> model = MarianMTModel.from_pretrained(mname)
|
||||
>>> tok = MarianTokenizer.from_pretrained(mname)
|
||||
>>> batch = tok.prepare_seq2seq_batch(src_texts=[sample_text], return_tensors="pt") # don't need tgt_text for inference
|
||||
>>> model = MarianMTModel.from_pretrained(model_name)
|
||||
>>> tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||
>>> batch = tokenizer([sample_text], return_tensors="pt")
|
||||
>>> gen = model.generate(**batch)
|
||||
>>> words: List[str] = tok.batch_decode(gen, skip_special_tokens=True) # returns "Where is the bus stop ?"
|
||||
>>> tokenizer.batch_decode(gen, skip_special_tokens=True)
|
||||
"Where is the bus stop ?"
|
||||
"""
|
||||
|
||||
MARIAN_INPUTS_DOCSTRING = r"""
|
||||
|
@ -557,13 +557,14 @@ MARIAN_GENERATION_EXAMPLE = r"""
|
||||
>>> src = 'fr' # source language
|
||||
>>> trg = 'en' # target language
|
||||
>>> sample_text = "où est l'arrêt de bus ?"
|
||||
>>> mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'
|
||||
>>> model_name = f'Helsinki-NLP/opus-mt-{src}-{trg}'
|
||||
|
||||
>>> model = MarianMTModel.from_pretrained(mname)
|
||||
>>> tok = MarianTokenizer.from_pretrained(mname)
|
||||
>>> batch = tok.prepare_seq2seq_batch(src_texts=[sample_text], return_tensors="tf") # don't need tgt_text for inference
|
||||
>>> model = TFMarianMTModel.from_pretrained(model_name)
|
||||
>>> tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||
>>> batch = tokenizer([sample_text], return_tensors="tf")
|
||||
>>> gen = model.generate(**batch)
|
||||
>>> words: List[str] = tok.batch_decode(gen, skip_special_tokens=True) # returns "Where is the bus stop ?"
|
||||
>>> tokenizer.batch_decode(gen, skip_special_tokens=True)
|
||||
"Where is the bus stop ?"
|
||||
"""
|
||||
|
||||
MARIAN_INPUTS_DOCSTRING = r"""
|
||||
|
@ -80,12 +80,15 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
Examples::
|
||||
|
||||
>>> from transformers import MarianTokenizer
|
||||
>>> tok = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
|
||||
>>> tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
|
||||
>>> src_texts = [ "I am a small frog.", "Tom asked his teacher for advice."]
|
||||
>>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."] # optional
|
||||
>>> batch_enc = tok.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, return_tensors="pt")
|
||||
>>> # keys [input_ids, attention_mask, labels].
|
||||
>>> # model(**batch) should work
|
||||
>>> inputs = tokenizer(src_texts, return_tensors="pt", padding=True)
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(tgt_texts, return_tensors="pt", padding=True)
|
||||
>>> inputs["labels"] = labels["input_ids"]
|
||||
# keys [input_ids, attention_mask, labels].
|
||||
>>> outputs = model(**inputs) should work
|
||||
"""
|
||||
|
||||
vocab_files_names = vocab_files_names
|
||||
|
@ -59,30 +59,23 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
||||
"""
|
||||
Construct an MBART tokenizer.
|
||||
|
||||
:class:`~transformers.MBartTokenizer` is a subclass of :class:`~transformers.XLMRobertaTokenizer` and adds a new
|
||||
:meth:`~transformers.MBartTokenizer.prepare_seq2seq_batch`
|
||||
|
||||
Refer to superclass :class:`~transformers.XLMRobertaTokenizer` for usage examples and documentation concerning the
|
||||
:class:`~transformers.MBartTokenizer` is a subclass of :class:`~transformers.XLMRobertaTokenizer`. Refer to
|
||||
superclass :class:`~transformers.XLMRobertaTokenizer` for usage examples and documentation concerning the
|
||||
initialization parameters and other methods.
|
||||
|
||||
.. warning::
|
||||
|
||||
``prepare_seq2seq_batch`` should be used to encode inputs. Other tokenizer methods like ``encode`` do not work
|
||||
properly.
|
||||
|
||||
The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and ``<language code>
|
||||
<tokens> <eos>``` for target language documents.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import MBartTokenizer
|
||||
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
|
||||
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro', src_lang="en_XX", tgt_lang="ro_RO")
|
||||
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
>>> batch: dict = tokenizer.prepare_seq2seq_batch(
|
||||
... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt"
|
||||
... )
|
||||
|
||||
>>> inputs = tokenizer(example_english_phrase, return_tensors="pt)
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(expected_translation_romanian, return_tensors="pt")
|
||||
>>> inputs["labels"] = labels["input_ids"]
|
||||
"""
|
||||
|
||||
vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
|
||||
@ -92,26 +85,38 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
||||
prefix_tokens: List[int] = []
|
||||
suffix_tokens: List[int] = []
|
||||
|
||||
def __init__(self, *args, tokenizer_file=None, **kwargs):
|
||||
super().__init__(*args, tokenizer_file=tokenizer_file, **kwargs)
|
||||
def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
|
||||
super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)
|
||||
|
||||
self.sp_model_size = len(self.sp_model)
|
||||
self.lang_code_to_id = {
|
||||
code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
|
||||
}
|
||||
self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
|
||||
self.cur_lang_code = self.lang_code_to_id["en_XX"]
|
||||
self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
|
||||
|
||||
self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
|
||||
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
|
||||
self._additional_special_tokens = list(self.lang_code_to_id.keys())
|
||||
self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
|
||||
|
||||
self._src_lang = src_lang if src_lang is not None else "en_XX"
|
||||
self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
|
||||
self.tgt_lang = tgt_lang
|
||||
self.set_src_lang_special_tokens(self._src_lang)
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1 # Plus 1 for the mask token
|
||||
|
||||
@property
|
||||
def src_lang(self) -> str:
|
||||
return self._src_lang
|
||||
|
||||
@src_lang.setter
|
||||
def src_lang(self, new_src_lang: str) -> None:
|
||||
self._src_lang = new_src_lang
|
||||
self.set_src_lang_special_tokens(self._src_lang)
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||
) -> List[int]:
|
||||
@ -181,7 +186,6 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
||||
) -> BatchEncoding:
|
||||
self.src_lang = src_lang
|
||||
self.tgt_lang = tgt_lang
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||
|
||||
@contextmanager
|
||||
|
@ -70,15 +70,9 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
||||
Construct a "fast" MBART tokenizer (backed by HuggingFace's `tokenizers` library). Based on `BPE
|
||||
<https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models>`__.
|
||||
|
||||
:class:`~transformers.MBartTokenizerFast` is a subclass of :class:`~transformers.XLMRobertaTokenizerFast` and adds
|
||||
a new :meth:`~transformers.MBartTokenizerFast.prepare_seq2seq_batch`.
|
||||
|
||||
Refer to superclass :class:`~transformers.XLMRobertaTokenizerFast` for usage examples and documentation concerning
|
||||
the initialization parameters and other methods.
|
||||
|
||||
.. warning::
|
||||
``prepare_seq2seq_batch`` should be used to encode inputs. Other tokenizer methods like ``encode`` do not work
|
||||
properly.
|
||||
:class:`~transformers.MBartTokenizerFast` is a subclass of :class:`~transformers.XLMRobertaTokenizerFast`. Refer to
|
||||
superclass :class:`~transformers.XLMRobertaTokenizerFast` for usage examples and documentation concerning the
|
||||
initialization parameters and other methods.
|
||||
|
||||
The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and ``<language code>
|
||||
<tokens> <eos>``` for target language documents.
|
||||
@ -86,12 +80,13 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
||||
Examples::
|
||||
|
||||
>>> from transformers import MBartTokenizerFast
|
||||
>>> tokenizer = MBartTokenizerFast.from_pretrained('facebook/mbart-large-en-ro')
|
||||
>>> tokenizer = MBartTokenizerFast.from_pretrained('facebook/mbart-large-en-ro', src_lang="en_XX", tgt_lang="ro_RO")
|
||||
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
>>> batch: dict = tokenizer.prepare_seq2seq_batch(
|
||||
... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt"
|
||||
... )
|
||||
>>> inputs = tokenizer(example_english_phrase, return_tensors="pt)
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(expected_translation_romanian, return_tensors="pt")
|
||||
>>> inputs["labels"] = labels["input_ids"]
|
||||
"""
|
||||
|
||||
vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
|
||||
@ -102,14 +97,25 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
||||
prefix_tokens: List[int] = []
|
||||
suffix_tokens: List[int] = []
|
||||
|
||||
def __init__(self, *args, tokenizer_file=None, **kwargs):
|
||||
super().__init__(*args, tokenizer_file=tokenizer_file, **kwargs)
|
||||
|
||||
self.cur_lang_code = self.convert_tokens_to_ids("en_XX")
|
||||
self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
|
||||
def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
|
||||
super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)
|
||||
|
||||
self.add_special_tokens({"additional_special_tokens": FAIRSEQ_LANGUAGE_CODES})
|
||||
|
||||
self._src_lang = src_lang if src_lang is not None else "en_XX"
|
||||
self.cur_lang_code = self.convert_tokens_to_ids(self._src_lang)
|
||||
self.tgt_lang = tgt_lang
|
||||
self.set_src_lang_special_tokens(self._src_lang)
|
||||
|
||||
@property
|
||||
def src_lang(self) -> str:
|
||||
return self._src_lang
|
||||
|
||||
@src_lang.setter
|
||||
def src_lang(self, new_src_lang: str) -> None:
|
||||
self._src_lang = new_src_lang
|
||||
self.set_src_lang_special_tokens(self._src_lang)
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||
) -> List[int]:
|
||||
@ -181,7 +187,6 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
||||
) -> BatchEncoding:
|
||||
self.src_lang = src_lang
|
||||
self.tgt_lang = tgt_lang
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||
|
||||
@contextmanager
|
||||
|
@ -31,13 +31,17 @@ class MT5Model(T5Model):
|
||||
alongside usage examples.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import MT5Model, T5Tokenizer
|
||||
>>> model = MT5Model.from_pretrained("google/mt5-small")
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
|
||||
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
||||
>>> summary = "Weiter Verhandlung in Syrien."
|
||||
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
|
||||
>>> outputs = model(input_ids=batch.input_ids, decoder_input_ids=batch.labels)
|
||||
>>> inputs = tokenizer(article, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(summary, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
|
||||
>>> hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
model_type = "mt5"
|
||||
@ -59,13 +63,17 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration):
|
||||
appropriate documentation alongside usage examples.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import MT5ForConditionalGeneration, T5Tokenizer
|
||||
>>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
|
||||
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
||||
>>> summary = "Weiter Verhandlung in Syrien."
|
||||
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
|
||||
>>> outputs = model(**batch)
|
||||
>>> inputs = tokenizer(article, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(summary, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**inputs,labels=labels["input_ids"])
|
||||
>>> loss = outputs.loss
|
||||
"""
|
||||
|
||||
|
@ -31,15 +31,17 @@ class TFMT5Model(TFT5Model):
|
||||
documentation alongside usage examples.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import TFMT5Model, T5Tokenizer
|
||||
>>> model = TFMT5Model.from_pretrained("google/mt5-small")
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
|
||||
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
||||
>>> summary = "Weiter Verhandlung in Syrien."
|
||||
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf")
|
||||
>>> batch["decoder_input_ids"] = batch["labels"]
|
||||
>>> del batch["labels"]
|
||||
>>> outputs = model(batch)
|
||||
>>> inputs = tokenizer(article, return_tensors="tf")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(summary, return_tensors="tf")
|
||||
|
||||
>>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
|
||||
>>> hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
model_type = "mt5"
|
||||
@ -52,13 +54,17 @@ class TFMT5ForConditionalGeneration(TFT5ForConditionalGeneration):
|
||||
appropriate documentation alongside usage examples.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import TFMT5ForConditionalGeneration, T5Tokenizer
|
||||
>>> model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
|
||||
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
||||
>>> summary = "Weiter Verhandlung in Syrien."
|
||||
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf")
|
||||
>>> outputs = model(batch)
|
||||
>>> inputs = tokenizer(article, return_tensors="tf")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(summary, return_tensors="tf")
|
||||
|
||||
>>> outputs = model(**inputs,labels=labels["input_ids"])
|
||||
>>> loss = outputs.loss
|
||||
"""
|
||||
|
||||
|
@ -550,10 +550,8 @@ class RagModel(RagPreTrainedModel):
|
||||
>>> # initialize with RagRetriever to do everything in one forward call
|
||||
>>> model = RagModel.from_pretrained("facebook/rag-token-base", retriever=retriever)
|
||||
|
||||
>>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt")
|
||||
>>> input_ids = input_dict["input_ids"]
|
||||
>>> outputs = model(input_ids=input_ids)
|
||||
|
||||
>>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
|
||||
>>> outputs = model(input_ids=inputs["input_ids"])
|
||||
"""
|
||||
n_docs = n_docs if n_docs is not None else self.config.n_docs
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
@ -752,9 +750,12 @@ class RagSequenceForGeneration(RagPreTrainedModel):
|
||||
>>> # initialize with RagRetriever to do everything in one forward call
|
||||
>>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
|
||||
|
||||
>>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt")
|
||||
>>> input_ids = input_dict["input_ids"]
|
||||
>>> outputs = model(input_ids=input_ids, labels=input_dict["labels"])
|
||||
>>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... targets = tokenizer("In Paris, there are 10 million people.", return_tensors="pt")
|
||||
>>> input_ids = inputs["input_ids"]
|
||||
>>> labels = targets["input_ids"]
|
||||
>>> outputs = model(input_ids=input_ids, labels=labels)
|
||||
|
||||
>>> # or use retriever separately
|
||||
>>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
|
||||
@ -764,7 +765,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
|
||||
>>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
|
||||
>>> doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
|
||||
>>> # 3. Forward to generator
|
||||
>>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"])
|
||||
>>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=labels)
|
||||
"""
|
||||
n_docs = n_docs if n_docs is not None else self.config.n_docs
|
||||
exclude_bos_score = exclude_bos_score if exclude_bos_score is not None else self.config.exclude_bos_score
|
||||
@ -1203,9 +1204,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
>>> # initialize with RagRetriever to do everything in one forward call
|
||||
>>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
|
||||
|
||||
>>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt")
|
||||
>>> input_ids = input_dict["input_ids"]
|
||||
>>> outputs = model(input_ids=input_ids, labels=input_dict["labels"])
|
||||
>>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... targets = tokenizer("In Paris, there are 10 million people.", return_tensors="pt")
|
||||
>>> input_ids = inputs["input_ids"]
|
||||
>>> labels = targets["input_ids"]
|
||||
>>> outputs = model(input_ids=input_ids, labels=labels)
|
||||
|
||||
>>> # or use retriever separately
|
||||
>>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True)
|
||||
@ -1215,7 +1219,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
>>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
|
||||
>>> doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
|
||||
>>> # 3. Forward to generator
|
||||
>>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"])
|
||||
>>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=labels)
|
||||
|
||||
>>> # or directly generate
|
||||
>>> generated = model.generate(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores)
|
||||
|
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for RAG."""
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional
|
||||
|
||||
@ -88,6 +89,13 @@ class RagTokenizer:
|
||||
truncation: bool = True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
warnings.warn(
|
||||
"`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of 🤗 Transformers. Use the "
|
||||
"regular `__call__` method to prepare your inputs and the tokenizer under the `with_target_tokenizer` "
|
||||
"context manager to prepare your targets. See the documentation of your specific tokenizer for more "
|
||||
"details",
|
||||
FutureWarning,
|
||||
)
|
||||
if max_length is None:
|
||||
max_length = self.current_tokenizer.model_max_length
|
||||
model_inputs = self(
|
||||
|
@ -3303,6 +3303,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed.
|
||||
Otherwise, input_ids, attention_mask will be the only keys.
|
||||
"""
|
||||
warnings.warn(
|
||||
"`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of 🤗 Transformers. Use the "
|
||||
"regular `__call__` method to prepare your inputs and the tokenizer under the `with_target_tokenizer` "
|
||||
"context manager to prepare your targets. See the documentation of your specific tokenizer for more "
|
||||
"details",
|
||||
FutureWarning,
|
||||
)
|
||||
# mBART-specific kwargs that should be ignored by other models.
|
||||
kwargs.pop("src_lang", None)
|
||||
kwargs.pop("tgt_lang", None)
|
||||
|
@ -354,9 +354,7 @@ class MarianIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(self.expected_text, generated_words)
|
||||
|
||||
def translate_src_text(self, **tokenizer_kwargs):
|
||||
model_inputs = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=self.src_text, return_tensors="pt", **tokenizer_kwargs
|
||||
).to(torch_device)
|
||||
model_inputs = self.tokenizer(self.src_text, return_tensors="pt", **tokenizer_kwargs).to(torch_device)
|
||||
self.assertEqual(self.model.device, model_inputs.input_ids.device)
|
||||
generated_ids = self.model.generate(
|
||||
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
|
||||
@ -373,9 +371,10 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
|
||||
src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
|
||||
expected_ids = [38, 121, 14, 697, 38848, 0]
|
||||
|
||||
model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt, return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
model_inputs = self.tokenizer(src, return_tensors="pt").to(torch_device)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(tgt, return_tensors="pt")
|
||||
model_inputs["labels"] = targets["input_ids"].to(torch_device)
|
||||
|
||||
self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
|
||||
|
||||
@ -397,16 +396,12 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
|
||||
|
||||
def test_unk_support(self):
|
||||
t = self.tokenizer
|
||||
ids = t.prepare_seq2seq_batch(["||"], return_tensors="pt").to(torch_device).input_ids[0].tolist()
|
||||
ids = t(["||"], return_tensors="pt").to(torch_device).input_ids[0].tolist()
|
||||
expected = [t.unk_token_id, t.unk_token_id, t.eos_token_id]
|
||||
self.assertEqual(expected, ids)
|
||||
|
||||
def test_pad_not_split(self):
|
||||
input_ids_w_pad = (
|
||||
self.tokenizer.prepare_seq2seq_batch(["I am a small frog <pad>"], return_tensors="pt")
|
||||
.input_ids[0]
|
||||
.tolist()
|
||||
)
|
||||
input_ids_w_pad = self.tokenizer(["I am a small frog <pad>"], return_tensors="pt").input_ids[0].tolist()
|
||||
expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0] # pad
|
||||
self.assertListEqual(expected_w_pad, input_ids_w_pad)
|
||||
|
||||
|
@ -349,7 +349,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
@slow
|
||||
def test_enro_generate_one(self):
|
||||
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
batch: BatchEncoding = self.tokenizer(
|
||||
["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt"
|
||||
).to(torch_device)
|
||||
translated_tokens = self.model.generate(**batch)
|
||||
@ -359,9 +359,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
@slow
|
||||
def test_enro_generate_batch(self):
|
||||
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text, return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
batch: BatchEncoding = self.tokenizer(self.src_text, return_tensors="pt").to(torch_device)
|
||||
translated_tokens = self.model.generate(**batch)
|
||||
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||
assert self.tgt_text == decoded
|
||||
@ -412,7 +410,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
@unittest.skip("This test is broken, still generates english")
|
||||
def test_cc25_generate(self):
|
||||
inputs = self.tokenizer.prepare_seq2seq_batch([self.src_text[0]], return_tensors="pt").to(torch_device)
|
||||
inputs = self.tokenizer([self.src_text[0]], return_tensors="pt").to(torch_device)
|
||||
translated_tokens = self.model.generate(
|
||||
input_ids=inputs["input_ids"].to(torch_device),
|
||||
decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"],
|
||||
@ -422,9 +420,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
@slow
|
||||
def test_fill_mask(self):
|
||||
inputs = self.tokenizer.prepare_seq2seq_batch(["One of the best <mask> I ever read!"], return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
inputs = self.tokenizer(["One of the best <mask> I ever read!"], return_tensors="pt").to(torch_device)
|
||||
outputs = self.model.generate(
|
||||
inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1
|
||||
)
|
||||
|
@ -363,9 +363,7 @@ class AbstractMarianIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(self.expected_text, generated_words)
|
||||
|
||||
def translate_src_text(self, **tokenizer_kwargs):
|
||||
model_inputs = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
|
||||
)
|
||||
model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf")
|
||||
generated_ids = self.model.generate(
|
||||
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
|
||||
)
|
||||
|
@ -330,9 +330,7 @@ class TFMBartModelIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(self.expected_text, generated_words)
|
||||
|
||||
def translate_src_text(self, **tokenizer_kwargs):
|
||||
model_inputs = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
|
||||
)
|
||||
model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf")
|
||||
generated_ids = self.model.generate(
|
||||
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2
|
||||
)
|
||||
|
@ -356,9 +356,7 @@ class TFPegasusIntegrationTests(unittest.TestCase):
|
||||
assert self.expected_text == generated_words
|
||||
|
||||
def translate_src_text(self, **tokenizer_kwargs):
|
||||
model_inputs = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
|
||||
)
|
||||
model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf")
|
||||
generated_ids = self.model.generate(
|
||||
model_inputs.input_ids,
|
||||
attention_mask=model_inputs.attention_mask,
|
||||
|
@ -86,18 +86,12 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
|
||||
return BartTokenizerFast.from_pretrained("facebook/bart-large")
|
||||
|
||||
@require_torch
|
||||
def test_prepare_seq2seq_batch(self):
|
||||
def test_prepare_batch(self):
|
||||
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
|
||||
tgt_text = [
|
||||
"Summary of the text.",
|
||||
"Another summary.",
|
||||
]
|
||||
expected_src_tokens = [0, 250, 251, 17818, 13, 39186, 1938, 4, 2]
|
||||
|
||||
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||
batch = tokenizer.prepare_seq2seq_batch(
|
||||
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
|
||||
)
|
||||
batch = tokenizer(src_text, max_length=len(expected_src_tokens), padding=True, return_tensors="pt")
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
|
||||
self.assertEqual((2, 9), batch.input_ids.shape)
|
||||
@ -106,12 +100,11 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertListEqual(expected_src_tokens, result)
|
||||
# Test that special tokens are reset
|
||||
|
||||
# Test Prepare Seq
|
||||
@require_torch
|
||||
def test_seq2seq_batch_empty_target_text(self):
|
||||
def test_prepare_batch_empty_target_text(self):
|
||||
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
|
||||
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
|
||||
batch = tokenizer(src_text, padding=True, return_tensors="pt")
|
||||
# check if input_ids are returned and no labels
|
||||
self.assertIn("input_ids", batch)
|
||||
self.assertIn("attention_mask", batch)
|
||||
@ -119,29 +112,21 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertNotIn("decoder_attention_mask", batch)
|
||||
|
||||
@require_torch
|
||||
def test_seq2seq_batch_max_target_length(self):
|
||||
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
|
||||
def test_as_target_tokenizer_target_length(self):
|
||||
tgt_text = [
|
||||
"Summary of the text.",
|
||||
"Another summary.",
|
||||
]
|
||||
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||
batch = tokenizer.prepare_seq2seq_batch(
|
||||
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt"
|
||||
)
|
||||
self.assertEqual(32, batch["labels"].shape[1])
|
||||
|
||||
# test None max_target_length
|
||||
batch = tokenizer.prepare_seq2seq_batch(
|
||||
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt"
|
||||
)
|
||||
self.assertEqual(32, batch["labels"].shape[1])
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text, max_length=32, padding="max_length", return_tensors="pt")
|
||||
self.assertEqual(32, targets["input_ids"].shape[1])
|
||||
|
||||
@require_torch
|
||||
def test_seq2seq_batch_not_longer_than_maxlen(self):
|
||||
def test_prepare_batch_not_longer_than_maxlen(self):
|
||||
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||
batch = tokenizer.prepare_seq2seq_batch(
|
||||
["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt"
|
||||
batch = tokenizer(
|
||||
["I am a small frog" * 1024, "I am a small frog"], padding=True, truncation=True, return_tensors="pt"
|
||||
)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
self.assertEqual(batch.input_ids.shape, (2, 1024))
|
||||
@ -154,9 +139,11 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
|
||||
"Summary of the text.",
|
||||
]
|
||||
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")
|
||||
input_ids = batch["input_ids"]
|
||||
labels = batch["labels"]
|
||||
inputs = tokenizer(src_text, return_tensors="pt")
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text, return_tensors="pt")
|
||||
input_ids = inputs["input_ids"]
|
||||
labels = targets["input_ids"]
|
||||
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
|
||||
self.assertTrue((labels[:, 0] == tokenizer.bos_token_id).all().item())
|
||||
self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
|
||||
|
@ -38,16 +38,12 @@ class BarthezTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@require_torch
|
||||
def test_prepare_seq2seq_batch(self):
|
||||
def test_prepare_batch(self):
|
||||
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
|
||||
tgt_text = [
|
||||
"Summary of the text.",
|
||||
"Another summary.",
|
||||
]
|
||||
expected_src_tokens = [0, 57, 3018, 70307, 91, 2]
|
||||
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
|
||||
batch = self.tokenizer(
|
||||
src_text, max_length=len(expected_src_tokens), padding=True, truncation=True, return_tensors="pt"
|
||||
)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
|
||||
|
@ -70,7 +70,7 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
def test_tokenizer_equivalence_en_de(self):
|
||||
en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de")
|
||||
batch = en_de_tokenizer.prepare_seq2seq_batch(["I am a small frog"], return_tensors=None)
|
||||
batch = en_de_tokenizer(["I am a small frog"], return_tensors=None)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
expected = [38, 121, 14, 697, 38848, 0]
|
||||
self.assertListEqual(expected, batch.input_ids[0])
|
||||
@ -84,12 +84,14 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def test_outputs_not_longer_than_maxlen(self):
|
||||
tok = self.get_tokenizer()
|
||||
|
||||
batch = tok.prepare_seq2seq_batch(["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK)
|
||||
batch = tok(
|
||||
["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK
|
||||
)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
self.assertEqual(batch.input_ids.shape, (2, 512))
|
||||
|
||||
def test_outputs_can_be_shorter(self):
|
||||
tok = self.get_tokenizer()
|
||||
batch_smaller = tok.prepare_seq2seq_batch(["I am a tiny frog", "I am a small frog"], return_tensors=FRAMEWORK)
|
||||
batch_smaller = tok(["I am a tiny frog", "I am a small frog"], padding=True, return_tensors=FRAMEWORK)
|
||||
self.assertIsInstance(batch_smaller, BatchEncoding)
|
||||
self.assertEqual(batch_smaller.input_ids.shape, (2, 10))
|
||||
|
@ -141,7 +141,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.tokenizer: MBartTokenizer = MBartTokenizer.from_pretrained(cls.checkpoint_name)
|
||||
cls.tokenizer: MBartTokenizer = MBartTokenizer.from_pretrained(
|
||||
cls.checkpoint_name, src_lang="en_XX", tgt_lang="ro_RO"
|
||||
)
|
||||
cls.pad_token_id = 1
|
||||
return cls
|
||||
|
||||
@ -166,10 +168,7 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
src_text = ["this is gunna be a long sentence " * 20]
|
||||
assert isinstance(src_text[0], str)
|
||||
desired_max_length = 10
|
||||
ids = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_text,
|
||||
max_length=desired_max_length,
|
||||
).input_ids[0]
|
||||
ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
|
||||
self.assertEqual(ids[-2], 2)
|
||||
self.assertEqual(ids[-1], EN_CODE)
|
||||
self.assertEqual(len(ids), desired_max_length)
|
||||
@ -184,31 +183,36 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
new_tok = MBartTokenizer.from_pretrained(tmpdirname)
|
||||
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
|
||||
|
||||
# prepare_seq2seq_batch tests below
|
||||
|
||||
@require_torch
|
||||
def test_batch_fairseq_parity(self):
|
||||
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
batch = self.tokenizer(self.src_text, padding=True)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
|
||||
|
||||
for k in batch:
|
||||
batch[k] = batch[k].tolist()
|
||||
# batch = {k: v.tolist() for k,v in batch.items()}
|
||||
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
|
||||
# batch.decoder_inputs_ids[0][0] ==
|
||||
assert batch.input_ids[1][-2:] == [2, EN_CODE]
|
||||
assert batch.decoder_input_ids[1][0] == RO_CODE
|
||||
assert batch.decoder_input_ids[1][-1] == 2
|
||||
assert batch.labels[1][-2:] == [2, RO_CODE]
|
||||
assert labels[1][-2:].tolist() == [2, RO_CODE]
|
||||
|
||||
@require_torch
|
||||
def test_enro_tokenizer_prepare_seq2seq_batch(self):
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
def test_enro_tokenizer_prepare_batch(self):
|
||||
batch = self.tokenizer(
|
||||
self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(
|
||||
self.tgt_text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
return_tensors="pt",
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
|
||||
self.assertEqual((2, 14), batch.input_ids.shape)
|
||||
@ -220,17 +224,12 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(self.tokenizer.prefix_tokens, [])
|
||||
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
|
||||
|
||||
def test_seq2seq_max_target_length(self):
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
def test_seq2seq_max_length(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||
# max_target_length will default to max_length if not specified
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
||||
|
@ -129,10 +129,7 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
|
||||
src_text = ["this is gunna be a long sentence " * 20]
|
||||
assert isinstance(src_text[0], str)
|
||||
desired_max_length = 10
|
||||
ids = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_text,
|
||||
max_length=desired_max_length,
|
||||
).input_ids[0]
|
||||
ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
|
||||
self.assertEqual(ids[0], EN_CODE)
|
||||
self.assertEqual(ids[-1], 2)
|
||||
self.assertEqual(len(ids), desired_max_length)
|
||||
@ -147,32 +144,38 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
|
||||
new_tok = MBart50Tokenizer.from_pretrained(tmpdirname)
|
||||
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
|
||||
|
||||
# prepare_seq2seq_batch tests below
|
||||
|
||||
@require_torch
|
||||
def test_batch_fairseq_parity(self):
|
||||
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
batch = self.tokenizer(self.src_text, padding=True)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
|
||||
labels = labels.tolist()
|
||||
|
||||
for k in batch:
|
||||
batch[k] = batch[k].tolist()
|
||||
# batch = {k: v.tolist() for k,v in batch.items()}
|
||||
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
|
||||
# batch.decoder_inputs_ids[0][0] ==
|
||||
assert batch.input_ids[1][0] == EN_CODE
|
||||
assert batch.input_ids[1][-1] == 2
|
||||
assert batch.labels[1][0] == RO_CODE
|
||||
assert batch.labels[1][-1] == 2
|
||||
assert labels[1][0] == RO_CODE
|
||||
assert labels[1][-1] == 2
|
||||
assert batch.decoder_input_ids[1][:2] == [2, RO_CODE]
|
||||
|
||||
@require_torch
|
||||
def test_tokenizer_prepare_seq2seq_batch(self):
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
def test_tokenizer_prepare_batch(self):
|
||||
batch = self.tokenizer(
|
||||
self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(
|
||||
self.tgt_text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
return_tensors="pt",
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
|
||||
self.assertEqual((2, 14), batch.input_ids.shape)
|
||||
@ -185,16 +188,11 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
|
||||
def test_seq2seq_max_target_length(self):
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||
# max_target_length will default to max_length if not specified
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
||||
|
@ -86,11 +86,13 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def test_large_seq2seq_truncation(self):
|
||||
src_texts = ["This is going to be way too long." * 150, "short example"]
|
||||
tgt_texts = ["not super long but more than 5 tokens", "tiny"]
|
||||
batch = self._large_tokenizer.prepare_seq2seq_batch(
|
||||
src_texts, tgt_texts=tgt_texts, max_target_length=5, return_tensors="pt"
|
||||
)
|
||||
batch = self._large_tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt")
|
||||
with self._large_tokenizer.as_target_tokenizer():
|
||||
targets = self._large_tokenizer(
|
||||
tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
assert batch.input_ids.shape == (2, 1024)
|
||||
assert batch.attention_mask.shape == (2, 1024)
|
||||
assert "labels" in batch # because tgt_texts was specified
|
||||
assert batch.labels.shape == (2, 5)
|
||||
assert len(batch) == 3 # input_ids, attention_mask, labels. Other things make by BartModel
|
||||
assert targets["input_ids"].shape == (2, 5)
|
||||
assert len(batch) == 2 # input_ids, attention_mask.
|
||||
|
@ -152,20 +152,12 @@ class ProphetNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
|
||||
|
||||
@require_torch
|
||||
def test_prepare_seq2seq_batch(self):
|
||||
def test_prepare_batch(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("microsoft/prophetnet-large-uncased")
|
||||
|
||||
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
|
||||
tgt_text = [
|
||||
"Summary of the text.",
|
||||
"Another summary.",
|
||||
]
|
||||
expected_src_tokens = [1037, 2146, 20423, 2005, 7680, 7849, 3989, 1012, 102]
|
||||
batch = tokenizer.prepare_seq2seq_batch(
|
||||
src_text,
|
||||
tgt_texts=tgt_text,
|
||||
return_tensors="pt",
|
||||
)
|
||||
batch = tokenizer(src_text, padding=True, return_tensors="pt")
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
result = list(batch.input_ids.numpy()[0])
|
||||
self.assertListEqual(expected_src_tokens, result)
|
||||
|
@ -151,19 +151,11 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""])
|
||||
self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"])
|
||||
|
||||
def test_prepare_seq2seq_batch(self):
|
||||
def test_prepare_batch(self):
|
||||
tokenizer = self.t5_base_tokenizer
|
||||
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
|
||||
tgt_text = [
|
||||
"Summary of the text.",
|
||||
"Another summary.",
|
||||
]
|
||||
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]
|
||||
batch = tokenizer.prepare_seq2seq_batch(
|
||||
src_text,
|
||||
tgt_texts=tgt_text,
|
||||
return_tensors=FRAMEWORK,
|
||||
)
|
||||
batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
result = list(batch.input_ids.numpy()[0])
|
||||
self.assertListEqual(expected_src_tokens, result)
|
||||
@ -174,36 +166,30 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def test_empty_target_text(self):
|
||||
tokenizer = self.t5_base_tokenizer
|
||||
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK)
|
||||
batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
|
||||
# check if input_ids are returned and no decoder_input_ids
|
||||
self.assertIn("input_ids", batch)
|
||||
self.assertIn("attention_mask", batch)
|
||||
self.assertNotIn("decoder_input_ids", batch)
|
||||
self.assertNotIn("decoder_attention_mask", batch)
|
||||
|
||||
def test_max_target_length(self):
|
||||
def test_max_length(self):
|
||||
tokenizer = self.t5_base_tokenizer
|
||||
src_text = ["A short paragraph for summarization.", "Another short paragraph for summarization."]
|
||||
tgt_text = [
|
||||
"Summary of the text.",
|
||||
"Another summary.",
|
||||
]
|
||||
batch = tokenizer.prepare_seq2seq_batch(
|
||||
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors=FRAMEWORK
|
||||
)
|
||||
self.assertEqual(32, batch["labels"].shape[1])
|
||||
|
||||
# test None max_target_length
|
||||
batch = tokenizer.prepare_seq2seq_batch(
|
||||
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors=FRAMEWORK
|
||||
)
|
||||
self.assertEqual(32, batch["labels"].shape[1])
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(
|
||||
tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
|
||||
)
|
||||
self.assertEqual(32, targets["input_ids"].shape[1])
|
||||
|
||||
def test_outputs_not_longer_than_maxlen(self):
|
||||
tokenizer = self.t5_base_tokenizer
|
||||
|
||||
batch = tokenizer.prepare_seq2seq_batch(
|
||||
["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
|
||||
batch = tokenizer(
|
||||
["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK
|
||||
)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
self.assertEqual(batch.input_ids.shape, (2, 512))
|
||||
@ -215,13 +201,12 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1]
|
||||
expected_tgt_tokens = [20698, 13, 8, 1499, 5, 1]
|
||||
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)
|
||||
batch = tokenizer(src_text)
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text)
|
||||
|
||||
src_ids = list(batch.input_ids.numpy()[0])
|
||||
tgt_ids = list(batch.labels.numpy()[0])
|
||||
|
||||
self.assertEqual(expected_src_tokens, src_ids)
|
||||
self.assertEqual(expected_tgt_tokens, tgt_ids)
|
||||
self.assertEqual(expected_src_tokens, batch["input_ids"][0])
|
||||
self.assertEqual(expected_tgt_tokens, targets["input_ids"][0])
|
||||
|
||||
def test_token_type_ids(self):
|
||||
src_text_1 = ["A first paragraph for summarization."]
|
||||
|
Loading…
Reference in New Issue
Block a user