from __future__ import annotations

import json
import os
import shutil
import tempfile
import unittest
from unittest.mock import patch

import numpy as np

from transformers import BartTokenizer
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.models.dpr.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.utils import cached_property, is_datasets_available, is_faiss_available, is_tf_available


if is_tf_available() and is_datasets_available() and is_faiss_available():
    import faiss
    import tensorflow as tf
    from datasets import Dataset

    from transformers import (
        AutoConfig,
        RagConfig,
        RagRetriever,
        RagTokenizer,
        TFAutoModel,
        TFAutoModelForSeq2SeqLM,
        TFRagModel,
        TFRagSequenceForGeneration,
        TFRagTokenForGeneration,
    )
    from transformers.modeling_tf_outputs import TFBaseModelOutput

from ..bart.test_modeling_tf_bart import TFBartModelTester
from ..dpr.test_modeling_tf_dpr import TFDPRModelTester


# Absolute tolerance used when comparing losses and document scores to reference values.
TOLERANCE = 1e-3


def require_retrieval(test_case):
    """
    Decorator marking a test that requires a set of dependencies necessary to perform retrieval with
    [`RagRetriever`].

    These tests are skipped when respective libraries are not installed.
    """
    if not (is_tf_available() and is_datasets_available() and is_faiss_available()):
        test_case = unittest.skip("test requires tensorflow, datasets and faiss")(test_case)
    return test_case


def _compute_doc_scores(question_hidden_states, retrieved_doc_embeds):
    """
    Score retrieved documents against the question encoding.

    Inserts an axis at position 1 of `question_hidden_states`, multiplies it with the transposed
    `retrieved_doc_embeds`, and squeezes axis 1 away again.
    """
    return tf.squeeze(
        tf.matmul(tf.expand_dims(question_hidden_states, axis=[1]), retrieved_doc_embeds, transpose_b=True),
        axis=[1],
    )


def _build_rag_config():
    """Build the `RagConfig` shared by the slow integration and save/load tests."""
    question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
    return RagConfig.from_question_encoder_generator_configs(
        question_encoder_config,
        generator_config,
        bos_token_id=0,
        decoder_start_token_id=2,
        eos_token_id=2,
        is_encoder_decoder=True,
        pad_token_id=1,
        vocab_size=50264,
        title_sep=" / ",
        doc_sep=" // ",
        n_docs=5,
        max_combined_length=300,
        dataset="wiki_dpr",
        dataset_split="train",
        index_name="exact",
        index_path=None,
        use_dummy_dataset=True,
        retrieval_vector_size=768,
        retrieval_batch_size=8,
        dataset_revision="b24a417",
    )


def _build_retriever_and_tokenizers(rag_config):
    """
    Load the BART and DPR tokenizers and wrap them in a `RagRetriever` built from `rag_config`.

    Returns:
        A `(retriever, question_encoder_tokenizer, decoder_tokenizer)` tuple.
    """
    rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base"
    )
    rag_retriever = RagRetriever(
        rag_config,
        question_encoder_tokenizer=rag_question_encoder_tokenizer,
        generator_tokenizer=rag_decoder_tokenizer,
    )
    return rag_retriever, rag_question_encoder_tokenizer, rag_decoder_tokenizer


def _encode_example(question_encoder_tokenizer, decoder_tokenizer):
    """Tokenize the fixed question/answer pair used by the slow tests and return `(input_ids, decoder_input_ids)`."""
    input_ids = question_encoder_tokenizer("who sings does he love me with reba", return_tensors="tf").input_ids
    decoder_input_ids = decoder_tokenizer("Linda Davis", return_tensors="tf").input_ids
    return input_ids, decoder_input_ids


