fix RagTokenizer (#10167)

Suraj Patil 2021-02-15 19:48:12 +05:30 committed by GitHub
parent c8d3fa0dfd
commit 2a5c990038


@@ -14,6 +14,7 @@
 # limitations under the License.
 """Tokenization classes for RAG."""
 import os
+from contextlib import contextmanager
 from typing import List, Optional
 
 from ...tokenization_utils_base import BatchEncoding
@@ -28,6 +29,7 @@ class RagTokenizer:
     def __init__(self, question_encoder, generator):
         self.question_encoder = question_encoder
         self.generator = generator
+        self.current_tokenizer = self.question_encoder
 
     def save_pretrained(self, save_directory):
         if os.path.isfile(save_directory):
@@ -57,23 +59,60 @@ class RagTokenizer:
         return cls(question_encoder=question_encoder, generator=generator)
 
     def __call__(self, *args, **kwargs):
-        return self.question_encoder(*args, **kwargs)
+        return self.current_tokenizer(*args, **kwargs)
 
     def batch_decode(self, *args, **kwargs):
         return self.generator.batch_decode(*args, **kwargs)
 
+    def decode(self, *args, **kwargs):
+        return self.generator.decode(*args, **kwargs)
+
+    @contextmanager
+    def as_target_tokenizer(self):
+        """
+        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizers associated with
+        sequence-to-sequence models that need slightly different processing for the labels.
+        """
+        self.current_tokenizer = self.generator
+        yield
+        self.current_tokenizer = self.question_encoder
+
     def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
         tgt_texts: Optional[List[str]] = None,
         max_length: Optional[int] = None,
         max_target_length: Optional[int] = None,
+        padding: str = "longest",
+        return_tensors: str = None,
+        truncation: bool = True,
         **kwargs,
     ) -> BatchEncoding:
         if max_length is None:
-            max_length = self.question_encoder.model_max_length
-        if max_target_length is None:
-            max_target_length = self.generator.model_max_length
-        return super().prepare_seq2seq_batch(
-            src_texts, tgt_texts, max_length=max_length, max_target_length=max_target_length, **kwargs
-        )
+            max_length = self.current_tokenizer.model_max_length
+        model_inputs = self(
+            src_texts,
+            add_special_tokens=True,
+            return_tensors=return_tensors,
+            max_length=max_length,
+            padding=padding,
+            truncation=truncation,
+            **kwargs,
+        )
+        if tgt_texts is None:
+            return model_inputs
+        # Process tgt_texts
+        with self.as_target_tokenizer():
+            if max_target_length is None:
+                max_target_length = self.current_tokenizer.model_max_length
+            labels = self(
+                tgt_texts,
+                add_special_tokens=True,
+                return_tensors=return_tensors,
+                padding=padding,
+                max_length=max_target_length,
+                truncation=truncation,
+                **kwargs,
+            )
+        model_inputs["labels"] = labels["input_ids"]
+        return model_inputs
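
With this change, source texts are tokenized by the question encoder's tokenizer while targets go through the generator's tokenizer, either implicitly via prepare_seq2seq_batch or explicitly via the new as_target_tokenizer() context manager. A minimal usage sketch against this revision (the checkpoint name facebook/rag-token-nq is illustrative; return_tensors="pt" assumes torch is installed):

    from transformers import RagTokenizer

    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")

    src_texts = ["who holds the record in 100m freestyle"]
    tgt_texts = ["michael phelps"]

    # One call: sources are encoded with the question encoder's tokenizer,
    # targets with the generator's, and the labels are attached to the batch.
    batch = tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, return_tensors="pt")
    # batch["input_ids"] / batch["attention_mask"] come from the question
    # encoder; batch["labels"] from the generator.

    # Equivalent manual form using the context manager directly:
    model_inputs = tokenizer(src_texts, padding="longest", truncation=True, return_tensors="pt")
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(tgt_texts, padding="longest", truncation=True, return_tensors="pt")
    model_inputs["labels"] = labels["input_ids"]

Inside the with block, __call__ is routed to the generator's tokenizer through self.current_tokenizer, and the previous routing to the question encoder is restored on exit.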