From f60eec40034adad7b6caac6339ceb51a4604a847 Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Tue, 15 Nov 2022 10:46:34 +0100
Subject: [PATCH] update relative positional embedding (#20203)

* update relative positional embedding

* make fix copies

* add `use_cache` to list of arguments

* fixup

* 1line fucntion

* add `test_decoder_model_past_with_large_inputs_relative_pos_emb`

* add relative pos embedding test for more models

* style
---
 src/transformers/models/bert/modeling_bert.py | 22 ++++++++++++++-----
 .../modeling_bert_generation.py               | 13 ++++++++---
 .../models/camembert/modeling_camembert.py    | 13 ++++++++---
 .../models/data2vec/modeling_data2vec_text.py | 13 ++++++++---
 .../models/electra/modeling_electra.py        | 13 ++++++++---
 .../models/ernie/modeling_ernie.py            | 22 ++++++++++++++-----
 .../models/layoutlm/modeling_layoutlm.py      | 13 ++++++++---
 .../models/markuplm/modeling_markuplm.py      | 22 ++++++++++++++-----
 .../megatron_bert/modeling_megatron_bert.py   | 13 ++++++++---
 .../models/realm/modeling_realm.py            | 13 ++++++++---
 .../models/roberta/modeling_roberta.py        | 13 ++++++++---
 .../models/roc_bert/modeling_roc_bert.py      | 13 ++++++++---
 .../models/splinter/modeling_splinter.py      | 13 ++++++++---
 .../xlm_roberta/modeling_xlm_roberta.py       | 13 ++++++++---
 .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 13 ++++++++---
 tests/models/bert/test_modeling_bert.py       |  5 +++++
 .../data2vec/test_modeling_data2vec_text.py   |  5 +++++
 tests/models/ernie/test_modeling_ernie.py     |  5 +++++
 tests/models/roberta/test_modeling_roberta.py |  5 +++++
 .../models/roc_bert/test_modeling_roc_bert.py |  5 +++++
 .../test_modeling_xlm_roberta_xl.py           |  5 +++++
 21 files changed, 201 insertions(+), 51 deletions(-)

diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 78e8e187679..16abb6c871a 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -309,6 +309,7 @@ class BertSelfAttention(nn.Module):
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
+        use_cache = past_key_value is not None
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
             # Further calls to cross_attention layer can then reuse all cross-attention
@@ -323,10 +324,16 @@ class BertSelfAttention(nn.Module):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if use_cache:
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r
+
             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
 
@@ -1267,7 +1274,7 @@ class BertLMHeadModel(BertPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=True, **model_kwargs):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
@@ -1277,7 +1284,12 @@ class BertLMHeadModel(BertPreTrainedModel):
         if past is not None:
             input_ids = input_ids[:, -1:]
 
-        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past,
+            "use_cache": use_cache,
+        }
 
     def _reorder_cache(self, past, beam_idx):
         reordered_past = ()
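The hunks above are the core of the change: with `relative_key`/`relative_key_query` embeddings, `use_cache = past_key_value is not None` detects incremental decoding, where the query covers only the newly generated token while the keys span the whole prefix, so the left position ids collapse to the single index `key_length - 1` instead of `arange(seq_length)`. A minimal standalone sketch of that distance computation (toy shapes and the helper name are illustrative, not the Hugging Face module itself):

```python
import torch


def relative_position_distance(query_layer, key_layer, use_cache):
    """Sketch of the distance computation from the patch above.

    query_layer/key_layer: (batch, heads, seq_len, head_dim). With a cache the
    query holds only the newly generated token while the keys cover the full
    prefix, so query_length == 1 and key_length == prefix length.
    """
    query_length, key_length = query_layer.shape[2], key_layer.shape[2]
    if use_cache:
        # only the last position is queried, so a single left index suffices
        position_ids_l = torch.tensor(key_length - 1, dtype=torch.long).view(-1, 1)
    else:
        position_ids_l = torch.arange(query_length, dtype=torch.long).view(-1, 1)
    position_ids_r = torch.arange(key_length, dtype=torch.long).view(1, -1)
    return position_ids_l - position_ids_r  # shape (query_length, key_length)


# toy check: cached decoding of the 5th token matches the last row of the full pass
q_cached = torch.zeros(1, 1, 1, 8)  # one new query position
k_cached = torch.zeros(1, 1, 5, 8)  # five cached key positions
full = torch.zeros(1, 1, 5, 8)
assert torch.equal(
    relative_position_distance(q_cached, k_cached, use_cache=True),
    relative_position_distance(full, full, use_cache=False)[-1:],
)
```

The resulting distance matrix is then shifted by `max_position_embeddings - 1` and looked up in `distance_embedding`, exactly as before; only the shape bookkeeping changes.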
diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py
index 2058bee9fec..237c2d2b449 100755
--- a/src/transformers/models/bert_generation/modeling_bert_generation.py
+++ b/src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -128,6 +128,7 @@ class BertGenerationSelfAttention(nn.Module):
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
+        use_cache = past_key_value is not None
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
             # Further calls to cross_attention layer can then reuse all cross-attention
@@ -142,10 +143,16 @@ class BertGenerationSelfAttention(nn.Module):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if use_cache:
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r
+
             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
 
diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py
index 514566596cc..9b9868d98e6 100644
--- a/src/transformers/models/camembert/modeling_camembert.py
+++ b/src/transformers/models/camembert/modeling_camembert.py
@@ -234,6 +234,7 @@ class CamembertSelfAttention(nn.Module):
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
+        use_cache = past_key_value is not None
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
             # Further calls to cross_attention layer can then reuse all cross-attention
@@ -248,10 +249,16 @@ class CamembertSelfAttention(nn.Module):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if use_cache:
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r
+
             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
 
diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py
index 543e5ee367c..3ba0ab1c483 100644
--- a/src/transformers/models/data2vec/modeling_data2vec_text.py
+++ b/src/transformers/models/data2vec/modeling_data2vec_text.py
@@ -220,6 +220,7 @@ class Data2VecTextSelfAttention(nn.Module):
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
+        use_cache = past_key_value is not None
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
             # Further calls to cross_attention layer can then reuse all cross-attention
@@ -234,10 +235,16 @@ class Data2VecTextSelfAttention(nn.Module):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if use_cache:
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r
+
             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
 
diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py
index 058a76acda1..06abc953f19 100644
--- a/src/transformers/models/electra/modeling_electra.py
+++ b/src/transformers/models/electra/modeling_electra.py
@@ -281,6 +281,7 @@ class ElectraSelfAttention(nn.Module):
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
+        use_cache = past_key_value is not None
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
             # Further calls to cross_attention layer can then reuse all cross-attention
@@ -295,10 +296,16 @@ class ElectraSelfAttention(nn.Module):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if use_cache:
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r
+
             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
 
diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py
index de3e539420a..95d7eac07cc 100644
--- a/src/transformers/models/ernie/modeling_ernie.py
+++ b/src/transformers/models/ernie/modeling_ernie.py
@@ -216,6 +216,7 @@ class ErnieSelfAttention(nn.Module):
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
+        use_cache = past_key_value is not None
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
             # Further calls to cross_attention layer can then reuse all cross-attention
@@ -230,10 +231,16 @@ class ErnieSelfAttention(nn.Module):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if use_cache:
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r
+
             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
 
@@ -1207,7 +1214,7 @@ class ErnieForCausalLM(ErniePreTrainedModel):
         )
 
     # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=True, **model_kwargs):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
@@ -1217,7 +1224,12 @@ class ErnieForCausalLM(ErniePreTrainedModel):
         if past is not None:
             input_ids = input_ids[:, -1:]
 
-        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past,
+            "use_cache": use_cache,
+        }
 
     # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache
     def _reorder_cache(self, past, beam_idx):
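Alongside the attention fix, `prepare_inputs_for_generation` in `BertLMHeadModel`, `ErnieForCausalLM` (above) and `MarkupLMModel` (below) now accepts `use_cache` and forwards it in the returned dict, so the flag reaches the decoder's forward pass on every generation step. A hedged sketch of what the patched method returns, using a tiny randomly initialised decoder (the config values and token ids are illustrative, not part of the patch):

```python
import torch
from transformers import BertConfig, BertLMHeadModel

# tiny decoder config, just to inspect prepare_inputs_for_generation
config = BertConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    is_decoder=True,
    position_embedding_type="relative_key",
)
model = BertLMHeadModel(config)

input_ids = torch.tensor([[101, 2023, 2003, 102]])
# run the prefix once to obtain a key/value cache
past = model(input_ids, use_cache=True).past_key_values

inputs = model.prepare_inputs_for_generation(input_ids, past=past, use_cache=True)
print(inputs["input_ids"].shape)  # torch.Size([1, 1]) -- only the last token is kept
print(list(inputs.keys()))        # ['input_ids', 'attention_mask', 'past_key_values', 'use_cache']
```

Because `use_cache` is now part of the returned dict, the single-token relative-key branch added above is exercised during cached generation instead of silently falling back to the full-sequence path.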
diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py
index 4535ef0e118..8ff5ff092ed 100644
--- a/src/transformers/models/layoutlm/modeling_layoutlm.py
+++ b/src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -196,6 +196,7 @@ class LayoutLMSelfAttention(nn.Module):
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
+        use_cache = past_key_value is not None
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
             # Further calls to cross_attention layer can then reuse all cross-attention
@@ -210,10 +211,16 @@ class LayoutLMSelfAttention(nn.Module):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if use_cache:
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r
+
             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
 
diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py
index d5c88ab8ab8..610b35f7378 100755
--- a/src/transformers/models/markuplm/modeling_markuplm.py
+++ b/src/transformers/models/markuplm/modeling_markuplm.py
@@ -405,6 +405,7 @@ class MarkupLMSelfAttention(nn.Module):
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
+        use_cache = past_key_value is not None
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
             # Further calls to cross_attention layer can then reuse all cross-attention
@@ -419,10 +420,16 @@ class MarkupLMSelfAttention(nn.Module):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if use_cache:
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r
+
             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
 
@@ -930,7 +937,7 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
         )
 
     # Copied from transformers.models.bert.modeling_bert.BertModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=True, **model_kwargs):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
@@ -940,7 +947,12 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
         if past is not None:
             input_ids = input_ids[:, -1:]
 
-        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past,
+            "use_cache": use_cache,
+        }
 
     # Copied from transformers.models.bert.modeling_bert.BertModel._reorder_cache
     def _reorder_cache(self, past, beam_idx):
diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py
index ba5460ac857..5293f1e78b9 100755
--- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py
+++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py
@@ -259,6 +259,7 @@ class MegatronBertSelfAttention(nn.Module):
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
+        use_cache = past_key_value is not None
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
             # Further calls to cross_attention layer can then reuse all cross-attention
@@ -273,10 +274,16 @@ class MegatronBertSelfAttention(nn.Module):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if use_cache:
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r
+
             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
 
diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py
index c1e92147529..7f835f80074 100644
--- a/src/transformers/models/realm/modeling_realm.py
+++ b/src/transformers/models/realm/modeling_realm.py
@@ -296,6 +296,7 @@ class RealmSelfAttention(nn.Module):
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
+        use_cache = past_key_value is not None
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
             # Further calls to cross_attention layer can then reuse all cross-attention
@@ -310,10 +311,16 @@ class RealmSelfAttention(nn.Module):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if use_cache:
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r
+
             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
 
diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py
index 0e0f822d415..466de1e39f3 100644
--- a/src/transformers/models/roberta/modeling_roberta.py
+++ b/src/transformers/models/roberta/modeling_roberta.py
@@ -220,6 +220,7 @@ class RobertaSelfAttention(nn.Module):
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
+        use_cache = past_key_value is not None
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
             # Further calls to cross_attention layer can then reuse all cross-attention
@@ -234,10 +235,16 @@ class RobertaSelfAttention(nn.Module):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if use_cache:
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r
+
             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
 
diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py
index 1c1e029b5c9..5f611a5e084 100644
--- a/src/transformers/models/roc_bert/modeling_roc_bert.py
+++ b/src/transformers/models/roc_bert/modeling_roc_bert.py
@@ -331,6 +331,7 @@ class RoCBertSelfAttention(nn.Module):
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
+        use_cache = past_key_value is not None
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
             # Further calls to cross_attention layer can then reuse all cross-attention
@@ -345,10 +346,16 @@ class RoCBertSelfAttention(nn.Module):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if use_cache:
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r
+
             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
 
diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py
index 1f94f6f9ad2..914f4784146 100755
--- a/src/transformers/models/splinter/modeling_splinter.py
+++ b/src/transformers/models/splinter/modeling_splinter.py
@@ -169,6 +169,7 @@ class SplinterSelfAttention(nn.Module):
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
+        use_cache = past_key_value is not None
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
             # Further calls to cross_attention layer can then reuse all cross-attention
@@ -183,10 +184,16 @@ class SplinterSelfAttention(nn.Module):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if use_cache:
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r
+
             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
 
diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
index 53939615088..7e62eb6a050 100644
--- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -221,6 +221,7 @@ class XLMRobertaSelfAttention(nn.Module):
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
+        use_cache = past_key_value is not None
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
             # Further calls to cross_attention layer can then reuse all cross-attention
@@ -235,10 +236,16 @@ class XLMRobertaSelfAttention(nn.Module):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if use_cache:
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r
+
             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
 
diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
index 75e4e72fa4b..0fe9db45a05 100644
--- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
+++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
@@ -214,6 +214,7 @@ class XLMRobertaXLSelfAttention(nn.Module):
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
+        use_cache = past_key_value is not None
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
             # Further calls to cross_attention layer can then reuse all cross-attention
@@ -228,10 +229,16 @@ class XLMRobertaXLSelfAttention(nn.Module):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if use_cache:
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r
+
             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
 
diff --git a/tests/models/bert/test_modeling_bert.py b/tests/models/bert/test_modeling_bert.py
index ea49f437e05..367e5ee53c4 100644
--- a/tests/models/bert/test_modeling_bert.py
+++ b/tests/models/bert/test_modeling_bert.py
@@ -525,6 +525,11 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
 
+    def test_decoder_model_past_with_large_inputs_relative_pos_emb(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+        config_and_inputs[0].position_embedding_type = "relative_key"
+        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+
     def test_for_multiple_choice(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
diff --git a/tests/models/data2vec/test_modeling_data2vec_text.py b/tests/models/data2vec/test_modeling_data2vec_text.py
index 631beb9e5cd..c3015c3f409 100644
--- a/tests/models/data2vec/test_modeling_data2vec_text.py
+++ b/tests/models/data2vec/test_modeling_data2vec_text.py
@@ -432,6 +432,11 @@ class Data2VecTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Te
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
 
+    def test_decoder_model_past_with_large_inputs_relative_pos_emb(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+        config_and_inputs[0].position_embedding_type = "relative_key"
+        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
diff --git a/tests/models/ernie/test_modeling_ernie.py b/tests/models/ernie/test_modeling_ernie.py
index 251900cdad2..ed0b4e1f3d4 100644
--- a/tests/models/ernie/test_modeling_ernie.py
+++ b/tests/models/ernie/test_modeling_ernie.py
@@ -524,6 +524,11 @@ class ErnieModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
 
+    def test_decoder_model_past_with_large_inputs_relative_pos_emb(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+        config_and_inputs[0].position_embedding_type = "relative_key"
+        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+
     def test_for_multiple_choice(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
diff --git a/tests/models/roberta/test_modeling_roberta.py b/tests/models/roberta/test_modeling_roberta.py
index d53b20058b2..5128789d41a 100644
--- a/tests/models/roberta/test_modeling_roberta.py
+++ b/tests/models/roberta/test_modeling_roberta.py
@@ -441,6 +441,11 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
 
+    def test_decoder_model_past_with_large_inputs_relative_pos_emb(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+        config_and_inputs[0].position_embedding_type = "relative_key"
+        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
diff --git a/tests/models/roc_bert/test_modeling_roc_bert.py b/tests/models/roc_bert/test_modeling_roc_bert.py
index cf3218ac839..1f814a17b20 100644
--- a/tests/models/roc_bert/test_modeling_roc_bert.py
+++ b/tests/models/roc_bert/test_modeling_roc_bert.py
@@ -627,6 +627,11 @@ class RoCBertModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
 
+    def test_decoder_model_past_with_large_inputs_relative_pos_emb(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+        config_and_inputs[0].position_embedding_type = "relative_key"
+        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+
     def test_for_question_answering(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
diff --git a/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py b/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py
index 60ba66b11a5..6c9577be777 100644
--- a/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py
+++ b/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py
@@ -431,6 +431,11 @@ class XLMRobertaXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Te
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
 
+    def test_decoder_model_past_with_large_inputs_relative_pos_emb(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+        config_and_inputs[0].position_embedding_type = "relative_key"
+        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
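Each new test reuses the existing decoder-with-past check after switching the config to `relative_key` embeddings, which only passes once the cached and uncached position ids agree. A rough standalone analogue of what these tests verify (the config values and helper name are hypothetical, not the model testers used in the patch):

```python
import torch
from transformers import BertConfig, BertLMHeadModel


def check_relative_key_cache_consistency():
    """Rough standalone analogue of the new *_relative_pos_emb tests."""
    config = BertConfig(
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=64,
        is_decoder=True,
        position_embedding_type="relative_key",
    )
    model = BertLMHeadModel(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 6))
    next_token = torch.randint(0, config.vocab_size, (1, 1))

    with torch.no_grad():
        # run the prefix once and keep the key/value cache
        past = model(input_ids, use_cache=True).past_key_values
        cached_logits = model(next_token, past_key_values=past).logits
        # recompute the full sequence without any cache
        full_logits = model(torch.cat([input_ids, next_token], dim=-1)).logits[:, -1:]

    # with the fix, the cached single-token step matches the full forward pass
    assert torch.allclose(cached_logits, full_logits, atol=1e-4)


check_relative_key_cache_consistency()
```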