[cleanup] consolidate some prune_heads logic (#4799)

Sam Shleifer 2020-06-08 17:08:04 -04:00 committed by GitHub
parent 4c7f564f9a
commit a139d1a160
8 changed files with 54 additions and 59 deletions
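Each of the seven model files below previously built the head mask and keep-index inline; this commit moves that logic into a single helper in modeling_utils.py and leaves the rest of each prune_heads method untouched. Roughly, the consolidated method now follows the sketch below (BERT-style attribute names; the tail of the method is not part of this diff and is reproduced here only for context, so treat it as an illustration rather than the exact source):

from transformers.modeling_utils import find_pruneable_heads_and_indices, prune_linear_layer

class BertAttentionSketch:
    # Schematic of the shared pattern; Conv1D-based models (GPT-2, OpenAI GPT)
    # use prune_conv1d_layer instead of prune_linear_layer.
    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        # One shared call replaces the hand-rolled mask/index loop in every model.
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )
        # Prune the projections along the head dimension and record what was removed.
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)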

src/transformers/modeling_albert.py View File

@@ -26,7 +26,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
from .configuration_albert import AlbertConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer
from .modeling_utils import PreTrainedModel
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices
logger = logging.getLogger(__name__)
@@ -199,14 +199,9 @@ class AlbertAttention(BertSelfAttention):
def prune_heads(self, heads):
if len(heads) == 0:
return
mask = torch.ones(self.num_attention_heads, self.attention_head_size)
heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
for head in heads:
# Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
heads, index = find_pruneable_heads_and_indices(
heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.query = prune_linear_layer(self.query, index)

src/transformers/modeling_bert.py View File

@@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
from .activations import gelu, gelu_new, swish
from .configuration_bert import BertConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import PreTrainedModel, prune_linear_layer
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
logger = logging.getLogger(__name__)
@@ -284,14 +284,9 @@ class BertAttention(nn.Module):
def prune_heads(self, heads):
if len(heads) == 0:
return
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
for head in heads:
# Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)

src/transformers/modeling_distilbert.py View File

@@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss
from .activations import gelu
from .configuration_distilbert import DistilBertConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import PreTrainedModel, prune_linear_layer
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
logger = logging.getLogger(__name__)
@@ -120,13 +120,7 @@ class MultiHeadSelfAttention(nn.Module):
attention_head_size = self.dim // self.n_heads
if len(heads) == 0:
return
mask = torch.ones(self.n_heads, attention_head_size)
heads = set(heads) - self.pruned_heads
for head in heads:
head -= sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
# Prune linear layers
self.q_lin = prune_linear_layer(self.q_lin, index)
self.k_lin = prune_linear_layer(self.k_lin, index)

src/transformers/modeling_gpt2.py View File

@@ -27,7 +27,13 @@ from torch.nn import CrossEntropyLoss
from .activations import ACT2FN
from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
from .modeling_utils import (
Conv1D,
PreTrainedModel,
SequenceSummary,
find_pruneable_heads_and_indices,
prune_conv1d_layer,
)
logger = logging.getLogger(__name__)
@@ -122,14 +128,9 @@ class Attention(nn.Module):
def prune_heads(self, heads):
if len(heads) == 0:
return
mask = torch.ones(self.n_head, self.split_size // self.n_head)
heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
for head in heads:
# Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
heads, index = find_pruneable_heads_and_indices(
heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
)
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
# Prune conv1d layers
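A note on the GPT-2 (and OpenAI GPT) variant: query, key and value live in one fused Conv1D, c_attn, whose output columns are laid out as [Q | K | V] with each block self.split_size wide, so the kept index has to be repeated at three offsets before pruning c_attn. A tiny standalone sketch of that offset construction (toy numbers, not taken from the diff):

import torch

split_size = 8  # width of one Q/K/V block in this toy example (n_embd in the real model)
index = torch.tensor([0, 1, 6, 7])  # columns to keep inside a single block after pruning

# Repeat the kept columns for the Q, K and V blocks of the fused c_attn weight.
index_attn = torch.cat([index, index + split_size, index + (2 * split_size)])
assert index_attn.tolist() == [0, 1, 6, 7, 8, 9, 14, 15, 16, 17, 22, 23]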

src/transformers/modeling_openai.py View File

@@ -29,7 +29,13 @@ from torch.nn import CrossEntropyLoss
from .activations import gelu_new, swish
from .configuration_openai import OpenAIGPTConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
from .modeling_utils import (
Conv1D,
PreTrainedModel,
SequenceSummary,
find_pruneable_heads_and_indices,
prune_conv1d_layer,
)
logger = logging.getLogger(__name__)
@@ -142,13 +148,9 @@ class Attention(nn.Module):
def prune_heads(self, heads):
if len(heads) == 0:
return
mask = torch.ones(self.n_head, self.split_size // self.n_head)
heads = set(heads) - self.pruned_heads
for head in heads:
head -= sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
heads, index = find_pruneable_heads_and_indices(
heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
)
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
# Prune conv1d layers
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)

src/transformers/modeling_t5.py View File

@@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss
from .configuration_t5 import T5Config
from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import PreTrainedModel, prune_linear_layer
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
logger = logging.getLogger(__name__)
@@ -216,13 +216,7 @@ class T5Attention(nn.Module):
def prune_heads(self, heads):
if len(heads) == 0:
return
mask = torch.ones(self.n_heads, self.d_kv)
heads = set(heads) - self.pruned_heads
for head in heads:
head -= sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, self.d_kv, self.pruned_heads)
# Prune linear layers
self.q = prune_linear_layer(self.q, index)
self.k = prune_linear_layer(self.k, index)

src/transformers/modeling_utils.py View File

@@ -55,6 +55,20 @@ except ImportError:
return input
def find_pruneable_heads_and_indices(
heads: List, n_heads: int, head_size: int, already_pruned_heads: set
) -> Tuple[set, "torch.LongTensor"]:
mask = torch.ones(n_heads, head_size)
heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
for head in heads:
# Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index: torch.LongTensor = torch.arange(len(mask))[mask].long()
return heads, index
class ModuleUtilsMixin:
"""
A few utilities for torch.nn.Modules, to be used as a mixin.
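For a concrete feel of what the new helper returns, here is a small worked example (a sketch, assuming the helper is importable from transformers.modeling_utils at this commit; the expected values are worked out by hand from the function body above):

import torch

from transformers.modeling_utils import find_pruneable_heads_and_indices

# A layer with 4 heads of size 2 and nothing pruned yet: prune heads 1 and 2.
heads, index = find_pruneable_heads_and_indices([1, 2], n_heads=4, head_size=2, already_pruned_heads=set())
assert heads == {1, 2}
assert torch.equal(index, torch.tensor([0, 1, 6, 7]))  # keep the rows of heads 0 and 3

# The same layer after head 1 was already pruned (3 heads remain): prune original head 2.
heads, index = find_pruneable_heads_and_indices([1, 2], n_heads=3, head_size=2, already_pruned_heads={1})
assert heads == {2}  # head 1 is dropped from the request, it is already gone
assert torch.equal(index, torch.tensor([0, 1, 4, 5]))  # head 2 now sits at position 1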

src/transformers/modeling_xlm.py View File

@@ -29,7 +29,13 @@ from torch.nn import functional as F
from .activations import gelu
from .configuration_xlm import XLMConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead, prune_linear_layer
from .modeling_utils import (
PreTrainedModel,
SequenceSummary,
SQuADHead,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
logger = logging.getLogger(__name__)
@@ -105,13 +111,7 @@ class MultiHeadAttention(nn.Module):
attention_head_size = self.dim // self.n_heads
if len(heads) == 0:
return
mask = torch.ones(self.n_heads, attention_head_size)
heads = set(heads) - self.pruned_heads
for head in heads:
head -= sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
# Prune linear layers
self.q_lin = prune_linear_layer(self.q_lin, index)
self.k_lin = prune_linear_layer(self.k_lin, index)
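End to end, none of this changes the public API: callers still go through PreTrainedModel.prune_heads with a {layer_index: [head_indices]} mapping, and each attention module's prune_heads now routes through the shared helper. A short usage example (assumes the usual pretrained checkpoint is available to download):

from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
# Prune heads 0 and 2 in layer 0 and head 5 in layer 2; every affected
# BertAttention.prune_heads call now goes through find_pruneable_heads_and_indices.
model.prune_heads({0: [0, 2], 2: [5]})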