Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 18:22:34 +06:00
[RAG] Propagate n_docs as a parameter to all RagModel-related functions (#7891)
* Propagate n_docs as a parameter, defaulting to self.config.n_docs, to all RagModel-related functions
* Make the n_docs parameter default to None in the marginalize function
* Fix code quality issues
* Handle the special case where the generator is a T5PreTrainedModel instance, since T5PreTrainedModel does not take n_docs as a parameter
* Address review comments

  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Correct a comment in response to a review comment
* Add an assert verifying that n_docs is set consistently; n_docs should be the same for both retriever and generator
* Fix a flake8-reported issue
* Correct the test datasets for RAG
* Check the assert against doc_scores instead of context_input_ids, since context_input_ids can be None in RagSequenceForGeneration
* doc_scores' second dimension holds the number of retrieved docs
* Reword the assert comment
* Apply suggestions from code review

  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 7e6b6fbec9
commit 0193c8290d
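
Before the diff, a minimal usage sketch of what this change enables: n_docs can now be overridden per call instead of always being read from config.n_docs. The checkpoint and index names follow the examples already in the RAG docstrings; the snippet is illustrative, not part of the commit.

    # Illustrative only: override the number of retrieved documents for a single call.
    from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
    )
    model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

    input_dict = tokenizer.prepare_seq2seq_batch(
        ["who holds the record in 100m freestyle"], return_tensors="pt"
    )

    # Retrieve and marginalize over 2 documents instead of the checkpoint's configured default.
    generated = model.generate(input_ids=input_dict["input_ids"], n_docs=2)
    print(tokenizer.batch_decode(generated, skip_special_tokens=True))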
src/transformers/modeling_rag.py
@@ -458,6 +458,8 @@ RAG_FORWARD_INPUTS_DOCSTRING = r"""
         output_retrieved(:obj:`bool`, `optional`):
             Whether or not to return the :obj:`retrieved_doc_embeds`, :obj:`retrieved_doc_ids`,
             :obj:`context_input_ids` and :obj:`context_attention_mask`. See returned tensors for more detail.
+        n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`)
+            Number of documents to retrieve and/or number of documents for which to generate an answer.
 """


@@ -521,6 +523,7 @@ class RagModel(RagPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         output_retrieved=None,
+        n_docs=None,
     ):
         r"""
         Returns:
@@ -540,6 +543,7 @@ class RagModel(RagPreTrainedModel):
             >>> outputs = model(input_ids=input_ids)

         """
+        n_docs = n_docs if n_docs is not None else self.config.n_docs
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -566,7 +570,7 @@ class RagModel(RagPreTrainedModel):
                     input_ids,
                     question_encoder_last_hidden_state.cpu().detach().to(torch.float32).numpy(),
                     prefix=self.generator.config.prefix,
-                    n_docs=self.config.n_docs,
+                    n_docs=n_docs,
                     return_tensors="pt",
                 )
                 context_input_ids, context_attention_mask, retrieved_doc_embeds, retrieved_doc_ids = (
@@ -600,12 +604,16 @@ class RagModel(RagPreTrainedModel):
                 doc_scores is not None
             ), "Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function."

+        assert (
+            doc_scores.shape[1] % n_docs
+        ) == 0, f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is {context_input_ids.shape[0]}."
+
         # Decoder input without context documents
         if decoder_input_ids is not None:
-            decoder_input_ids = decoder_input_ids.repeat_interleave(self.config.n_docs, dim=0)
+            decoder_input_ids = decoder_input_ids.repeat_interleave(n_docs, dim=0)

         if decoder_attention_mask is not None:
-            decoder_attention_mask = decoder_attention_mask.repeat_interleave(self.config.n_docs, dim=0)
+            decoder_attention_mask = decoder_attention_mask.repeat_interleave(n_docs, dim=0)

         gen_outputs = self.generator(
             input_ids=context_input_ids,
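
The new assert above encodes the shape contract between retriever and generator: the retriever returns context_input_ids with batch_size * n_docs rows and doc_scores with n_docs columns, and the decoder inputs are then copied once per retrieved document via repeat_interleave. A small sketch of that repetition, with made-up sizes:

    # Illustrative shapes only: one copy of each target sequence per retrieved document.
    import torch

    batch_size, n_docs, tgt_len = 2, 3, 4
    decoder_input_ids = torch.arange(batch_size * tgt_len).view(batch_size, tgt_len)

    repeated = decoder_input_ids.repeat_interleave(n_docs, dim=0)
    print(repeated.shape)  # torch.Size([6, 4]) == (batch_size * n_docs, tgt_len)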
@@ -702,6 +710,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         exclude_bos_score=None,
         reduce_loss=None,
         labels=None,
+        n_docs=None,
         **kwargs  # needs kwargs for generation
     ):
         r"""
@@ -741,6 +750,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
             >>> # 3. Forward to generator
             >>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"])
         """
+        n_docs = n_docs if n_docs is not None else self.config.n_docs
         exclude_bos_score = exclude_bos_score if exclude_bos_score is not None else self.config.exclude_bos_score
         reduce_loss = reduce_loss if reduce_loss is not None else self.config.reduce_loss

@@ -763,6 +773,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             output_retrieved=output_retrieved,
+            n_docs=n_docs,
         )

         loss = None
@@ -774,6 +785,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
                 reduce_loss=reduce_loss,
                 epsilon=self.config.label_smoothing,
                 exclude_bos_score=exclude_bos_score,
+                n_docs=n_docs,
             )

         return RetrievAugLMMarginOutput(
@@ -816,6 +828,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         do_deduplication=None,  # defaults to True
         num_return_sequences=None,  # defaults to 1
         num_beams=None,  # defaults to 1
+        n_docs=None,
         **kwargs
     ):
         """
@@ -847,6 +860,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
                 function, where we set ``num_return_sequences`` to :obj:`num_beams`.
             num_beams (:obj:`int`, `optional`, defaults to 1):
                 Number of beams for beam search. 1 means no beam search.
+            n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`)
+                Number of documents to retrieve and/or number of documents for which to generate an answer.
             kwargs:
                 Additional kwargs will be passed to :meth:`~transformers.PreTrainedModel.generate`.

@@ -856,6 +871,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
                 shorter if all batches finished early due to the :obj:`eos_token_id`.
         """

+        n_docs = n_docs if n_docs is not None else self.config.n_docs
         do_deduplication = do_deduplication if do_deduplication is not None else self.config.do_deduplication
         num_doc_return_sequences = (
             num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
@@ -869,7 +885,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
                 input_ids,
                 question_hidden_states.cpu().detach().to(torch.float32).numpy(),
                 prefix=self.generator.config.prefix,
-                n_docs=self.config.n_docs,
+                n_docs=n_docs,
                 return_tensors="pt",
             )["context_input_ids"]

@@ -880,12 +896,11 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         kwargs["num_beams"] = num_beams
         kwargs["num_return_sequences"] = num_beams
         kwargs["attention_mask"] = None
+        kwargs["n_docs"] = n_docs

         for index in range(len(input_ids)):
             # first, generate beams from documents:
-            generator_input_ids = context_input_ids[
-                index * self.config.n_docs : (index + 1) * self.config.n_docs
-            ]  # (n_docs, max_len)
+            generator_input_ids = context_input_ids[index * n_docs : (index + 1) * n_docs]  # (n_docs, max_len)

             output_sequences = self.generator.generate(
                 generator_input_ids,
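
The rewritten slice above depends on the retriever's row layout: context_input_ids stacks the n_docs document contexts for question 0, then the n_docs contexts for question 1, and so on. A sketch with made-up sizes:

    # Illustrative shapes only: rows are grouped per question, n_docs consecutive rows each.
    import torch

    batch_size, n_docs, max_len = 2, 3, 5
    context_input_ids = torch.arange(batch_size * n_docs * max_len).view(batch_size * n_docs, max_len)

    index = 1  # second question in the batch
    generator_input_ids = context_input_ids[index * n_docs : (index + 1) * n_docs]
    print(generator_input_ids.shape)  # torch.Size([3, 5]) == (n_docs, max_len)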
@@ -905,12 +920,16 @@ class RagSequenceForGeneration(RagPreTrainedModel):

         return self._cat_and_pad(hypos, pad_token_id=self.config.generator.pad_token_id)

-    def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, exclude_bos_score=False):
+    def get_nll(
+        self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, exclude_bos_score=False, n_docs=None
+    ):
         # shift tokens left
         target = torch.cat(
             [target[:, 1:], target.new(target.shape[0], 1).fill_(self.config.generator.pad_token_id)], 1
         )

+        n_docs = n_docs if n_docs is not None else self.config.n_docs
+
         # bos_token_id is None for T5
         bos_token_id = self.config.bos_token_id or self.config.generator.bos_token_id
         use_bos = bos_token_id is not None and target[:, 0].eq(bos_token_id).all()
@@ -923,7 +942,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
             return ll.squeeze(-1), smooth_obj.squeeze(-1)

         seq_logprobs = torch.nn.functional.log_softmax(seq_logits, dim=-1).view(
-            seq_logits.shape[0] // self.config.n_docs, self.config.n_docs, -1, seq_logits.size(-1)
+            seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
         )  # batch_size x n_docs x tgt_len x dim
         doc_logprobs = torch.nn.functional.log_softmax(doc_scores, dim=1).unsqueeze(-1).unsqueeze(-1)

@@ -934,7 +953,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         rag_logprobs = torch.cat([first_token_scores, second_token_scores + doc_logprobs, remainder], dim=2)

         # calcualate loss
-        target = target.unsqueeze(1).unsqueeze(-1).repeat(1, self.config.n_docs, 1, 1)
+        target = target.unsqueeze(1).unsqueeze(-1).repeat(1, n_docs, 1, 1)
         assert target.dim() == rag_logprobs.dim()

         ll = rag_logprobs.gather(dim=-1, index=target)
@@ -1004,7 +1023,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
         return self.rag.generator.adjust_logits_during_generation(logits, cur_len=cur_len, max_length=max_length)

     def prepare_inputs_for_generation(
-        self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, doc_scores, **kwargs
+        self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, doc_scores, n_docs=None, **kwargs
     ):
         return {
             "input_ids": None,
@@ -1015,6 +1034,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
             "past_key_values": past,
             "use_cache": use_cache,
             "do_marginalize": True,
+            "n_docs": n_docs,
         }

     @property
@@ -1053,10 +1073,13 @@ class RagTokenForGeneration(RagPreTrainedModel):

         return reordered_past

-    def marginalize(self, seq_logits, doc_scores):
+    def marginalize(self, seq_logits, doc_scores, n_docs=None):
+
+        n_docs = n_docs if n_docs is not None else self.config.n_docs
+
         # RAG-token marginalization
         seq_logprobs = torch.nn.functional.log_softmax(seq_logits, dim=-1).view(
-            seq_logits.shape[0] // self.config.n_docs, self.config.n_docs, -1, seq_logits.size(-1)
+            seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
         )
         doc_logprobs = torch.log_softmax(doc_scores, dim=1)
         log_prob_sum = seq_logprobs + doc_logprobs.unsqueeze(-1).unsqueeze(-1)
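
marginalize implements the RAG-token mixture in log space: for each position t, log p(y_t | x) = logsumexp over documents z of [log p(z | x) + log p(y_t | x, z)], which is why a consistent n_docs is needed to fold the flat (batch_size * n_docs) axis back into separate batch and document axes. A standalone sketch with made-up shapes (in the full source the method ends with torch.logsumexp(log_prob_sum, dim=1)):

    # Illustrative shapes only: per-token marginalization over retrieved documents.
    import torch

    batch_size, n_docs, tgt_len, vocab_size = 2, 3, 4, 11
    seq_logits = torch.randn(batch_size * n_docs, tgt_len, vocab_size)
    doc_scores = torch.randn(batch_size, n_docs)

    seq_logprobs = torch.nn.functional.log_softmax(seq_logits, dim=-1).view(
        batch_size, n_docs, tgt_len, vocab_size
    )
    doc_logprobs = torch.log_softmax(doc_scores, dim=1)
    log_prob_sum = seq_logprobs + doc_logprobs.unsqueeze(-1).unsqueeze(-1)
    marginalized = torch.logsumexp(log_prob_sum, dim=1)  # (batch_size, tgt_len, vocab_size)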
@@ -1082,6 +1105,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
         do_marginalize=None,
         reduce_loss=None,
         labels=None,
+        n_docs=None,
         **kwargs  # needs kwargs for generation
     ):
         r"""
@@ -1124,6 +1148,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
             >>> generated = model.generate(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores)
             >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
         """
+        n_docs = n_docs if n_docs is not None else self.config.n_docs
         do_marginalize = do_marginalize if do_marginalize is not None else self.config.do_marginalize
         reduce_loss = reduce_loss if reduce_loss is not None else self.config.reduce_loss

@@ -1146,6 +1171,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             output_retrieved=output_retrieved,
+            n_docs=n_docs,
         )

         loss = None
@@ -1158,10 +1184,11 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 labels,
                 reduce_loss=reduce_loss,
                 epsilon=self.config.label_smoothing,
+                n_docs=n_docs,
             )

         if do_marginalize:
-            logits = self.marginalize(logits, outputs.doc_scores)
+            logits = self.marginalize(logits, outputs.doc_scores, n_docs)

         return RetrievAugLMMarginOutput(
             loss=loss,
@@ -1203,6 +1230,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
         bad_words_ids=None,
         num_return_sequences=None,
         decoder_start_token_id=None,
+        n_docs=None,
         **kwargs
     ):
         """
@@ -1274,6 +1302,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 function, where we set ``num_return_sequences`` to :obj:`num_beams`.
             decoder_start_token_id (:obj:`int`, `optional`):
                 If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
+            n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`)
+                Number of documents to retrieve and/or number of documents for which to generate an answer.

         Return:
             :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`:
@@ -1281,6 +1311,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 shorter if all batches finished early due to the :obj:`eos_token_id`.
         """
         # set default parameters
+        n_docs = n_docs if n_docs is not None else self.config.n_docs
         max_length = max_length if max_length is not None else self.config.max_length
         min_length = min_length if min_length is not None else self.config.min_length
         early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
@@ -1310,7 +1341,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 input_ids,
                 question_hidden_states.cpu().detach().to(torch.float32).numpy(),
                 prefix=self.generator.config.prefix,
-                n_docs=self.config.n_docs,
+                n_docs=n_docs,
                 return_tensors="pt",
             )
             context_input_ids, context_attention_mask, retrieved_doc_embeds = (
@@ -1329,8 +1360,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 1
             )

+        assert (
+            context_input_ids.shape[0] % n_docs
+        ) == 0, f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is {context_input_ids.shape[0]}."
+
         # batch_size
-        batch_size = context_input_ids.shape[0] // self.config.n_docs
+        batch_size = context_input_ids.shape[0] // n_docs

         encoder = self.rag.generator.get_encoder()
         encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
@@ -1345,11 +1380,11 @@ class RagTokenForGeneration(RagPreTrainedModel):

         def extend_enc_output(tensor, num_beams=None):
             # split into `batch_size`, `num_beams`, `num_docs`
-            tensor = tensor[None, None, :].reshape((batch_size, 1, self.config.n_docs) + tensor.shape[1:])
+            tensor = tensor[None, None, :].reshape((batch_size, 1, n_docs) + tensor.shape[1:])
             # repeat same last hidden states over `num_beams` dimension
-            tensor = tensor.expand((batch_size, num_beams, self.config.n_docs) + tensor.shape[3:])
+            tensor = tensor.expand((batch_size, num_beams, n_docs) + tensor.shape[3:])
             # merge `batch_size`, `num_beams`, `num_docs` dims again
-            return tensor.reshape((batch_size * num_beams * self.config.n_docs,) + tensor.shape[3:])
+            return tensor.reshape((batch_size * num_beams * n_docs,) + tensor.shape[3:])

         # correctly extend last_hidden_state and attention mask
         context_attention_mask = extend_enc_output(context_attention_mask, num_beams=num_beams)
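
extend_enc_output, patched above, tiles encoder states that were computed once per (question, document) pair across num_beams, so beam search can treat each beam as its own batch entry. A sketch with made-up sizes:

    # Illustrative shapes only: tile per-(question, document) encoder states over beams.
    import torch

    batch_size, n_docs, num_beams, seq_len, hidden = 2, 3, 4, 5, 8
    tensor = torch.randn(batch_size * n_docs, seq_len, hidden)

    tensor = tensor[None, None, :].reshape((batch_size, 1, n_docs) + tensor.shape[1:])
    tensor = tensor.expand((batch_size, num_beams, n_docs) + tensor.shape[3:])
    tensor = tensor.reshape((batch_size * num_beams * n_docs,) + tensor.shape[3:])
    print(tensor.shape)  # torch.Size([24, 5, 8]) == (batch * beams * docs, seq_len, hidden)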
@@ -1362,6 +1397,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
         vocab_size = self.config.generator.vocab_size
         kwargs["doc_scores"] = doc_scores
         kwargs["encoder_outputs"] = encoder_outputs
+        kwargs["n_docs"] = n_docs

         # not needed. TODO(PVP): change after generate refactor
         do_sample = False
@@ -1431,7 +1467,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
         shifted_input_ids[:, 0] = start_token_id
         return shifted_input_ids

-    def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0):
+    def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, n_docs=None):
+        n_docs = n_docs if n_docs is not None else self.config.n_docs
         # shift tokens left
         target = torch.cat(
             [target[:, 1:], target.new(target.shape[0], 1).fill_(self.config.generator.pad_token_id)], 1
@@ -1444,7 +1481,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
             smooth_obj.masked_fill_(pad_mask, 0.0)
             return ll.squeeze(-1), smooth_obj.squeeze(-1)

-        rag_logprobs = self.marginalize(seq_logits, doc_scores)
+        rag_logprobs = self.marginalize(seq_logits, doc_scores, n_docs)

         target = target.unsqueeze(-1)
         assert target.dim() == rag_logprobs.dim()
tests/test_modeling_rag.py
@@ -82,7 +82,7 @@ def require_retrieval(test_case):

    """
    if not (is_torch_available() and is_datasets_available() and is_faiss_available()):
-        test_case = unittest.skip("test requires PyTorch")(test_case)
+        test_case = unittest.skip("test requires PyTorch, datasets and faiss")(test_case)
    return test_case


@@ -98,7 +98,7 @@ class RagTestMixin:
     )

     retrieval_vector_size = 32
-    n_docs = 2
+    n_docs = 3
     max_combined_length = 16

     def setUp(self):
@@ -186,10 +186,14 @@ class RagTestMixin:
     def get_retriever(self, config):
         dataset = Dataset.from_dict(
             {
-                "id": ["0", "1"],
-                "text": ["foo", "bar"],
-                "title": ["Foo", "Bar"],
-                "embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
+                "id": ["0", "1", "3"],
+                "text": ["foo", "bar", "qux"],
+                "title": ["Foo", "Bar", "Qux"],
+                "embeddings": [
+                    np.ones(self.retrieval_vector_size),
+                    2 * np.ones(self.retrieval_vector_size),
+                    3 * np.ones(self.retrieval_vector_size),
+                ],
             }
         )
         dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
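
The fixture change above tracks the mixin's new default of n_docs = 3: with only two documents in the dummy index, a retrieval of three could not return three distinct entries. The embeddings stay constant vectors (all ones, twos, and threes), which keeps the three inner-product scores separated whenever the query's components do not sum to zero.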
@@ -315,6 +319,125 @@ class RagTestMixin:
             # doc scores
             self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))

+    def check_model_custom_n_docs(
+        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, n_docs, **kwargs
+    ):
+        self.assertIsNotNone(config.question_encoder)
+        self.assertIsNotNone(config.generator)
+
+        retriever = self.get_retriever(config)
+
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            model.eval()
+            self.assertTrue(model.config.is_encoder_decoder)
+
+            question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]
+
+            out = retriever(
+                input_ids,
+                question_hidden_states.cpu().detach().to(torch.float32).numpy(),
+                prefix=config.generator.prefix,
+                return_tensors="pt",
+                n_docs=n_docs,
+            )
+
+            context_input_ids, context_attention_mask, retrieved_doc_embeds = (
+                out["context_input_ids"],
+                out["context_attention_mask"],
+                out["retrieved_doc_embeds"],
+            )
+
+            # cast
+            retrieved_doc_embeds = retrieved_doc_embeds.to(question_hidden_states)
+            context_input_ids = context_input_ids.to(input_ids)
+            context_attention_mask = context_attention_mask.to(input_ids)
+
+            # compute doc_scores
+            doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze(
+                1
+            )
+
+            outputs = model(
+                context_input_ids=context_input_ids,
+                context_attention_mask=context_attention_mask,
+                doc_scores=doc_scores,
+                decoder_input_ids=decoder_input_ids,
+                decoder_attention_mask=decoder_attention_mask,
+                n_docs=n_docs,
+            )
+
+            # logits
+            self.assertEqual(
+                outputs.logits.shape,
+                (n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size),
+            )
+            # generator encoder last hidden states
+            self.assertEqual(
+                outputs.generator_enc_last_hidden_state.shape,
+                (n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size),
+            )
+            # doc scores
+            self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], n_docs))
+
+    def check_model_with_mismatch_n_docs_value(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        decoder_input_ids,
+        decoder_attention_mask,
+        retriever_n_docs,
+        generator_n_docs,
+        **kwargs
+    ):
+        self.assertIsNotNone(config.question_encoder)
+        self.assertIsNotNone(config.generator)
+
+        retriever = self.get_retriever(config)
+
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            model.eval()
+            self.assertTrue(model.config.is_encoder_decoder)
+
+            question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]
+
+            out = retriever(
+                input_ids,
+                question_hidden_states.cpu().detach().to(torch.float32).numpy(),
+                prefix=config.generator.prefix,
+                return_tensors="pt",
+                n_docs=retriever_n_docs,
+            )
+
+            context_input_ids, context_attention_mask, retrieved_doc_embeds = (
+                out["context_input_ids"],
+                out["context_attention_mask"],
+                out["retrieved_doc_embeds"],
+            )
+
+            # cast
+            retrieved_doc_embeds = retrieved_doc_embeds.to(question_hidden_states)
+            context_input_ids = context_input_ids.to(input_ids)
+            context_attention_mask = context_attention_mask.to(input_ids)
+
+            # compute doc_scores
+            doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze(
+                1
+            )
+
+            self.assertRaises(
+                AssertionError,
+                model.__call__,
+                context_input_ids=context_input_ids,
+                context_attention_mask=context_attention_mask,
+                doc_scores=doc_scores,
+                decoder_input_ids=decoder_input_ids,
+                decoder_attention_mask=decoder_attention_mask,
+                n_docs=generator_n_docs,
+            )
+
     def check_model_with_encoder_outputs(
         self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
     ):
@@ -373,6 +496,17 @@ class RagTestMixin:
         inputs_dict = self.config_and_inputs
         self.check_model_generate(**inputs_dict)

+    def test_model_with_custom_n_docs(self):
+        inputs_dict = self.config_and_inputs
+        inputs_dict["n_docs"] = 1
+        self.check_model_custom_n_docs(**inputs_dict)
+
+    def test_model_with_mismatch_n_docs_value(self):
+        inputs_dict = self.config_and_inputs
+        inputs_dict["retriever_n_docs"] = 3
+        inputs_dict["generator_n_docs"] = 2
+        self.check_model_with_mismatch_n_docs_value(**inputs_dict)
+

 @require_torch
 @require_retrieval
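
Taken together, the two new tests exercise the parameter end to end: test_model_with_custom_n_docs overrides the default with n_docs=1 and checks that every output shape follows the override, while test_model_with_mismatch_n_docs_value retrieves with retriever_n_docs=3 but forwards with generator_n_docs=2 and expects the new consistency assert to raise AssertionError. They can be run in isolation with, for example, python -m pytest tests/test_modeling_rag.py -k "n_docs".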