[ctrl] fix pruning of MultiHeadAttention (#4904)

Authored by Amil Khare on 2020-06-10 23:36:55 +05:30; committed by GitHub
parent 4e10acb3e5
commit 5d63ca6c38
2 changed files with 21 additions and 3 deletions

src/transformers/modeling_ctrl.py

@@ -25,7 +25,7 @@ from torch.nn import CrossEntropyLoss
 
 from .configuration_ctrl import CTRLConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import Conv1D, PreTrainedModel
+from .modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
 
 
 logger = logging.getLogger(__name__)
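
Note: the two new imports do the actual surgery. find_pruneable_heads_and_indices translates head numbers into a flat index over the feature dimension (skipping heads that were pruned earlier), and prune_linear_layer rebuilds an nn.Linear keeping only the selected rows or columns of its weight. A minimal sketch of what prune_linear_layer does, for illustration only (the real helper lives in modeling_utils and also handles device placement and requires_grad flags):

import torch
from torch import nn

def prune_linear_layer_sketch(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
    # Keep only the rows (dim=0) or columns (dim=1) of the weight whose
    # positions appear in `index`, and return a smaller nn.Linear.
    weight = layer.weight.index_select(dim, index).clone().detach()
    if layer.bias is not None:
        # The bias is per output unit, so it only shrinks when rows are pruned.
        bias = layer.bias.clone().detach() if dim == 1 else layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None)
    with torch.no_grad():
        new_layer.weight.copy_(weight)
        if layer.bias is not None:
            new_layer.bias.copy_(bias)
    return new_layer
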
@@ -95,6 +95,24 @@ class MultiHeadAttention(torch.nn.Module):
         self.Wv = torch.nn.Linear(d_model_size, d_model_size)
 
         self.dense = torch.nn.Linear(d_model_size, d_model_size)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        attention_head_size = self.d_model_size // self.num_heads
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, attention_head_size, self.pruned_heads)
+
+        # Prune linear layers
+        self.Wq = prune_linear_layer(self.Wq, index)
+        self.Wk = prune_linear_layer(self.Wk, index)
+        self.Wv = prune_linear_layer(self.Wv, index)
+        self.dense = prune_linear_layer(self.dense, index, dim=1)
+
+        # Update hyper params
+        self.num_heads = self.num_heads - len(heads)
+        self.d_model_size = attention_head_size * self.num_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
 
     def split_into_heads(self, x, batch_size):
         x = x.reshape(batch_size, -1, self.num_heads, self.depth)
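
The head-to-index translation that prune_heads relies on can be sketched as follows. This is an approximation of find_pruneable_heads_and_indices, written out only to show the offset logic for already-pruned heads; the function name with the _sketch suffix is hypothetical:

import torch

def find_pruneable_heads_and_indices_sketch(heads, n_heads, head_size, already_pruned_heads):
    # Ignore heads that are already gone.
    heads = set(heads) - already_pruned_heads
    mask = torch.ones(n_heads, head_size)
    for head in heads:
        # Every previously pruned head with a smaller index shifts
        # this head one slot to the left in the current layer.
        head -= sum(1 for h in already_pruned_heads if h < head)
        mask[head] = 0
    # Flat positions of the features that survive.
    index = torch.arange(n_heads * head_size)[mask.view(-1).eq(1)].long()
    return heads, index
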
@@ -306,7 +324,7 @@ class CTRLModel(CTRLPreTrainedModel):
                 heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
         """
         for layer, heads in heads_to_prune.items():
-            self.h[layer].attn.prune_heads(heads)
+            self.h[layer].multi_head_attention.prune_heads(heads)
 
     @add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
     def forward(
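
With the attribute name corrected (the CTRL block stores its attention as multi_head_attention, not attn), model-level pruning now reaches the right module. A usage sketch, assuming the public "ctrl" checkpoint (note that it is large):

from transformers import CTRLModel

model = CTRLModel.from_pretrained("ctrl")
# Remove heads 0 and 2 in layer 0 and head 5 in layer 3.
model.prune_heads({0: [0, 2], 3: [5]})

attn = model.h[0].multi_head_attention
# num_heads and d_model_size shrink together; the per-head size is preserved.
print(attn.num_heads, attn.d_model_size)
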

tests/test_modeling_ctrl.py

@@ -32,7 +32,7 @@ class CTRLModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
     all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else ()
-    test_pruning = False
+    test_pruning = True
     test_torchscript = False
     test_resize_embeddings = False
     test_head_masking = False
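
Flipping test_pruning to True re-enables the shared pruning tests from ModelTesterMixin for CTRL. A standalone smoke test in the same spirit (the config values here are arbitrary, chosen only to keep the model tiny):

import torch
from transformers import CTRLConfig, CTRLModel

config = CTRLConfig(vocab_size=100, n_layer=2, n_head=4, n_embd=32, dff=64)
model = CTRLModel(config)
model.eval()

model.prune_heads({0: [0, 1]})
input_ids = torch.randint(0, 100, (1, 5))
with torch.no_grad():
    hidden_states = model(input_ids)[0]

attn = model.h[0].multi_head_attention
assert attn.num_heads == 2                 # two of the four heads are gone
assert attn.d_model_size == 16             # 2 heads x head size 8
assert hidden_states.shape == (1, 5, 32)   # model width is unchanged
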