diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py
index 035787a97b2..cf121eee416 100644
--- a/pytorch_transformers/modeling_xlm.py
+++ b/pytorch_transformers/modeling_xlm.py
@@ -559,6 +559,12 @@ class XLMModel(XLMPreTrainedModel):
             self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
             self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
 
+        if hasattr(config, "pruned_heads"):
+            pruned_heads = config.pruned_heads.copy().items()
+            for layer, heads in pruned_heads:
+                if self.attentions[int(layer)].n_heads == config.n_heads:
+                    self.prune_heads({int(layer): list(map(int, heads))})
+
         self.apply(self.init_weights)
 
     def _resize_token_embeddings(self, new_num_tokens):
diff --git a/pytorch_transformers/tests/modeling_common_test.py b/pytorch_transformers/tests/modeling_common_test.py
index 7ed1eddbfba..dbb041ab054 100644
--- a/pytorch_transformers/tests/modeling_common_test.py
+++ b/pytorch_transformers/tests/modeling_common_test.py
@@ -269,7 +269,6 @@ class CommonTestCases:
         shutil.rmtree(directory)
 
     def test_head_pruning_save_load_from_config_init(self):
-        print(self.test_pruning)
        if not self.test_pruning:
            return
 
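
The first hunk makes `XLMModel.__init__` re-apply any head pruning recorded in `config.pruned_heads`, and the guard on `n_heads == config.n_heads` skips layers that were already pruned. A minimal round-trip sketch of what this enables, assuming the `pytorch_transformers` API of this era (`XLMConfig`/`XLMModel`, `prune_heads`, `save_pretrained`/`from_pretrained`); the attribute names `attentions` and `n_heads` come from the diff itself:

```python
import tempfile

from pytorch_transformers import XLMConfig, XLMModel

model = XLMModel(XLMConfig())
model.prune_heads({0: [0, 1]})   # drop heads 0 and 1 of attention layer 0

save_dir = tempfile.mkdtemp()
model.save_pretrained(save_dir)  # pruned_heads is serialized with the config

# With this patch, XLMModel.__init__ re-applies config.pruned_heads when the
# model is rebuilt, so the reloaded model matches the pruned architecture.
restored = XLMModel.from_pretrained(save_dir)
assert restored.attentions[0].n_heads == model.attentions[0].n_heads
```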