huggingface/transformers (https://github.com/huggingface/transformers.git)

Merge pull request #1077 from huggingface/pruning-save-and-load
Pruning changes so that deleted heads are kept on save/load

Merge commit: ff7368eb6b
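The change is easiest to see end to end: pruned heads are now recorded in config.pruned_heads, written out by save_pretrained(), and re-applied by init_weights() when the model is rebuilt. A minimal usage sketch (not part of the diff; it assumes a working pytorch_transformers install at this revision, network access for the pretrained weights, and an illustrative output directory name):

    from pytorch_transformers import BertModel

    # Prune two heads from layer 0; the pruning is recorded in model.config.pruned_heads.
    model = BertModel.from_pretrained("bert-base-uncased")
    model.prune_heads({0: [0, 2]})

    # save_pretrained() writes pruned_heads into the config JSON next to the weights.
    model.save_pretrained("./bert-pruned")

    # On reload, init_weights() sees config.pruned_heads and prunes again before the
    # (already pruned) weights are loaded, so the deleted heads stay deleted.
    reloaded = BertModel.from_pretrained("./bert-pruned")
    assert sorted(reloaded.config.pruned_heads[0]) == [0, 2]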
@@ -337,23 +337,30 @@ class BertAttention(nn.Module):
         super(BertAttention, self).__init__()
         self.self = BertSelfAttention(config)
         self.output = BertSelfOutput(config)
+        self.pruned_heads = set()
 
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
         mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
+        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
         for head in heads:
+            # Compute how many pruned heads are before the head and move the index accordingly
+            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()
 
         # Prune linear layers
         self.self.query = prune_linear_layer(self.self.query, index)
         self.self.key = prune_linear_layer(self.self.key, index)
         self.self.value = prune_linear_layer(self.self.value, index)
         self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
-        # Update hyper params
+        # Update hyper params and store pruned heads
         self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
 
     def forward(self, input_tensor, attention_mask, head_mask=None):
         self_outputs = self.self(input_tensor, attention_mask, head_mask)
@@ -531,12 +538,8 @@ class BertPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"
 
-    def __init__(self, *inputs, **kwargs):
-        super(BertPreTrainedModel, self).__init__(*inputs, **kwargs)
-
-    def init_weights(self, module):
-        """ Initialize the weights.
-        """
+    def _init_weights(self, module):
+        """ Initialize the weights """
         if isinstance(module, (nn.Linear, nn.Embedding)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
@@ -649,7 +652,7 @@ class BertModel(BertPreTrainedModel):
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def _resize_token_embeddings(self, new_num_tokens):
         old_embeddings = self.embeddings.word_embeddings
@@ -758,7 +761,7 @@ class BertForPreTraining(BertPreTrainedModel):
         self.bert = BertModel(config)
         self.cls = BertPreTrainingHeads(config)
 
-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()
 
     def tie_weights(self):
@@ -826,7 +829,7 @@ class BertForMaskedLM(BertPreTrainedModel):
         self.bert = BertModel(config)
         self.cls = BertOnlyMLMHead(config)
 
-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()
 
     def tie_weights(self):
@@ -891,7 +894,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         self.bert = BertModel(config)
         self.cls = BertOnlyNSPHead(config)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None,
                 position_ids=None, head_mask=None):
@@ -952,7 +955,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
                 position_ids=None, head_mask=None):
@@ -1056,7 +1059,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
                 position_ids=None, head_mask=None):
@@ -1124,7 +1127,7 @@ class BertForTokenClassification(BertPreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
                 position_ids=None, head_mask=None):
@@ -1198,7 +1201,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
         self.bert = BertModel(config)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
                 end_positions=None, position_ids=None, head_mask=None):
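The re-indexing line added above deserves a short illustration (an explanatory sketch, not part of the diff, with made-up head indices): indices passed to prune_heads refer to the original head layout, while the weight matrices have already been compacted by earlier prunings, so each index is shifted down by the number of previously pruned heads that precede it.

    # Standalone sketch of the index adjustment performed in prune_heads.
    pruned_heads = {1, 5}            # heads pruned in an earlier call
    heads = {0, 6} - pruned_heads    # new request, minus already-pruned heads

    for head in heads:
        # Shift by the number of already-pruned heads with a smaller original index.
        adjusted = head - sum(1 if h < head else 0 for h in pruned_heads)
        print(head, "->", adjusted)  # 0 -> 0, 6 -> 4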
@@ -174,12 +174,16 @@ class MultiHeadSelfAttention(nn.Module):
         self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
         self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
 
+        self.pruned_heads = set()
 
     def prune_heads(self, heads):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
             return
         mask = torch.ones(self.n_heads, attention_head_size)
+        heads = set(heads) - self.pruned_heads
         for head in heads:
+            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()
@@ -191,6 +195,7 @@ class MultiHeadSelfAttention(nn.Module):
         # Update hyper params
         self.n_heads = self.n_heads - len(heads)
         self.dim = attention_head_size * self.n_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
 
     def forward(self, query, key, value, mask, head_mask = None):
         """
@@ -395,7 +400,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
     def __init__(self, *inputs, **kwargs):
         super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs)
 
-    def init_weights(self, module):
+    def _init_weights(self, module):
         """ Initialize the weights.
         """
         if isinstance(module, nn.Embedding):
@@ -480,7 +485,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
         self.embeddings = Embeddings(config)  # Embeddings
         self.transformer = Transformer(config)  # Encoder
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def _resize_token_embeddings(self, new_num_tokens):
         old_embeddings = self.embeddings.word_embeddings
@@ -568,7 +573,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
         self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
         self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
 
-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()
 
         self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
@@ -642,7 +647,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
         self.classifier = nn.Linear(config.dim, config.num_labels)
         self.dropout = nn.Dropout(config.seq_classif_dropout)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, attention_mask=None, labels=None, head_mask=None):
         distilbert_output = self.distilbert(input_ids=input_ids,
@@ -716,7 +721,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
         assert config.num_labels == 2
         self.dropout = nn.Dropout(config.qa_dropout)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, attention_mask=None, start_positions=None, end_positions=None, head_mask=None):
         distilbert_output = self.distilbert(input_ids=input_ids,
@@ -233,22 +233,29 @@ class Attention(nn.Module):
         self.c_proj = Conv1D(n_state, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
+        self.pruned_heads = set()
 
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
         mask = torch.ones(self.n_head, self.split_size // self.n_head)
+        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
        for head in heads:
+            # Compute how many pruned heads are before the head and move the index accordingly
+            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()
         index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])
 
         # Prune conv1d layers
         self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
         self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
 
         # Update hyper params
         self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
         self.n_head = self.n_head - len(heads)
+        self.pruned_heads = self.pruned_heads.union(heads)
 
     def _attn(self, q, k, v, head_mask=None):
         w = torch.matmul(q, k)
@@ -354,7 +361,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
     def __init__(self, *inputs, **kwargs):
         super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)
 
-    def init_weights(self, module):
+    def _init_weights(self, module):
         """ Initialize the weights.
         """
         if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
@@ -453,7 +460,7 @@ class GPT2Model(GPT2PreTrainedModel):
         self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
         self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def _resize_token_embeddings(self, new_num_tokens):
         self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
@@ -584,7 +591,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         self.transformer = GPT2Model(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()
 
     def tie_weights(self):
@@ -708,7 +715,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         self.multiple_choice_head = SequenceSummary(config)
 
-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()
 
     def tie_weights(self):
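A note on the GPT-2 hunk above (explanatory, not part of the diff): c_attn is a single Conv1D whose output dimension packs query, key and value side by side, so the kept-column index has to be repeated three times with offsets of split_size before the fused layer can be pruned, while c_proj only needs the plain index on its input dimension. A small sketch with illustrative sizes:

    import torch

    split_size = 8                            # per-copy width of the fused c_attn output (illustrative)
    index = torch.tensor([0, 1, 2, 3, 6, 7])  # columns kept after masking one head

    # The same kept columns, repeated for the Q, K and V blocks of c_attn.
    index_attn = torch.cat([index, index + split_size, index + 2 * split_size])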
@@ -249,12 +249,15 @@ class Attention(nn.Module):
         self.c_proj = Conv1D(n_state, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
+        self.pruned_heads = set()
 
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
         mask = torch.ones(self.n_head, self.split_size // self.n_head)
+        heads = set(heads) - self.pruned_heads
         for head in heads:
+            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()
@@ -265,6 +268,7 @@ class Attention(nn.Module):
         # Update hyper params
         self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
         self.n_head = self.n_head - len(heads)
+        self.pruned_heads = self.pruned_heads.union(heads)
 
     def _attn(self, q, k, v, head_mask=None):
         w = torch.matmul(q, k)
@@ -363,10 +367,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_openai_gpt
     base_model_prefix = "transformer"
 
-    def __init__(self, *inputs, **kwargs):
-        super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs)
-
-    def init_weights(self, module):
+    def _init_weights(self, module):
         """ Initialize the weights.
         """
         if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
@@ -456,7 +457,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         self.drop = nn.Dropout(config.embd_pdrop)
         self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def _resize_token_embeddings(self, new_num_tokens):
         self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens)
@@ -569,7 +570,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         self.transformer = OpenAIGPTModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()
 
     def tie_weights(self):
@@ -676,7 +677,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         self.multiple_choice_head = SequenceSummary(config)
 
-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()
 
     def tie_weights(self):
@@ -168,7 +168,7 @@ class RobertaModel(BertModel):
         super(RobertaModel, self).__init__(config)
 
         self.embeddings = RobertaEmbeddings(config)
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
         if input_ids[:, 0].sum().item() != 0:
@@ -220,7 +220,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
         self.roberta = RobertaModel(config)
         self.lm_head = RobertaLMHead(config)
 
-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()
 
     def tie_weights(self):
@@ -853,9 +853,6 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_transfo_xl
     base_model_prefix = "transformer"
 
-    def __init__(self, *inputs, **kwargs):
-        super(TransfoXLPreTrainedModel, self).__init__(*inputs, **kwargs)
-
     def _init_weight(self, weight):
         if self.config.init == 'uniform':
             nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
@@ -865,7 +862,7 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
     def _init_bias(self, bias):
         nn.init.constant_(bias, 0.0)
 
-    def init_weights(self, m):
+    def _init_weights(self, m):
         """ Initialize the weights.
         """
         classname = m.__class__.__name__
@@ -1059,7 +1056,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
             self.r_emb = nn.Parameter(torch.FloatTensor(
                     self.n_layer, self.max_klen, self.n_head, self.d_head))
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def _resize_token_embeddings(self, new_num_tokens):
         return self.word_emb
@@ -1306,7 +1303,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
         else:
             self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
                                                     config.cutoffs, div_val=config.div_val)
-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()
 
     def tie_weights(self):
@@ -104,6 +104,7 @@ class PretrainedConfig(object):
         self.output_attentions = kwargs.pop('output_attentions', False)
         self.output_hidden_states = kwargs.pop('output_hidden_states', False)
         self.torchscript = kwargs.pop('torchscript', False)
+        self.pruned_heads = kwargs.pop('pruned_heads', {})
 
     def save_pretrained(self, save_directory):
         """ Save a configuration object to the directory `save_directory`, so that it
@@ -200,6 +201,9 @@ class PretrainedConfig(object):
         # Load config
         config = cls.from_json_file(resolved_config_file)
 
+        if hasattr(config, 'pruned_heads'):
+            config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items())
+
         # Update config with kwargs if needed
         to_remove = []
         for key, value in kwargs.items():
@@ -311,7 +315,7 @@ class PreTrainedModel(nn.Module):
         new_embeddings.to(old_embeddings.weight.device)
 
         # initialize all new embeddings (in particular added tokens)
-        self.init_weights(new_embeddings)
+        self._init_weights(new_embeddings)
 
         # Copy word embeddings from the previous weights
         num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
@@ -355,14 +359,30 @@ class PreTrainedModel(nn.Module):
 
         return model_embeds
 
+    def init_weights(self):
+        """ Initialize and prunes weights if needed. """
+        # Initialize weights
+        self.apply(self._init_weights)
+
+        # Prune heads if needed
+        if self.config.pruned_heads:
+            self.prune_heads(self.config.pruned_heads)
+
     def prune_heads(self, heads_to_prune):
         """ Prunes heads of the base model.
 
             Arguments:
 
                 heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
+                E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
         """
         base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
+
+        # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
+        for layer, heads in heads_to_prune.items():
+            union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
+            self.config.pruned_heads[layer] = list(union_heads)  # Unfortunately we have to store it as list for JSON
+
         base_model._prune_heads(heads_to_prune)
 
     def save_pretrained(self, save_directory):
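The int/set conversion added to the config loading above exists because JSON object keys are always strings: after a save/load round trip the layer indices in pruned_heads come back as "0", "2", and so on. A small, library-independent sketch of that round trip:

    import json

    # What save_pretrained() stores: lists of head indices under int layer keys.
    pruned_heads = {0: [0, 2], 2: [1]}
    serialized = json.dumps({"pruned_heads": pruned_heads})

    # json.loads gives back string keys: {"0": [0, 2], "2": [1]}.
    loaded = json.loads(serialized)["pruned_heads"]

    # The conversion mirrored from the added config-loading lines above.
    restored = dict((int(key), set(value)) for key, value in loaded.items())
    assert restored == {0: {0, 2}, 2: {1}}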
@@ -271,13 +271,16 @@ class MultiHeadAttention(nn.Module):
         self.k_lin = nn.Linear(dim, dim)
         self.v_lin = nn.Linear(dim, dim)
         self.out_lin = nn.Linear(dim, dim)
+        self.pruned_heads = set()
 
     def prune_heads(self, heads):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
             return
         mask = torch.ones(self.n_heads, attention_head_size)
+        heads = set(heads) - self.pruned_heads
         for head in heads:
+            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()
@@ -289,6 +292,7 @@ class MultiHeadAttention(nn.Module):
         # Update hyper params
         self.n_heads = self.n_heads - len(heads)
         self.dim = attention_head_size * self.n_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
 
     def forward(self, input, mask, kv=None, cache=None, head_mask=None):
         """
@@ -383,7 +387,7 @@ class XLMPreTrainedModel(PreTrainedModel):
     def __init__(self, *inputs, **kwargs):
         super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)
 
-    def init_weights(self, module):
+    def _init_weights(self, module):
         """ Initialize the weights. """
         if isinstance(module, nn.Embedding):
             if self.config is not None and self.config.embed_init_std is not None:
@@ -559,7 +563,14 @@ class XLMModel(XLMPreTrainedModel):
             self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
             self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
 
-        self.apply(self.init_weights)
+        if hasattr(config, "pruned_heads"):
+            pruned_heads = config.pruned_heads.copy().items()
+            config.pruned_heads = {}
+            for layer, heads in pruned_heads:
+                if self.attentions[int(layer)].n_heads == config.n_heads:
+                    self.prune_heads({int(layer): list(map(int, heads))})
+
+        self.init_weights()
 
     def _resize_token_embeddings(self, new_num_tokens):
         self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)
@@ -771,7 +782,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
         self.transformer = XLMModel(config)
         self.pred_layer = XLMPredLayer(config)
 
-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()
 
     def tie_weights(self):
@@ -833,7 +844,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
         self.transformer = XLMModel(config)
         self.sequence_summary = SequenceSummary(config)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
                 attention_mask=None, cache=None, labels=None, head_mask=None):
@@ -911,7 +922,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
         self.transformer = XLMModel(config)
         self.qa_outputs = SQuADHead(config)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
                 attention_mask=None, cache=None, start_positions=None, end_positions=None,
@@ -586,10 +586,7 @@ class XLNetPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_xlnet
     base_model_prefix = "transformer"
 
-    def __init__(self, *inputs, **kwargs):
-        super(XLNetPreTrainedModel, self).__init__(*inputs, **kwargs)
-
-    def init_weights(self, module):
+    def _init_weights(self, module):
         """ Initialize the weights.
         """
         if isinstance(module, (nn.Linear, nn.Embedding)):
@@ -736,7 +733,7 @@ class XLNetModel(XLNetPreTrainedModel):
         self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)])
         self.dropout = nn.Dropout(config.dropout)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def _resize_token_embeddings(self, new_num_tokens):
         self.word_embedding = self._get_resized_embeddings(self.word_embedding, new_num_tokens)
@@ -1037,7 +1034,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         self.transformer = XLNetModel(config)
         self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
 
-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()
 
     def tie_weights(self):
@@ -1114,7 +1111,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         self.sequence_summary = SequenceSummary(config)
         self.logits_proj = nn.Linear(config.d_model, config.num_labels)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None,
@@ -1216,7 +1213,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
         self.end_logits = PoolerEndLogits(config)
         self.answer_class = PoolerAnswerClass(config)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None,
@@ -213,12 +213,12 @@ class CommonTestCases:
         if not self.test_pruning:
             return
 
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        if "head_mask" in inputs_dict:
-            del inputs_dict["head_mask"]
-
         for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            if "head_mask" in inputs_dict:
+                del inputs_dict["head_mask"]
+
             config.output_attentions = True
             config.output_hidden_states = False
             model = model_class(config=config)
@@ -237,6 +237,120 @@ class CommonTestCases:
             self.assertEqual(
                 attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
 
+    def test_head_pruning_save_load_from_pretrained(self):
+        if not self.test_pruning:
+            return
+
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            if "head_mask" in inputs_dict:
+                del inputs_dict["head_mask"]
+
+            config.output_attentions = True
+            config.output_hidden_states = False
+            model = model_class(config=config)
+            model.eval()
+            heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
+                              -1: [0]}
+            model.prune_heads(heads_to_prune)
+            directory = "pruned_model"
+            if not os.path.exists(directory):
+                os.makedirs(directory)
+            model.save_pretrained(directory)
+            model = model_class.from_pretrained(directory)
+
+            outputs = model(**inputs_dict)
+            attentions = outputs[-1]
+            self.assertEqual(attentions[0].shape[-3], 1)
+            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
+            self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
+
+            shutil.rmtree(directory)
+
+    def test_head_pruning_save_load_from_config_init(self):
+        if not self.test_pruning:
+            return
+
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            if "head_mask" in inputs_dict:
+                del inputs_dict["head_mask"]
+
+            config.output_attentions = True
+            config.output_hidden_states = False
+
+            heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
+                              -1: [0]}
+            config.pruned_heads = heads_to_prune
+
+            model = model_class(config=config)
+            model.eval()
+
+            outputs = model(**inputs_dict)
+            attentions = outputs[-1]
+
+            self.assertEqual(attentions[0].shape[-3], 1)
+            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
+            self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
+
+    def test_head_pruning_integration(self):
+        if not self.test_pruning:
+            return
+
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            if "head_mask" in inputs_dict:
+                del inputs_dict["head_mask"]
+
+            config.output_attentions = True
+            config.output_hidden_states = False
+
+            heads_to_prune = {0: [0], 1: [1, 2]}
+            config.pruned_heads = heads_to_prune
+
+            model = model_class(config=config)
+            model.eval()
+
+            outputs = model(**inputs_dict)
+            attentions = outputs[-1]
+
+            self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
+            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
+            self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
+            self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
+
+            directory = "pruned_model"
+
+            if not os.path.exists(directory):
+                os.makedirs(directory)
+            model.save_pretrained(directory)
+            model = model_class.from_pretrained(directory)
+            shutil.rmtree(directory)
+
+            outputs = model(**inputs_dict)
+            attentions = outputs[-1]
+
+            self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
+            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
+            self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
+            self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
+
+            heads_to_prune = {0: [0], 2: [1, 2]}
+            model.prune_heads(heads_to_prune)
+
+            outputs = model(**inputs_dict)
+            attentions = outputs[-1]
+
+            self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
+            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
+            self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads - 2)
+            self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
+
+            self.assertDictEqual(model.config.pruned_heads, {0: [0], 1: [1, 2], 2: [1, 2]})
+
 
     def test_hidden_states_output(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()