import json
import os
import shutil
import tempfile
from unittest import TestCase

from transformers.configuration_bart import BartConfig
from transformers.configuration_dpr import DPRConfig
from transformers.file_utils import is_datasets_available, is_faiss_available, is_torch_available
from transformers.testing_utils import require_datasets, require_faiss, require_torch
from transformers.tokenization_bart import BartTokenizer
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES


# The RAG modules are only imported when all optional backends are present;
# the test class below is guarded by the matching ``require_*`` decorators.
if is_torch_available() and is_datasets_available() and is_faiss_available():
    from transformers.configuration_rag import RagConfig
    from transformers.tokenization_rag import RagTokenizer


@require_faiss
@require_datasets
@require_torch
class RagTokenizerTest(TestCase):
    """Save/load round-trip test for ``RagTokenizer``.

    Each test runs against two tiny tokenizers (a DPR question-encoder
    tokenizer and a BART tokenizer) whose vocabulary files are written into a
    fresh temporary directory by ``setUp`` and removed by ``tearDown``.
    """

    def setUp(self):
        """Create the temporary directory and write both tokenizers' vocab files."""
        self.tmpdirname = tempfile.mkdtemp()
        self.retrieval_vector_size = 8

        # DPR tok: BERT-style vocab file, one token per line.
        dpr_vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
        os.makedirs(dpr_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join(f"{token}\n" for token in dpr_vocab_tokens))

        # BART tok: JSON token->id vocab plus a BPE merges file.
        bart_vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            # Unknown-token entry; kept identical to special_tokens_map["unk_token"].
            "<unk>",
        ]
        bart_vocab_tokens = dict(zip(bart_vocab, range(len(bart_vocab))))
        # The trailing empty entry makes the joined merges file end with a newline.
        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
        self.special_tokens_map = {"unk_token": "<unk>"}

        bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
        os.makedirs(bart_tokenizer_path, exist_ok=True)
        # NOTE: self.vocab_file is reassigned here and from now on points at
        # the BART vocab, not the DPR one written above.
        self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(bart_vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

    def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
        """Load the DPR question-encoder tokenizer from the files written by ``setUp``."""
        return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))

    def get_bart_tokenizer(self) -> BartTokenizer:
        """Load the BART tokenizer from the files written by ``setUp``."""
        return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))

    def tearDown(self):
        """Delete the temporary directory and everything written into it."""
        shutil.rmtree(self.tmpdirname)

    def test_save_load_pretrained_with_saved_config(self):
        """A saved ``RagTokenizer`` reloads with the same sub-tokenizer types and vocabularies."""
        save_dir = os.path.join(self.tmpdirname, "rag_tokenizer")
        rag_config = RagConfig(question_encoder=DPRConfig().to_dict(), generator=BartConfig().to_dict())
        rag_tokenizer = RagTokenizer(question_encoder=self.get_dpr_tokenizer(), generator=self.get_bart_tokenizer())

        # Config and tokenizer are saved side by side in the same directory.
        rag_config.save_pretrained(save_dir)
        rag_tokenizer.save_pretrained(save_dir)
        new_rag_tokenizer = RagTokenizer.from_pretrained(save_dir, config=rag_config)

        self.assertIsInstance(new_rag_tokenizer.question_encoder, DPRQuestionEncoderTokenizer)
        self.assertEqual(new_rag_tokenizer.question_encoder.vocab, rag_tokenizer.question_encoder.vocab)
        self.assertIsInstance(new_rag_tokenizer.generator, BartTokenizer)
        self.assertEqual(new_rag_tokenizer.generator.encoder, rag_tokenizer.generator.encoder)