Mirror of https://github.com/huggingface/transformers.git
[ctrl] fix pruning of MultiHeadAttention (#4904)
This commit is contained in:
parent 4e10acb3e5
commit 5d63ca6c38
@@ -25,7 +25,7 @@ from torch.nn import CrossEntropyLoss
 from .configuration_ctrl import CTRLConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import Conv1D, PreTrainedModel
+from .modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
 
 
 logger = logging.getLogger(__name__)
@@ -95,6 +95,24 @@ class MultiHeadAttention(torch.nn.Module):
         self.Wv = torch.nn.Linear(d_model_size, d_model_size)
 
         self.dense = torch.nn.Linear(d_model_size, d_model_size)
+        self.pruned_heads = set()
 
+    def prune_heads(self, heads):
+        attention_head_size = self.d_model_size // self.num_heads
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, attention_head_size, self.pruned_heads)
+
+        # Prune linear layers
+        self.Wq = prune_linear_layer(self.Wq, index)
+        self.Wk = prune_linear_layer(self.Wk, index)
+        self.Wv = prune_linear_layer(self.Wv, index)
+        self.dense = prune_linear_layer(self.dense, index, dim=1)
+
+        # Update hyper params
+        self.num_heads = self.num_heads - len(heads)
+        self.d_model_size = attention_head_size * self.num_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
     def split_into_heads(self, x, batch_size):
         x = x.reshape(batch_size, -1, self.num_heads, self.depth)
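For reference, the two helpers pulled in from modeling_utils above are not shown in this diff. The sketch below is a hypothetical re-implementation of their semantics as they are used in prune_heads (keep only the feature positions belonging to the surviving heads, and rebuild each Linear along the pruned dimension); it is not the library source.

import torch
from torch import nn

def find_pruneable_heads_and_indices(heads, n_heads, head_size, already_pruned_heads):
    # Ignore heads that were pruned in an earlier call, then build a flat
    # index of the feature positions that belong to the heads we keep.
    heads = set(heads) - already_pruned_heads
    mask = torch.ones(n_heads, head_size)
    for head in heads:
        # Heads removed earlier shift the position of the remaining ones.
        shifted = head - sum(1 for h in already_pruned_heads if h < head)
        mask[shifted] = 0
    mask = mask.view(-1).eq(1)
    index = torch.arange(len(mask))[mask].long()
    return heads, index

def prune_linear_layer(layer, index, dim=0):
    # Rebuild an nn.Linear keeping only the entries of `index` along `dim`
    # (dim=0 prunes output features, dim=1 prunes input features).
    index = index.to(layer.weight.device)
    weight = layer.weight.index_select(dim, index).clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
    with torch.no_grad():
        new_layer.weight.copy_(weight)
        if layer.bias is not None:
            bias = layer.bias if dim == 1 else layer.bias[index]
            new_layer.bias.copy_(bias)
    return new_layer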
@@ -306,7 +324,7 @@ class CTRLModel(CTRLPreTrainedModel):
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
         """
         for layer, heads in heads_to_prune.items():
-            self.h[layer].attn.prune_heads(heads)
+            self.h[layer].multi_head_attention.prune_heads(heads)
 
     @add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
     def forward(
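With the attribute name fixed (the CTRL encoder layer stores its attention module as multi_head_attention, not attn), pruning can be driven through the model-level prune_heads API, which dispatches to the _prune_heads method shown above. A small usage sketch with an arbitrary toy configuration; the config values here are illustrative only.

import torch
from transformers import CTRLConfig, CTRLModel

# Tiny, arbitrary config so the example runs quickly; real CTRL is far larger.
config = CTRLConfig(vocab_size=100, n_positions=64, n_ctx=64, n_embd=32, dff=64, n_head=4, n_layer=2)
model = CTRLModel(config)

# Remove heads 0 and 1 in layer 0 and head 3 in layer 1.
model.prune_heads({0: [0, 1], 1: [3]})

attn = model.h[0].multi_head_attention
print(attn.num_heads, attn.d_model_size)  # 2 heads left, projections shrunk to 16 dims

# The pruned model still runs a forward pass.
hidden_states = model(torch.randint(0, 100, (1, 8)))[0]
print(hidden_states.shape)  # (1, 8, 32): the output width is unchanged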
@@ -32,7 +32,7 @@ class CTRLModelTest(ModelTesterMixin, unittest.TestCase):
 
     all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
     all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else ()
-    test_pruning = False
+    test_pruning = True
     test_torchscript = False
     test_resize_embeddings = False
     test_head_masking = False
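Flipping test_pruning to True opts the CTRL tests into the shared ModelTesterMixin head-pruning checks. Roughly, those checks assert the kind of invariant sketched below; this is an illustrative check, not the actual test code, and it assumes attention weights come back with shape (batch, heads, seq, seq).

import torch
from transformers import CTRLConfig, CTRLModel

config = CTRLConfig(
    vocab_size=100, n_positions=64, n_ctx=64, n_embd=32, dff=64, n_head=4, n_layer=2,
    output_attentions=True,
)
model = CTRLModel(config)
model.eval()

model.prune_heads({0: [0, 1]})

attentions = model(torch.randint(0, 100, (1, 8)))[-1]
assert attentions[0].shape[1] == 2  # pruned layer reports only the surviving heads
assert attentions[1].shape[1] == 4  # untouched layer keeps all of its heads
assert model.h[0].multi_head_attention.pruned_heads == {0, 1}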