diff --git a/src/transformers/models/rag/tokenization_rag.py b/src/transformers/models/rag/tokenization_rag.py
index 03ae2f68a87..7b5916b78dd 100644
--- a/src/transformers/models/rag/tokenization_rag.py
+++ b/src/transformers/models/rag/tokenization_rag.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Tokenization classes for RAG."""
 import os
+from contextlib import contextmanager
 from typing import List, Optional
 
 from ...tokenization_utils_base import BatchEncoding
@@ -28,6 +29,7 @@ class RagTokenizer:
     def __init__(self, question_encoder, generator):
         self.question_encoder = question_encoder
         self.generator = generator
+        self.current_tokenizer = self.question_encoder
 
     def save_pretrained(self, save_directory):
         if os.path.isfile(save_directory):
@@ -57,23 +59,60 @@ class RagTokenizer:
         return cls(question_encoder=question_encoder, generator=generator)
 
     def __call__(self, *args, **kwargs):
-        return self.question_encoder(*args, **kwargs)
+        return self.current_tokenizer(*args, **kwargs)
 
     def batch_decode(self, *args, **kwargs):
         return self.generator.batch_decode(*args, **kwargs)
 
+    def decode(self, *args, **kwargs):
+        return self.generator.decode(*args, **kwargs)
+
+    @contextmanager
+    def as_target_tokenizer(self):
+        """
+        Temporarily sets the tokenizer used for encoding the targets. Useful for tokenizers associated with
+        sequence-to-sequence models that need slightly different processing for the labels.
+        """
+        self.current_tokenizer = self.generator
+        yield
+        self.current_tokenizer = self.question_encoder
+
     def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
         tgt_texts: Optional[List[str]] = None,
         max_length: Optional[int] = None,
         max_target_length: Optional[int] = None,
+        padding: str = "longest",
+        return_tensors: str = None,
+        truncation: bool = True,
         **kwargs,
     ) -> BatchEncoding:
         if max_length is None:
-            max_length = self.question_encoder.model_max_length
-        if max_target_length is None:
-            max_target_length = self.generator.model_max_length
-        return super().prepare_seq2seq_batch(
-            src_texts, tgt_texts, max_length=max_length, max_target_length=max_target_length, **kwargs
+            max_length = self.current_tokenizer.model_max_length
+        model_inputs = self(
+            src_texts,
+            add_special_tokens=True,
+            return_tensors=return_tensors,
+            max_length=max_length,
+            padding=padding,
+            truncation=truncation,
+            **kwargs,
         )
+        if tgt_texts is None:
+            return model_inputs
+        # Process tgt_texts
+        with self.as_target_tokenizer():
+            if max_target_length is None:
+                max_target_length = self.current_tokenizer.model_max_length
+            labels = self(
+                tgt_texts,
+                add_special_tokens=True,
+                return_tensors=return_tensors,
+                padding=padding,
+                max_length=max_target_length,
+                truncation=truncation,
+                **kwargs,
+            )
+            model_inputs["labels"] = labels["input_ids"]
+        return model_inputs
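
For context, a minimal usage sketch of the `RagTokenizer` API after this change. The checkpoint name `facebook/rag-token-nq` and the example strings are only illustrative, and PyTorch is assumed for `return_tensors="pt"`:

```python
from transformers import RagTokenizer

# RagTokenizer wraps two tokenizers: one for the question encoder, one for the generator.
# Illustrative checkpoint, not part of this diff.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")

# prepare_seq2seq_batch encodes src_texts with the question-encoder tokenizer and,
# inside as_target_tokenizer(), encodes tgt_texts with the generator tokenizer.
batch = tokenizer.prepare_seq2seq_batch(
    src_texts=["who holds the record in 100m freestyle"],
    tgt_texts=["michael phelps"],
    return_tensors="pt",
)
print(batch["input_ids"].shape, batch["labels"].shape)

# The context manager can also be used directly when building labels by hand;
# __call__ is routed to the generator tokenizer inside the block
# (decode/batch_decode always use the generator).
with tokenizer.as_target_tokenizer():
    labels = tokenizer(["michael phelps"], return_tensors="pt")["input_ids"]
```

This follows the `as_target_tokenizer` pattern used for other seq2seq tokenizers in the library, so callers no longer need to reach into `tokenizer.generator` directly to build labels.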