mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
[ctrl] fix pruning of MultiHeadAttention (#4904)
This commit is contained in:
parent
4e10acb3e5
commit
5d63ca6c38
@ -25,7 +25,7 @@ from torch.nn import CrossEntropyLoss
|
||||
|
||||
from .configuration_ctrl import CTRLConfig
|
||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_utils import Conv1D, PreTrainedModel
|
||||
from .modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -95,6 +95,24 @@ class MultiHeadAttention(torch.nn.Module):
|
||||
self.Wv = torch.nn.Linear(d_model_size, d_model_size)
|
||||
|
||||
self.dense = torch.nn.Linear(d_model_size, d_model_size)
|
||||
self.pruned_heads = set()
|
||||
|
||||
def prune_heads(self, heads):
|
||||
attention_head_size = self.d_model_size // self.num_heads
|
||||
if len(heads) == 0:
|
||||
return
|
||||
heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, attention_head_size, self.pruned_heads)
|
||||
|
||||
# Prune linear layers
|
||||
self.Wq = prune_linear_layer(self.Wq, index)
|
||||
self.Wk = prune_linear_layer(self.Wk, index)
|
||||
self.Wv = prune_linear_layer(self.Wv, index)
|
||||
self.dense = prune_linear_layer(self.dense, index, dim=1)
|
||||
|
||||
# Update hyper params
|
||||
self.num_heads = self.num_heads - len(heads)
|
||||
self.d_model_size = attention_head_size * self.num_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
def split_into_heads(self, x, batch_size):
|
||||
x = x.reshape(batch_size, -1, self.num_heads, self.depth)
|
||||
@ -306,7 +324,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.h[layer].attn.prune_heads(heads)
|
||||
self.h[layer].multi_head_attention.prune_heads(heads)
|
||||
|
||||
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
|
@ -32,7 +32,7 @@ class CTRLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
|
||||
all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_pruning = True
|
||||
test_torchscript = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
|
Loading…
Reference in New Issue
Block a user