mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[RAG, Bart] Align RAG, Bart cache with T5 and other models of transformers (#9098)
* fix rag * fix slow test * fix past in bart
This commit is contained in:
parent
6587cf9f84
commit
fa1ddced9e
@ -16,7 +16,7 @@
|
||||
import math
|
||||
import random
|
||||
import warnings
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -407,7 +407,7 @@ class BartDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_attn_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[torch.Tensor] = False,
|
||||
):
|
||||
@ -416,9 +416,10 @@ class BartDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
# Self Attention
|
||||
|
||||
# decoder uni-directional self-attention cached key/values tuple is at first position
|
||||
self_attn_past_key_value = past_key_value[0] if past_key_value is not None else None
|
||||
hidden_states, self_attn_weights, self_attn_present_key_value = self.self_attn(
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
attn_mask=attn_mask,
|
||||
@ -437,8 +438,8 @@ class BartDecoderLayer(nn.Module):
|
||||
if self.normalize_before:
|
||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||
|
||||
# cross_attn cached key/values tuple is at second position
|
||||
cross_attn_past_key_value = past_key_value[1] if past_key_value is not None else None
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
@ -451,6 +452,9 @@ class BartDecoderLayer(nn.Module):
|
||||
if not self.normalize_before:
|
||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||
|
||||
# add cross-attn to positions 3,4 of present_key_value tuple
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
if self.normalize_before:
|
||||
@ -463,9 +467,6 @@ class BartDecoderLayer(nn.Module):
|
||||
if not self.normalize_before:
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
|
||||
# make sure decoder uni-directional self-attn at 1st position and cross-attn at 2nd position.
|
||||
present_key_value = (self_attn_present_key_value, cross_attn_present_key_value)
|
||||
|
||||
return (
|
||||
hidden_states,
|
||||
self_attn_weights,
|
||||
@ -600,7 +601,7 @@ BART_INPUTS_DOCSTRING = r"""
|
||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||
`optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
|
||||
cross-attention of the decoder.
|
||||
past_key_values (:obj:`Tuple[Tuple[Tuple[torch.Tensor]]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
|
||||
|
||||
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||
@ -857,7 +858,7 @@ class BartDecoder(BartPretrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
past_key_values (:obj:`Tuple[Tuple[Tuple[torch.Tensor]]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
decoding.
|
||||
|
||||
@ -897,7 +898,7 @@ class BartDecoder(BartPretrainedModel):
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0][0].shape[2] if past_key_values is not None else 0
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
@ -1284,12 +1285,9 @@ class BartForConditionalGeneration(BartPretrainedModel):
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
def _reorder_buffer(cache: Tuple[torch.Tensor], new_order) -> Dict:
|
||||
return tuple(past_state.index_select(0, new_order) for past_state in cache)
|
||||
|
||||
reordered_past = ()
|
||||
for layer_past in past:
|
||||
reordered_past += (tuple(_reorder_buffer(cache, beam_idx) for cache in layer_past),)
|
||||
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||
return reordered_past
|
||||
|
||||
|
||||
|
@ -1029,6 +1029,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
n_docs=None,
|
||||
**kwargs
|
||||
):
|
||||
if past is not None:
|
||||
# if past is defined use only last decoder_input_ids
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
|
||||
return {
|
||||
"input_ids": None,
|
||||
"encoder_outputs": encoder_outputs,
|
||||
@ -1057,23 +1061,17 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
def _reorder_cache(past, beam_idx):
|
||||
"""Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
|
||||
|
||||
def _reorder_stacked(hidden_states):
|
||||
n_docs = hidden_states.shape[0] // beam_idx.shape[0]
|
||||
def _reorder_stacked(hidden_states, new_order):
|
||||
n_docs = hidden_states.shape[0] // new_order.shape[0]
|
||||
hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:])
|
||||
hidden_states = hidden_states.index_select(0, beam_idx)
|
||||
return hidden_states.view(-1, *hidden_states.shape[2:])
|
||||
hidden_states = hidden_states.index_select(0, new_order)
|
||||
result = hidden_states.view(-1, *hidden_states.shape[2:])
|
||||
return result
|
||||
|
||||
def _reorder_buffer(attn_cache):
|
||||
for k, input_buffer_k in attn_cache.items():
|
||||
if input_buffer_k is not None:
|
||||
attn_cache[k] = _reorder_stacked(input_buffer_k)
|
||||
return attn_cache
|
||||
|
||||
reordered_past = []
|
||||
reordered_past = ()
|
||||
for layer_past in past:
|
||||
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
|
||||
layer_past_new = {attn_key: _reorder_buffer(attn_cache) for attn_key, attn_cache in layer_past.items()}
|
||||
reordered_past.append(layer_past_new)
|
||||
reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
|
||||
|
||||
return reordered_past
|
||||
|
||||
|
@ -535,7 +535,6 @@ class RagDPRBartTest(RagTestMixin, unittest.TestCase):
|
||||
n_docs=self.n_docs,
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
max_combined_length=self.max_combined_length,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
return {
|
||||
@ -565,7 +564,6 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
|
||||
n_docs=self.n_docs,
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
max_combined_length=self.max_combined_length,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
return {
|
||||
@ -758,8 +756,8 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
generator_tokenizer=rag_decoder_tokenizer,
|
||||
)
|
||||
|
||||
rag_token = self.sequence_model
|
||||
rag_token.set_retriever(rag_retriever)
|
||||
rag_sequence = self.sequence_model
|
||||
rag_sequence.set_retriever(rag_retriever)
|
||||
|
||||
input_ids = rag_question_encoder_tokenizer(
|
||||
"who sings does he love me with reba", return_tensors="pt"
|
||||
@ -767,9 +765,9 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
output_ids = rag_token.generate(
|
||||
output_ids = rag_sequence.generate(
|
||||
input_ids,
|
||||
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
|
||||
decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
@ -810,7 +808,7 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
retriever = RagRetriever.from_pretrained(
|
||||
"facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
|
||||
)
|
||||
rag_sequence = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
|
||||
rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
@ -844,9 +842,9 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
" walls of the abdomen",
|
||||
" spodumene",
|
||||
" obama",
|
||||
" grainger's compound",
|
||||
" new orleans",
|
||||
" japan",
|
||||
" old trafford stadium",
|
||||
" old trafford",
|
||||
]
|
||||
self.assertListEqual(outputs, EXPECTED_OUTPUTS)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user