mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
fix RagTokenizer (#10167)
This commit is contained in:
parent
c8d3fa0dfd
commit
2a5c990038
@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tokenization classes for RAG."""
|
"""Tokenization classes for RAG."""
|
||||||
import os
|
import os
|
||||||
|
from contextlib import contextmanager
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from ...tokenization_utils_base import BatchEncoding
|
from ...tokenization_utils_base import BatchEncoding
|
||||||
@ -28,6 +29,7 @@ class RagTokenizer:
|
|||||||
def __init__(self, question_encoder, generator):
    """Bundle the two tokenizers a RAG model needs.

    Args:
        question_encoder: tokenizer used to encode the input questions.
        generator: tokenizer used to encode target texts and to decode
            generated token ids.
    """
    self.question_encoder = question_encoder
    self.generator = generator
    # Inputs are encoded with the question-encoder tokenizer by default;
    # `as_target_tokenizer` temporarily switches this to the generator.
    self.current_tokenizer = self.question_encoder
|
||||||
|
|
||||||
def save_pretrained(self, save_directory):
|
def save_pretrained(self, save_directory):
|
||||||
if os.path.isfile(save_directory):
|
if os.path.isfile(save_directory):
|
||||||
@ -57,23 +59,60 @@ class RagTokenizer:
|
|||||||
return cls(question_encoder=question_encoder, generator=generator)
|
return cls(question_encoder=question_encoder, generator=generator)
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
    """Tokenize with whichever tokenizer is currently active.

    Delegates to ``self.current_tokenizer`` — the question-encoder tokenizer
    by default, or the generator tokenizer while inside the
    ``as_target_tokenizer`` context.
    """
    return self.current_tokenizer(*args, **kwargs)
|
||||||
|
|
||||||
def batch_decode(self, *args, **kwargs):
    """Decode batches of token ids via the generator tokenizer.

    Decoding always goes through the generator, since generated sequences
    come from the generator model's vocabulary.
    """
    return self.generator.batch_decode(*args, **kwargs)
|
||||||
|
|
||||||
|
def decode(self, *args, **kwargs):
    """Decode a single sequence of token ids via the generator tokenizer.

    Mirrors ``batch_decode``: generated ids belong to the generator's
    vocabulary, so the generator tokenizer does the decoding.
    """
    return self.generator.decode(*args, **kwargs)
|
||||||
|
|
||||||
|
@contextmanager
def as_target_tokenizer(self):
    """
    Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
    sequence-to-sequence models that need a slightly different processing for the labels.
    """
    self.current_tokenizer = self.generator
    try:
        yield
    finally:
        # Restore the input tokenizer even if the caller's `with` body raises,
        # so an exception cannot leave the generator tokenizer active.
        self.current_tokenizer = self.question_encoder
|
||||||
|
|
||||||
def prepare_seq2seq_batch(
    self,
    src_texts: List[str],
    tgt_texts: Optional[List[str]] = None,
    max_length: Optional[int] = None,
    max_target_length: Optional[int] = None,
    padding: str = "longest",
    return_tensors: Optional[str] = None,
    truncation: bool = True,
    **kwargs,
) -> "BatchEncoding":
    """Tokenize source texts (and optionally target texts) for seq2seq work.

    Sources are encoded with the currently active tokenizer (the
    question-encoder tokenizer by default); targets are encoded inside the
    ``as_target_tokenizer`` context so the generator tokenizer is used, and
    their ``input_ids`` are attached to the result under ``"labels"``.

    Args:
        src_texts: input texts to encode.
        tgt_texts: optional target texts; when omitted only the encoded
            sources are returned.
        max_length: max source length; defaults to the active tokenizer's
            ``model_max_length``.
        max_target_length: max target length; defaults to the generator
            tokenizer's ``model_max_length``.
        padding: padding strategy forwarded to the tokenizer.
        return_tensors: tensor framework identifier forwarded to the tokenizer.
        truncation: truncation flag forwarded to the tokenizer.
        **kwargs: extra arguments forwarded to both tokenizer calls.

    Returns:
        A ``BatchEncoding`` with the encoded sources, plus ``"labels"`` when
        ``tgt_texts`` is given.
    """
    if max_length is None:
        max_length = self.current_tokenizer.model_max_length
    model_inputs = self(
        src_texts,
        add_special_tokens=True,
        return_tensors=return_tensors,
        max_length=max_length,
        padding=padding,
        truncation=truncation,
        **kwargs,
    )
    if tgt_texts is None:
        return model_inputs
    # Process tgt_texts
    with self.as_target_tokenizer():
        # Resolved inside the context: current_tokenizer is the generator here.
        if max_target_length is None:
            max_target_length = self.current_tokenizer.model_max_length
        labels = self(
            tgt_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            padding=padding,
            max_length=max_target_length,
            truncation=truncation,
            **kwargs,
        )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
|
||||||
|
Loading…
Reference in New Issue
Block a user