diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index e2c75087736..304a05b85aa 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -377,6 +377,7 @@ RAG_START_DOCSTRING = r"""
     subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
     general usage and behavior.
 
+
     Args:
         config (:class:`~transformers.RagConfig`):
             Model configuration class with all the parameters of the model. Initializing with a config file does not
@@ -822,6 +823,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         context_input_ids=None,
+        context_attention_mask=None,
+        doc_scores=None,
         do_deduplication=None,  # defaults to True
         num_return_sequences=None,  # defaults to 1
         num_beams=None,  # defaults to 1
@@ -846,6 +849,20 @@ class RagSequenceForGeneration(RagPreTrainedModel):
            context_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`):
                Input IDs post-processed from the retrieved documents and the question encoder input_ids by the
                retriever.
+            context_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`):
+                Attention mask post-processed from the retrieved documents and the question encoder :obj:`input_ids` by
+                the retriever.
+
+                If the model is not initialized with a ``retriever`` or ``input_ids`` is not given,
+                :obj:`context_input_ids` and :obj:`context_attention_mask` have to be provided to the forward pass.
+                They are returned by :meth:`~transformers.RagRetriever.__call__`.
+            doc_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.n_docs)`):
+                Score between each retrieved document embedding (see :obj:`retrieved_doc_embeds`) and
+                :obj:`question_encoder_last_hidden_state`.
+
+                If the model is not initialized with a ``retriever`` or ``input_ids`` is not given, :obj:`doc_scores`
+                has to be provided to the forward pass. :obj:`doc_scores` are returned by
+                :meth:`~transformers.RagRetriever.__call__`.
            do_deduplication (:obj:`bool`, `optional`):
                Whether or not to deduplicate the generations from different context documents for a given input. Has
                to be set to :obj:`False` if used while training with distributed backend.
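
A minimal caller-side sketch (not part of the diff) of the workflow the new docstring describes: fetching context_input_ids and context_attention_mask from RagRetriever.__call__ and computing doc_scores by hand before calling generate() on a model that holds no retriever. The checkpoint name, the dummy index, and the example question are assumptions borrowed from the tests further down and the library's documented RAG usage.

    # Sketch only: assumes the facebook/rag-sequence-nq checkpoint and the dummy exact index.
    import torch
    from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
    )
    model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")  # no retriever attached

    inputs = tokenizer("who sings does he love me with reba", return_tensors="pt")
    question_hidden_states = model.question_encoder(inputs["input_ids"])[0]

    # RagRetriever.__call__ returns the post-processed ids, mask and doc embeddings
    docs = retriever(
        inputs["input_ids"].numpy(), question_hidden_states.detach().numpy(), return_tensors="pt"
    )
    # doc_scores: inner product between the question embedding and each doc embedding
    doc_scores = torch.bmm(
        question_hidden_states.unsqueeze(1), docs["retrieved_doc_embeds"].float().transpose(1, 2)
    ).squeeze(1)  # (batch_size, n_docs)

    generated = model.generate(
        context_input_ids=docs["context_input_ids"],
        context_attention_mask=docs["context_attention_mask"],
        doc_scores=doc_scores,
    )
    print(tokenizer.batch_decode(generated, skip_special_tokens=True))
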
@@ -873,6 +890,10 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         )
         num_beams = num_beams if num_beams is not None else self.config.num_beams
 
+        assert (
+            input_ids is not None or context_input_ids is not None
+        ), "At least one of input_ids or context_input_ids must be given"
+
         if self.retriever is not None and context_input_ids is None:
             question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
             context_input_ids = self.retriever(
@@ -891,7 +912,9 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         model_kwargs["num_return_sequences"] = num_beams
         model_kwargs["attention_mask"] = None
 
-        for index in range(len(input_ids)):
+        batch_size = input_ids.shape[0] if input_ids is not None else context_input_ids.shape[0] // n_docs
+
+        for index in range(batch_size):
             # first, generate beams from documents:
             generator_input_ids = context_input_ids[index * n_docs : (index + 1) * n_docs]  # (n_docs, max_len)
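
The batch_size derivation above leans on the layout of context_input_ids: the n_docs contexts retrieved for each question are stored as contiguous rows of a (batch_size * n_docs, max_len) tensor, so the batch size is recoverable even when input_ids is absent. A toy sketch with made-up shapes:

    import torch

    n_docs, batch_size, max_len = 5, 2, 7  # made-up sizes for illustration
    context_input_ids = torch.zeros(batch_size * n_docs, max_len, dtype=torch.long)

    # the flattened first dimension factors back into (batch_size, n_docs)
    assert context_input_ids.shape[0] // n_docs == batch_size

    for index in range(batch_size):
        # rows [index * n_docs, (index + 1) * n_docs) hold one question's contexts
        generator_input_ids = context_input_ids[index * n_docs : (index + 1) * n_docs]
        assert generator_input_ids.shape == (n_docs, max_len)
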
@@ -903,9 +926,40 @@ class RagSequenceForGeneration(RagPreTrainedModel):
             #  do_deduplication, max_output_len
             output_sequences = torch.stack(list({str(k.tolist()): k for k in output_sequences}.values()))
 
+            num_candidates = output_sequences.shape[
+                0
+            ]  # after deduplication, this number can be less than n_docs * num_beams
+
             # then, run model forwards to get nll scores:
-            new_input_ids = input_ids[index : index + 1].repeat(len(output_sequences), 1)
-            outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
+            if input_ids is not None:
+                new_input_ids = input_ids[index : index + 1].repeat(num_candidates, 1)
+                outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
+            else:  # input_ids is None, need context_input_ids/mask and doc_scores
+                assert (
+                    context_attention_mask is not None
+                ), "Make sure that `context_attention_mask` is passed if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
+                assert (
+                    doc_scores is not None
+                ), "Make sure that `doc_scores` is passed if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
+
+                individual_input_ids = generator_input_ids.repeat(
+                    num_candidates, 1
+                )  # (num_candidates * n_docs, max_len)
+
+                individual_attention_mask = context_attention_mask[index * n_docs : (index + 1) * n_docs]
+                individual_attention_mask = individual_attention_mask.repeat(num_candidates, 1)
+
+                individual_doc_scores = doc_scores[index : (index + 1), :]  # doc_scores.shape = [batch, n_docs]
+                individual_doc_scores = individual_doc_scores.repeat(num_candidates, 1)  # [num_candidates, n_docs]
+
+                outputs = self(
+                    context_input_ids=individual_input_ids,
+                    context_attention_mask=individual_attention_mask,
+                    doc_scores=individual_doc_scores,
+                    labels=output_sequences,
+                    exclude_bos_score=True,
+                )
+
             top_cand_inds = (-outputs["loss"]).topk(num_doc_return_sequences)[1]
 
             # add hypothesis
@@ -934,9 +988,10 @@ class RagSequenceForGeneration(RagPreTrainedModel):
             smooth_obj.masked_fill_(pad_mask, 0.0)
             return ll.squeeze(-1), smooth_obj.squeeze(-1)
 
+        # seq_logits dim = (batch * n_docs, tgt_len, vocab_size)
         seq_logprobs = torch.nn.functional.log_softmax(seq_logits, dim=-1).view(
             seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
-        )  # batch_size x n_docs x tgt_len x dim
+        )  # batch_size x n_docs x tgt_len x vocab_size
         doc_logprobs = torch.nn.functional.log_softmax(doc_scores, dim=1).unsqueeze(-1).unsqueeze(-1)
 
         # RAG-sequence marginalization
diff --git a/tests/test_modeling_rag.py b/tests/test_modeling_rag.py
index 382c59c32fa..5f69e1608de 100644
--- a/tests/test_modeling_rag.py
+++ b/tests/test_modeling_rag.py
@@ -246,6 +246,53 @@ class RagTestMixin:
         # doc scores
         self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))
 
+    def check_model_generate_from_context_input_ids(
+        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
+    ):
+        self.assertIsNotNone(config.question_encoder)
+        self.assertIsNotNone(config.generator)
+
+        retriever = self.get_retriever(config)
+
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            model.eval()
+            self.assertTrue(model.config.is_encoder_decoder)
+
+            question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]
+
+            out = retriever(
+                input_ids,
+                question_hidden_states.cpu().detach().to(torch.float32).numpy(),
+                prefix=config.generator.prefix,
+                return_tensors="pt",
+            )
+
+            context_input_ids, context_attention_mask, retrieved_doc_embeds = (
+                out["context_input_ids"],
+                out["context_attention_mask"],
+                out["retrieved_doc_embeds"],
+            )
+
+            # cast
+            retrieved_doc_embeds = retrieved_doc_embeds.to(question_hidden_states)
+            context_input_ids = context_input_ids.to(input_ids)
+            context_attention_mask = context_attention_mask.to(input_ids)
+
+            # compute doc_scores
+            doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze(
+                1
+            )
+
+            outputs = model.generate(
+                context_input_ids=context_input_ids,
+                context_attention_mask=context_attention_mask,
+                doc_scores=doc_scores,
+                do_deduplication=True,
+            )
+
+            self.assertIsNotNone(outputs)
+
     def check_model_generate(
         self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
     ):
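
For reference, the RAG-sequence marginalization that the relabeled shape comment feeds into, as a toy sketch with made-up shapes. The library's get_nll additionally handles bos exclusion, padding, and label smoothing; this only shows the core math.

    import torch

    batch_size, n_docs, tgt_len, vocab_size = 2, 5, 4, 11  # toy sizes
    seq_logits = torch.randn(batch_size * n_docs, tgt_len, vocab_size)
    doc_scores = torch.randn(batch_size, n_docs)
    target = torch.randint(vocab_size, (batch_size, tgt_len))

    # seq_logits dim = (batch * n_docs, tgt_len, vocab_size), regrouped per document
    seq_logprobs = torch.nn.functional.log_softmax(seq_logits, dim=-1).view(
        batch_size, n_docs, tgt_len, vocab_size
    )
    # log p(y|x,z): gather the target tokens under each document, sum over the sequence
    target_ = target.unsqueeze(1).unsqueeze(-1).repeat(1, n_docs, 1, 1)
    per_doc_ll = seq_logprobs.gather(dim=-1, index=target_).squeeze(-1).sum(-1)  # (batch, n_docs)
    # log p(z|x): retrieval scores normalized over the retrieved documents
    doc_logprobs = torch.nn.functional.log_softmax(doc_scores, dim=1)  # (batch, n_docs)
    # log p(y|x) = logsumexp_z( log p(z|x) + log p(y|x,z) )
    marginalized_ll = (per_doc_ll + doc_logprobs).logsumexp(dim=1)  # (batch,)
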
index_name="exact", use_dummy_dataset=True + ) + rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to( + torch_device + ) + + input_dict = tokenizer( + self.test_data_questions, + return_tensors="pt", + padding=True, + truncation=True, + ) + + input_ids = input_dict.input_ids.to(torch_device) + attention_mask = input_dict.attention_mask.to(torch_device) + + question_hidden_states = rag_sequence.question_encoder(input_ids, attention_mask=attention_mask)[0] + docs_dict = retriever( + input_ids.cpu().detach().numpy(), question_hidden_states.cpu().detach().numpy(), return_tensors="pt" + ) + doc_scores = torch.bmm( + question_hidden_states.unsqueeze(1), + docs_dict["retrieved_doc_embeds"].to(torch_device).float().transpose(1, 2), + ).squeeze(1) + + output_ids = rag_sequence.generate( + context_input_ids=docs_dict["context_input_ids"].to(torch_device), + context_attention_mask=docs_dict["context_attention_mask"].to(torch_device), + doc_scores=doc_scores.to(torch_device), + do_deduplication=True, + ) + + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + + EXPECTED_OUTPUTS = [ + " albert einstein", + " june 22, 2018", + " amplitude modulation", + " tim besley ( chairman )", + " june 20, 2018", + " 1980", + " 7.0", + " 8", + " reticular formation", + " walls of the abdomen", + " spodumene", + " obama", + " new orleans", + " japan", + " old trafford", + ] + self.assertListEqual(outputs, EXPECTED_OUTPUTS) + @slow def test_rag_token_generate_batch(self): tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")