Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)
Proposed Fix : [RagSequenceForGeneration] generate "without" input_ids (#9220)
* Create modeling_tf_dpr.py
* Add TFDPR
* Add back TFPegasus, TFMarian, TFMBart, TFBlenderBot (the last commit accidentally deleted these 4 lines, so I recovered them)
* Add TFDPR
* Add TFDPR
* Clean up some comments, add TF input-style doc string
* Add TFDPR
* Make return_dict=False the default
* Fix return_dict bug (in .from_pretrained)
* Add get_input_embeddings()
* Create test_modeling_tf_dpr.py. The current version already passes all 27 tests; see the test run at https://colab.research.google.com/drive/1czS_m9zy5k-iSJbzA_DP1k1xAAC_sdkf?usp=sharing
* Fix quality
* Delete init weights
* Run fix copies
* Fix repo consistency
* Delete config_class, load_tf_weights (they should be "PyTorch only")
* Add config_class back; after removing it a test failed, so on Lysandre's suggestion only "use_tf_weights = None" is removed
* Newline after .. note::
* Import tf, np (necessary for ModelIntegrationTest)
* Slow test from_pretrained with from_pt=True. At the moment we don't have TF weights (since there is no official TF model); previously I did not run the slow test, so I missed this bug
* Add simple TFDPRModelIntegrationTest. Note that this only tests that TF and PyTorch give approximately the same output; I could not test against the official DPR repo's output yet
* Upload correct TF model
* Remove position_ids as missing keys
* Fix RagSeq generate with context_input_ids
* Apply style
* Delete unused lines
* Add test_rag_sequence_generate_batch_from_context_input_ids
* Readability improved
* Styling
* Stylize
* Typos
* Add check_model_generate_from_context_input_ids
* Make style
* Apply suggestions from code review
* Make style 2

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: patrickvonplaten <patrick@huggingface.co>
This commit is contained in:
parent 2a18b70998
commit f3a3b91d6f
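
As context for the diff below, here is a minimal sketch of the call pattern this fix enables: running RagSequenceForGeneration.generate() without input_ids by passing pre-computed context_input_ids, context_attention_mask, and doc_scores. It mirrors the new slow test at the bottom of this commit; the checkpoint name and the retriever call are taken from that test, the question string is purely illustrative.

    import torch
    from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
    )
    model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")

    # Encode the question and retrieve documents manually.
    inputs = tokenizer(["who sings does he love me with reba"], return_tensors="pt")
    question_hidden_states = model.question_encoder(inputs["input_ids"])[0]
    docs = retriever(
        inputs["input_ids"].numpy(),
        question_hidden_states.detach().numpy(),
        return_tensors="pt",
    )

    # doc_scores: inner product of question embedding and document embeddings.
    doc_scores = torch.bmm(
        question_hidden_states.unsqueeze(1),
        docs["retrieved_doc_embeds"].float().transpose(1, 2),
    ).squeeze(1)

    # The fixed code path: generate() with no input_ids at all.
    output_ids = model.generate(
        context_input_ids=docs["context_input_ids"],
        context_attention_mask=docs["context_attention_mask"],
        doc_scores=doc_scores,
    )
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))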
@@ -377,6 +377,7 @@ RAG_START_DOCSTRING = r"""
     subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
     general usage and behavior.
 
+
     Args:
         config (:class:`~transformers.RagConfig`):
             Model configuration class with all the parameters of the model. Initializing with a config file does not
@@ -822,6 +823,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         context_input_ids=None,
+        context_attention_mask=None,
+        doc_scores=None,
         do_deduplication=None,  # defaults to True
         num_return_sequences=None,  # defaults to 1
         num_beams=None,  # defaults to 1
@@ -846,6 +849,20 @@ class RagSequenceForGeneration(RagPreTrainedModel):
             context_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`):
                 Input IDs post-processed from the retrieved documents and the question encoder input_ids by the
                 retriever.
+            context_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size * config.n_docs, config.max_combined_length)`, `optional`, returned when `output_retrieved=True`):
+                Attention mask post-processed from the retrieved documents and the question encoder :obj:`input_ids` by
+                the retriever.
+
+                If the model is not initialized with a ``retriever`` or ``input_ids`` is not given,
+                :obj:`context_input_ids` and :obj:`context_attention_mask` have to be provided to the forward pass.
+                They are returned by :meth:`~transformers.RagRetriever.__call__`.
+            doc_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.n_docs)`):
+                Score between each retrieved document embeddings (see :obj:`retrieved_doc_embeds`) and
+                :obj:`question_encoder_last_hidden_state`.
+
+                If the model is not initialized with a ``retriever`` or ``input_ids`` is not given, :obj:`doc_scores`
+                has to be provided to the forward pass. :obj:`doc_scores` are returned by
+                :meth:`~transformers.RagRetriever.__call__`.
             do_deduplication (:obj:`bool`, `optional`):
                 Whether or not to deduplicate the generations from different context documents for a given input. Has
                 to be set to :obj:`False` if used while training with distributed backend.
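
The two new docstring entries describe tensors that callers must be able to build themselves when no retriever is attached. As both new tests below show, doc_scores is just the inner product between the question encoder output and each retrieved document embedding; a small shape-checked sketch with assumed dimensions:

    import torch

    batch_size, n_docs, hidden = 2, 5, 768  # illustrative sizes
    question_hidden_states = torch.randn(batch_size, hidden)
    retrieved_doc_embeds = torch.randn(batch_size, n_docs, hidden)

    # (batch, 1, hidden) @ (batch, hidden, n_docs) -> (batch, 1, n_docs) -> (batch, n_docs)
    doc_scores = torch.bmm(
        question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
    ).squeeze(1)
    assert doc_scores.shape == (batch_size, n_docs)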
@@ -873,6 +890,10 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         )
         num_beams = num_beams if num_beams is not None else self.config.num_beams
 
+        assert (
+            input_ids is not None or context_input_ids is not None
+        ), " At least one of input_ids or context_input_ids must be given"
+
         if self.retriever is not None and context_input_ids is None:
             question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
             context_input_ids = self.retriever(
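
The new assert fails fast with a readable message instead of letting a None input_ids crash deeper in the generation loop. A hypothetical caller-side view, reusing the retriever-less model from the sketch at the top of this page:

    try:
        model.generate()  # neither input_ids nor context_input_ids given
    except AssertionError as err:
        print(err)  # " At least one of input_ids or context_input_ids must be given"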
@@ -891,7 +912,9 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         model_kwargs["num_return_sequences"] = num_beams
         model_kwargs["attention_mask"] = None
 
-        for index in range(len(input_ids)):
+        batch_size = input_ids.shape[0] if input_ids is not None else context_input_ids.shape[0] // n_docs
+
+        for index in range(batch_size):
             # first, generate beams from documents:
             generator_input_ids = context_input_ids[index * n_docs : (index + 1) * n_docs]  # (n_docs, max_len)
 
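
Because context_input_ids stacks the n_docs retrieved contexts of every example along the first dimension, the batch size is recoverable by integer division, which is exactly what the replacement line does. A worked example with assumed sizes:

    import torch

    n_docs, batch_size, max_len = 5, 2, 300  # illustrative values
    context_input_ids = torch.ones(batch_size * n_docs, max_len, dtype=torch.long)

    # Mirrors the new line: 10 rows // 5 docs per example == batch of 2.
    assert context_input_ids.shape[0] // n_docs == batch_size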
@@ -903,9 +926,40 @@ class RagSequenceForGeneration(RagPreTrainedModel):
             # do_deduplication, max_output_len
             output_sequences = torch.stack(list({str(k.tolist()): k for k in output_sequences}.values()))
 
+            num_candidates = output_sequences.shape[
+                0
+            ]  # after deduplication, this number can be less than n_docs*n_beam
+
             # then, run model forwards to get nll scores:
-            new_input_ids = input_ids[index : index + 1].repeat(len(output_sequences), 1)
-            outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
+            if input_ids is not None:
+                new_input_ids = input_ids[index : index + 1].repeat(num_candidates, 1)
+                outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
+            else:  # input_ids is None, need context_input_ids/mask and doc_scores
+                assert (
+                    context_attention_mask is not None
+                ), "Make sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
+                assert (
+                    doc_scores is not None
+                ), "Make sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
+
+                individual_input_ids = generator_input_ids.repeat(
+                    num_candidates, 1
+                )  # (num_candidates*n_docs, max_len)
+
+                individual_attention_mask = context_attention_mask[index * n_docs : (index + 1) * n_docs]
+                individual_attention_mask = individual_attention_mask.repeat(num_candidates, 1)
+
+                individual_doc_scores = doc_scores[index : (index + 1), :]  # doc_scores.shape = [batch, n_docs]
+                individual_doc_scores = individual_doc_scores.repeat(num_candidates, 1)  # [num_candidates, n_docs]
+
+                outputs = self(
+                    context_input_ids=individual_input_ids,
+                    context_attention_mask=individual_attention_mask,
+                    doc_scores=individual_doc_scores,
+                    labels=output_sequences,
+                    exclude_bos_score=True,
+                )
+
             top_cand_inds = (-outputs["loss"]).topk(num_doc_return_sequences)[1]
 
             # add hypothesis
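
A note on num_candidates, introduced above: the dict comprehension on the unchanged line before it keys each candidate by str(k.tolist()), so duplicate beams collapse to one entry and num_candidates can be smaller than n_docs * num_beams. The old code used len(output_sequences) for the repeat, which is the same value, but naming it lets both branches read uniformly. A self-contained illustration of the deduplication trick:

    import torch

    # Three beams, two of which decode to the same token sequence.
    output_sequences = torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6]])
    deduped = torch.stack(list({str(k.tolist()): k for k in output_sequences}.values()))

    num_candidates = deduped.shape[0]
    print(num_candidates)  # 2, fewer than the 3 original beams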
@@ -934,9 +988,10 @@ class RagSequenceForGeneration(RagPreTrainedModel):
             smooth_obj.masked_fill_(pad_mask, 0.0)
             return ll.squeeze(-1), smooth_obj.squeeze(-1)
 
+        # seq_logits dim = (batch*n_docs, tgt_len , #vocabs)
         seq_logprobs = torch.nn.functional.log_softmax(seq_logits, dim=-1).view(
             seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
-        )  # batch_size x n_docs x tgt_len x dim
+        )  # batch_size x n_docs x tgt_len x #vocab_size
         doc_logprobs = torch.nn.functional.log_softmax(doc_scores, dim=1).unsqueeze(-1).unsqueeze(-1)
 
         # RAG-sequence marginalization
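
For readers following the corrected comment: after the view, seq_logprobs has shape (batch_size, n_docs, tgt_len, vocab_size), so the last axis is the vocabulary, not a hidden dimension, and doc_logprobs with shape (batch_size, n_docs, 1, 1) broadcasts over the token and vocabulary axes. A sketch of the shape arithmetic only, with assumed small sizes (the surrounding loss computation lives elsewhere in this file):

    import torch

    batch, n_docs, tgt_len, vocab = 2, 5, 7, 100  # illustrative sizes
    seq_logits = torch.randn(batch * n_docs, tgt_len, vocab)
    doc_scores = torch.randn(batch, n_docs)

    seq_logprobs = torch.nn.functional.log_softmax(seq_logits, dim=-1).view(
        batch, n_docs, -1, seq_logits.size(-1)
    )  # batch_size x n_docs x tgt_len x vocab_size
    doc_logprobs = torch.nn.functional.log_softmax(doc_scores, dim=1).unsqueeze(-1).unsqueeze(-1)

    # RAG-sequence marginalization adds log p(doc) to every token logprob;
    # broadcasting yields the same (batch, n_docs, tgt_len, vocab) shape.
    assert (seq_logprobs + doc_logprobs).shape == (batch, n_docs, tgt_len, vocab)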
@@ -246,6 +246,53 @@ class RagTestMixin:
         # doc scores
         self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))
 
+    def check_model_generate_from_context_input_ids(
+        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
+    ):
+        self.assertIsNotNone(config.question_encoder)
+        self.assertIsNotNone(config.generator)
+
+        retriever = self.get_retriever(config)
+
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            model.eval()
+            self.assertTrue(model.config.is_encoder_decoder)
+
+            question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]
+
+            out = retriever(
+                input_ids,
+                question_hidden_states.cpu().detach().to(torch.float32).numpy(),
+                prefix=config.generator.prefix,
+                return_tensors="pt",
+            )
+
+            context_input_ids, context_attention_mask, retrieved_doc_embeds = (
+                out["context_input_ids"],
+                out["context_attention_mask"],
+                out["retrieved_doc_embeds"],
+            )
+
+            # cast
+            retrieved_doc_embeds = retrieved_doc_embeds.to(question_hidden_states)
+            context_input_ids = context_input_ids.to(input_ids)
+            context_attention_mask = context_attention_mask.to(input_ids)
+
+            # compute doc_scores
+            doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze(
+                1
+            )
+
+            outputs = model.generate(
+                context_input_ids=context_input_ids,
+                context_attention_mask=context_attention_mask,
+                doc_scores=doc_scores,
+                do_deduplication=True,
+            )
+
+            self.assertIsNotNone(outputs)
+
     def check_model_generate(
         self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
     ):
@@ -848,6 +895,63 @@ class RagModelIntegrationTests(unittest.TestCase):
         ]
         self.assertListEqual(outputs, EXPECTED_OUTPUTS)
 
+    @slow
+    def test_rag_sequence_generate_batch_from_context_input_ids(self):
+        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
+        retriever = RagRetriever.from_pretrained(
+            "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
+        )
+        rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
+            torch_device
+        )
+
+        input_dict = tokenizer(
+            self.test_data_questions,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+        )
+
+        input_ids = input_dict.input_ids.to(torch_device)
+        attention_mask = input_dict.attention_mask.to(torch_device)
+
+        question_hidden_states = rag_sequence.question_encoder(input_ids, attention_mask=attention_mask)[0]
+        docs_dict = retriever(
+            input_ids.cpu().detach().numpy(), question_hidden_states.cpu().detach().numpy(), return_tensors="pt"
+        )
+        doc_scores = torch.bmm(
+            question_hidden_states.unsqueeze(1),
+            docs_dict["retrieved_doc_embeds"].to(torch_device).float().transpose(1, 2),
+        ).squeeze(1)
+
+        output_ids = rag_sequence.generate(
+            context_input_ids=docs_dict["context_input_ids"].to(torch_device),
+            context_attention_mask=docs_dict["context_attention_mask"].to(torch_device),
+            doc_scores=doc_scores.to(torch_device),
+            do_deduplication=True,
+        )
+
+        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+        EXPECTED_OUTPUTS = [
+            " albert einstein",
+            " june 22, 2018",
+            " amplitude modulation",
+            " tim besley ( chairman )",
+            " june 20, 2018",
+            " 1980",
+            " 7.0",
+            " 8",
+            " reticular formation",
+            " walls of the abdomen",
+            " spodumene",
+            " obama",
+            " new orleans",
+            " japan",
+            " old trafford",
+        ]
+        self.assertListEqual(outputs, EXPECTED_OUTPUTS)
+
     @slow
     def test_rag_token_generate_batch(self):
         tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
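
The new integration test above is decorated with @slow, so it is skipped in the default CI run; assuming the standard transformers test layout of this era, it can be exercised directly with:

    RUN_SLOW=1 pytest tests/test_modeling_rag.py -k test_rag_sequence_generate_batch_from_context_input_ids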