def prune_heads(self, heads):
    """Prune the given attention heads from this attention module.

    Replaces the ``Wq``, ``Wk``, ``Wv`` and ``dense`` linear layers with
    pruned versions and updates ``num_heads``, ``d_model_size`` and
    ``pruned_heads`` to match.

    Args:
        heads: collection of head indices to remove. An empty collection
            is a no-op and leaves the module untouched.

    Returns:
        None. The module is modified in place.
    """
    # Check for the no-op case *before* dividing by ``num_heads``: once every
    # head has been pruned ``num_heads`` is 0 and the division below would
    # raise ZeroDivisionError even though there is nothing to do.
    if len(heads) == 0:
        return

    # Per-head width, derived from the current (possibly already pruned) sizes.
    attention_head_size = self.d_model_size // self.num_heads
    # NOTE(review): presumably drops heads already in ``pruned_heads`` and
    # returns the weight indices to keep -- confirm against modeling_utils.
    heads, index = find_pruneable_heads_and_indices(
        heads, self.num_heads, attention_head_size, self.pruned_heads
    )

    # Prune linear layers
    self.Wq = prune_linear_layer(self.Wq, index)
    self.Wk = prune_linear_layer(self.Wk, index)
    self.Wv = prune_linear_layer(self.Wv, index)
    # ``dense`` is pruned along dim=1, unlike the three projections above.
    self.dense = prune_linear_layer(self.dense, index, dim=1)

    # Update hyper params
    self.num_heads = self.num_heads - len(heads)
    self.d_model_size = attention_head_size * self.num_heads
    self.pruned_heads = self.pruned_heads.union(heads)
a/tests/test_modeling_ctrl.py +++ b/tests/test_modeling_ctrl.py @@ -32,7 +32,7 @@ class CTRLModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else () all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else () - test_pruning = False + test_pruning = True test_torchscript = False test_resize_embeddings = False test_head_masking = False