Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[Roformer] Fixing a bug in RoFormerEncoder where it was ignoring the length of past_key_values when generating as a decoder (#22416)
* fix RoFormerEncoder position embedding when generating as a decoder
* make fixup
* add test case to check generation with past key values
* remove duplicated code
parent 1905384fd5
commit ad5e9b6c6a
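The bug in brief: RoFormer rotates queries and keys by their absolute positions, but RoFormerEncoder built the sinusoidal table from the current hidden_states length alone. With a key/value cache, each incremental decoding step feeds only the newest token, so every step was rotated as if it sat at position 0 and cached generation diverged from uncached generation. A minimal sketch of the required offset (hypothetical helper, not the library's API):

import torch

def sinusoidal_angles(seq_len: int, dim: int, past_length: int = 0) -> torch.Tensor:
    # Positions must start at past_length when a cache is in use; starting at 0
    # rotates every incremental token as if it were the first in the sequence.
    positions = torch.arange(past_length, past_length + seq_len, dtype=torch.float32)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = positions[:, None] * inv_freq[None, :]          # [seq_len, dim // 2]
    return torch.cat([angles.sin(), angles.cos()], dim=-1)   # [seq_len, dim]

# Decoding one token with 10 already cached must match position 10 of a full pass:
step = sinusoidal_angles(seq_len=1, dim=64, past_length=10)
full = sinusoidal_angles(seq_len=11, dim=64)
assert torch.allclose(step[0], full[10])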
src/transformers/models/roformer/modeling_roformer.py

@@ -259,11 +259,6 @@ class RoFormerSelfAttention(nn.Module):
             key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
             value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
             attention_mask = encoder_attention_mask
-        elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
-            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         else:
             key_layer = self.transpose_for_scores(self.key(hidden_states))
             value_layer = self.transpose_for_scores(self.value(hidden_states))
@@ -276,6 +271,9 @@ class RoFormerSelfAttention(nn.Module):
                 query_layer, key_layer = self.apply_rotary_position_embeddings(
                     sinusoidal_pos, query_layer, key_layer
                 )
+        if past_key_value is not None:
+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
             # Further calls to cross_attention layer can then reuse all cross-attention
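Together, these two hunks change the order of operations: previously the cached keys/values were concatenated before apply_rotary_position_embeddings, so the already rotated cache entries were rotated a second time (with positions computed for the new tokens only); now only the freshly computed key is rotated, then appended to the cache. A toy demonstration that rotate-then-concatenate matches a full forward pass (hypothetical rotate helper standing in for the rotary mixing):

import torch

def rotate(x, sin, cos):
    # Rotary mixing of even/odd feature pairs; sin/cos carry one angle set per position.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)

dim = 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
angles = torch.arange(4, dtype=torch.float32)[:, None] * inv_freq  # positions 0..3

keys = torch.randn(4, dim)  # unrotated keys for positions 0..3

# Full pass: rotate all four keys at their own positions.
full = rotate(keys, angles.sin(), angles.cos())

# Cached pass: positions 0..2 were rotated when first computed; only position 3
# is rotated now, then concatenated -- the order the fixed code uses.
cache = rotate(keys[:3], angles[:3].sin(), angles[:3].cos())
step = rotate(keys[3:], angles[3:].sin(), angles[3:].cos())
assert torch.allclose(torch.cat([cache, step], dim=0), full)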
@@ -566,8 +564,10 @@ class RoFormerEncoder(nn.Module):
         all_self_attentions = () if output_attentions else None
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
 
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
         # [sequence_length, embed_size_per_head] -> [batch_size, num_heads, sequence_length, embed_size_per_head]
-        sinusoidal_pos = self.embed_positions(hidden_states.shape[:-1])[None, None, :, :]
+        sinusoidal_pos = self.embed_positions(hidden_states.shape[:-1], past_key_values_length)[None, None, :, :]
 
         next_decoder_cache = () if use_cache else None
         for i, layer_module in enumerate(self.layer):
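In the legacy tuple cache, each layer stores (key_states, value_states), each shaped [batch_size, num_heads, seq_len, head_dim], so dimension 2 of the first layer's keys gives the number of tokens already processed — the offset the new past_key_values_length argument feeds into embed_positions. A small sketch of that layout:

import torch

# One (key, value) pair per layer; dim 2 is the cached sequence length.
batch_size, num_heads, past_len, head_dim = 2, 12, 7, 64
layer_cache = (
    torch.zeros(batch_size, num_heads, past_len, head_dim),  # cached keys
    torch.zeros(batch_size, num_heads, past_len, head_dim),  # cached values
)
past_key_values = (layer_cache,)  # tuple over layers

past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
assert past_key_values_length == 7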
tests/models/roformer/test_modeling_roformer.py

@@ -220,6 +220,25 @@ class RoFormerModelTester:
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
 
+    def create_and_check_for_generate_causal_lm(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+    ):
+        model = RoFormerForCausalLM(config=config).to(torch_device).eval()
+        torch.manual_seed(0)
+        output_without_past_cache = model.generate(
+            input_ids[:1], num_beams=2, max_length=15, do_sample=True, use_cache=False
+        )
+        torch.manual_seed(0)
+        output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=15, do_sample=True)
+        self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
+
     def create_and_check_for_masked_lm(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
@@ -405,6 +424,10 @@ class RoFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
 
+    def test_for_generate_causal_lm(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_generate_causal_lm(*config_and_inputs)
+
     def test_for_multiple_choice(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
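The test's trick is reseeding the RNG before each generate call, so that sampling draws the same random numbers with and without the cache; identical outputs then imply the per-step logits match. A hedged standalone reproduction in the spirit of the new test, assuming the public junnyu/roformer_chinese_base checkpoint (the slow RoFormerTokenizer also needs the rjieba package):

import torch
from transformers import RoFormerForCausalLM, RoFormerTokenizer

name = "junnyu/roformer_chinese_base"
tokenizer = RoFormerTokenizer.from_pretrained(name)
model = RoFormerForCausalLM.from_pretrained(name, is_decoder=True).eval()

input_ids = tokenizer("今天天气非常好", return_tensors="pt").input_ids  # "the weather is very nice today"

torch.manual_seed(0)
without_cache = model.generate(input_ids, max_length=15, do_sample=True, use_cache=False)
torch.manual_seed(0)
with_cache = model.generate(input_ids, max_length=15, do_sample=True, use_cache=True)

# Diverges before this commit (positions reset to 0 each step); matches after.
assert torch.equal(without_cache, with_cache)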