Mirror of https://github.com/huggingface/transformers.git
XLM can be pruned
commit fc1fbae45d
parent 42e00cf9e1
@@ -559,6 +559,12 @@ class XLMModel(XLMPreTrainedModel):
             self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
             self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
 
+        if hasattr(config, "pruned_heads"):
+            pruned_heads = config.pruned_heads.copy().items()
+            for layer, heads in pruned_heads:
+                if self.attentions[int(layer)].n_heads == config.n_heads:
+                    self.prune_heads({int(layer): list(map(int, heads))})
+
         self.apply(self.init_weights)
 
     def _resize_token_embeddings(self, new_num_tokens):
@@ -269,7 +269,6 @@ class CommonTestCases:
             shutil.rmtree(directory)
 
         def test_head_pruning_save_load_from_config_init(self):
-            print(self.test_pruning)
             if not self.test_pruning:
                 return
 
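A minimal usage sketch of the new code path, assuming the library exposes XLMConfig and XLMModel at the top level (as pytorch_transformers did around this commit) and that a default-constructed XLMConfig is valid; the layer and head indices below are purely illustrative, not taken from this diff:

# Minimal sketch; the imports, default config, and chosen indices are assumptions.
from pytorch_transformers import XLMConfig, XLMModel

config = XLMConfig()
# Map layer index -> heads to prune. The string key mimics a config reloaded
# from JSON, which is why the new __init__ code casts layer and head ids to int.
config.pruned_heads = {0: [0, 1], "2": [3]}

model = XLMModel(config)

# The `n_heads == config.n_heads` guard makes the pruning idempotent: heads are
# only removed from layers that still carry the full head count.
print(model.attentions[0].n_heads)  # expected: config.n_heads - 2
print(model.attentions[2].n_heads)  # expected: config.n_heads - 1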