XLM can be pruned

LysandreJik 2019-08-21 18:57:30 -04:00
parent 42e00cf9e1
commit fc1fbae45d
2 changed files with 6 additions and 1 deletion


@@ -559,6 +559,12 @@ class XLMModel(XLMPreTrainedModel):
             self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
             self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
 
+        if hasattr(config, "pruned_heads"):
+            pruned_heads = config.pruned_heads.copy().items()
+            for layer, heads in pruned_heads:
+                if self.attentions[int(layer)].n_heads == config.n_heads:
+                    self.prune_heads({int(layer): list(map(int, heads))})
+
         self.apply(self.init_weights)
 
     def _resize_token_embeddings(self, new_num_tokens):
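For context, a minimal usage sketch of what this change enables, not part of the commit itself: heads pruned on an XLM model are recorded in its configuration, and a model rebuilt from that saved configuration now re-applies the pruning at init time via the hunk above. The package name and local path below are assumptions based on the 2019-era pytorch-transformers release.

# Hypothetical round-trip: prune -> save -> reload from the saved config.
# Assumes the pytorch_transformers package name of this era; the path is a placeholder.
from pytorch_transformers import XLMConfig, XLMModel

config = XLMConfig()   # default config; a pretrained checkpoint works the same way
model = XLMModel(config)

# Prune heads 0 and 1 of layer 0; the mapping is {layer_index: [head_indices]}.
model.prune_heads({0: [0, 1]})

# save_pretrained writes config.json, which carries config.pruned_heads.
model.save_pretrained("./xlm-pruned")

# On reload, the new hasattr(config, "pruned_heads") branch re-runs prune_heads,
# so the attention layers come back with the reduced head count.
reloaded = XLMModel.from_pretrained("./xlm-pruned")
assert reloaded.attentions[0].n_heads == model.attentions[0].n_heads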


@@ -269,7 +269,6 @@ class CommonTestCases:
             shutil.rmtree(directory)
 
         def test_head_pruning_save_load_from_config_init(self):
-            print(self.test_pruning)
             if not self.test_pruning:
                 return