Mask T5 relative position bias when heads are pruned (#17968)

* add position bias head masking if heads are pruned

* fix pruning function in T5 encoder

* make style

* make fix-copies

* Revert added folder

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Authored by Had on 2022-09-06 08:39:31 +00:00, committed by GitHub
parent d4dbd7ca59
commit 734b7e2a5a
2 changed files with 17 additions and 3 deletions
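
As the commit message notes, prune_heads shrinks the attention projections but leaves relative_attention_bias untouched, so the computed position bias still carries one slice per head of the original config while the scores only cover the surviving heads. The diff below masks out the slices of pruned heads before adding the bias to the scores. A minimal standalone sketch of that masking; the tensor sizes and the pruned_heads set here are illustrative, not taken from the commit:

import torch

n_heads = 8                      # heads in the original config
pruned_heads = {1, 5}            # heads removed by prune_heads()
batch_size, seq_length = 2, 16

# the position bias is still produced for all 8 heads ...
position_bias = torch.randn(batch_size, n_heads, seq_length, seq_length)
# ... while the scores only cover the 6 surviving heads
scores = torch.randn(batch_size, n_heads - len(pruned_heads), seq_length, seq_length)

# same masking logic as the diff below: drop the bias slices of pruned heads
mask = torch.ones(position_bias.shape[1])
mask[list(pruned_heads)] = 0
position_bias_masked = position_bias[:, mask.bool()]

scores += position_bias_masked
print(scores.shape)  # torch.Size([2, 6, 16, 16])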

src/transformers/models/longt5/modeling_longt5.py

@@ -518,7 +518,14 @@ class LongT5Attention(nn.Module):
             if mask is not None:
                 position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
 
-        scores += position_bias
+        if self.pruned_heads:
+            mask = torch.ones(position_bias.shape[1])
+            mask[list(self.pruned_heads)] = 0
+            position_bias_masked = position_bias[:, mask.bool()]
+        else:
+            position_bias_masked = position_bias
+
+        scores += position_bias_masked
         attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
             scores
         )  # (batch_size, n_heads, seq_length, key_length)
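
LongT5Attention is a "Copied from" clone of T5Attention (hence the "make fix-copies" step in the commit message), so the same head/bias mismatch appears in both files. To see where the mismatch comes from, a hedged module-level sketch using the internal T5Attention class; the constructor is not a public API and its signature may differ across transformers versions, and the config values and pruned head indices are arbitrary:

from transformers import T5Config
from transformers.models.t5.modeling_t5 import T5Attention  # internal class, not a public API

config = T5Config(d_model=64, d_kv=8, num_heads=8, num_layers=2)
attn = T5Attention(config, has_relative_attention_bias=True)

attn.prune_heads([1, 5])
print(attn.n_heads, attn.inner_dim)               # 6 48 -> scores now have 6 heads
print(attn.relative_attention_bias.weight.shape)  # torch.Size([32, 8]) -> bias still covers 8 heads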

src/transformers/models/t5/modeling_t5.py

@@ -528,7 +528,14 @@ class T5Attention(nn.Module):
             if mask is not None:
                 position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
 
-        scores += position_bias
+        if self.pruned_heads:
+            mask = torch.ones(position_bias.shape[1])
+            mask[list(self.pruned_heads)] = 0
+            position_bias_masked = position_bias[:, mask.bool()]
+        else:
+            position_bias_masked = position_bias
+
+        scores += position_bias_masked
         attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
             scores
         )  # (batch_size, n_heads, seq_length, key_length)
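
With the masking in place, pruning encoder heads on a full T5 model no longer breaks the forward pass on the position-bias addition. A hedged end-to-end sketch; the "t5-small" checkpoint, the layer/head choice, and the token ids are arbitrary example inputs:

import torch
from transformers import T5Model

model = T5Model.from_pretrained("t5-small")    # any T5 checkpoint should behave the same
model.prune_heads({0: [0, 2]})                 # drop heads 0 and 2 in the first encoder block

input_ids = torch.tensor([[37, 423, 215, 1]])  # arbitrary token ids ending with </s>
decoder_input_ids = torch.tensor([[0]])        # decoder start (pad) token
outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
print(outputs.last_hidden_state.shape)         # torch.Size([1, 1, 512])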
@@ -1802,7 +1809,7 @@ class T5EncoderModel(T5PreTrainedModel):
         class PreTrainedModel
         """
         for layer, heads in heads_to_prune.items():
-            self.encoder.layer[layer].attention.prune_heads(heads)
+            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
 
     @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
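
The last hunk fixes T5EncoderModel._prune_heads, which previously referenced a non-existent self.encoder.layer attribute; the corrected path matches the one T5Model already uses. A hedged sketch of pruning the encoder-only model; the checkpoint and the layer/head indices are arbitrary:

import torch
from transformers import T5EncoderModel

model = T5EncoderModel.from_pretrained("t5-small")
model.prune_heads({0: [1], 1: [0, 3]})                 # now resolves to encoder.block[layer].layer[0].SelfAttention

attn = model.encoder.block[0].layer[0].SelfAttention   # the corrected attribute path from the diff
print(attn.n_heads, sorted(attn.pruned_heads))         # 7 [1]  (t5-small starts with 8 heads)

input_ids = torch.tensor([[37, 423, 215, 1]])
outputs = model(input_ids=input_ids)
print(outputs.last_hidden_state.shape)                 # torch.Size([1, 4, 512])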