From b43b130f35d1c6e3e925762c1c06e3e53ebdea37 Mon Sep 17 00:00:00 2001
From: LysandreJik
Date: Wed, 3 Jul 2019 16:21:17 -0400
Subject: [PATCH] TorchScript flag in config; Tied weights when not running TorchScript; tuple concatenation clean-up.

---
 pytorch_pretrained_bert/model_utils.py     |  1 +
 pytorch_pretrained_bert/modeling.py        | 18 ++++++++++++------
 pytorch_pretrained_bert/modeling_gpt2.py   | 16 ++++++++++------
 pytorch_pretrained_bert/modeling_openai.py | 18 +++++++++++-------
 pytorch_pretrained_bert/modeling_xlnet.py  | 14 +++++++++-----
 .../tests/model_tests_commons.py           |  1 +
 6 files changed, 44 insertions(+), 24 deletions(-)

diff --git a/pytorch_pretrained_bert/model_utils.py b/pytorch_pretrained_bert/model_utils.py
index 8c116df54ad..ec735c3e0af 100644
--- a/pytorch_pretrained_bert/model_utils.py
+++ b/pytorch_pretrained_bert/model_utils.py
@@ -46,6 +46,7 @@ class PretrainedConfig(object):
         self.num_labels = kwargs.pop('num_labels', 2)
         self.output_attentions = kwargs.pop('output_attentions', False)
         self.output_hidden_states = kwargs.pop('output_hidden_states', False)
+        self.torchscript = kwargs.pop('torchscript', False)

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py
index eb7fdf1a14c..7b18cb84527 100644
--- a/pytorch_pretrained_bert/modeling.py
+++ b/pytorch_pretrained_bert/modeling.py
@@ -428,23 +428,23 @@ class BertEncoder(nn.Module):
         all_attentions = ()
         for i, layer_module in enumerate(self.layer):
             if self.output_hidden_states:
-                all_hidden_states += (hidden_states,)
+                all_hidden_states = all_hidden_states + (hidden_states,)

             layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
             hidden_states = layer_outputs[0]

             if self.output_attentions:
-                all_attentions += (layer_outputs[1],)
+                all_attentions = all_attentions + (layer_outputs[1],)

         # Add last layer
         if self.output_hidden_states:
-            all_hidden_states += (hidden_states,)
+            all_hidden_states = all_hidden_states + (hidden_states,)

         outputs = (hidden_states,)
         if self.output_hidden_states:
-            outputs += (all_hidden_states,)
+            outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
-            outputs += (all_attentions,)
+            outputs = outputs + (all_attentions,)
         return outputs  # outputs, (hidden states), (attentions)
@@ -484,13 +484,19 @@ class BertLMPredictionHead(nn.Module):
     def __init__(self, config, bert_model_embedding_weights):
         super(BertLMPredictionHead, self).__init__()
         self.transform = BertPredictionHeadTransform(config)
+        self.torchscript = config.torchscript

         # The output weights are the same as the input embeddings, but there is
         # an output-only bias for each token.
         self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                  bert_model_embedding_weights.size(0),
                                  bias=False)
-        self.decoder.weight = nn.Parameter(bert_model_embedding_weights.clone())
+
+        if self.torchscript:
+            self.decoder.weight = nn.Parameter(bert_model_embedding_weights.clone())
+        else:
+            self.decoder.weight = bert_model_embedding_weights
+
         self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))

     def forward(self, hidden_states):
diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py
index d878cf5234e..ba4fd3e2aa8 100644
--- a/pytorch_pretrained_bert/modeling_gpt2.py
+++ b/pytorch_pretrained_bert/modeling_gpt2.py
@@ -322,6 +322,7 @@ class GPT2LMHead(nn.Module):
         self.n_embd = config.n_embd
         self.vocab_size = config.vocab_size
         self.predict_special_tokens = config.predict_special_tokens
+        self.torchscript = config.torchscript
         embed_shape = model_embeddings_weights.shape
         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.set_embeddings_weights(model_embeddings_weights)
@@ -329,7 +330,10 @@
     def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
         self.predict_special_tokens = predict_special_tokens
         # Export to TorchScript can't handle parameter sharing so we are cloning them.
-        self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())  # Tied weights
+        if self.torchscript:
+            self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())
+        else:
+            self.decoder.weight = model_embeddings_weights  # Tied weights

     def forward(self, hidden_state):
         lm_logits = self.decoder(hidden_state)
@@ -563,11 +567,11 @@ class GPT2Model(GPT2PreTrainedModel):
         all_hidden_states = ()
         for i, (block, layer_past) in enumerate(zip(self.h, past)):
             if self.output_hidden_states:
-                all_hidden_states += (hidden_states.view(*output_shape),)
+                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

             outputs = block(hidden_states, layer_past, head_mask[i])
             hidden_states, present = outputs[:2]
-            presents += (present,)
+            presents = presents + (present,)

             if self.output_attentions:
                 all_attentions.append(outputs[2])
@@ -577,16 +581,16 @@
         hidden_states = hidden_states.view(*output_shape)
         # Add last hidden state
         if self.output_hidden_states:
-            all_hidden_states += (hidden_states,)
+            all_hidden_states = all_hidden_states + (hidden_states,)

         outputs = (hidden_states, presents)
         if self.output_hidden_states:
-            outputs += (all_hidden_states,)
+            outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
             # let the number of heads free (-1) so we can extract attention even after head pruning
             attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
             all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
-            outputs += (all_attentions,)
+            outputs = outputs + (all_attentions,)

         return outputs  # last hidden state, presents, (all hidden_states), (attentions)
diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py
index 0db4b28caf6..ed3c0c13ee2 100644
--- a/pytorch_pretrained_bert/modeling_openai.py
+++ b/pytorch_pretrained_bert/modeling_openai.py
@@ -348,14 +348,18 @@ class OpenAIGPTLMHead(nn.Module):
         self.n_embd = config.n_embd
         self.vocab_size = config.vocab_size
         self.predict_special_tokens = config.predict_special_tokens
+        self.torchscript = config.torchscript
         embed_shape = model_embeddings_weights.shape
         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.set_embeddings_weights(model_embeddings_weights)

     def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
         self.predict_special_tokens = predict_special_tokens
-        embed_shape = model_embeddings_weights.shape
-        self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())  # Tied weights
+
+        if self.torchscript:
+            self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())
+        else:
+            self.decoder.weight = model_embeddings_weights  # Tied weights

     def forward(self, hidden_state):
         lm_logits = self.decoder(hidden_state)
@@ -583,22 +587,22 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         all_hidden_states = ()
         for i, block in enumerate(self.h):
             if self.output_hidden_states:
-                all_hidden_states += (hidden_states.view(*output_shape),)
+                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

             outputs = block(hidden_states, head_mask[i])
             hidden_states = outputs[0]
             if self.output_attentions:
-                all_attentions += (outputs[1],)
+                all_attentions = all_attentions + (outputs[1],)

         # Add last layer
         if self.output_hidden_states:
-            all_hidden_states += (hidden_states.view(*output_shape),)
+            all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

         outputs = (hidden_states.view(*output_shape),)
         if self.output_hidden_states:
-            outputs += (all_hidden_states,)
+            outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
-            outputs += (all_attentions,)
+            outputs = outputs + (all_attentions,)
         return outputs  # last hidden state, (all hidden states), (all attentions)
diff --git a/pytorch_pretrained_bert/modeling_xlnet.py b/pytorch_pretrained_bert/modeling_xlnet.py
index c4c3354070e..2771ba7ca5e 100644
--- a/pytorch_pretrained_bert/modeling_xlnet.py
+++ b/pytorch_pretrained_bert/modeling_xlnet.py
@@ -530,7 +530,7 @@
         outputs = (output_h, output_g)
         if self.output_attentions:
-            outputs += (attn_prob,)
+            outputs = outputs + (attn_prob,)
         return outputs

 class XLNetFeedForward(nn.Module):
@@ -878,7 +878,7 @@ class XLNetModel(XLNetPreTrainedModel):
         hidden_states = []
         for i, layer_module in enumerate(self.layer):
             # cache new mems
-            new_mems += (self.cache_mem(output_h, mems[i]),)
+            new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
             if self.output_hidden_states:
                 hidden_states.append((output_h, output_g) if output_g is not None else output_h)
@@ -902,10 +902,10 @@
                 hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
             else:
                 hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)
-            outputs += (hidden_states,)
+            outputs = outputs + (hidden_states,)
         if self.output_attentions:
             attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
-            outputs += (attentions,)
+            outputs = outputs + (attentions,)

         return outputs  # outputs, new_mems, (hidden_states), (attentions)
@@ -975,6 +975,7 @@
         super(XLNetLMHeadModel, self).__init__(config)
         self.attn_type = config.attn_type
         self.same_length = config.same_length
+        self.torchscript = config.torchscript

         self.transformer = XLNetModel(config)
         self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
@@ -987,7 +988,10 @@
     def tie_weights(self):
         """ Make sure we are sharing the embeddings
         """
-        self.lm_loss.weight = nn.Parameter(self.transformer.word_embedding.weight.clone())
+        if self.torchscript:
+            self.lm_loss.weight = nn.Parameter(self.transformer.word_embedding.weight.clone())
+        else:
+            self.lm_loss.weight = self.transformer.word_embedding.weight

     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
diff --git a/pytorch_pretrained_bert/tests/model_tests_commons.py b/pytorch_pretrained_bert/tests/model_tests_commons.py
index 0afda5f2ce9..e93cc98ffeb 100644
--- a/pytorch_pretrained_bert/tests/model_tests_commons.py
+++ b/pytorch_pretrained_bert/tests/model_tests_commons.py
@@ -41,6 +41,7 @@ def _create_and_check_torchscript_output_hidden_state(tester, model_classes, con

 def _create_and_check_torchscript(tester, model_classes, config, inputs_dict):
     configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
+    configs_no_init.torchscript = True
     for model_class in model_classes:
         model = model_class(config=configs_no_init)
         model.eval()
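
Usage note (not part of the patch): the sketch below shows how the new flag is meant to be used, mirroring the test change above, which simply sets torchscript = True on the config before instantiating the model. It assumes the pytorch_pretrained_bert API at this revision; the BertConfig constructor arguments, the BertForMaskedLM attribute paths (cls.predictions.decoder, bert.embeddings.word_embeddings), the dummy input shape, and the output file name are illustrative assumptions rather than part of this change.

import torch
from pytorch_pretrained_bert import BertConfig, BertForMaskedLM

# Deliberately tiny config so the sketch runs quickly; argument names are
# assumed from the repository's BertConfig.
config = BertConfig(vocab_size_or_config_json_file=100, hidden_size=32,
                    num_hidden_layers=2, num_attention_heads=2,
                    intermediate_size=64)

# Default behaviour (torchscript=False): the LM head shares the embedding
# Parameter, i.e. the weights stay tied.
tied = BertForMaskedLM(config)
assert tied.cls.predictions.decoder.weight is tied.bert.embeddings.word_embeddings.weight

# With the flag on, the head receives a clone instead, so there is no
# parameter sharing and the model can be traced.
config.torchscript = True
traceable = BertForMaskedLM(config)
assert traceable.cls.predictions.decoder.weight is not traceable.bert.embeddings.word_embeddings.weight

traceable.eval()
input_ids = torch.ones((1, 7), dtype=torch.long)  # dummy (batch, sequence) token ids
traced = torch.jit.trace(traceable, (input_ids,))
traced.save('traced_bert_for_masked_lm.pt')

As the in-code comment notes, TorchScript export can't handle parameter sharing, so enabling the flag trades weight tying for traceability; the default (torchscript=False) restores the usual tied weights.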