XLM can be pruned
This commit is contained in:
parent 42e00cf9e1
commit fc1fbae45d
@@ -559,6 +559,12 @@ class XLMModel(XLMPreTrainedModel):
             self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
             self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
 
+        if hasattr(config, "pruned_heads"):
+            pruned_heads = config.pruned_heads.copy().items()
+            for layer, heads in pruned_heads:
+                if self.attentions[int(layer)].n_heads == config.n_heads:
+                    self.prune_heads({int(layer): list(map(int, heads))})
+
         self.apply(self.init_weights)
 
     def _resize_token_embeddings(self, new_num_tokens):
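For context, a minimal usage sketch (not part of the diff) of the behaviour this change enables: heads pruned from an XLM model are recorded on its config, so rebuilding the model from that config re-applies the pruning through the new hasattr(config, "pruned_heads") block above. The import path and the save directory below are assumptions (the library was named pytorch_transformers at the time of this commit; current versions import from transformers), and the exact bookkeeping may differ by version.

    # Minimal sketch, assuming the pytorch_transformers-era API of this commit.
    from pytorch_transformers import XLMConfig, XLMModel

    config = XLMConfig()            # default XLM configuration
    model = XLMModel(config)

    # Prune head 0 of layer 0 and heads 1 and 2 of layer 2; the pruned heads
    # are tracked on the model config so they can be re-applied later.
    model.prune_heads({0: [0], 2: [1, 2]})

    # Reloading goes through XLMModel.__init__, where the new
    # `if hasattr(config, "pruned_heads")` block re-applies the pruning.
    model.save_pretrained("./xlm-pruned")               # hypothetical local path
    reloaded = XLMModel.from_pretrained("./xlm-pruned")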
@@ -269,7 +269,6 @@ class CommonTestCases:
             shutil.rmtree(directory)
 
         def test_head_pruning_save_load_from_config_init(self):
-            print(self.test_pruning)
             if not self.test_pruning:
                 return
 