[RAG, Bart] Align RAG, Bart cache with T5 and other models of transformers (#9098)

* fix rag

* fix slow test

* fix past in bart
Patrick von Platen 2020-12-14 12:32:26 +01:00 committed by GitHub
parent 6587cf9f84
commit fa1ddced9e
3 changed files with 33 additions and 39 deletions

modeling_bart.py

@@ -16,7 +16,7 @@
import math
import random
import warnings
from typing import Dict, Optional, Tuple
from typing import Optional, Tuple
import numpy as np
import torch
@@ -407,7 +407,7 @@ class BartDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_attn_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attn_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[torch.Tensor] = False,
):
@@ -416,9 +416,10 @@ class BartDecoderLayer(nn.Module):
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at first position
self_attn_past_key_value = past_key_value[0] if past_key_value is not None else None
hidden_states, self_attn_weights, self_attn_present_key_value = self.self_attn(
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attn_mask=attn_mask,
@@ -437,8 +438,8 @@ class BartDecoderLayer(nn.Module):
if self.normalize_before:
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# cross_attn cached key/values tuple is at second position
cross_attn_past_key_value = past_key_value[1] if past_key_value is not None else None
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
@@ -451,6 +452,9 @@ class BartDecoderLayer(nn.Module):
if not self.normalize_before:
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# add cross-attn to positions 3,4 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value
# Fully Connected
residual = hidden_states
if self.normalize_before:
@@ -463,9 +467,6 @@ class BartDecoderLayer(nn.Module):
if not self.normalize_before:
hidden_states = self.final_layer_norm(hidden_states)
# make sure decoder uni-directional self-attn at 1st position and cross-attn at 2nd position.
present_key_value = (self_attn_present_key_value, cross_attn_present_key_value)
return (
hidden_states,
self_attn_weights,
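
To make the new layout concrete: after this change each decoder layer keeps a single flat cache tuple, with the self-attention key/value at positions 1,2 and the cross-attention key/value appended at positions 3,4, matching T5. Below is a minimal sketch of how such a tuple is assembled and sliced; all tensor shapes are illustrative placeholders, not values from the diff.

import torch

# Illustrative shapes only: (batch, heads, cached_length, head_dim)
batch, heads, tgt_len, src_len, head_dim = 2, 4, 5, 7, 8
self_attn_k = torch.randn(batch, heads, tgt_len, head_dim)
self_attn_v = torch.randn(batch, heads, tgt_len, head_dim)
cross_attn_k = torch.randn(batch, heads, src_len, head_dim)
cross_attn_v = torch.randn(batch, heads, src_len, head_dim)

# Self-attn cache first, cross-attn cache appended, as in the diff above.
present_key_value = (self_attn_k, self_attn_v) + (cross_attn_k, cross_attn_v)
assert len(present_key_value) == 4

# Slicing on the next decoding step mirrors past_key_value[:2] / [-2:].
self_attn_past_key_value = present_key_value[:2]
cross_attn_past_key_value = present_key_value[-2:]
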
@@ -600,7 +601,7 @@ BART_INPUTS_DOCSTRING = r"""
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
`optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
cross-attention of the decoder.
past_key_values (:obj:`Tuple[Tuple[Tuple[torch.Tensor]]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
@@ -857,7 +858,7 @@ class BartDecoder(BartPretrainedModel):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
past_key_values (:obj:`Tuple[Tuple[Tuple[torch.Tensor]]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
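
The docstring describes the same flattened cache from the caller's side: once past_key_values are returned, only the newest decoder token needs to be passed on the next step. A rough usage sketch with a randomly initialized toy model follows; the tiny config values and token ids are made up for illustration.

import torch
from transformers import BartConfig, BartForConditionalGeneration

# Toy config; all sizes here are illustrative.
config = BartConfig(
    vocab_size=50,
    d_model=16,
    encoder_layers=1,
    decoder_layers=1,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
)
model = BartForConditionalGeneration(config).eval()

input_ids = torch.tensor([[0, 5, 6, 2]])
decoder_input_ids = torch.tensor([[2, 8, 9]])

with torch.no_grad():
    out = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        use_cache=True,
        return_dict=True,
    )
    # Next step: feed only the last generated token plus the cache.
    model(
        input_ids=input_ids,
        decoder_input_ids=torch.tensor([[4]]),
        past_key_values=out.past_key_values,
        use_cache=True,
        return_dict=True,
    )
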
@@ -897,7 +898,7 @@ class BartDecoder(BartPretrainedModel):
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0][0].shape[2] if past_key_values is not None else 0
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
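
With the flattened layout, the number of already-cached target positions is read straight from the first layer's self-attention key tensor (dimension 2), which is what the one-level-shallower indexing above does. A small sketch with made-up shapes:

import torch

batch, heads, past_len, head_dim = 2, 4, 6, 8  # illustrative shapes
layer_cache = tuple(torch.randn(batch, heads, past_len, head_dim) for _ in range(4))
past_key_values = (layer_cache,)  # one 4-tuple per decoder layer

past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
assert past_key_values_length == past_len
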
@@ -1284,12 +1285,9 @@ class BartForConditionalGeneration(BartPretrainedModel):
@staticmethod
def _reorder_cache(past, beam_idx):
def _reorder_buffer(cache: Tuple[torch.Tensor], new_order) -> Dict:
return tuple(past_state.index_select(0, new_order) for past_state in cache)
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(_reorder_buffer(cache, beam_idx) for cache in layer_past),)
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past
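
The simplified _reorder_cache no longer needs a helper: every tensor in every layer tuple is gathered along the batch dimension with beam_idx. A standalone sketch of the same logic on a fake single-layer cache (shapes and indices are illustrative):

import torch

def reorder_cache(past, beam_idx):
    # Same logic as the new BART _reorder_cache above.
    reordered_past = ()
    for layer_past in past:
        reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
    return reordered_past

batch, heads, seq, head_dim = 3, 2, 4, 8  # illustrative
past = (tuple(torch.randn(batch, heads, seq, head_dim) for _ in range(4)),)
beam_idx = torch.tensor([2, 0, 1])
reordered = reorder_cache(past, beam_idx)
assert torch.equal(reordered[0][0][0], past[0][0][2])
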

modeling_rag.py

@@ -1029,6 +1029,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
n_docs=None,
**kwargs
):
if past is not None:
# if past is defined use only last decoder_input_ids
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None,
"encoder_outputs": encoder_outputs,
@@ -1057,23 +1061,17 @@ class RagTokenForGeneration(RagPreTrainedModel):
def _reorder_cache(past, beam_idx):
"""Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
def _reorder_stacked(hidden_states):
n_docs = hidden_states.shape[0] // beam_idx.shape[0]
def _reorder_stacked(hidden_states, new_order):
n_docs = hidden_states.shape[0] // new_order.shape[0]
hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:])
hidden_states = hidden_states.index_select(0, beam_idx)
return hidden_states.view(-1, *hidden_states.shape[2:])
hidden_states = hidden_states.index_select(0, new_order)
result = hidden_states.view(-1, *hidden_states.shape[2:])
return result
def _reorder_buffer(attn_cache):
for k, input_buffer_k in attn_cache.items():
if input_buffer_k is not None:
attn_cache[k] = _reorder_stacked(input_buffer_k)
return attn_cache
reordered_past = []
reordered_past = ()
for layer_past in past:
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
layer_past_new = {attn_key: _reorder_buffer(attn_cache) for attn_key, attn_cache in layer_past.items()}
reordered_past.append(layer_past_new)
reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
return reordered_past
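
RAG's cache carries an extra document dimension folded into the batch axis (batch_size * n_docs), so the reordering first unfolds that axis, applies the beam indices, and flattens back. A self-contained sketch of _reorder_stacked on a fake cache tensor (shapes are illustrative):

import torch

def reorder_stacked(hidden_states, new_order):
    # Mirrors _reorder_stacked above: unfold docs, reorder beams, flatten back.
    n_docs = hidden_states.shape[0] // new_order.shape[0]
    hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:])
    hidden_states = hidden_states.index_select(0, new_order)
    return hidden_states.view(-1, *hidden_states.shape[2:])

batch, n_docs, heads, seq, head_dim = 2, 3, 2, 4, 8  # illustrative
cache_tensor = torch.randn(batch * n_docs, heads, seq, head_dim)
beam_idx = torch.tensor([1, 0])
reordered = reorder_stacked(cache_tensor, beam_idx)
# The docs belonging to beam 1 now come first.
assert torch.equal(reordered[:n_docs], cache_tensor[n_docs:])
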

test_modeling_rag.py

@@ -535,7 +535,6 @@ class RagDPRBartTest(RagTestMixin, unittest.TestCase):
n_docs=self.n_docs,
retrieval_vector_size=self.retrieval_vector_size,
max_combined_length=self.max_combined_length,
use_cache=False,
)
return {
@@ -565,7 +564,6 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
n_docs=self.n_docs,
retrieval_vector_size=self.retrieval_vector_size,
max_combined_length=self.max_combined_length,
use_cache=False,
)
return {
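
Dropping use_cache=False from these two test configs relies on the library defaults: assuming both generator configs default to use_cache=True (as BartConfig and T5Config do), the RAG tests now exercise the rewritten cache path during generation. A quick check of that assumption:

from transformers import BartConfig, T5Config

# Both generator config classes used by these tests default to use_cache=True.
assert BartConfig().use_cache is True
assert T5Config().use_cache is True
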
@@ -758,8 +756,8 @@ class RagModelIntegrationTests(unittest.TestCase):
generator_tokenizer=rag_decoder_tokenizer,
)
rag_token = self.sequence_model
rag_token.set_retriever(rag_retriever)
rag_sequence = self.sequence_model
rag_sequence.set_retriever(rag_retriever)
input_ids = rag_question_encoder_tokenizer(
"who sings does he love me with reba", return_tensors="pt"
@@ -767,9 +765,9 @@ class RagModelIntegrationTests(unittest.TestCase):
input_ids = input_ids.to(torch_device)
output_ids = rag_token.generate(
output_ids = rag_sequence.generate(
input_ids,
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
num_beams=2,
num_return_sequences=2,
)
@@ -810,7 +808,7 @@ class RagModelIntegrationTests(unittest.TestCase):
retriever = RagRetriever.from_pretrained(
"facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
)
rag_sequence = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
torch_device
)
@@ -844,9 +842,9 @@ class RagModelIntegrationTests(unittest.TestCase):
" walls of the abdomen",
" spodumene",
" obama",
" grainger's compound",
" new orleans",
" japan",
" old trafford stadium",
" old trafford",
]
self.assertListEqual(outputs, EXPECTED_OUTPUTS)