[ConvBert] Fix #21523 (#21849)

* fix reshaping
Fixes #21523

* add test

* styling

* last fixes

* Update src/transformers/models/convbert/modeling_convbert.py

* code quallity
This commit is contained in:
Arthur 2023-03-01 11:11:04 +01:00 committed by GitHub
parent 44e3e3fb49
commit b599b19289
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 2 deletions

View File

@ -316,7 +316,7 @@ class ConvBertSelfAttention(nn.Module):
if config.hidden_size % self.num_attention_heads != 0:
raise ValueError("hidden_size should be divisible by num_attention_heads")
self.attention_head_size = config.hidden_size // config.num_attention_heads
self.attention_head_size = (config.hidden_size // self.num_attention_heads) // 2
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
@ -413,7 +413,10 @@ class ConvBertSelfAttention(nn.Module):
conv_out = torch.reshape(conv_out_layer, [batch_size, -1, self.num_attention_heads, self.attention_head_size])
context_layer = torch.cat([context_layer, conv_out], 2)
new_context_layer_shape = context_layer.size()[:-2] + (self.head_ratio * self.all_head_size,)
# conv and context
new_context_layer_shape = context_layer.size()[:-2] + (
self.num_attention_heads * self.attention_head_size * 2,
)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

View File

@ -459,6 +459,11 @@ class ConvBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
result = model(inputs_embeds=inputs_embeds)
self.assertEqual(result.last_hidden_state.shape, (batch_size, seq_length, config.hidden_size))
def test_reducing_attention_heads(self):
config, *inputs_dict = self.model_tester.prepare_config_and_inputs()
config.head_ratio = 4
self.model_tester.create_and_check_for_masked_lm(config, *inputs_dict)
@require_torch
class ConvBertModelIntegrationTest(unittest.TestCase):