@require_tf
@require_retrieval
@require_sentencepiece
class TFRagTestMixin:
    """
    Shared fixtures and checks for TF RAG model tests.

    Subclasses are expected to provide a `config_and_inputs` attribute holding the keyword arguments that
    are forwarded to the `check_*` methods.
    """

    all_model_classes = (
        (TFRagModel, TFRagTokenForGeneration, TFRagSequenceForGeneration)
        if is_tf_available() and is_datasets_available() and is_faiss_available()
        else ()
    )
    all_generative_model_classes = (
        (TFRagTokenForGeneration, TFRagSequenceForGeneration)
        if is_tf_available() and is_datasets_available() and is_faiss_available()
        else ()
    )

    retrieval_vector_size = 32
    n_docs = 3
    max_combined_length = 16

    def setUp(self):
        """Write small DPR and BART tokenizer files into a fresh temporary directory."""
        self.tmpdirname = tempfile.mkdtemp()

        # DPR tok
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
        os.makedirs(dpr_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

        # BART tok
        # NOTE(review): the empty-string entries in `vocab`, `merges` and `special_tokens_map` are presumably
        # placeholders for the tokenizer's unknown token - verify against the upstream test before relying on them.
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            "",
        ]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
        self.special_tokens_map = {"unk_token": ""}

        bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
        os.makedirs(bart_tokenizer_path, exist_ok=True)
        # `self.vocab_file` is deliberately overwritten: from here on it points at the BART vocabulary.
        self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

    @cached_property
    def dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
        """DPR question-encoder tokenizer loaded from the files written by `setUp`."""
        return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))

    @cached_property
    def bart_tokenizer(self) -> BartTokenizer:
        """BART tokenizer loaded from the files written by `setUp`."""
        return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))

    def tearDown(self):
        """Remove the temporary directory created by `setUp`."""
        shutil.rmtree(self.tmpdirname)

    def get_retriever(self, config):
        """
        Build a `RagRetriever` backed by a three-document in-memory dataset with a flat faiss index.

        `load_dataset` is patched while the retriever is constructed so that the in-memory dataset is returned
        instead.
        """
        dataset = Dataset.from_dict(
            {
                "id": ["0", "1", "3"],
                "text": ["foo", "bar", "qux"],
                "title": ["Foo", "Bar", "Qux"],
                "embeddings": [
                    np.ones(self.retrieval_vector_size),
                    2 * np.ones(self.retrieval_vector_size),
                    3 * np.ones(self.retrieval_vector_size),
                ],
            }
        )
        dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
        tokenizer = self.bart_tokenizer
        with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
            mock_load_dataset.return_value = dataset
            retriever = RagRetriever(
                config,
                question_encoder_tokenizer=self.dpr_tokenizer,
                generator_tokenizer=tokenizer,
            )
        return retriever

    def _retrieve_context(self, model, retriever, config, input_ids, attention_mask, **retriever_kwargs):
        """
        Encode the question with `model.question_encoder`, call `retriever` and score the retrieved documents.

        Extra keyword arguments (e.g. `n_docs`) are forwarded to the retriever call unchanged.

        Returns:
            A `(context_input_ids, context_attention_mask, doc_scores)` tuple.
        """
        question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]

        out = retriever(
            input_ids,
            question_hidden_states.numpy(),
            prefix=config.generator.prefix,
            return_tensors="tf",
            **retriever_kwargs,
        )

        # cast before the matmul so both operands are float32
        retrieved_doc_embeds = tf.cast(out["retrieved_doc_embeds"], tf.float32)

        # compute doc_scores
        doc_scores = _compute_doc_scores(question_hidden_states, retrieved_doc_embeds)
        return out["context_input_ids"], out["context_attention_mask"], doc_scores

    def _assert_output_shapes(self, outputs, config, input_ids, decoder_input_ids, n_docs):
        """Assert the shapes of `logits`, `generator_enc_last_hidden_state` and `doc_scores` for `n_docs` documents."""
        # logits
        self.assertEqual(
            outputs.logits.shape,
            (n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size),
        )
        # generator encoder last hidden states
        self.assertEqual(
            outputs.generator_enc_last_hidden_state.shape,
            (n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size),
        )
        # doc scores
        self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], n_docs))

    def check_model_with_retriever(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Run every model class with an attached retriever and check the output shapes."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        for model_class in self.all_model_classes:
            model = model_class(config, retriever=self.get_retriever(config))

            self.assertTrue(model.config.is_encoder_decoder)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            self._assert_output_shapes(outputs, config, input_ids, decoder_input_ids, self.n_docs)

    def check_model_generate_from_context_input_ids(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Call `generate` on every generative model class with pre-retrieved context inputs and doc scores."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_generative_model_classes:
            model = model_class(config)
            self.assertTrue(model.config.is_encoder_decoder)

            context_input_ids, context_attention_mask, doc_scores = self._retrieve_context(
                model, retriever, config, input_ids, attention_mask
            )

            outputs = model.generate(
                context_input_ids=context_input_ids,
                context_attention_mask=context_attention_mask,
                doc_scores=doc_scores,
            )

            self.assertIsNotNone(outputs)

    def check_model_generate(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Call `generate` with two beams on every generative model class that has a retriever attached."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        for model_class in self.all_generative_model_classes:
            model = model_class(config, retriever=self.get_retriever(config))

            self.assertTrue(model.config.is_encoder_decoder)

            input_ids = tf.cast(input_ids, tf.int32)
            outputs = model.generate(
                input_ids=input_ids,
                num_beams=2,
                num_return_sequences=2,
                decoder_start_token_id=config.generator.eos_token_id,
                max_new_tokens=5,
            )

            self.assertIsNotNone(outputs)

    def check_model_without_retriever(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Run every model class without a retriever, feeding it externally retrieved context, and check shapes."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertTrue(model.config.is_encoder_decoder)

            context_input_ids, context_attention_mask, doc_scores = self._retrieve_context(
                model, retriever, config, input_ids, attention_mask
            )

            outputs = model(
                input_ids=None,
                context_input_ids=context_input_ids,
                context_attention_mask=context_attention_mask,
                doc_scores=doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            self._assert_output_shapes(outputs, config, input_ids, decoder_input_ids, self.n_docs)

    def check_model_custom_n_docs(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, n_docs, **kwargs
    ):
        """Same as `check_model_without_retriever`, but with `n_docs` passed to both the retriever and the model."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertTrue(model.config.is_encoder_decoder)

            context_input_ids, context_attention_mask, doc_scores = self._retrieve_context(
                model, retriever, config, input_ids, attention_mask, n_docs=n_docs
            )

            outputs = model(
                input_ids=None,
                context_input_ids=context_input_ids,
                context_attention_mask=context_attention_mask,
                doc_scores=doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                n_docs=n_docs,
            )

            self._assert_output_shapes(outputs, config, input_ids, decoder_input_ids, n_docs)

    def check_model_with_mismatch_n_docs_value(
        self,
        config,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        retriever_n_docs,
        generator_n_docs,
        **kwargs,
    ):
        """Check that the model raises `AssertionError` when its `n_docs` differs from the retriever's."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertTrue(model.config.is_encoder_decoder)

            context_input_ids, context_attention_mask, doc_scores = self._retrieve_context(
                model, retriever, config, input_ids, attention_mask, n_docs=retriever_n_docs
            )

            self.assertRaises(
                AssertionError,
                model.__call__,
                input_ids=None,
                context_input_ids=context_input_ids,
                context_attention_mask=context_attention_mask,
                doc_scores=doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                n_docs=generator_n_docs,
            )

    def check_model_with_encoder_outputs(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Run the full model once, then re-run only the generator from the captured encoder outputs."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        for model_class in self.all_model_classes:
            model = model_class(config, retriever=self.get_retriever(config))

            self.assertTrue(model.config.is_encoder_decoder)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            encoder_outputs = TFBaseModelOutput(outputs.generator_enc_last_hidden_state)

            # run only generator
            outputs = model(
                input_ids=None,
                encoder_outputs=encoder_outputs,
                doc_scores=outputs.doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            self._assert_output_shapes(outputs, config, input_ids, decoder_input_ids, self.n_docs)

    def test_model_with_retriever(self):
        inputs_dict = self.config_and_inputs
        self.check_model_with_retriever(**inputs_dict)

    def test_model_without_retriever(self):
        inputs_dict = self.config_and_inputs
        self.check_model_without_retriever(**inputs_dict)

    @slow
    def test_model_generate_from_context_input_ids(self):
        inputs_dict = self.config_and_inputs
        self.check_model_generate_from_context_input_ids(**inputs_dict)

    def test_model_with_encoder_outputs(self):
        inputs_dict = self.config_and_inputs
        self.check_model_with_encoder_outputs(**inputs_dict)

    @slow
    def test_model_generate(self):
        inputs_dict = self.config_and_inputs
        self.check_model_generate(**inputs_dict)

    def test_model_with_custom_n_docs(self):
        # build a new dict rather than mutating the cached `config_and_inputs`
        inputs_dict = {**self.config_and_inputs, "n_docs": 1}
        self.check_model_custom_n_docs(**inputs_dict)

    def test_model_with_mismatch_n_docs_value(self):
        # build a new dict rather than mutating the cached `config_and_inputs`
        inputs_dict = {**self.config_and_inputs, "retriever_n_docs": 3, "generator_n_docs": 2}
        self.check_model_with_mismatch_n_docs_value(**inputs_dict)


@require_tf
@require_retrieval
class TFRagDPRBartTest(TFRagTestMixin, unittest.TestCase):
    """Runs the `TFRagTestMixin` checks with a DPR question encoder and a BART generator."""

    @cached_property
    def config_and_inputs(self):
        """Keyword arguments for the `check_*` methods, built from the DPR and BART model testers."""
        question_encoder_tester = TFDPRModelTester(self)
        dpr_config_and_inputs = question_encoder_tester.prepare_config_and_inputs()
        generator_tester = TFBartModelTester(self)
        bart_config_and_inputs = generator_tester.prepare_config_and_inputs_for_common()

        (question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs
        (generator_config, bart_inputs_dict) = bart_config_and_inputs
        # the BART tester's encoder inputs are reused as decoder inputs for RAG
        decoder_input_ids, decoder_attention_mask = bart_inputs_dict["input_ids"], bart_inputs_dict["attention_mask"]

        config = RagConfig.from_question_encoder_generator_configs(
            question_encoder_config,
            generator_config,
            n_docs=self.n_docs,
            retrieval_vector_size=self.retrieval_vector_size,
            max_combined_length=self.max_combined_length,
        )

        return {
            "config": config,
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
        }


@require_tf
@require_retrieval
@require_sentencepiece
@require_tokenizers
class TFRagModelIntegrationTests(unittest.TestCase):
    """Slow tests comparing pretrained TF RAG models against recorded losses, scores and generations."""

    @cached_property
    def token_model(self):
        return TFRagTokenForGeneration.from_pretrained_question_encoder_generator(
            "facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large-cnn"
        )

    @cached_property
    def sequence_model(self):
        return TFRagSequenceForGeneration.from_pretrained_question_encoder_generator(
            "facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large-cnn"
        )

    def token_model_nq_checkpoint(self, retriever):
        return TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

    def get_rag_config(self):
        """Return the `RagConfig` used by the inference tests of this class."""
        return _build_rag_config()

    def _assert_inference_output(self, output, vocab_size, expected_doc_scores, expected_loss):
        """Check the logits shape, then the loss and doc scores against reference values within `TOLERANCE`."""
        expected_shape = tf.TensorShape([5, 5, vocab_size])
        self.assertEqual(output.logits.shape, expected_shape)

        expected_doc_scores = tf.convert_to_tensor([expected_doc_scores])
        expected_loss = tf.convert_to_tensor([expected_loss])

        tf.debugging.assert_near(output.loss, expected_loss, atol=TOLERANCE)
        tf.debugging.assert_near(output.doc_scores, expected_doc_scores, atol=TOLERANCE)

    @slow
    def test_rag_sequence_inference(self):
        rag_config = self.get_rag_config()
        rag_retriever, rag_question_encoder_tokenizer, rag_decoder_tokenizer = _build_retriever_and_tokenizers(
            rag_config
        )

        rag_sequence = self.sequence_model
        rag_sequence.set_retriever(rag_retriever)

        input_ids, decoder_input_ids = _encode_example(rag_question_encoder_tokenizer, rag_decoder_tokenizer)

        output = rag_sequence(
            input_ids,
            labels=decoder_input_ids,
        )

        self._assert_inference_output(
            output,
            vocab_size=50264,
            expected_doc_scores=[75.0286, 74.4998, 74.0804, 74.0306, 73.9504],
            expected_loss=36.7368,
        )

    @slow
    def test_rag_token_inference(self):
        rag_config = self.get_rag_config()
        rag_retriever, rag_question_encoder_tokenizer, rag_decoder_tokenizer = _build_retriever_and_tokenizers(
            rag_config
        )

        rag_token = self.token_model
        rag_token.set_retriever(rag_retriever)

        input_ids, decoder_input_ids = _encode_example(rag_question_encoder_tokenizer, rag_decoder_tokenizer)

        output = rag_token(
            input_ids,
            labels=decoder_input_ids,
        )

        self._assert_inference_output(
            output,
            vocab_size=50264,
            expected_doc_scores=[75.0286, 74.4998, 74.0804, 74.0306, 73.9504],
            expected_loss=36.3557,
        )

    @slow
    def test_rag_token_inference_nq_checkpoint(self):
        rag_config = self.get_rag_config()
        rag_retriever, rag_question_encoder_tokenizer, rag_decoder_tokenizer = _build_retriever_and_tokenizers(
            rag_config
        )

        rag_token = self.token_model_nq_checkpoint(retriever=rag_retriever)

        # check that outputs after saving and loading are equal
        with tempfile.TemporaryDirectory() as tmpdirname:
            rag_token.save_pretrained(tmpdirname)
            rag_token = TFRagTokenForGeneration.from_pretrained(tmpdirname, retriever=rag_retriever)

        input_ids, decoder_input_ids = _encode_example(rag_question_encoder_tokenizer, rag_decoder_tokenizer)

        output = rag_token(
            input_ids,
            labels=decoder_input_ids,
        )

        self._assert_inference_output(
            output,
            vocab_size=50265,
            expected_doc_scores=[62.9402, 62.7107, 62.2382, 62.1194, 61.8578],
            expected_loss=32.521812,
        )

    @slow
    def test_rag_token_inference_save_pretrained(self):
        rag_config = self.get_rag_config()
        rag_retriever, rag_question_encoder_tokenizer, rag_decoder_tokenizer = _build_retriever_and_tokenizers(
            rag_config
        )

        rag_token = self.token_model
        rag_token.set_retriever(rag_retriever)

        input_ids, decoder_input_ids = _encode_example(rag_question_encoder_tokenizer, rag_decoder_tokenizer)

        # model must run once to be functional before loading/saving works
        rag_token(
            input_ids,
            labels=decoder_input_ids,
        )

        # check that outputs after saving and loading are equal
        with tempfile.TemporaryDirectory() as tmpdirname:
            rag_token.save_pretrained(tmpdirname)
            rag_token = TFRagTokenForGeneration.from_pretrained(tmpdirname, retriever=rag_retriever)

        output = rag_token(
            input_ids,
            labels=decoder_input_ids,
        )

        self._assert_inference_output(
            output,
            vocab_size=50264,
            expected_doc_scores=[75.0286, 74.4998, 74.0804, 74.0306, 73.9504],
            expected_loss=36.3557,
        )

    @slow
    def test_init_and_from_pretrained(self):
        rag_config = self.get_rag_config()
        rag_retriever, rag_question_encoder_tokenizer, rag_decoder_tokenizer = _build_retriever_and_tokenizers(
            rag_config
        )

        # the model itself is built from a different, hub-hosted config than the retriever
        rag_config = RagConfig.from_pretrained("facebook/rag-sequence-base")
        rag = TFRagTokenForGeneration(rag_config, retriever=rag_retriever)

        input_ids, decoder_input_ids = _encode_example(rag_question_encoder_tokenizer, rag_decoder_tokenizer)

        rag(
            input_ids,
            decoder_input_ids=decoder_input_ids,
        )

        # this should not give any warnings
        with tempfile.TemporaryDirectory() as tmpdirname:
            rag.save_pretrained(tmpdirname)
            rag = TFRagTokenForGeneration.from_pretrained(tmpdirname, retriever=rag_retriever)

    @property
    def test_data_questions(self):
        return [
            "who got the first nobel prize in physics",
            "when is the next deadpool movie being released",
            "which mode is used for short wave broadcast service",
            "who is the owner of reading football club",
            "when is the next scandal episode coming out",
            "when is the last time the philadelphia won the superbowl",
            "what is the most current adobe flash player version",
            "how many episodes are there in dragon ball z",
        ]

    @slow
    def test_rag_token_greedy_search(self):
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True, dataset_revision="b24a417"
        )
        rag_token = TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

        # check first two questions
        input_dict = tokenizer(
            self.test_data_questions[:2],
            return_tensors="tf",
            padding=True,
            truncation=True,
        )

        input_ids = input_dict.input_ids
        attention_mask = input_dict.attention_mask

        # make sure only 1 beam is used
        rag_token.config.num_beams = 1

        output_ids = rag_token.generate(
            input_ids,
            attention_mask=attention_mask,
        )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        EXPECTED_OUTPUTS = [
            " albert einstein",
            " september 22, 2017",
        ]
        self.assertListEqual(outputs, EXPECTED_OUTPUTS)

    @slow
    def test_rag_token_generate_batch(self):
        # NOTE: gold labels comes from num_beam=4, so this is effectively beam-search test
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True, dataset_revision="b24a417"
        )
        rag_token = TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

        input_dict = tokenizer(
            self.test_data_questions,
            return_tensors="tf",
            padding=True,
            truncation=True,
        )

        input_ids = input_dict.input_ids
        attention_mask = input_dict.attention_mask

        EXPECTED_OUTPUTS = [
            " albert einstein",
            " september 22, 2017",
            " amplitude modulation",
            " stefan persson",
            " april 20, 2018",
            " the 1970s",
            " 7.1. 2",
            " 13",
        ]

        # Split into 2 batches of 4 examples to avoid GPU OOM.
        output_ids = rag_token.generate(
            input_ids[:4],
            attention_mask=attention_mask[:4],
        )
        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        self.assertListEqual(outputs, EXPECTED_OUTPUTS[:4])

        output_ids = rag_token.generate(
            input_ids[4:],
            attention_mask=attention_mask[4:],
        )
        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        self.assertListEqual(outputs, EXPECTED_OUTPUTS[4:])

    @slow
    def test_rag_sequence_generate_batch(self):
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-sequence-nq",
            index_name="exact",
            use_dummy_dataset=True,
            dataset_revision="b24a417",
        )
        rag_sequence = TFRagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

        input_dict = tokenizer(
            self.test_data_questions,
            return_tensors="tf",
            padding=True,
            truncation=True,
        )

        input_ids = input_dict.input_ids
        attention_mask = input_dict.attention_mask

        output_ids = rag_sequence.generate(
            input_ids,
            attention_mask=attention_mask,
        )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        EXPECTED_OUTPUTS = [
            " albert einstein",
            " june 22, 2018",
            " amplitude modulation",
            " tim besley ( chairman )",
            " june 20, 2018",
            " 1980",
            " 7.0",
            " 8",
        ]
        self.assertListEqual(outputs, EXPECTED_OUTPUTS)

    @slow
    def test_rag_sequence_generate_batch_from_context_input_ids(self):
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True, dataset_revision="b24a417"
        )
        rag_sequence = TFRagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

        input_dict = tokenizer(
            self.test_data_questions,
            return_tensors="tf",
            padding=True,
            truncation=True,
        )

        input_ids = input_dict.input_ids

        question_hidden_states = rag_sequence.question_encoder(input_ids)[0]
        docs_dict = retriever(input_ids.numpy(), question_hidden_states.numpy(), return_tensors="tf")
        doc_scores = _compute_doc_scores(question_hidden_states, docs_dict["retrieved_doc_embeds"])

        output_ids = rag_sequence.generate(
            context_input_ids=docs_dict["context_input_ids"],
            context_attention_mask=docs_dict["context_attention_mask"],
            doc_scores=doc_scores,
            do_deduplication=True,
        )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        EXPECTED_OUTPUTS = [
            " albert einstein",
            " june 22, 2018",
            " amplitude modulation",
            " tim besley ( chairman )",
            " june 20, 2018",
            " 1980",
            " 7.0",
            " 8",
        ]
        self.assertListEqual(outputs, EXPECTED_OUTPUTS)


@require_tf
@require_retrieval
class TFRagModelSaveLoadTests(unittest.TestCase):
    """Slow tests comparing a saved-and-reloaded RAG model with one assembled from its sub-models."""

    def get_rag_config(self):
        """Return the `RagConfig` used by the save/load tests of this class."""
        return _build_rag_config()

    def _check_from_pretrained_matches_init(self, model_class):
        """
        Compare the loss of a saved-and-reloaded `model_class` with one built from separately loaded sub-models.

        The model is created with `from_pretrained_question_encoder_generator`, saved, and reloaded with
        `from_pretrained`; its loss on the fixed example must match, to 4 decimal places, the loss of a model
        constructed directly from a question encoder and a generator.
        """
        load_weight_prefix = "tf_rag_model_1"

        rag_config = self.get_rag_config()
        rag_retriever, rag_question_encoder_tokenizer, rag_decoder_tokenizer = _build_retriever_and_tokenizers(
            rag_config
        )

        input_ids, decoder_input_ids = _encode_example(rag_question_encoder_tokenizer, rag_decoder_tokenizer)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            rag_model = model_class.from_pretrained_question_encoder_generator(
                "facebook/dpr-question_encoder-single-nq-base",
                "facebook/bart-large-cnn",
                retriever=rag_retriever,
                config=rag_config,
            )
            rag_model.build_in_name_scope()
            # check that the from pretrained methods work
            rag_model.save_pretrained(tmp_dirname)
            # `from_pretrained` returns a new model; keep it so the loss below comes from the reloaded weights
            # (previously the return value was discarded and the original instance was evaluated instead)
            rag_model = model_class.from_pretrained(tmp_dirname, retriever=rag_retriever)
            output = rag_model(input_ids, labels=decoder_input_ids)
            loss_pretrained = output.loss
            del rag_model

        question_encoder = TFAutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        generator = TFAutoModelForSeq2SeqLM.from_pretrained(
            "facebook/bart-large-cnn", load_weight_prefix=load_weight_prefix, name="generator"
        )

        rag_model = model_class(
            config=rag_config, question_encoder=question_encoder, generator=generator, retriever=rag_retriever
        )

        output = rag_model(input_ids, labels=decoder_input_ids)
        loss_init = output.loss

        self.assertAlmostEqual(loss_pretrained, loss_init, places=4)

    @slow
    def test_rag_sequence_from_pretrained(self):
        self._check_from_pretrained_matches_init(TFRagSequenceForGeneration)

    @slow
    def test_rag_token_from_pretrained(self):
        self._check_from_pretrained_matches_init(TFRagTokenForGeneration)