From 42e00cf9e1969973a563db2900ed86bbf58dbc71 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 19 Aug 2019 22:43:02 -0400 Subject: [PATCH 01/11] Pruning saved to configuration first try --- pytorch_transformers/modeling_bert.py | 6 ++ pytorch_transformers/modeling_utils.py | 10 ++++ .../tests/modeling_common_test.py | 56 +++++++++++++++++++ 3 files changed, 72 insertions(+) diff --git a/pytorch_transformers/modeling_bert.py b/pytorch_transformers/modeling_bert.py index f918afff3ea..4a68c2b96b6 100644 --- a/pytorch_transformers/modeling_bert.py +++ b/pytorch_transformers/modeling_bert.py @@ -649,6 +649,12 @@ class BertModel(BertPreTrainedModel): self.encoder = BertEncoder(config) self.pooler = BertPooler(config) + if hasattr(config, "pruned_heads"): + pruned_heads = config.pruned_heads.copy().items() + for layer, heads in pruned_heads: + if self.encoder.layer[int(layer)].attention.self.num_attention_heads == config.num_attention_heads: + self.prune_heads({int(layer): list(map(int, heads))}) + self.apply(self.init_weights) def _resize_token_embeddings(self, new_num_tokens): diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 0d4fce67f0c..351fbfd0e14 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -104,6 +104,7 @@ class PretrainedConfig(object): self.output_attentions = kwargs.pop('output_attentions', False) self.output_hidden_states = kwargs.pop('output_hidden_states', False) self.torchscript = kwargs.pop('torchscript', False) + self.pruned_heads = kwargs.pop('pruned_heads', {}) def save_pretrained(self, save_directory): """ Save a configuration object to the directory `save_directory`, so that it @@ -363,6 +364,15 @@ class PreTrainedModel(nn.Module): heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`). 
""" base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed + + for layer, heads in heads_to_prune.items(): + if str(layer) not in self.config.pruned_heads: + self.config.pruned_heads[str(layer)] = heads + else: + for head in heads: + if head not in self.config.pruned_heads[str(layer)]: + self.config.pruned_heads[str(layer)].append(head) + base_model._prune_heads(heads_to_prune) def save_pretrained(self, save_directory): diff --git a/pytorch_transformers/tests/modeling_common_test.py b/pytorch_transformers/tests/modeling_common_test.py index 8b9a2ffd170..7ed1eddbfba 100644 --- a/pytorch_transformers/tests/modeling_common_test.py +++ b/pytorch_transformers/tests/modeling_common_test.py @@ -219,6 +219,7 @@ class CommonTestCases: del inputs_dict["head_mask"] for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.output_attentions = True config.output_hidden_states = False model = model_class(config=config) @@ -237,6 +238,61 @@ class CommonTestCases: self.assertEqual( attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1) + def test_head_pruning_save_load_from_pretrained(self): + if not self.test_pruning: + return + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_attentions = True + config.output_hidden_states = False + model = model_class(config=config) + model.eval() + heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), + -1: [0]} + model.prune_heads(heads_to_prune) + directory = "pruned_model" + if not os.path.exists(directory): + os.makedirs(directory) + model.save_pretrained(directory) + model = model_class.from_pretrained(directory) + + outputs = model(**inputs_dict) + attentions = outputs[-1] + self.assertEqual( + attentions[0].shape[-3], 1) + self.assertEqual( + attentions[1].shape[-3], self.model_tester.num_attention_heads) + self.assertEqual( + attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1) + + shutil.rmtree(directory) + + def test_head_pruning_save_load_from_config_init(self): + print(self.test_pruning) + if not self.test_pruning: + return + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_attentions = True + config.output_hidden_states = False + + heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), + -1: [0]} + config.pruned_heads = heads_to_prune + + model = model_class(config=config) + model.eval() + + outputs = model(**inputs_dict) + attentions = outputs[-1] + self.assertEqual( + attentions[0].shape[-3], 1) + self.assertEqual( + attentions[1].shape[-3], self.model_tester.num_attention_heads) + self.assertEqual( + attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1) def test_hidden_states_output(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From fc1fbae45df552eb4ff5220463cbde11cfa2b71e Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 21 Aug 2019 18:57:30 -0400 Subject: [PATCH 02/11] XLM can be pruned --- pytorch_transformers/modeling_xlm.py | 6 ++++++ pytorch_transformers/tests/modeling_common_test.py | 1 - 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index 035787a97b2..cf121eee416 100644 --- a/pytorch_transformers/modeling_xlm.py +++ 
b/pytorch_transformers/modeling_xlm.py @@ -559,6 +559,12 @@ class XLMModel(XLMPreTrainedModel): self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config)) self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps)) + if hasattr(config, "pruned_heads"): + pruned_heads = config.pruned_heads.copy().items() + for layer, heads in pruned_heads: + if self.attentions[int(layer)].n_heads == config.n_heads: + self.prune_heads({int(layer): list(map(int, heads))}) + self.apply(self.init_weights) def _resize_token_embeddings(self, new_num_tokens): diff --git a/pytorch_transformers/tests/modeling_common_test.py b/pytorch_transformers/tests/modeling_common_test.py index 7ed1eddbfba..dbb041ab054 100644 --- a/pytorch_transformers/tests/modeling_common_test.py +++ b/pytorch_transformers/tests/modeling_common_test.py @@ -269,7 +269,6 @@ class CommonTestCases: shutil.rmtree(directory) def test_head_pruning_save_load_from_config_init(self): - print(self.test_pruning) if not self.test_pruning: return From 719cb3738d442431d246c107899b40441c3dd5ae Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 21 Aug 2019 20:12:06 -0400 Subject: [PATCH 03/11] Pruning for GPT and GPT-2 --- pytorch_transformers/modeling_gpt2.py | 6 ++++++ pytorch_transformers/modeling_openai.py | 6 ++++++ .../tests/modeling_common_test.py | 17 ++++++++++++----- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index 283dc68a6ae..23cc7f53132 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -453,6 +453,12 @@ class GPT2Model(GPT2PreTrainedModel): self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + if hasattr(config, "pruned_heads"): + pruned_heads = config.pruned_heads.copy().items() + for layer, heads in pruned_heads: + if self.h[int(layer)].attn.n_head == config.n_head: + self.prune_heads({int(layer): list(map(int, heads))}) + self.apply(self.init_weights) def _resize_token_embeddings(self, new_num_tokens): diff --git a/pytorch_transformers/modeling_openai.py b/pytorch_transformers/modeling_openai.py index 690aa7812be..c640b7c86c2 100644 --- a/pytorch_transformers/modeling_openai.py +++ b/pytorch_transformers/modeling_openai.py @@ -456,6 +456,12 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): self.drop = nn.Dropout(config.embd_pdrop) self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) + if hasattr(config, "pruned_heads"): + pruned_heads = config.pruned_heads.copy().items() + for layer, heads in pruned_heads: + if self.h[int(layer)].attn.n_head == config.n_head: + self.prune_heads({int(layer): list(map(int, heads))}) + self.apply(self.init_weights) def _resize_token_embeddings(self, new_num_tokens): diff --git a/pytorch_transformers/tests/modeling_common_test.py b/pytorch_transformers/tests/modeling_common_test.py index dbb041ab054..c06c5011530 100644 --- a/pytorch_transformers/tests/modeling_common_test.py +++ b/pytorch_transformers/tests/modeling_common_test.py @@ -213,13 +213,12 @@ class CommonTestCases: if not self.test_pruning: return - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - if "head_mask" in inputs_dict: - del inputs_dict["head_mask"] - for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() 
+ + if "head_mask" in inputs_dict: + del inputs_dict["head_mask"] + config.output_attentions = True config.output_hidden_states = False model = model_class(config=config) @@ -244,6 +243,10 @@ class CommonTestCases: for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + if "head_mask" in inputs_dict: + del inputs_dict["head_mask"] + config.output_attentions = True config.output_hidden_states = False model = model_class(config=config) @@ -274,6 +277,10 @@ class CommonTestCases: for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + if "head_mask" in inputs_dict: + del inputs_dict["head_mask"] + config.output_attentions = True config.output_hidden_states = False From 87747518e94860e730606848e6a8d2ed68ae8a51 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 21 Aug 2019 21:20:39 -0400 Subject: [PATCH 04/11] Blocks deletion from already deleted heads. Necessary integration test. Now raises a warning when a head to be deleted already has been deleted. An integration test verifying the total pipeline (-> from config -> save model -> load model -> additional head pruning) has been added. --- pytorch_transformers/modeling_bert.py | 1 + pytorch_transformers/modeling_gpt2.py | 1 + pytorch_transformers/modeling_openai.py | 1 + pytorch_transformers/modeling_utils.py | 21 +++-- pytorch_transformers/modeling_xlm.py | 1 + .../tests/modeling_common_test.py | 76 ++++++++++++++++--- 6 files changed, 84 insertions(+), 17 deletions(-) diff --git a/pytorch_transformers/modeling_bert.py b/pytorch_transformers/modeling_bert.py index 4a68c2b96b6..5a65e442d02 100644 --- a/pytorch_transformers/modeling_bert.py +++ b/pytorch_transformers/modeling_bert.py @@ -651,6 +651,7 @@ class BertModel(BertPreTrainedModel): if hasattr(config, "pruned_heads"): pruned_heads = config.pruned_heads.copy().items() + config.pruned_heads = {} for layer, heads in pruned_heads: if self.encoder.layer[int(layer)].attention.self.num_attention_heads == config.num_attention_heads: self.prune_heads({int(layer): list(map(int, heads))}) diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index 23cc7f53132..8aa5347c717 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -455,6 +455,7 @@ class GPT2Model(GPT2PreTrainedModel): if hasattr(config, "pruned_heads"): pruned_heads = config.pruned_heads.copy().items() + config.pruned_heads = {} for layer, heads in pruned_heads: if self.h[int(layer)].attn.n_head == config.n_head: self.prune_heads({int(layer): list(map(int, heads))}) diff --git a/pytorch_transformers/modeling_openai.py b/pytorch_transformers/modeling_openai.py index c640b7c86c2..ce3768c6762 100644 --- a/pytorch_transformers/modeling_openai.py +++ b/pytorch_transformers/modeling_openai.py @@ -458,6 +458,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): if hasattr(config, "pruned_heads"): pruned_heads = config.pruned_heads.copy().items() + config.pruned_heads = {} for layer, heads in pruned_heads: if self.h[int(layer)].attn.n_head == config.n_head: self.prune_heads({int(layer): list(map(int, heads))}) diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 351fbfd0e14..0a47d07fd4e 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -201,6 +201,10 @@ class PretrainedConfig(object): # Load config config = 
cls.from_json_file(resolved_config_file) + if hasattr(config, 'pruned_heads'): + config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()} + + # Update config with kwargs if needed to_remove = [] for key, value in kwargs.items(): @@ -365,15 +369,22 @@ class PreTrainedModel(nn.Module): """ base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed + to_be_pruned = {} + for layer, heads in heads_to_prune.items(): - if str(layer) not in self.config.pruned_heads: - self.config.pruned_heads[str(layer)] = heads + if int(layer) not in self.config.pruned_heads: + self.config.pruned_heads[int(layer)] = heads + to_be_pruned[int(layer)] = heads else: for head in heads: - if head not in self.config.pruned_heads[str(layer)]: - self.config.pruned_heads[str(layer)].append(head) + if head not in self.config.pruned_heads[int(layer)]: + self.config.pruned_heads[int(layer)].append(head) + to_be_pruned[int(layer)].append(head) + else: + logger.warning(f"Tried to remove head {head} of layer {layer} but it was already removed. " + f"The removed heads are {heads_to_prune}") - base_model._prune_heads(heads_to_prune) + base_model._prune_heads(to_be_pruned) def save_pretrained(self, save_directory): """ Save a model and its configuration file to a directory, so that it diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index cf121eee416..1e0f8d7c772 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -561,6 +561,7 @@ class XLMModel(XLMPreTrainedModel): if hasattr(config, "pruned_heads"): pruned_heads = config.pruned_heads.copy().items() + config.pruned_heads = {} for layer, heads in pruned_heads: if self.attentions[int(layer)].n_heads == config.n_heads: self.prune_heads({int(layer): list(map(int, heads))}) diff --git a/pytorch_transformers/tests/modeling_common_test.py b/pytorch_transformers/tests/modeling_common_test.py index c06c5011530..8b1a70fcf3c 100644 --- a/pytorch_transformers/tests/modeling_common_test.py +++ b/pytorch_transformers/tests/modeling_common_test.py @@ -262,12 +262,9 @@ class CommonTestCases: outputs = model(**inputs_dict) attentions = outputs[-1] - self.assertEqual( - attentions[0].shape[-3], 1) - self.assertEqual( - attentions[1].shape[-3], self.model_tester.num_attention_heads) - self.assertEqual( - attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1) + self.assertEqual(attentions[0].shape[-3], 1) + self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads) + self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1) shutil.rmtree(directory) @@ -293,12 +290,67 @@ class CommonTestCases: outputs = model(**inputs_dict) attentions = outputs[-1] - self.assertEqual( - attentions[0].shape[-3], 1) - self.assertEqual( - attentions[1].shape[-3], self.model_tester.num_attention_heads) - self.assertEqual( - attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1) + + self.assertEqual(attentions[0].shape[-3], 1) + self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads) + self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1) + + def test_head_pruning_integration(self): + if not self.test_pruning: + return + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + if "head_mask" in inputs_dict: + del inputs_dict["head_mask"] + + config.output_attentions = True + 
config.output_hidden_states = False + + heads_to_prune = {0: [0], 1: [1, 2]} + config.pruned_heads = heads_to_prune + + model = model_class(config=config) + model.eval() + + outputs = model(**inputs_dict) + attentions = outputs[-1] + + self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1) + self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2) + self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads) + self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads) + + directory = "pruned_model" + + if not os.path.exists(directory): + os.makedirs(directory) + model.save_pretrained(directory) + model = model_class.from_pretrained(directory) + shutil.rmtree(directory) + + outputs = model(**inputs_dict) + attentions = outputs[-1] + + self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1) + self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2) + self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads) + self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads) + + heads_to_prune = {0: [0], 2: [1, 2]} + model.prune_heads(heads_to_prune) + + outputs = model(**inputs_dict) + attentions = outputs[-1] + + self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads -1) + self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2) + self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads - 2) + self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads) + + self.assertDictEqual(model.config.pruned_heads, {0: [0], 1: [1, 2], 2: [1, 2]}) + def test_hidden_states_output(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From 5c2b94c82aa48db997cfaf9dc63dbd520ac45609 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 21 Aug 2019 21:24:48 -0400 Subject: [PATCH 05/11] Changed string so that Circle CI accepts the warning --- pytorch_transformers/modeling_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 0a47d07fd4e..5a89badba64 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -381,8 +381,9 @@ class PreTrainedModel(nn.Module): self.config.pruned_heads[int(layer)].append(head) to_be_pruned[int(layer)].append(head) else: - logger.warning(f"Tried to remove head {head} of layer {layer} but it was already removed. " - f"The removed heads are {heads_to_prune}") + logger.warning("Tried to remove head " + head + + " of layer " + layer + + " but it was already removed. 
The current removed heads are " + heads_to_prune) base_model._prune_heads(to_be_pruned) From c85b5db61a8825edda59a0e9f12bc1be08c63cdc Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 21 Aug 2019 21:37:30 -0400 Subject: [PATCH 06/11] Conditional append/init + fixed warning --- pytorch_transformers/modeling_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 5a89badba64..c69cba49e3e 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -379,11 +379,15 @@ class PreTrainedModel(nn.Module): for head in heads: if head not in self.config.pruned_heads[int(layer)]: self.config.pruned_heads[int(layer)].append(head) - to_be_pruned[int(layer)].append(head) + + if int(layer) in to_be_pruned: + to_be_pruned[int(layer)].append(head) + else: + to_be_pruned[int(layer)] = [head] else: - logger.warning("Tried to remove head " + head + - " of layer " + layer + - " but it was already removed. The current removed heads are " + heads_to_prune) + logger.warning("Tried to remove head " + str(head) + + " of layer " + str(layer) + + " but it was already removed. The current removed heads are " + str(heads_to_prune)) base_model._prune_heads(to_be_pruned) From 0cd283522ab46a9c1c50576be4fd309c08974d8e Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Tue, 27 Aug 2019 15:56:59 -0400 Subject: [PATCH 07/11] Attempt to fix head index --- pytorch_transformers/modeling_gpt2.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index 8aa5347c717..8b39ad372e3 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -233,12 +233,14 @@ class Attention(nn.Module): self.c_proj = Conv1D(n_state, nx) self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.pruned_heads = [] def prune_heads(self, heads): if len(heads) == 0: return mask = torch.ones(self.n_head, self.split_size // self.n_head) for head in heads: + head -= len(list(filter(lambda h: h < head, self.pruned_heads))) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() @@ -249,6 +251,7 @@ class Attention(nn.Module): # Update hyper params self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) self.n_head = self.n_head - len(heads) + self.pruned_heads.extend(heads) def _attn(self, q, k, v, head_mask=None): w = torch.matmul(q, k) From 0c8e823b031d99d06bddff2b88fd4da2d7500117 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Thu, 29 Aug 2019 17:20:11 -0400 Subject: [PATCH 08/11] Added patch to remaining models --- pytorch_transformers/modeling_bert.py | 3 +++ pytorch_transformers/modeling_openai.py | 3 +++ pytorch_transformers/modeling_xlm.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/pytorch_transformers/modeling_bert.py b/pytorch_transformers/modeling_bert.py index 5a65e442d02..9aa25edbe3c 100644 --- a/pytorch_transformers/modeling_bert.py +++ b/pytorch_transformers/modeling_bert.py @@ -337,12 +337,14 @@ class BertAttention(nn.Module): super(BertAttention, self).__init__() self.self = BertSelfAttention(config) self.output = BertSelfOutput(config) + self.pruned_heads = [] def prune_heads(self, heads): if len(heads) == 0: return mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) for head in heads: + head -= len(list(filter(lambda 
h: h < head, self.pruned_heads))) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() @@ -354,6 +356,7 @@ class BertAttention(nn.Module): # Update hyper params self.self.num_attention_heads = self.self.num_attention_heads - len(heads) self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads.extend(heads) def forward(self, input_tensor, attention_mask, head_mask=None): self_outputs = self.self(input_tensor, attention_mask, head_mask) diff --git a/pytorch_transformers/modeling_openai.py b/pytorch_transformers/modeling_openai.py index ce3768c6762..78e57b0c592 100644 --- a/pytorch_transformers/modeling_openai.py +++ b/pytorch_transformers/modeling_openai.py @@ -249,12 +249,14 @@ class Attention(nn.Module): self.c_proj = Conv1D(n_state, nx) self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.pruned_heads = [] def prune_heads(self, heads): if len(heads) == 0: return mask = torch.ones(self.n_head, self.split_size // self.n_head) for head in heads: + head -= len(list(filter(lambda h: h < head, self.pruned_heads))) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() @@ -265,6 +267,7 @@ class Attention(nn.Module): # Update hyper params self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) self.n_head = self.n_head - len(heads) + self.pruned_heads.extend(heads) def _attn(self, q, k, v, head_mask=None): w = torch.matmul(q, k) diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index 1e0f8d7c772..17e39528f82 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -271,6 +271,7 @@ class MultiHeadAttention(nn.Module): self.k_lin = nn.Linear(dim, dim) self.v_lin = nn.Linear(dim, dim) self.out_lin = nn.Linear(dim, dim) + self.pruned_heads = [] def prune_heads(self, heads): attention_head_size = self.dim // self.n_heads @@ -278,6 +279,7 @@ class MultiHeadAttention(nn.Module): return mask = torch.ones(self.n_heads, attention_head_size) for head in heads: + head -= len(list(filter(lambda h: h < head, self.pruned_heads))) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() @@ -289,6 +291,7 @@ class MultiHeadAttention(nn.Module): # Update hyper params self.n_heads = self.n_heads - len(heads) self.dim = attention_head_size * self.n_heads + self.pruned_heads.extend(heads) def forward(self, input, mask, kv=None, cache=None, head_mask=None): """ From bdb4409ed8de4d199907c75832398f2c49a564e1 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sat, 31 Aug 2019 01:59:07 +0200 Subject: [PATCH 09/11] updated pruning logic with sets - Bert and GPT-2 --- pytorch_transformers/modeling_bert.py | 43 +++++++++++--------------- pytorch_transformers/modeling_gpt2.py | 25 +++++++-------- pytorch_transformers/modeling_utils.py | 38 ++++++++++------------- 3 files changed, 45 insertions(+), 61 deletions(-) diff --git a/pytorch_transformers/modeling_bert.py b/pytorch_transformers/modeling_bert.py index 9aa25edbe3c..e2d83460713 100644 --- a/pytorch_transformers/modeling_bert.py +++ b/pytorch_transformers/modeling_bert.py @@ -337,26 +337,30 @@ class BertAttention(nn.Module): super(BertAttention, self).__init__() self.self = BertSelfAttention(config) self.output = BertSelfOutput(config) - self.pruned_heads = [] + self.pruned_heads = set() def prune_heads(self, heads): if len(heads) == 0: 
return mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) + heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads for head in heads: - head -= len(list(filter(lambda h: h < head, self.pruned_heads))) + # Compute how many pruned heads are before the head and move the index accordingly + head = head - sum(1 if h < head else 0 for h in self.pruned_heads) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() + # Prune linear layers self.self.query = prune_linear_layer(self.self.query, index) self.self.key = prune_linear_layer(self.self.key, index) self.self.value = prune_linear_layer(self.self.value, index) self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) - # Update hyper params + + # Update hyper params and store pruned heads self.self.num_attention_heads = self.self.num_attention_heads - len(heads) self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads - self.pruned_heads.extend(heads) + self.pruned_heads = self.pruned_heads.union(heads) def forward(self, input_tensor, attention_mask, head_mask=None): self_outputs = self.self(input_tensor, attention_mask, head_mask) @@ -534,12 +538,8 @@ class BertPreTrainedModel(PreTrainedModel): load_tf_weights = load_tf_weights_in_bert base_model_prefix = "bert" - def __init__(self, *inputs, **kwargs): - super(BertPreTrainedModel, self).__init__(*inputs, **kwargs) - - def init_weights(self, module): - """ Initialize the weights. - """ + def _init_weights(self, module): + """ Initialize the weights """ if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 @@ -652,14 +652,7 @@ class BertModel(BertPreTrainedModel): self.encoder = BertEncoder(config) self.pooler = BertPooler(config) - if hasattr(config, "pruned_heads"): - pruned_heads = config.pruned_heads.copy().items() - config.pruned_heads = {} - for layer, heads in pruned_heads: - if self.encoder.layer[int(layer)].attention.self.num_attention_heads == config.num_attention_heads: - self.prune_heads({int(layer): list(map(int, heads))}) - - self.apply(self.init_weights) + self.init_weights() def _resize_token_embeddings(self, new_num_tokens): old_embeddings = self.embeddings.word_embeddings @@ -768,7 +761,7 @@ class BertForPreTraining(BertPreTrainedModel): self.bert = BertModel(config) self.cls = BertPreTrainingHeads(config) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): @@ -836,7 +829,7 @@ class BertForMaskedLM(BertPreTrainedModel): self.bert = BertModel(config) self.cls = BertOnlyMLMHead(config) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): @@ -901,7 +894,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel): self.bert = BertModel(config) self.cls = BertOnlyNSPHead(config) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, position_ids=None, head_mask=None): @@ -962,7 +955,7 @@ class BertForSequenceClassification(BertPreTrainedModel): self.bert = BertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, position_ids=None,
head_mask=None): @@ -1066,7 +1059,7 @@ class BertForMultipleChoice(BertPreTrainedModel): self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, 1) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, position_ids=None, head_mask=None): @@ -1134,7 +1127,7 @@ class BertForTokenClassification(BertPreTrainedModel): self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, position_ids=None, head_mask=None): @@ -1208,7 +1201,7 @@ class BertForQuestionAnswering(BertPreTrainedModel): self.bert = BertModel(config) self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None, position_ids=None, head_mask=None): diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index 8b39ad372e3..017ad4f7b47 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -233,25 +233,29 @@ class Attention(nn.Module): self.c_proj = Conv1D(n_state, nx) self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) - self.pruned_heads = [] + self.pruned_heads = set() def prune_heads(self, heads): if len(heads) == 0: return mask = torch.ones(self.n_head, self.split_size // self.n_head) + heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads for head in heads: - head -= len(list(filter(lambda h: h < head, self.pruned_heads))) + # Compute how many pruned heads are before the head and move the index accordingly + head = head - sum(1 if h < head else 0 for h in self.pruned_heads) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)]) + # Prune conv1d layers self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + # Update hyper params self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) self.n_head = self.n_head - len(heads) - self.pruned_heads.extend(heads) + self.pruned_heads = self.pruned_heads.union(heads) def _attn(self, q, k, v, head_mask=None): w = torch.matmul(q, k) @@ -357,7 +361,7 @@ class GPT2PreTrainedModel(PreTrainedModel): def __init__(self, *inputs, **kwargs): super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs) - def init_weights(self, module): + def _init_weights(self, module): """ Initialize the weights.
""" if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)): @@ -456,14 +460,7 @@ class GPT2Model(GPT2PreTrainedModel): self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - if hasattr(config, "pruned_heads"): - pruned_heads = config.pruned_heads.copy().items() - config.pruned_heads = {} - for layer, heads in pruned_heads: - if self.h[int(layer)].attn.n_head == config.n_head: - self.prune_heads({int(layer): list(map(int, heads))}) - - self.apply(self.init_weights) + self.init_weights() def _resize_token_embeddings(self, new_num_tokens): self.wte = self._get_resized_embeddings(self.wte, new_num_tokens) @@ -594,7 +591,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): self.transformer = GPT2Model(config) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): @@ -718,7 +715,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.multiple_choice_head = SequenceSummary(config) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index c69cba49e3e..33bcb968b5d 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -202,8 +202,7 @@ class PretrainedConfig(object): config = cls.from_json_file(resolved_config_file) if hasattr(config, 'pruned_heads'): - config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()} - + config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items()) # Update config with kwargs if needed to_remove = [] @@ -316,7 +315,7 @@ class PreTrainedModel(nn.Module): new_embeddings.to(old_embeddings.weight.device) # initialize all new embeddings (in particular added tokens) - self.init_weights(new_embeddings) + self._init_weights(new_embeddings) # Copy word embeddings from the previous weights num_tokens_to_copy = min(old_num_tokens, new_num_tokens) @@ -360,36 +359,31 @@ class PreTrainedModel(nn.Module): return model_embeds + def init_weights(self): + """ Initialize and prunes weights if needed. """ + # Initialize weights + self.apply(self._init_weights) + + # Prune heads if needed + if self.config.pruned_heads: + self.prune_heads(self.config.pruned_heads) + def prune_heads(self, heads_to_prune): """ Prunes heads of the base model. Arguments: heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`). + E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2. 
""" base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed - to_be_pruned = {} - + # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads for layer, heads in heads_to_prune.items(): - if int(layer) not in self.config.pruned_heads: - self.config.pruned_heads[int(layer)] = heads - to_be_pruned[int(layer)] = heads - else: - for head in heads: - if head not in self.config.pruned_heads[int(layer)]: - self.config.pruned_heads[int(layer)].append(head) + union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads) + self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON - if int(layer) in to_be_pruned: - to_be_pruned[int(layer)].append(head) - else: - to_be_pruned[int(layer)] = [head] - else: - logger.warning("Tried to remove head " + str(head) + - " of layer " + str(layer) + - " but it was already removed. The current removed heads are " + str(heads_to_prune)) - - base_model._prune_heads(to_be_pruned) + base_model._prune_heads(heads_to_prune) def save_pretrained(self, save_directory): """ Save a model and its configuration file to a directory, so that it From b6992b7b476fe7e231c8e144e36582fbbde0b4d4 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Sat, 31 Aug 2019 00:33:11 -0400 Subject: [PATCH 10/11] Applied patch to OpenAI GPT, RoBERTa, TransfoL, XLM and XLNet --- pytorch_transformers/modeling_openai.py | 25 +++++++-------------- pytorch_transformers/modeling_roberta.py | 4 ++-- pytorch_transformers/modeling_transfo_xl.py | 9 +++----- pytorch_transformers/modeling_xlm.py | 17 +++++++------- pytorch_transformers/modeling_xlnet.py | 13 +++++------ 5 files changed, 27 insertions(+), 41 deletions(-) diff --git a/pytorch_transformers/modeling_openai.py b/pytorch_transformers/modeling_openai.py index 78e57b0c592..8bf9d86696c 100644 --- a/pytorch_transformers/modeling_openai.py +++ b/pytorch_transformers/modeling_openai.py @@ -249,14 +249,15 @@ class Attention(nn.Module): self.c_proj = Conv1D(n_state, nx) self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) - self.pruned_heads = [] + self.pruned_heads = set() def prune_heads(self, heads): if len(heads) == 0: return mask = torch.ones(self.n_head, self.split_size // self.n_head) + heads = set(heads) - self.pruned_heads for head in heads: - head -= len(list(filter(lambda h: h < head, self.pruned_heads))) + head -= sum(1 if h < head else 0 for h in self.pruned_heads) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() @@ -267,7 +268,7 @@ class Attention(nn.Module): # Update hyper params self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) self.n_head = self.n_head - len(heads) - self.pruned_heads.extend(heads) + self.pruned_heads = self.pruned_heads.union(heads) def _attn(self, q, k, v, head_mask=None): w = torch.matmul(q, k) @@ -366,10 +367,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel): load_tf_weights = load_tf_weights_in_openai_gpt base_model_prefix = "transformer" - def __init__(self, *inputs, **kwargs): - super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs) - - def init_weights(self, module): + def _init_weights(self, module): """ Initialize the weights. 
""" if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)): @@ -459,14 +457,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): self.drop = nn.Dropout(config.embd_pdrop) self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) - if hasattr(config, "pruned_heads"): - pruned_heads = config.pruned_heads.copy().items() - config.pruned_heads = {} - for layer, heads in pruned_heads: - if self.h[int(layer)].attn.n_head == config.n_head: - self.prune_heads({int(layer): list(map(int, heads))}) - - self.apply(self.init_weights) + self.init_weights() def _resize_token_embeddings(self, new_num_tokens): self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens) @@ -579,7 +570,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): self.transformer = OpenAIGPTModel(config) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): @@ -686,7 +677,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.multiple_choice_head = SequenceSummary(config) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): diff --git a/pytorch_transformers/modeling_roberta.py b/pytorch_transformers/modeling_roberta.py index cbd88ab86e8..6ae5cd1d440 100644 --- a/pytorch_transformers/modeling_roberta.py +++ b/pytorch_transformers/modeling_roberta.py @@ -168,7 +168,7 @@ class RobertaModel(BertModel): super(RobertaModel, self).__init__(config) self.embeddings = RobertaEmbeddings(config) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None): if input_ids[:, 0].sum().item() != 0: @@ -220,7 +220,7 @@ class RobertaForMaskedLM(BertPreTrainedModel): self.roberta = RobertaModel(config) self.lm_head = RobertaLMHead(config) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): diff --git a/pytorch_transformers/modeling_transfo_xl.py b/pytorch_transformers/modeling_transfo_xl.py index 283fa66daf7..0c5c5b77983 100644 --- a/pytorch_transformers/modeling_transfo_xl.py +++ b/pytorch_transformers/modeling_transfo_xl.py @@ -853,9 +853,6 @@ class TransfoXLPreTrainedModel(PreTrainedModel): load_tf_weights = load_tf_weights_in_transfo_xl base_model_prefix = "transformer" - def __init__(self, *inputs, **kwargs): - super(TransfoXLPreTrainedModel, self).__init__(*inputs, **kwargs) - def _init_weight(self, weight): if self.config.init == 'uniform': nn.init.uniform_(weight, -self.config.init_range, self.config.init_range) @@ -865,7 +862,7 @@ class TransfoXLPreTrainedModel(PreTrainedModel): def _init_bias(self, bias): nn.init.constant_(bias, 0.0) - def init_weights(self, m): + def _init_weights(self, m): """ Initialize the weights. 
""" classname = m.__class__.__name__ @@ -1059,7 +1056,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel): self.r_emb = nn.Parameter(torch.FloatTensor( self.n_layer, self.max_klen, self.n_head, self.d_head)) - self.apply(self.init_weights) + self.init_weights() def _resize_token_embeddings(self, new_num_tokens): return self.word_emb @@ -1306,7 +1303,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): else: self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index 17e39528f82..9eff09b362a 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -271,15 +271,16 @@ class MultiHeadAttention(nn.Module): self.k_lin = nn.Linear(dim, dim) self.v_lin = nn.Linear(dim, dim) self.out_lin = nn.Linear(dim, dim) - self.pruned_heads = [] + self.pruned_heads = set() def prune_heads(self, heads): attention_head_size = self.dim // self.n_heads if len(heads) == 0: return mask = torch.ones(self.n_heads, attention_head_size) + heads = set(heads) - self.pruned_heads for head in heads: - head -= len(list(filter(lambda h: h < head, self.pruned_heads))) + head -= sum(1 if h < head else 0 for h in self.pruned_heads) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() @@ -291,7 +292,7 @@ class MultiHeadAttention(nn.Module): # Update hyper params self.n_heads = self.n_heads - len(heads) self.dim = attention_head_size * self.n_heads - self.pruned_heads.extend(heads) + self.pruned_heads = self.pruned_heads.union(heads) def forward(self, input, mask, kv=None, cache=None, head_mask=None): """ @@ -386,7 +387,7 @@ class XLMPreTrainedModel(PreTrainedModel): def __init__(self, *inputs, **kwargs): super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs) - def init_weights(self, module): + def _init_weights(self, module): """ Initialize the weights. 
""" if isinstance(module, nn.Embedding): if self.config is not None and self.config.embed_init_std is not None: @@ -569,7 +570,7 @@ class XLMModel(XLMPreTrainedModel): if self.attentions[int(layer)].n_heads == config.n_heads: self.prune_heads({int(layer): list(map(int, heads))}) - self.apply(self.init_weights) + self.init_weights() def _resize_token_embeddings(self, new_num_tokens): self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens) @@ -781,7 +782,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): self.transformer = XLMModel(config) self.pred_layer = XLMPredLayer(config) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): @@ -843,7 +844,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel): self.transformer = XLMModel(config) self.sequence_summary = SequenceSummary(config) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None, attention_mask=None, cache=None, labels=None, head_mask=None): @@ -921,7 +922,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel): self.transformer = XLMModel(config) self.qa_outputs = SQuADHead(config) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None, attention_mask=None, cache=None, start_positions=None, end_positions=None, diff --git a/pytorch_transformers/modeling_xlnet.py b/pytorch_transformers/modeling_xlnet.py index cc9c1379a1e..516e87e99ba 100644 --- a/pytorch_transformers/modeling_xlnet.py +++ b/pytorch_transformers/modeling_xlnet.py @@ -586,10 +586,7 @@ class XLNetPreTrainedModel(PreTrainedModel): load_tf_weights = load_tf_weights_in_xlnet base_model_prefix = "transformer" - def __init__(self, *inputs, **kwargs): - super(XLNetPreTrainedModel, self).__init__(*inputs, **kwargs) - - def init_weights(self, module): + def _init_weights(self, module): """ Initialize the weights. 
""" if isinstance(module, (nn.Linear, nn.Embedding)): @@ -736,7 +733,7 @@ class XLNetModel(XLNetPreTrainedModel): self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)]) self.dropout = nn.Dropout(config.dropout) - self.apply(self.init_weights) + self.init_weights() def _resize_token_embeddings(self, new_num_tokens): self.word_embedding = self._get_resized_embeddings(self.word_embedding, new_num_tokens) @@ -1037,7 +1034,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): self.transformer = XLNetModel(config) self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): @@ -1114,7 +1111,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): self.sequence_summary = SequenceSummary(config) self.logits_proj = nn.Linear(config.d_model, config.num_labels) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, @@ -1216,7 +1213,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): self.end_logits = PoolerEndLogits(config) self.answer_class = PoolerAnswerClass(config) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, From 11600edc6e4e6a5ce148ca1d617c9d7e58bc7a7c Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Sat, 31 Aug 2019 00:37:41 -0400 Subject: [PATCH 11/11] Rebase on master + DistilBERT head pruning patch --- pytorch_transformers/modeling_distilbert.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/pytorch_transformers/modeling_distilbert.py b/pytorch_transformers/modeling_distilbert.py index 1a0bd2496c5..d9a2f1a1770 100644 --- a/pytorch_transformers/modeling_distilbert.py +++ b/pytorch_transformers/modeling_distilbert.py @@ -174,12 +174,16 @@ class MultiHeadSelfAttention(nn.Module): self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim) self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim) + self.pruned_heads = set() + def prune_heads(self, heads): attention_head_size = self.dim // self.n_heads if len(heads) == 0: return mask = torch.ones(self.n_heads, attention_head_size) + heads = set(heads) - self.pruned_heads for head in heads: + head -= sum(1 if h < head else 0 for h in self.pruned_heads) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() @@ -191,6 +195,7 @@ class MultiHeadSelfAttention(nn.Module): # Update hyper params self.n_heads = self.n_heads - len(heads) self.dim = attention_head_size * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) def forward(self, query, key, value, mask, head_mask = None): """ @@ -395,7 +400,7 @@ class DistilBertPreTrainedModel(PreTrainedModel): def __init__(self, *inputs, **kwargs): super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs) - def init_weights(self, module): + def _init_weights(self, module): """ Initialize the weights. 
""" if isinstance(module, nn.Embedding): @@ -480,7 +485,7 @@ class DistilBertModel(DistilBertPreTrainedModel): self.embeddings = Embeddings(config) # Embeddings self.transformer = Transformer(config) # Encoder - self.apply(self.init_weights) + self.init_weights() def _resize_token_embeddings(self, new_num_tokens): old_embeddings = self.embeddings.word_embeddings @@ -568,7 +573,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel): self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12) self.vocab_projector = nn.Linear(config.dim, config.vocab_size) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1) @@ -642,7 +647,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel): self.classifier = nn.Linear(config.dim, config.num_labels) self.dropout = nn.Dropout(config.seq_classif_dropout) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, attention_mask=None, labels=None, head_mask=None): distilbert_output = self.distilbert(input_ids=input_ids, @@ -716,7 +721,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel): assert config.num_labels == 2 self.dropout = nn.Dropout(config.qa_dropout) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, attention_mask=None, start_positions=None, end_positions=None, head_mask=None): distilbert_output = self.distilbert(input_ids=input_ids,