fix RagTokenizer (#10167)

Suraj Patil 2021-02-15 19:48:12 +05:30 committed by GitHub
parent c8d3fa0dfd
commit 2a5c990038


@@ -14,6 +14,7 @@
 # limitations under the License.
 """Tokenization classes for RAG."""
 import os
+from contextlib import contextmanager
 from typing import List, Optional
 
 from ...tokenization_utils_base import BatchEncoding
@@ -28,6 +29,7 @@ class RagTokenizer:
     def __init__(self, question_encoder, generator):
         self.question_encoder = question_encoder
         self.generator = generator
+        self.current_tokenizer = self.question_encoder
 
     def save_pretrained(self, save_directory):
         if os.path.isfile(save_directory):
@@ -57,23 +59,60 @@ class RagTokenizer:
         return cls(question_encoder=question_encoder, generator=generator)
 
     def __call__(self, *args, **kwargs):
-        return self.question_encoder(*args, **kwargs)
+        return self.current_tokenizer(*args, **kwargs)
 
     def batch_decode(self, *args, **kwargs):
         return self.generator.batch_decode(*args, **kwargs)
 
     def decode(self, *args, **kwargs):
         return self.generator.decode(*args, **kwargs)
 
+    @contextmanager
+    def as_target_tokenizer(self):
+        """
+        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizers associated with
+        sequence-to-sequence models that need a slightly different processing for the labels.
+        """
+        self.current_tokenizer = self.generator
+        yield
+        self.current_tokenizer = self.question_encoder
+
     def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
         tgt_texts: Optional[List[str]] = None,
         max_length: Optional[int] = None,
         max_target_length: Optional[int] = None,
+        padding: str = "longest",
+        return_tensors: str = None,
+        truncation: bool = True,
         **kwargs,
     ) -> BatchEncoding:
         if max_length is None:
-            max_length = self.question_encoder.model_max_length
-        if max_target_length is None:
-            max_target_length = self.generator.model_max_length
-        return super().prepare_seq2seq_batch(
-            src_texts, tgt_texts, max_length=max_length, max_target_length=max_target_length, **kwargs
+            max_length = self.current_tokenizer.model_max_length
+        model_inputs = self(
+            src_texts,
+            add_special_tokens=True,
+            return_tensors=return_tensors,
+            max_length=max_length,
+            padding=padding,
+            truncation=truncation,
+            **kwargs,
         )
+        if tgt_texts is None:
+            return model_inputs
+        # Process tgt_texts
+        with self.as_target_tokenizer():
+            if max_target_length is None:
+                max_target_length = self.current_tokenizer.model_max_length
+            labels = self(
+                tgt_texts,
+                add_special_tokens=True,
+                return_tensors=return_tensors,
+                padding=padding,
+                max_length=max_target_length,
+                truncation=truncation,
+                **kwargs,
+            )
+        model_inputs["labels"] = labels["input_ids"]
+        return model_inputs
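
After this fix, calling the tokenizer inside `as_target_tokenizer()` routes through the generator tokenizer instead of the question-encoder tokenizer. A minimal usage sketch, assuming a checkpoint such as facebook/rag-token-nq is available locally or from the Hub:

    from transformers import RagTokenizer

    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")

    # Default path: questions are encoded with the question-encoder tokenizer.
    inputs = tokenizer(["who holds the record in 100m freestyle"], return_tensors="pt")

    # Inside the context manager, __call__ is routed to the generator tokenizer,
    # whose vocabulary matches the decoder, so it is the right one for labels.
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(["michael phelps"], return_tensors="pt")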
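The rewritten prepare_seq2seq_batch bundles both steps into one call; under the same assumptions, the equivalent batch would look like:

    batch = tokenizer.prepare_seq2seq_batch(
        src_texts=["who holds the record in 100m freestyle"],
        tgt_texts=["michael phelps"],
        return_tensors="pt",
    )
    # batch["input_ids"] come from the question-encoder tokenizer;
    # batch["labels"] come from the generator tokenizer via as_target_tokenizer().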