Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[cleanup] consolidate some prune_heads logic (#4799)
Parent: 4c7f564f9a
Commit: a139d1a160
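
Summary of the change: each affected attention module (ALBERT, BERT, DistilBERT, GPT-2, OpenAI GPT, T5, XLM) previously built its own head mask and kept-index computation inside prune_heads. The diff below moves that duplicated logic into a single helper, find_pruneable_heads_and_indices, defined in modeling_utils, and has every module call it instead.

The following is a minimal usage sketch, not part of the diff, assuming a transformers build that includes this commit; the checkpoint name and head choices are purely illustrative. It exercises the public prune_heads API, which routes through each layer's attention prune_heads and therefore through the new helper.

from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")

# Prune heads 0 and 2 in layer 0, and head 1 in layer 2.
model.prune_heads({0: [0, 2], 2: [1]})

# The affected attention modules now carry fewer heads and remember what was pruned.
attn = model.encoder.layer[0].attention
print(attn.self.num_attention_heads)  # 12 - 2 = 10
print(attn.pruned_heads)              # {0, 2}
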
@@ -26,7 +26,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
 from .configuration_albert import AlbertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer
-from .modeling_utils import PreTrainedModel
+from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices


 logger = logging.getLogger(__name__)
@@ -199,14 +199,9 @@ class AlbertAttention(BertSelfAttention):
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        mask = torch.ones(self.num_attention_heads, self.attention_head_size)
-        heads = set(heads) - self.pruned_heads  # Convert to set and emove already pruned heads
-        for head in heads:
-            # Compute how many pruned heads are before the head and move the index accordingly
-            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
+        )

         # Prune linear layers
         self.query = prune_linear_layer(self.query, index)
@@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
 from .activations import gelu, gelu_new, swish
 from .configuration_bert import BertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import PreTrainedModel, prune_linear_layer
+from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer


 logger = logging.getLogger(__name__)
@@ -284,14 +284,9 @@ class BertAttention(nn.Module):
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
-        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
-        for head in heads:
-            # Compute how many pruned heads are before the head and move the index accordingly
-            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )

         # Prune linear layers
         self.self.query = prune_linear_layer(self.self.query, index)
@@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss
 from .activations import gelu
 from .configuration_distilbert import DistilBertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import PreTrainedModel, prune_linear_layer
+from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer


 logger = logging.getLogger(__name__)
@@ -120,13 +120,7 @@ class MultiHeadSelfAttention(nn.Module):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
             return
-        mask = torch.ones(self.n_heads, attention_head_size)
-        heads = set(heads) - self.pruned_heads
-        for head in heads:
-            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
         # Prune linear layers
         self.q_lin = prune_linear_layer(self.q_lin, index)
         self.k_lin = prune_linear_layer(self.k_lin, index)
@@ -27,7 +27,13 @@ from torch.nn import CrossEntropyLoss
 from .activations import ACT2FN
 from .configuration_gpt2 import GPT2Config
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
+from .modeling_utils import (
+    Conv1D,
+    PreTrainedModel,
+    SequenceSummary,
+    find_pruneable_heads_and_indices,
+    prune_conv1d_layer,
+)


 logger = logging.getLogger(__name__)
@@ -122,14 +128,9 @@ class Attention(nn.Module):
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        mask = torch.ones(self.n_head, self.split_size // self.n_head)
-        heads = set(heads) - self.pruned_heads  # Convert to set and emove already pruned heads
-        for head in heads:
-            # Compute how many pruned heads are before the head and move the index accordingly
-            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
+        )
         index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

         # Prune conv1d layers
@@ -29,7 +29,13 @@ from torch.nn import CrossEntropyLoss
 from .activations import gelu_new, swish
 from .configuration_openai import OpenAIGPTConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
+from .modeling_utils import (
+    Conv1D,
+    PreTrainedModel,
+    SequenceSummary,
+    find_pruneable_heads_and_indices,
+    prune_conv1d_layer,
+)


 logger = logging.getLogger(__name__)
@@ -142,13 +148,9 @@ class Attention(nn.Module):
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        mask = torch.ones(self.n_head, self.split_size // self.n_head)
-        heads = set(heads) - self.pruned_heads
-        for head in heads:
-            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
+        )
         index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
         # Prune conv1d layers
         self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
@@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss

 from .configuration_t5 import T5Config
 from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import PreTrainedModel, prune_linear_layer
+from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer


 logger = logging.getLogger(__name__)
@@ -216,13 +216,7 @@ class T5Attention(nn.Module):
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        mask = torch.ones(self.n_heads, self.d_kv)
-        heads = set(heads) - self.pruned_heads
-        for head in heads:
-            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, self.d_kv, self.pruned_heads)
         # Prune linear layers
         self.q = prune_linear_layer(self.q, index)
         self.k = prune_linear_layer(self.k, index)
@@ -55,6 +55,20 @@ except ImportError:
             return input


+def find_pruneable_heads_and_indices(
+    heads: List, n_heads: int, head_size: int, already_pruned_heads: set
+) -> Tuple[set, "torch.LongTensor"]:
+    mask = torch.ones(n_heads, head_size)
+    heads = set(heads) - already_pruned_heads  # Convert to set and remove already pruned heads
+    for head in heads:
+        # Compute how many pruned heads are before the head and move the index accordingly
+        head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
+        mask[head] = 0
+    mask = mask.view(-1).contiguous().eq(1)
+    index: torch.LongTensor = torch.arange(len(mask))[mask].long()
+    return heads, index
+
+
 class ModuleUtilsMixin:
     """
     A few utilities for torch.nn.Modules, to be used as a mixin.
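
A small worked example of the new helper (not part of the diff; numbers chosen only for illustration, and the import path assumes the modeling_utils module shown above): suppose a layer originally had 4 heads of size 2, head 1 was pruned earlier, and head 2 is pruned next. The helper shifts head 2 down to row 1 of the already-shrunk weights and returns the flat indices to keep.

from transformers.modeling_utils import find_pruneable_heads_and_indices

# 3 heads remain (head 1 was already pruned), each of size 2; prune head 2 next.
heads, index = find_pruneable_heads_and_indices(
    heads=[2], n_heads=3, head_size=2, already_pruned_heads={1}
)
print(heads)  # {2} -- heads that will actually be pruned now
print(index)  # tensor([0, 1, 4, 5]) -- flat positions to keep in the q/k/v projections
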
@@ -29,7 +29,13 @@ from torch.nn import functional as F
 from .activations import gelu
 from .configuration_xlm import XLMConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead, prune_linear_layer
+from .modeling_utils import (
+    PreTrainedModel,
+    SequenceSummary,
+    SQuADHead,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)


 logger = logging.getLogger(__name__)
@@ -105,13 +111,7 @@ class MultiHeadAttention(nn.Module):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
             return
-        mask = torch.ones(self.n_heads, attention_head_size)
-        heads = set(heads) - self.pruned_heads
-        for head in heads:
-            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
         # Prune linear layers
         self.q_lin = prune_linear_layer(self.q_lin, index)
         self.k_lin = prune_linear_layer(self.k_lin, index)