mirror of https://github.com/huggingface/transformers.git
Mask T5 relative position bias when heads are pruned (#17968)
* add position bias head masking if heads are pruned
* fix the pruning function in the T5 encoder
* make style
* make fix-copies
* Revert added folder

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent: d4dbd7ca59
commit: 734b7e2a5a
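Why the masking is needed: `prune_heads` shrinks a layer's `n_heads` and its q/k/v/o projections, but the relative position bias embedding is left at the original head count, so the computed bias has more head channels than the pruned attention scores. A minimal standalone sketch of the shape logic (toy tensors and head counts chosen for illustration, not the actual model code):

import torch

# Toy setup: 8 original heads, heads {1, 3} pruned -> 6 head channels left in scores.
n_heads, pruned_heads = 8, {1, 3}
scores = torch.randn(2, n_heads - len(pruned_heads), 5, 5)  # (batch, kept_heads, q_len, k_len)
position_bias = torch.randn(2, n_heads, 5, 5)               # bias is still built for all 8 heads

# The fix: boolean-mask out the pruned head channels before adding the bias.
mask = torch.ones(position_bias.shape[1])
mask[list(pruned_heads)] = 0
position_bias_masked = position_bias[:, mask.bool()]        # (2, 6, 5, 5)

scores += position_bias_masked                              # shapes now agree
print(scores.shape)  # torch.Size([2, 6, 5, 5])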
@@ -518,7 +518,14 @@ class LongT5Attention(nn.Module):
             if mask is not None:
                 position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
 
-        scores += position_bias
+        if self.pruned_heads:
+            mask = torch.ones(position_bias.shape[1])
+            mask[list(self.pruned_heads)] = 0
+            position_bias_masked = position_bias[:, mask.bool()]
+        else:
+            position_bias_masked = position_bias
+
+        scores += position_bias_masked
         attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
             scores
         )  # (batch_size, n_heads, seq_length, key_length)
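The same hunk appears twice because `LongT5Attention` is maintained as a copy of `T5Attention` (hence the `make fix-copies` step in the commit message); the identical fix is propagated to the original T5 implementation below.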
@@ -528,7 +528,14 @@ class T5Attention(nn.Module):
             if mask is not None:
                 position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
 
-        scores += position_bias
+        if self.pruned_heads:
+            mask = torch.ones(position_bias.shape[1])
+            mask[list(self.pruned_heads)] = 0
+            position_bias_masked = position_bias[:, mask.bool()]
+        else:
+            position_bias_masked = position_bias
+
+        scores += position_bias_masked
         attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
             scores
         )  # (batch_size, n_heads, seq_length, key_length)
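To see the fixed branch in action at the layer level, one can prune heads on a single `T5Attention` module and run a forward pass; before this commit, `scores += position_bias` failed here with a size mismatch. A sketch assuming a transformers install that includes this fix (the tiny config values and pruned head indices are arbitrary):

import torch
from transformers import T5Config
from transformers.models.t5.modeling_t5 import T5Attention

config = T5Config(d_model=64, d_kv=8, num_heads=8)
attn = T5Attention(config, has_relative_attention_bias=True)
attn.prune_heads([1, 3])              # n_heads drops from 8 to 6

hidden_states = torch.randn(1, 5, 64)
attn_output = attn(hidden_states)[0]  # bias is computed for all 8 heads, then masked to the 6 kept
print(attn_output.shape)              # torch.Size([1, 5, 64])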
@@ -1802,7 +1809,7 @@ class T5EncoderModel(T5PreTrainedModel):
         class PreTrainedModel
         """
         for layer, heads in heads_to_prune.items():
-            self.encoder.layer[layer].attention.prune_heads(heads)
+            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
 
     @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
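At the model level, the corrected `_prune_heads` now resolves T5's actual module path (`encoder.block[layer].layer[0].SelfAttention`) rather than the BERT-style `encoder.layer[layer].attention` it previously assumed, which raised an AttributeError on T5. A hedged end-to-end sketch using the public `t5-small` checkpoint (the pruned head indices are arbitrary):

import torch
from transformers import AutoTokenizer, T5EncoderModel

model = T5EncoderModel.from_pretrained("t5-small")
model.prune_heads({0: [0, 1], 2: [3]})  # dict of {layer index: list of heads to prune}

tokenizer = AutoTokenizer.from_pretrained("t5-small")
inputs = tokenizer("translate English to German: hello", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, seq_length, 512) for t5-small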