mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
* fix reshaping Fixes #21523 * add test * styling * last fixes * Update src/transformers/models/convbert/modeling_convbert.py * code quallity
This commit is contained in:
parent
44e3e3fb49
commit
b599b19289
@ -316,7 +316,7 @@ class ConvBertSelfAttention(nn.Module):
|
||||
if config.hidden_size % self.num_attention_heads != 0:
|
||||
raise ValueError("hidden_size should be divisible by num_attention_heads")
|
||||
|
||||
self.attention_head_size = config.hidden_size // config.num_attention_heads
|
||||
self.attention_head_size = (config.hidden_size // self.num_attention_heads) // 2
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
@ -413,7 +413,10 @@ class ConvBertSelfAttention(nn.Module):
|
||||
conv_out = torch.reshape(conv_out_layer, [batch_size, -1, self.num_attention_heads, self.attention_head_size])
|
||||
context_layer = torch.cat([context_layer, conv_out], 2)
|
||||
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.head_ratio * self.all_head_size,)
|
||||
# conv and context
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (
|
||||
self.num_attention_heads * self.attention_head_size * 2,
|
||||
)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
@ -459,6 +459,11 @@ class ConvBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
result = model(inputs_embeds=inputs_embeds)
|
||||
self.assertEqual(result.last_hidden_state.shape, (batch_size, seq_length, config.hidden_size))
|
||||
|
||||
def test_reducing_attention_heads(self):
|
||||
config, *inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||
config.head_ratio = 4
|
||||
self.model_tester.create_and_check_for_masked_lm(config, *inputs_dict)
|
||||
|
||||
|
||||
@require_torch
|
||||
class ConvBertModelIntegrationTest(unittest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user