diff --git a/src/transformers/models/bart/tokenization_bart.py b/src/transformers/models/bart/tokenization_bart.py index 57a8e2448b1..4a468d811e6 100644 --- a/src/transformers/models/bart/tokenization_bart.py +++ b/src/transformers/models/bart/tokenization_bart.py @@ -13,10 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional - -from ...file_utils import add_start_docstrings -from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding from ...utils import logging from ..roberta.tokenization_roberta import RobertaTokenizer @@ -54,45 +50,3 @@ class BartTokenizer(RobertaTokenizer): "vocab_file": {m: vocab_url for m in _all_bart_models}, "merges_file": {m: merges_url for m in _all_bart_models}, } - - @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) - def prepare_seq2seq_batch( - self, - src_texts: List[str], - tgt_texts: Optional[List[str]] = None, - max_length: Optional[int] = None, - max_target_length: Optional[int] = None, - padding: str = "longest", - return_tensors: str = None, - truncation=True, - **kwargs, - ) -> BatchEncoding: - kwargs.pop("src_lang", None) - kwargs.pop("tgt_lang", None) - if max_length is None: - max_length = self.model_max_length - model_inputs: BatchEncoding = self( - src_texts, - add_special_tokens=True, - return_tensors=return_tensors, - max_length=max_length, - padding=padding, - truncation=truncation, - **kwargs, - ) - if tgt_texts is None: - return model_inputs - # Process tgt_texts - if max_target_length is None: - max_target_length = max_length - labels = self( - tgt_texts, - add_special_tokens=True, - return_tensors=return_tensors, - padding=padding, - max_length=max_target_length, - truncation=truncation, - **kwargs, - )["input_ids"] - model_inputs["labels"] = labels - return model_inputs diff --git a/src/transformers/models/bart/tokenization_bart_fast.py b/src/transformers/models/bart/tokenization_bart_fast.py index 19678f9d521..87ae6158216 100644 --- a/src/transformers/models/bart/tokenization_bart_fast.py +++ b/src/transformers/models/bart/tokenization_bart_fast.py @@ -13,10 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional - -from ...file_utils import add_start_docstrings -from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding from ...utils import logging from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast from .tokenization_bart import BartTokenizer @@ -49,43 +45,3 @@ class BartTokenizerFast(RobertaTokenizerFast): "tokenizer_file": {m: tokenizer_url for m in _all_bart_models}, } slow_tokenizer_class = BartTokenizer - - @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) - def prepare_seq2seq_batch( - self, - src_texts: List[str], - tgt_texts: Optional[List[str]] = None, - max_length: Optional[int] = None, - max_target_length: Optional[int] = None, - padding: str = "longest", - return_tensors: Optional[str] = None, - truncation=True, - **kwargs, - ) -> BatchEncoding: - if max_length is None: - max_length = self.model_max_length - model_inputs: BatchEncoding = self( - src_texts, - add_special_tokens=True, - return_tensors=return_tensors, - max_length=max_length, - padding=padding, - truncation=truncation, - **kwargs, - ) - if tgt_texts is None: - return model_inputs - # Process tgt_texts - if max_target_length is None: - max_target_length = max_length - labels = self( - tgt_texts, - add_special_tokens=True, - return_tensors=return_tensors, - padding=padding, - max_length=max_target_length, - truncation=truncation, - **kwargs, - )["input_ids"] - model_inputs["labels"] = labels - return model_inputs diff --git a/src/transformers/models/barthez/tokenization_barthez.py b/src/transformers/models/barthez/tokenization_barthez.py index a9774ba0ba8..3096f485ad5 100644 --- a/src/transformers/models/barthez/tokenization_barthez.py +++ b/src/transformers/models/barthez/tokenization_barthez.py @@ -21,9 +21,7 @@ from typing import List, Optional, Tuple import sentencepiece as spm -from ...file_utils import add_start_docstrings from ...tokenization_utils import PreTrainedTokenizer -from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding from ...utils import logging @@ -264,45 +262,3 @@ class BarthezTokenizer(PreTrainedTokenizer): copyfile(self.vocab_file, out_vocab_file) return (out_vocab_file,) - - @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) - def prepare_seq2seq_batch( - self, - src_texts: List[str], - tgt_texts: Optional[List[str]] = None, - max_length: Optional[int] = None, - max_target_length: Optional[int] = None, - padding: str = "longest", - return_tensors: str = "None", - truncation=True, - **kwargs, - ) -> BatchEncoding: - kwargs.pop("src_lang", None) - kwargs.pop("tgt_lang", None) - if max_length is None: - max_length = self.model_max_length - model_inputs: BatchEncoding = self( - src_texts, - add_special_tokens=True, - return_tensors=return_tensors, - max_length=max_length, - padding=padding, - truncation=truncation, - **kwargs, - ) - if tgt_texts is None: - return model_inputs - # Process tgt_texts - if max_target_length is None: - max_target_length = max_length - labels = self( - tgt_texts, - add_special_tokens=True, - return_tensors=return_tensors, - padding=padding, - max_length=max_target_length, - truncation=truncation, - **kwargs, - )["input_ids"] - model_inputs["labels"] = labels - return model_inputs diff --git a/src/transformers/models/barthez/tokenization_barthez_fast.py b/src/transformers/models/barthez/tokenization_barthez_fast.py index c0870778118..c8391dea80e 100644 --- a/src/transformers/models/barthez/tokenization_barthez_fast.py +++ 
b/src/transformers/models/barthez/tokenization_barthez_fast.py @@ -19,8 +19,7 @@ import os from shutil import copyfile from typing import List, Optional, Tuple -from ...file_utils import add_start_docstrings, is_sentencepiece_available -from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding +from ...file_utils import is_sentencepiece_available from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...utils import logging @@ -228,45 +227,3 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast): copyfile(self.vocab_file, out_vocab_file) return (out_vocab_file,) - - @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) - def prepare_seq2seq_batch( - self, - src_texts: List[str], - tgt_texts: Optional[List[str]] = None, - max_length: Optional[int] = None, - max_target_length: Optional[int] = None, - padding: str = "longest", - return_tensors: str = "None", - truncation=True, - **kwargs, - ) -> BatchEncoding: - kwargs.pop("src_lang", None) - kwargs.pop("tgt_lang", None) - if max_length is None: - max_length = self.model_max_length - model_inputs: BatchEncoding = self( - src_texts, - add_special_tokens=True, - return_tensors=return_tensors, - max_length=max_length, - padding=padding, - truncation=truncation, - **kwargs, - ) - if tgt_texts is None: - return model_inputs - # Process tgt_texts - if max_target_length is None: - max_target_length = max_length - labels = self( - tgt_texts, - add_special_tokens=True, - return_tensors=return_tensors, - padding=padding, - max_length=max_target_length, - truncation=truncation, - **kwargs, - )["input_ids"] - model_inputs["labels"] = labels - return model_inputs diff --git a/src/transformers/models/fsmt/tokenization_fsmt.py b/src/transformers/models/fsmt/tokenization_fsmt.py index 71bfd93000f..30b4e5afd95 100644 --- a/src/transformers/models/fsmt/tokenization_fsmt.py +++ b/src/transformers/models/fsmt/tokenization_fsmt.py @@ -23,9 +23,7 @@ from typing import Dict, List, Optional, Tuple import sacremoses as sm -from ...file_utils import add_start_docstrings -from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer -from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING +from ...tokenization_utils import PreTrainedTokenizer from ...utils import logging @@ -484,40 +482,6 @@ class FSMTTokenizer(PreTrainedTokenizer): return len(token_ids_0 + sep) * [0] return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] - @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) - def prepare_seq2seq_batch( - self, - src_texts: List[str], - tgt_texts: Optional[List[str]] = None, - max_length: Optional[int] = None, - max_target_length: Optional[int] = None, - return_tensors: Optional[str] = None, - truncation=True, - padding="longest", - **unused, - ) -> BatchEncoding: - if type(src_texts) is not list: - raise ValueError("src_texts is expected to be a list") - if "" in src_texts: - raise ValueError(f"found empty string in src_texts: {src_texts}") - - tokenizer_kwargs = dict( - add_special_tokens=True, - return_tensors=return_tensors, - max_length=max_length, - truncation=truncation, - padding=padding, - ) - model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs) - - if tgt_texts is None: - return model_inputs - if max_target_length is not None: - tokenizer_kwargs["max_length"] = max_target_length - - model_inputs["labels"] = self(tgt_texts, **tokenizer_kwargs)["input_ids"] - return model_inputs - def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> 
Tuple[str]: if not os.path.isdir(save_directory): logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) diff --git a/src/transformers/models/marian/tokenization_marian.py b/src/transformers/models/marian/tokenization_marian.py index 0fb03c53855..0d4dea27425 100644 --- a/src/transformers/models/marian/tokenization_marian.py +++ b/src/transformers/models/marian/tokenization_marian.py @@ -15,15 +15,14 @@ import json import re import warnings +from contextlib import contextmanager from pathlib import Path from shutil import copyfile from typing import Dict, List, Optional, Tuple, Union import sentencepiece -from ...file_utils import add_start_docstrings -from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer -from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING +from ...tokenization_utils import PreTrainedTokenizer vocab_files_names = { @@ -182,40 +181,15 @@ class MarianTokenizer(PreTrainedTokenizer): # We don't expect to process pairs, but leave the pair logic for API consistency return token_ids_0 + token_ids_1 + [self.eos_token_id] - @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) - def prepare_seq2seq_batch( - self, - src_texts: List[str], - tgt_texts: Optional[List[str]] = None, - max_length: Optional[int] = None, - max_target_length: Optional[int] = None, - return_tensors: Optional[str] = None, - truncation=True, - padding="longest", - **unused, - ) -> BatchEncoding: - if "" in src_texts: - raise ValueError(f"found empty string in src_texts: {src_texts}") - self.current_spm = self.spm_source - src_texts = [self.normalize(t) for t in src_texts] # this does not appear to do much - tokenizer_kwargs = dict( - add_special_tokens=True, - return_tensors=return_tensors, - max_length=max_length, - truncation=truncation, - padding=padding, - ) - model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs) - - if tgt_texts is None: - return model_inputs - if max_target_length is not None: - tokenizer_kwargs["max_length"] = max_target_length - + @contextmanager + def as_target_tokenizer(self): + """ + Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to + sequence-to-sequence models that need a slightly different processing for the labels. + """ self.current_spm = self.spm_target - model_inputs["labels"] = self(tgt_texts, **tokenizer_kwargs)["input_ids"] + yield self.current_spm = self.spm_source - return model_inputs @property def vocab_size(self) -> int: diff --git a/src/transformers/models/mbart/tokenization_mbart.py b/src/transformers/models/mbart/tokenization_mbart.py index e8425fe8c53..8b88c98e680 100644 --- a/src/transformers/models/mbart/tokenization_mbart.py +++ b/src/transformers/models/mbart/tokenization_mbart.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from contextlib import contextmanager from typing import List, Optional -from ...file_utils import add_start_docstrings from ...tokenization_utils import BatchEncoding -from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING from ...utils import logging from ..xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer @@ -172,52 +171,28 @@ class MBartTokenizer(XLMRobertaTokenizer): # We don't expect to process pairs, but leave the pair logic for API consistency return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens - @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) def prepare_seq2seq_batch( self, src_texts: List[str], src_lang: str = "en_XX", tgt_texts: Optional[List[str]] = None, tgt_lang: str = "ro_RO", - max_length: Optional[int] = None, - max_target_length: Optional[int] = None, - truncation: bool = True, - padding: str = "longest", - return_tensors: Optional[str] = None, - add_prefix_space: bool = False, # ignored **kwargs, ) -> BatchEncoding: - if max_length is None: - max_length = self.model_max_length - self.set_src_lang_special_tokens(src_lang) - model_inputs: BatchEncoding = self( - src_texts, - add_special_tokens=True, - return_tensors=return_tensors, - max_length=max_length, - padding=padding, - truncation=truncation, - **kwargs, - ) - if tgt_texts is None: - return model_inputs - # Process tgt_texts - if max_target_length is None: - max_target_length = max_length - self.set_tgt_lang_special_tokens(tgt_lang) + self.src_lang = src_lang + self.tgt_lang = tgt_lang + self.set_src_lang_special_tokens(self.src_lang) + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) - labels = self( - tgt_texts, - add_special_tokens=True, - return_tensors=return_tensors, - padding=padding, - max_length=max_target_length, - truncation=True, - **kwargs, - )["input_ids"] - model_inputs["labels"] = labels - self.set_src_lang_special_tokens(src_lang) # sets to src_lang - return model_inputs + @contextmanager + def as_target_tokenizer(self): + """ + Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to + sequence-to-sequence models that need a slightly different processing for the labels. + """ + self.set_tgt_lang_special_tokens(self.tgt_lang) + yield + self.set_src_lang_special_tokens(self.src_lang) def set_src_lang_special_tokens(self, src_lang) -> None: """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code].""" diff --git a/src/transformers/models/mbart/tokenization_mbart_fast.py b/src/transformers/models/mbart/tokenization_mbart_fast.py index 56e7c065f39..80e0efed804 100644 --- a/src/transformers/models/mbart/tokenization_mbart_fast.py +++ b/src/transformers/models/mbart/tokenization_mbart_fast.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from contextlib import contextmanager from typing import List, Optional from tokenizers import processors -from ...file_utils import add_start_docstrings, is_sentencepiece_available +from ...file_utils import is_sentencepiece_available from ...tokenization_utils import BatchEncoding -from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING from ...utils import logging from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast @@ -171,51 +171,28 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast): # We don't expect to process pairs, but leave the pair logic for API consistency return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens - @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) def prepare_seq2seq_batch( self, src_texts: List[str], src_lang: str = "en_XX", tgt_texts: Optional[List[str]] = None, tgt_lang: str = "ro_RO", - max_length: Optional[int] = None, - max_target_length: Optional[int] = None, - truncation: bool = True, - padding: str = "longest", - return_tensors: str = None, **kwargs, ) -> BatchEncoding: - if max_length is None: - max_length = self.model_max_length - self.set_src_lang_special_tokens(src_lang) - model_inputs: BatchEncoding = self( - src_texts, - add_special_tokens=True, - return_tensors=return_tensors, - max_length=max_length, - padding=padding, - truncation=truncation, - **kwargs, - ) - if tgt_texts is None: - return model_inputs - # Process tgt_texts - if max_target_length is None: - max_target_length = max_length - self.set_tgt_lang_special_tokens(tgt_lang) + self.src_lang = src_lang + self.tgt_lang = tgt_lang + self.set_src_lang_special_tokens(self.src_lang) + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) - labels = self( - tgt_texts, - add_special_tokens=True, - return_tensors=return_tensors, - padding=padding, - max_length=max_target_length, - truncation=True, - **kwargs, - )["input_ids"] - model_inputs["labels"] = labels - self.set_src_lang_special_tokens(src_lang) # sets to src_lang - return model_inputs + @contextmanager + def as_target_tokenizer(self): + """ + Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to + sequence-to-sequence models that need a slightly different processing for the labels. + """ + self.set_tgt_lang_special_tokens(self.tgt_lang) + yield + self.set_src_lang_special_tokens(self.src_lang) def set_src_lang_special_tokens(self, src_lang) -> None: """Reset the special tokens to the source lang setting. 
No prefix and suffix=[eos, src_lang_code].""" diff --git a/src/transformers/models/pegasus/tokenization_pegasus.py b/src/transformers/models/pegasus/tokenization_pegasus.py index 099bdf3e7b3..c7d39a332c6 100644 --- a/src/transformers/models/pegasus/tokenization_pegasus.py +++ b/src/transformers/models/pegasus/tokenization_pegasus.py @@ -18,9 +18,7 @@ from typing import Dict, List, Optional, Tuple import sentencepiece as spm -from ...file_utils import add_start_docstrings from ...tokenization_utils import PreTrainedTokenizer -from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding from ...utils import logging @@ -250,36 +248,6 @@ class PegasusTokenizer(PreTrainedTokenizer): # We don't expect to process pairs, but leave the pair logic for API consistency return token_ids_0 + token_ids_1 + [self.eos_token_id] - @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) - def prepare_seq2seq_batch( - self, - src_texts: List[str], - tgt_texts: Optional[List[str]] = None, - max_length: Optional[int] = None, - max_target_length: Optional[int] = None, - return_tensors: str = None, - truncation=True, - padding="longest", - **unused, - ) -> BatchEncoding: - if "" in src_texts: - raise ValueError(f"found empty string in src_texts: {src_texts}") - tokenizer_kwargs = dict( - add_special_tokens=True, - return_tensors=return_tensors, - max_length=max_length, - truncation=truncation, - padding=padding, - ) - model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs) - if tgt_texts is None: - return model_inputs - if max_target_length is not None: - tokenizer_kwargs["max_length"] = max_target_length - labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"] - model_inputs["labels"] = labels - return model_inputs - def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: if not os.path.isdir(save_directory): logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) diff --git a/src/transformers/models/pegasus/tokenization_pegasus_fast.py b/src/transformers/models/pegasus/tokenization_pegasus_fast.py index c9b0d076314..f967ad4d480 100644 --- a/src/transformers/models/pegasus/tokenization_pegasus_fast.py +++ b/src/transformers/models/pegasus/tokenization_pegasus_fast.py @@ -19,8 +19,7 @@ import os from shutil import copyfile from typing import List, Optional, Tuple -from ...file_utils import add_start_docstrings, is_sentencepiece_available -from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding +from ...file_utils import is_sentencepiece_available from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...utils import logging @@ -188,36 +187,6 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast): # We don't expect to process pairs, but leave the pair logic for API consistency return token_ids_0 + token_ids_1 + [self.eos_token_id] - @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) - def prepare_seq2seq_batch( - self, - src_texts: List[str], - tgt_texts: Optional[List[str]] = None, - max_length: Optional[int] = None, - max_target_length: Optional[int] = None, - return_tensors: str = None, - truncation=True, - padding="longest", - **unused, - ) -> BatchEncoding: - if "" in src_texts: - raise ValueError(f"found empty string in src_texts: {src_texts}") - tokenizer_kwargs = dict( - add_special_tokens=True, - return_tensors=return_tensors, - max_length=max_length, - truncation=truncation, - padding=padding, - ) - model_inputs: BatchEncoding = 
self(src_texts, **tokenizer_kwargs) - if tgt_texts is None: - return model_inputs - if max_target_length is not None: - tokenizer_kwargs["max_length"] = max_target_length - labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"] - model_inputs["labels"] = labels - return model_inputs - def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: if not os.path.isdir(save_directory): logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) diff --git a/src/transformers/models/prophetnet/tokenization_prophetnet.py b/src/transformers/models/prophetnet/tokenization_prophetnet.py index 5d93a00e852..213e303a88b 100644 --- a/src/transformers/models/prophetnet/tokenization_prophetnet.py +++ b/src/transformers/models/prophetnet/tokenization_prophetnet.py @@ -17,9 +17,7 @@ import collections import os from typing import List, Optional, Tuple -from ...file_utils import add_start_docstrings -from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer -from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING +from ...tokenization_utils import PreTrainedTokenizer from ...utils import logging from ..bert.tokenization_bert import BasicTokenizer, WordpieceTokenizer @@ -288,43 +286,3 @@ class ProphetNetTokenizer(PreTrainedTokenizer): return token_ids_0 + [self.sep_token_id] sep = [self.sep_token_id] return token_ids_0 + sep + token_ids_1 + sep - - @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) - def prepare_seq2seq_batch( - self, - src_texts: List[str], - tgt_texts: Optional[List[str]] = None, - max_length: Optional[int] = None, - max_target_length: Optional[int] = None, - padding: str = "longest", - return_tensors: str = None, - truncation: bool = True, - **kwargs, - ) -> BatchEncoding: - if max_length is None: - max_length = self.model_max_length - model_inputs = self( - src_texts, - add_special_tokens=True, - return_tensors=return_tensors, - max_length=max_length, - padding=padding, - truncation=truncation, - **kwargs, - ) - if tgt_texts is None: - return model_inputs - # Process tgt_texts - if max_target_length is None: - max_target_length = max_length - labels_and_decoder_mask = self( - tgt_texts, - add_special_tokens=True, - return_tensors=return_tensors, - padding=padding, - max_length=max_target_length, - truncation=truncation, - **kwargs, - ) - model_inputs["labels"] = labels_and_decoder_mask["input_ids"] - return model_inputs diff --git a/src/transformers/models/rag/tokenization_rag.py b/src/transformers/models/rag/tokenization_rag.py index 766d04662d7..03ae2f68a87 100644 --- a/src/transformers/models/rag/tokenization_rag.py +++ b/src/transformers/models/rag/tokenization_rag.py @@ -16,8 +16,7 @@ import os from typing import List, Optional -from ...file_utils import add_start_docstrings -from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding +from ...tokenization_utils_base import BatchEncoding from ...utils import logging from .configuration_rag import RagConfig @@ -63,42 +62,18 @@ class RagTokenizer: def batch_decode(self, *args, **kwargs): return self.generator.batch_decode(*args, **kwargs) - @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) def prepare_seq2seq_batch( self, src_texts: List[str], tgt_texts: Optional[List[str]] = None, max_length: Optional[int] = None, max_target_length: Optional[int] = None, - padding: str = "longest", - return_tensors: str = None, - truncation=True, **kwargs, ) -> BatchEncoding: if max_length is None: max_length = 
self.question_encoder.model_max_length - model_inputs: BatchEncoding = self.question_encoder( - src_texts, - add_special_tokens=True, - return_tensors=return_tensors, - max_length=max_length, - padding=padding, - truncation=truncation, - **kwargs, - ) - if tgt_texts is None: - return model_inputs - # Process tgt_texts if max_target_length is None: max_target_length = self.generator.model_max_length - labels = self.generator( - tgt_texts, - add_special_tokens=True, - return_tensors=return_tensors, - padding=padding, - max_length=max_target_length, - truncation=truncation, - **kwargs, - )["input_ids"] - model_inputs["labels"] = labels - return model_inputs + return super().prepare_seq2seq_batch( + src_texts, tgt_texts, max_length=max_length, max_target_length=max_target_length, **kwargs + ) diff --git a/src/transformers/models/t5/tokenization_t5.py b/src/transformers/models/t5/tokenization_t5.py index c4d57e0ac19..dae98df3a0d 100644 --- a/src/transformers/models/t5/tokenization_t5.py +++ b/src/transformers/models/t5/tokenization_t5.py @@ -23,9 +23,7 @@ from typing import List, Optional, Tuple import sentencepiece as spm -from ...file_utils import add_start_docstrings -from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer -from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING +from ...tokenization_utils import PreTrainedTokenizer from ...utils import logging @@ -295,43 +293,3 @@ class T5Tokenizer(PreTrainedTokenizer): copyfile(self.vocab_file, out_vocab_file) return (out_vocab_file,) - - @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) - def prepare_seq2seq_batch( - self, - src_texts: List[str], - tgt_texts: Optional[List[str]] = None, - max_length: Optional[int] = None, - max_target_length: Optional[int] = None, - padding: str = "longest", - return_tensors: str = None, - truncation: bool = True, - **kwargs, - ) -> BatchEncoding: - if max_length is None: - max_length = self.model_max_length - model_inputs = self( - src_texts, - add_special_tokens=True, - return_tensors=return_tensors, - max_length=max_length, - padding=padding, - truncation=truncation, - **kwargs, - ) - if tgt_texts is None: - return model_inputs - # Process tgt_texts - if max_target_length is None: - max_target_length = max_length - labels_and_decoder_mask = self( - tgt_texts, - add_special_tokens=True, - return_tensors=return_tensors, - padding=padding, - max_length=max_target_length, - truncation=truncation, - **kwargs, - ) - model_inputs["labels"] = labels_and_decoder_mask["input_ids"] - return model_inputs diff --git a/src/transformers/models/t5/tokenization_t5_fast.py b/src/transformers/models/t5/tokenization_t5_fast.py index e4ad4a30664..f3b3fd6dcf6 100644 --- a/src/transformers/models/t5/tokenization_t5_fast.py +++ b/src/transformers/models/t5/tokenization_t5_fast.py @@ -19,9 +19,7 @@ import os from shutil import copyfile from typing import List, Optional, Tuple -from ...file_utils import add_start_docstrings, is_sentencepiece_available -from ...tokenization_utils import BatchEncoding -from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING +from ...file_utils import is_sentencepiece_available from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...utils import logging @@ -212,47 +210,3 @@ class T5TokenizerFast(PreTrainedTokenizerFast): if token_ids_1 is None: return len(token_ids_0 + eos) * [0] return len(token_ids_0 + eos + token_ids_1 + eos) * [0] - - @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) - def prepare_seq2seq_batch( - self, - 
src_texts: List[str], - tgt_texts: Optional[List[str]] = None, - max_length: Optional[int] = None, - max_target_length: Optional[int] = None, - padding: str = "longest", - return_tensors: str = None, - truncation: bool = True, - **kwargs, - ) -> BatchEncoding: - if max_length is None: - max_length = self.model_max_length - self.prefix_tokens = [] - model_inputs = self( - src_texts, - add_special_tokens=True, - return_tensors=return_tensors, - max_length=max_length, - padding=padding, - truncation=truncation, - **kwargs, - ) - if tgt_texts is None: - return model_inputs - # Process tgt_texts - if max_target_length is None: - max_target_length = max_length - # set prefix_tokens for target text - self.prefix_tokens = [self.pad_token_id] - labels_and_decoder_mask = self( - tgt_texts, - add_special_tokens=True, - return_tensors=return_tensors, - padding=padding, - max_length=max_target_length, - truncation=truncation, - **kwargs, - ) - model_inputs["labels"] = labels_and_decoder_mask["input_ids"] - self.prefix_tokens = [] - return model_inputs diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index d6212ae0b68..5f10f4b6f49 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -738,80 +738,3 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): return clean_text else: return text - - def prepare_seq2seq_batch( - self, - src_texts: List[str], - tgt_texts: Optional[List[str]] = None, - max_length: Optional[int] = None, - max_target_length: Optional[int] = None, - padding: str = "longest", - return_tensors: str = "None", - truncation=True, - **kwargs, - ) -> BatchEncoding: - r""" - - Prepare a batch that can be passed directly to an instance of :class:`~transformers.AutoModelForSeq2SeqLM`. - - Args: - src_texts: (:obj:`List[str]`): - List of documents to summarize or source language texts. - tgt_texts: (:obj:`List[str]`, `optional`): - List of summaries or target language texts. - max_length (:obj:`int`, `optional`): - Controls the maximum length for encoder inputs (documents to summarize or source language texts). If - left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length - is required by one of the truncation/padding parameters. If the model has no specific maximum input - length (like XLNet) truncation/padding to a maximum length will be deactivated. - max_target_length (:obj:`int`, `optional`): - Controls the maximum length of decoder inputs (target language texts or summaries). If left unset or - set to :obj:`None`, this will use the max_length value. - padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): - Activates and controls padding. Accepts the following values: - - * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a - single sequence if provided). - * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the - maximum acceptable input length for the model if that argument is not provided. - * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of - different lengths). - return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`): - If set, will return tensors instead of list of python integers. Acceptable values are: - - * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. 
- * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. - * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. - truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`): - Activates and controls truncation. Accepts the following values: - - * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument - :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not - provided. This will truncate token by token, removing a token from the longest sequence in the pair - if a pair of sequences (or a batch of pairs) is provided. - * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to - the maximum acceptable input length for the model if that argument is not provided. This will only - truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. - * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or - to the maximum acceptable input length for the model if that argument is not provided. This will only - truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. - * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with - sequence lengths greater than the model maximum admissible input size). - **kwargs: - Additional keyword arguments passed along to :obj:`self.__call__`. - - Returns: - :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: - - - **input_ids** -- List of token ids to be fed to the encoder. - - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. - - **labels** -- List of token ids for tgt_texts - - The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed. - Otherwise, input_ids, attention_mask will be the only keys. - """ - raise NotImplementedError( - "If your model requires more than input_ids for a typical forward pass, you should implement this method. " - "Returned keys should be [input_ids, attention_mask, labels]. See MarianTokenizer or T5Tokenizer for a " - "reference implementation." - ) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 05ac93eefcb..11d4dce5414 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -23,6 +23,7 @@ import json import os import warnings from collections import OrderedDict, UserDict +from contextlib import contextmanager from dataclasses import dataclass, field from enum import Enum from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union @@ -1473,68 +1474,6 @@ INIT_TOKENIZER_DOCSTRING = r""" """ -PREPARE_SEQ2SEQ_BATCH_DOCSTRING = """ - Prepare model inputs for translation. For best performance, translate one sentence at a time. - - Arguments: - src_texts (:obj:`List[str]`): - List of documents to summarize or source language texts. - tgt_texts (:obj:`list`, `optional`): - List of summaries or target language texts. 
- max_length (:obj:`int`, `optional`): - Controls the maximum length for encoder inputs (documents to summarize or source language texts) If - left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length - is required by one of the truncation/padding parameters. If the model has no specific maximum input - length (like XLNet) truncation/padding to a maximum length will be deactivated. - max_target_length (:obj:`int`, `optional`): - Controls the maximum length of decoder inputs (target language texts or summaries) If left unset or set - to :obj:`None`, this will use the max_length value. - padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): - Activates and controls padding. Accepts the following values: - - * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a - single sequence if provided). - * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the - maximum acceptable input length for the model if that argument is not provided. - * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of - different lengths). - return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`): - If set, will return tensors instead of list of python integers. Acceptable values are: - - * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. - * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. - * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. - truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`): - Activates and controls truncation. Accepts the following values: - - * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument - :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not - provided. This will truncate token by token, removing a token from the longest sequence in the pair - if a pair of sequences (or a batch of pairs) is provided. - * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to - the maximum acceptable input length for the model if that argument is not provided. This will only - truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. - * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or - to the maximum acceptable input length for the model if that argument is not provided. This will only - truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. - * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with - sequence lengths greater than the model maximum admissible input size). - **kwargs: - Additional keyword arguments passed along to :obj:`self.__call__`. - - Return: - :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: - - - **input_ids** -- List of token ids to be fed to the encoder. - - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. - - **labels** -- List of token ids for tgt_texts. 
- - The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed. - Otherwise, input_ids, attention_mask will be the only keys. - -""" - - @add_end_docstrings(INIT_TOKENIZER_DOCSTRING) class PreTrainedTokenizerBase(SpecialTokensMixin): """ @@ -3252,3 +3191,113 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): "indexing errors".format(len(ids), self.model_max_length) ) self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True + + @contextmanager + def as_target_tokenizer(self): + """ + Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to + sequence-to-sequence models that need a slightly different processing for the labels. + """ + yield + + def prepare_seq2seq_batch( + self, + src_texts: List[str], + tgt_texts: Optional[List[str]] = None, + max_length: Optional[int] = None, + max_target_length: Optional[int] = None, + padding: str = "longest", + return_tensors: str = None, + truncation: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Prepare model inputs for translation. For best performance, translate one sentence at a time. + + Arguments: + src_texts (:obj:`List[str]`): + List of documents to summarize or source language texts. + tgt_texts (:obj:`list`, `optional`): + List of summaries or target language texts. + max_length (:obj:`int`, `optional`): + Controls the maximum length for encoder inputs (documents to summarize or source language texts) If + left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length + is required by one of the truncation/padding parameters. If the model has no specific maximum input + length (like XLNet) truncation/padding to a maximum length will be deactivated. + max_target_length (:obj:`int`, `optional`): + Controls the maximum length of decoder inputs (target language texts or summaries) If left unset or set + to :obj:`None`, this will use the max_length value. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): + Activates and controls padding. Accepts the following values: + + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a + single sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`): + If set, will return tensors instead of list of python integers. Acceptable values are: + + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. + truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`): + Activates and controls truncation. Accepts the following values: + + * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument + :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not + provided. 
This will truncate token by token, removing a token from the longest sequence in the pair + if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to + the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with + sequence lengths greater than the model maximum admissible input size). + **kwargs: + Additional keyword arguments passed along to :obj:`self.__call__`. + + Return: + :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: + + - **input_ids** -- List of token ids to be fed to the encoder. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **labels** -- List of token ids for tgt_texts. + + The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed. + Otherwise, input_ids, attention_mask will be the only keys. + """ + # mBART-specific kwargs that should be ignored by other models. + kwargs.pop("src_lang", None) + kwargs.pop("tgt_lang", None) + if max_length is None: + max_length = self.model_max_length + model_inputs = self( + src_texts, + add_special_tokens=True, + return_tensors=return_tensors, + max_length=max_length, + padding=padding, + truncation=truncation, + **kwargs, + ) + if tgt_texts is None: + return model_inputs + # Process tgt_texts + if max_target_length is None: + max_target_length = max_length + with self.as_target_tokenizer(): + labels = self( + tgt_texts, + add_special_tokens=True, + return_tensors=return_tensors, + padding=padding, + max_length=max_target_length, + truncation=truncation, + **kwargs, + ) + model_inputs["labels"] = labels["input_ids"] + return model_inputs diff --git a/tests/test_modeling_marian.py b/tests/test_modeling_marian.py index eee61080779..e5a815d963f 100644 --- a/tests/test_modeling_marian.py +++ b/tests/test_modeling_marian.py @@ -508,12 +508,6 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest): def test_batch_generation_en_ROMANCE_multi(self): self._assert_generated_batch_equal_expected() - def test_tokenizer_handles_empty(self): - normalized = self.tokenizer.normalize("") - self.assertIsInstance(normalized, str) - with self.assertRaises(ValueError): - self.tokenizer.prepare_seq2seq_batch([""], return_tensors="pt") - @slow def test_pipeline(self): device = 0 if torch_device == "cuda" else -1 diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index dd4ae1a7298..9462c86b7b9 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -83,6 +83,7 @@ class TokenizerTesterMixin: from_pretrained_kwargs = None from_pretrained_filter = None from_pretrained_vocab_key = "vocab_file" + test_seq2seq = True def setUp(self) -> None: # Tokenizer.filter makes it possible to filter which Tokenizer to case based on all the @@ -1799,10 +1800,11 @@ class TokenizerTesterMixin: @require_torch def 
test_prepare_seq2seq_batch(self): + if not self.test_seq2seq: + return + tokenizer = self.get_tokenizer() - if not hasattr(tokenizer, "prepare_seq2seq_batch"): - return # Longer text that will definitely require truncation. src_text = [ " UN Chief Says There Is No Military Solution in Syria", diff --git a/tests/test_tokenization_ctrl.py b/tests/test_tokenization_ctrl.py index 435e1f3bb40..f4cd52d6011 100644 --- a/tests/test_tokenization_ctrl.py +++ b/tests/test_tokenization_ctrl.py @@ -26,6 +26,7 @@ class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = CTRLTokenizer test_rust_tokenizer = False + test_seq2seq = False def setUp(self): super().setUp() diff --git a/tests/test_tokenization_gpt2.py b/tests/test_tokenization_gpt2.py index b1003573101..fcc0162738b 100644 --- a/tests/test_tokenization_gpt2.py +++ b/tests/test_tokenization_gpt2.py @@ -32,6 +32,7 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase): rust_tokenizer_class = GPT2TokenizerFast test_rust_tokenizer = True from_pretrained_kwargs = {"add_prefix_space": True} + test_seq2seq = False def setUp(self): super().setUp() diff --git a/tests/test_tokenization_openai.py b/tests/test_tokenization_openai.py index ad6fbb0715f..a8180d65cc5 100644 --- a/tests/test_tokenization_openai.py +++ b/tests/test_tokenization_openai.py @@ -31,6 +31,7 @@ class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = OpenAIGPTTokenizer rust_tokenizer_class = OpenAIGPTTokenizerFast test_rust_tokenizer = True + test_seq2seq = False def setUp(self): super().setUp() diff --git a/tests/test_tokenization_reformer.py b/tests/test_tokenization_reformer.py index f2f1bc49074..3a9c3d04e13 100644 --- a/tests/test_tokenization_reformer.py +++ b/tests/test_tokenization_reformer.py @@ -33,6 +33,7 @@ class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = ReformerTokenizer rust_tokenizer_class = ReformerTokenizerFast test_rust_tokenizer = True + test_seq2seq = False def setUp(self): super().setUp() diff --git a/tests/test_tokenization_tapas.py b/tests/test_tokenization_tapas.py index 064be0d4b90..75dc81af404 100644 --- a/tests/test_tokenization_tapas.py +++ b/tests/test_tokenization_tapas.py @@ -44,6 +44,7 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase): test_rust_tokenizer = False space_between_special_tokens = True from_pretrained_filter = filter_non_english + test_seq2seq = False def get_table( self, diff --git a/tests/test_tokenization_transfo_xl.py b/tests/test_tokenization_transfo_xl.py index 6d8fa0aad46..fab36948445 100644 --- a/tests/test_tokenization_transfo_xl.py +++ b/tests/test_tokenization_transfo_xl.py @@ -26,6 +26,7 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = TransfoXLTokenizer test_rust_tokenizer = False + test_seq2seq = False def setUp(self): super().setUp()
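The pattern every tokenizer converges on after this patch can also be used directly, without going through `prepare_seq2seq_batch`: encode the source texts with a normal `__call__`, then encode the targets inside `as_target_tokenizer()` so the tokenizer can switch to its target-side state (Marian swaps to `spm_target`, mBART swaps to the target-language special tokens, the base class yields without changing anything). A minimal sketch, assuming the public `Helsinki-NLP/opus-mt-en-de` Marian checkpoint and illustrative example sentences:

```python
from transformers import MarianTokenizer

tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")

src_texts = ["The house is wonderful."]
tgt_texts = ["Das Haus ist wunderbar."]

# Source side: a plain __call__, nothing seq2seq-specific.
model_inputs = tokenizer(src_texts, padding="longest", truncation=True, return_tensors="pt")

# Target side: as_target_tokenizer() makes MarianTokenizer tokenize with its target
# SentencePiece model while the block is active.
with tokenizer.as_target_tokenizer():
    labels = tokenizer(tgt_texts, padding="longest", truncation=True, return_tensors="pt")

model_inputs["labels"] = labels["input_ids"]
```

This is exactly the sequence of calls that the shared `PreTrainedTokenizerBase.prepare_seq2seq_batch` added in `tokenization_utils_base.py` performs on behalf of the caller.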
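For the multilingual tokenizers, the override shrinks to bookkeeping: `MBartTokenizer.prepare_seq2seq_batch` (slow and fast) now records the language pair, applies the source-language special tokens, and defers to the base implementation, which enters `as_target_tokenizer()` and therefore the target-language special tokens while encoding `tgt_texts`. A sketch of the call site, which is unchanged for users, assuming the public `facebook/mbart-large-en-ro` checkpoint and an illustrative English/Romanian pair:

```python
from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")

batch = tokenizer.prepare_seq2seq_batch(
    src_texts=["UN Chief Says There Is No Military Solution in Syria"],
    src_lang="en_XX",
    tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],
    tgt_lang="ro_RO",
    return_tensors="pt",
)

# As before the refactor, the batch carries input_ids, attention_mask and labels,
# with the en_XX code appended to the source ids and the ro_RO code to the labels.
```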
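On the test side, the previous guard `if not hasattr(tokenizer, "prepare_seq2seq_batch"): return` no longer works, because every tokenizer now inherits the method from `PreTrainedTokenizerBase`; the common mixin therefore gains an explicit `test_seq2seq` class attribute, and tokenizers whose encodings are not meant for seq2seq models opt out. An abridged sketch of what the CTRL test class looks like once the patch is applied (the real module also builds a small vocabulary fixture in `setUp`):

```python
import unittest

from transformers import CTRLTokenizer

from .test_tokenization_common import TokenizerTesterMixin


class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = CTRLTokenizer
    test_rust_tokenizer = False
    test_seq2seq = False  # CTRL is decoder-only, so skip test_prepare_seq2seq_batch

    def setUp(self):
        super().setUp()
```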