[Roformer] Fixing a bug in RoFormerEncoder where it was ignoring the length of past_key_values when generating as a decoder (#22416)

* fix RoFormerEncoder position embedding when generating as a decoder

* make fixup

* add test case to check generation with past key values

* remove duplicated code
TheWall9 2023-04-04 18:50:33 +08:00 committed by GitHub
parent 1905384fd5
commit ad5e9b6c6a
2 changed files with 29 additions and 6 deletions
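In short, the encoder-level bug: when RoFormer generates as a decoder with use_cache=True, each step only feeds the newest token through the model, so that token's sinusoidal position index has to start at the number of tokens already held in past_key_values; before this fix the offset was effectively always zero. Below is a minimal sketch of that offset, using a hypothetical sinusoidal_positions helper with a simplified sin/cos layout, not the actual RoFormerSinusoidalPositionalEmbedding module:

import torch


def sinusoidal_positions(seq_len: int, dim: int, past_key_values_length: int = 0) -> torch.Tensor:
    """[seq_len, dim] sinusoids whose position index starts after the cached tokens."""
    positions = torch.arange(past_key_values_length, past_key_values_length + seq_len, dtype=torch.float32)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = positions[:, None] * inv_freq[None, :]          # [seq_len, dim // 2]
    return torch.cat([angles.sin(), angles.cos()], dim=-1)   # [seq_len, dim]


# During cached generation only the newest token is passed through the model,
# so its sinusoid must describe position past_len, not position 0.
past_len = 4                                       # i.e. past_key_values[0][0].shape[2]
sin_pos = sinusoidal_positions(1, 64, past_len)    # shape [1, 64], position index 4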


@@ -259,11 +259,6 @@ class RoFormerSelfAttention(nn.Module):
             key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
             value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
             attention_mask = encoder_attention_mask
-        elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
-            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         else:
             key_layer = self.transpose_for_scores(self.key(hidden_states))
             value_layer = self.transpose_for_scores(self.value(hidden_states))
@@ -276,6 +271,9 @@ class RoFormerSelfAttention(nn.Module):
                 query_layer, key_layer = self.apply_rotary_position_embeddings(
                     sinusoidal_pos, query_layer, key_layer
                 )
+        if past_key_value is not None:
+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
             # Further calls to cross_attention layer can then reuse all cross-attention
@@ -566,8 +564,10 @@ class RoFormerEncoder(nn.Module):
         all_self_attentions = () if output_attentions else None
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
         # [sequence_length, embed_size_per_head] -> [batch_size, num_heads, sequence_length, embed_size_per_head]
-        sinusoidal_pos = self.embed_positions(hidden_states.shape[:-1])[None, None, :, :]
+        sinusoidal_pos = self.embed_positions(hidden_states.shape[:-1], past_key_values_length)[None, None, :, :]

         next_decoder_cache = () if use_cache else None
         for i, layer_module in enumerate(self.layer):
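The attention-level change above is about ordering: new key states must receive their rotary rotation (at positions already offset by the cache length) before they are concatenated onto the cached keys, which were rotated at their own positions on earlier steps; since the rotation at each position is independent of the other positions, the cached and uncached paths then agree. The self-contained check below illustrates that property; rotate here is a simplified RoPE-style rotation written only for this sketch, not the model's apply_rotary_position_embeddings:

import torch


def rotate(x: torch.Tensor, offset: int = 0) -> torch.Tensor:
    """Rotate x of shape [batch, heads, seq_len, head_dim], with position indices starting at offset."""
    _, _, seq_len, dim = x.shape
    positions = torch.arange(offset, offset + seq_len, dtype=torch.float32)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = positions[:, None] * inv_freq[None, :]   # [seq_len, dim // 2]
    sin, cos = angles.sin(), angles.cos()
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out


keys = torch.randn(1, 2, 5, 8)              # 5 tokens in total
full = rotate(keys)                         # rotate the whole sequence in one pass
cached = rotate(keys[:, :, :4])             # first 4 tokens, rotated on earlier steps
new = rotate(keys[:, :, 4:], offset=4)      # newest token rotated at position 4, not 0
assert torch.allclose(full, torch.cat([cached, new], dim=2))

The added test below exercises the same equivalence at the model level: with a fixed seed, generate must produce identical tokens with and without the key/value cache.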


@@ -220,6 +220,25 @@ class RoFormerModelTester:
             result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
             self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

+    def create_and_check_for_generate_causal_lm(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+    ):
+        model = RoFormerForCausalLM(config=config).to(torch_device).eval()
+        torch.manual_seed(0)
+        output_without_past_cache = model.generate(
+            input_ids[:1], num_beams=2, max_length=15, do_sample=True, use_cache=False
+        )
+        torch.manual_seed(0)
+        output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=15, do_sample=True)
+        self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
+
     def create_and_check_for_masked_lm(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
@@ -405,6 +424,10 @@ class RoFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

+    def test_for_generate_causal_lm(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_generate_causal_lm(*config_and_inputs)
+
     def test_for_multiple_choice(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)