Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-30 17:52:35 +06:00)
fix RagTokenizer (#10167)
parent c8d3fa0dfd
commit 2a5c990038
src/transformers/models/rag/tokenization_rag.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Tokenization classes for RAG."""
 import os
+from contextlib import contextmanager
 from typing import List, Optional
 
 from ...tokenization_utils_base import BatchEncoding
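The new contextlib.contextmanager import supports the as_target_tokenizer context manager added in the last hunk. As a minimal standalone sketch of the pattern (the helper temporarily_set is hypothetical, not part of this commit): a generator decorated with @contextmanager runs everything before yield on entering the with block and everything after yield on leaving it.

    from contextlib import contextmanager

    @contextmanager
    def temporarily_set(obj, attr, value):
        # On entry: remember the old attribute and install the temporary one.
        previous = getattr(obj, attr)
        setattr(obj, attr, value)
        try:
            yield obj
        finally:
            # On exit: restore the previous attribute, even if the block raised.
            setattr(obj, attr, previous)

as_target_tokenizer below applies the same shape to current_tokenizer, except that it resets to the question encoder on exit rather than restoring whatever was there before.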
@@ -28,6 +29,7 @@ class RagTokenizer:
     def __init__(self, question_encoder, generator):
         self.question_encoder = question_encoder
         self.generator = generator
+        self.current_tokenizer = self.question_encoder
 
     def save_pretrained(self, save_directory):
         if os.path.isfile(save_directory):
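current_tokenizer makes RagTokenizer stateful: encoding calls are routed to whichever of the two wrapped tokenizers is current, with the question encoder as the default. A toy model of the dispatch (names and callables are illustrative; the real class wraps two HuggingFace tokenizers):

    class Dispatcher:
        # Toy stand-in for RagTokenizer's routing between two callables.
        def __init__(self, question_encoder, generator):
            self.question_encoder = question_encoder
            self.generator = generator
            self.current_tokenizer = self.question_encoder  # default route

        def __call__(self, text):
            return self.current_tokenizer(text)

    d = Dispatcher(str.upper, str.lower)
    assert d("Rag") == "RAG"          # default: question_encoder route
    d.current_tokenizer = d.generator
    assert d("Rag") == "rag"          # after switching: generator route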
@@ -57,23 +59,60 @@ class RagTokenizer:
         return cls(question_encoder=question_encoder, generator=generator)
 
     def __call__(self, *args, **kwargs):
-        return self.question_encoder(*args, **kwargs)
+        return self.current_tokenizer(*args, **kwargs)
 
     def batch_decode(self, *args, **kwargs):
         return self.generator.batch_decode(*args, **kwargs)
 
     def decode(self, *args, **kwargs):
         return self.generator.decode(*args, **kwargs)
 
+    @contextmanager
+    def as_target_tokenizer(self):
+        """
+        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizers associated with
+        sequence-to-sequence models that need slightly different processing for the labels.
+        """
+        self.current_tokenizer = self.generator
+        yield
+        self.current_tokenizer = self.question_encoder
+
     def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
         tgt_texts: Optional[List[str]] = None,
         max_length: Optional[int] = None,
         max_target_length: Optional[int] = None,
+        padding: str = "longest",
+        return_tensors: str = None,
+        truncation: bool = True,
         **kwargs,
     ) -> BatchEncoding:
         if max_length is None:
-            max_length = self.question_encoder.model_max_length
-        if max_target_length is None:
-            max_target_length = self.generator.model_max_length
-        return super().prepare_seq2seq_batch(
-            src_texts, tgt_texts, max_length=max_length, max_target_length=max_target_length, **kwargs
-        )
+            max_length = self.current_tokenizer.model_max_length
+        model_inputs = self(
+            src_texts,
+            add_special_tokens=True,
+            return_tensors=return_tensors,
+            max_length=max_length,
+            padding=padding,
+            truncation=truncation,
+            **kwargs,
+        )
+        if tgt_texts is None:
+            return model_inputs
+        # Process tgt_texts
+        with self.as_target_tokenizer():
+            if max_target_length is None:
+                max_target_length = self.current_tokenizer.model_max_length
+            labels = self(
+                tgt_texts,
+                add_special_tokens=True,
+                return_tensors=return_tensors,
+                padding=padding,
+                max_length=max_target_length,
+                truncation=truncation,
+                **kwargs,
+            )
+        model_inputs["labels"] = labels["input_ids"]
+        return model_inputs
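Why the rewrite: RagTokenizer is declared as a plain class with no base other than object, so the old return super().prepare_seq2seq_batch(...) could only raise AttributeError at runtime. The method now builds the batch itself, encoding src_texts with the question-encoder tokenizer and tgt_texts with the generator tokenizer via as_target_tokenizer(). A usage sketch against the post-commit API (checkpoint name and texts are illustrative):

    from transformers import RagTokenizer

    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")

    # src_texts go through the DPR question-encoder tokenizer; tgt_texts are
    # encoded with the BART generator tokenizer and returned under "labels".
    batch = tokenizer.prepare_seq2seq_batch(
        src_texts=["who holds the record in 100m freestyle"],
        tgt_texts=["michael phelps"],
        return_tensors="pt",
    )
    print(batch["input_ids"].shape, batch["labels"].shape)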