Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
[Rag] Fix loading of pretrained Rag Tokenizer (#7756)

* fix rag

* Update tokenizer save_pretrained

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>

This commit is contained in:
parent 2d4e928d97
commit 82b09a8481
@@ -1637,9 +1637,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         if special_tokens_map_file is not None:
             with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
                 special_tokens_map = json.load(special_tokens_map_handle)
-
-            special_tokens_map = convert_added_tokens(special_tokens_map)
             for key, value in special_tokens_map.items():
                 if isinstance(value, dict):
                     value = AddedToken(**value)
+                elif isinstance(value, list):
+                    value = [AddedToken(**token) if isinstance(token, dict) else token for token in value]
                 setattr(tokenizer, key, value)
+
+        # Add supplementary tokens.
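The load-side change can be exercised on its own. A special_tokens_map.json written with add_type_field=False (see the save-side hunk below) holds bare AddedToken state dicts, and a list-valued entry such as additional_special_tokens previously slipped through unconverted. A minimal sketch, using a stand-in AddedToken dataclass whose field names are assumed for illustration rather than taken from the tokenizers library:

from dataclasses import dataclass

@dataclass
class AddedToken:  # stand-in; field names assumed for illustration
    content: str
    single_word: bool = False
    lstrip: bool = False
    rstrip: bool = False

# What a special_tokens_map.json saved with add_type_field=False might
# contain: bare state dicts, including a list-valued entry.
special_tokens_map = {
    "pad_token": {"content": "<pad>", "single_word": False, "lstrip": False, "rstrip": False},
    "additional_special_tokens": [
        {"content": "<ctx>", "single_word": False, "lstrip": False, "rstrip": False},
        "<sep>",  # plain strings pass through untouched
    ],
}

for key, value in special_tokens_map.items():
    if isinstance(value, dict):
        value = AddedToken(**value)
    elif isinstance(value, list):
        # The branch added by this commit: rehydrate each dict in the list.
        value = [AddedToken(**token) if isinstance(token, dict) else token for token in value]
    print(key, value)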
@@ -1706,23 +1708,25 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
             tokenizer_config.pop(file_id, None)
 
         # Sanitize AddedTokens
-        def convert_added_tokens(obj: Union[AddedToken, Any]):
+        def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):
             if isinstance(obj, AddedToken):
                 out = obj.__getstate__()
-                out["__type"] = "AddedToken"
+                if add_type_field:
+                    out["__type"] = "AddedToken"
                 return out
             elif isinstance(obj, (list, tuple)):
-                return list(convert_added_tokens(o) for o in obj)
+                return list(convert_added_tokens(o, add_type_field=add_type_field) for o in obj)
             elif isinstance(obj, dict):
-                return {k: convert_added_tokens(v) for k, v in obj.items()}
+                return {k: convert_added_tokens(v, add_type_field=add_type_field) for k, v in obj.items()}
             return obj
 
-        tokenizer_config = convert_added_tokens(tokenizer_config)
+        # add_type_field=True to allow dicts in the kwargs / differentiate from AddedToken serialization
+        tokenizer_config = convert_added_tokens(tokenizer_config, add_type_field=True)
         with open(tokenizer_config_file, "w", encoding="utf-8") as f:
             f.write(json.dumps(tokenizer_config, ensure_ascii=False))
 
         # Sanitize AddedTokens in special_tokens_map
-        write_dict = convert_added_tokens(self.special_tokens_map_extended)
+        write_dict = convert_added_tokens(self.special_tokens_map_extended, add_type_field=False)
         with open(special_tokens_map_file, "w", encoding="utf-8") as f:
             f.write(json.dumps(write_dict, ensure_ascii=False))
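The effect of the new add_type_field flag can be checked with a self-contained copy of the helper. Here obj.__getstate__() is approximated with vars(obj), and the stand-in AddedToken from the previous sketch is repeated so the snippet runs on its own:

from dataclasses import dataclass
from typing import Any, Union

@dataclass
class AddedToken:  # stand-in; field names assumed for illustration
    content: str
    single_word: bool = False
    lstrip: bool = False
    rstrip: bool = False

def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):
    if isinstance(obj, AddedToken):
        out = dict(vars(obj))  # approximates obj.__getstate__()
        if add_type_field:
            out["__type"] = "AddedToken"
        return out
    elif isinstance(obj, (list, tuple)):
        return list(convert_added_tokens(o, add_type_field=add_type_field) for o in obj)
    elif isinstance(obj, dict):
        return {k: convert_added_tokens(v, add_type_field=add_type_field) for k, v in obj.items()}
    return obj

# tokenizer_config.json keeps the marker, so serialized AddedTokens stay
# distinguishable from ordinary dicts in the kwargs:
print(convert_added_tokens({"pad_token": AddedToken("<pad>")}, add_type_field=True))
# {'pad_token': {'content': '<pad>', 'single_word': False, 'lstrip': False, 'rstrip': False, '__type': 'AddedToken'}}

# special_tokens_map.json now drops it, matching the plain-dict handling
# in the load-side hunk above:
print(convert_added_tokens({"pad_token": AddedToken("<pad>")}, add_type_field=False))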
@@ -7,7 +7,7 @@ from unittest import TestCase
 from transformers.configuration_bart import BartConfig
 from transformers.configuration_dpr import DPRConfig
 from transformers.file_utils import is_datasets_available, is_faiss_available, is_torch_available
-from transformers.testing_utils import require_datasets, require_faiss, require_torch
+from transformers.testing_utils import require_datasets, require_faiss, require_torch, slow
 from transformers.tokenization_bart import BartTokenizer
 from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
 from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
@@ -108,3 +108,49 @@ class RagTokenizerTest(TestCase):
         self.assertEqual(new_rag_tokenizer.question_encoder.vocab, rag_tokenizer.question_encoder.vocab)
         self.assertIsInstance(new_rag_tokenizer.generator, BartTokenizer)
         self.assertEqual(new_rag_tokenizer.generator.encoder, rag_tokenizer.generator.encoder)
+
+    @slow
+    def test_pretrained_token_nq_tokenizer(self):
+        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
+        input_strings = [
+            "who got the first nobel prize in physics",
+            "when is the next deadpool movie being released",
+            "which mode is used for short wave broadcast service",
+            "who is the owner of reading football club",
+            "when is the next scandal episode coming out",
+            "when is the last time the philadelphia won the superbowl",
+            "what is the most current adobe flash player version",
+            "how many episodes are there in dragon ball z",
+            "what is the first step in the evolution of the eye",
+            "where is gall bladder situated in human body",
+            "what is the main mineral in lithium batteries",
+            "who is the president of usa right now",
+            "where do the greasers live in the outsiders",
+            "panda is a national animal of which country",
+            "what is the name of manchester united stadium",
+        ]
+        input_dict = tokenizer(input_strings)
+        self.assertIsNotNone(input_dict)
+
+    @slow
+    def test_pretrained_sequence_nq_tokenizer(self):
+        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
+        input_strings = [
+            "who got the first nobel prize in physics",
+            "when is the next deadpool movie being released",
+            "which mode is used for short wave broadcast service",
+            "who is the owner of reading football club",
+            "when is the next scandal episode coming out",
+            "when is the last time the philadelphia won the superbowl",
+            "what is the most current adobe flash player version",
+            "how many episodes are there in dragon ball z",
+            "what is the first step in the evolution of the eye",
+            "where is gall bladder situated in human body",
+            "what is the main mineral in lithium batteries",
+            "who is the president of usa right now",
+            "where do the greasers live in the outsiders",
+            "panda is a national animal of which country",
+            "what is the name of manchester united stadium",
+        ]
+        input_dict = tokenizer(input_strings)
+        self.assertIsNotNone(input_dict)
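The two new slow tests are smoke tests: they only assert that the pretrained RAG checkpoints load and tokenize end to end, which is what failed before this fix. The same round trip can be run by hand (this downloads the checkpoint, so network access is assumed):

from transformers import RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
input_dict = tokenizer(["who got the first nobel prize in physics"])
# A BatchEncoding; typically exposes input_ids and attention_mask
# for the question-encoder side.
print(input_dict)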