BERT can be exported to TorchScript

This commit is contained in:
LysandreJik 2019-07-02 17:23:18 -04:00
parent 6ce1ee04fc
commit e891bb43d5

View File

@ -323,7 +323,7 @@ class BertSelfAttention(nn.Module):
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = [context_layer, attention_probs] if self.output_attentions else [context_layer]
outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
return outputs
@ -367,7 +367,7 @@ class BertAttention(nn.Module):
def forward(self, input_tensor, attention_mask, head_mask=None):
self_outputs = self.self(input_tensor, attention_mask, head_mask)
attention_output = self.output(self_outputs[0], input_tensor)
outputs = [attention_output] + self_outputs[1:] # add attentions if we output them
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
@ -412,7 +412,7 @@ class BertLayer(nn.Module):
attention_output = attention_outputs[0]
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
outputs = [layer_output] + attention_outputs[1:] # add attentions if we output them
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
return outputs
@ -424,27 +424,27 @@ class BertEncoder(nn.Module):
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask, head_mask=None):
all_hidden_states = []
all_attentions = []
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
all_hidden_states.append(hidden_states)
all_hidden_states += (hidden_states,)
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
hidden_states = layer_outputs[0]
if self.output_attentions:
all_attentions.append(layer_outputs[1])
all_attentions += (layer_outputs[1],)
# Add last layer
if self.output_hidden_states:
all_hidden_states.append(hidden_states)
all_hidden_states += (hidden_states,)
outputs = [hidden_states]
outputs = (hidden_states,)
if self.output_hidden_states:
outputs.append(all_hidden_states)
outputs += (all_hidden_states,)
if self.output_attentions:
outputs.append(all_attentions)
outputs += (all_attentions,)
return outputs # outputs, (hidden states), (attentions)
@ -490,7 +490,7 @@ class BertLMPredictionHead(nn.Module):
self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
bert_model_embedding_weights.size(0),
bias=False)
self.decoder.weight = bert_model_embedding_weights
self.decoder.weight = nn.Parameter(bert_model_embedding_weights.clone())
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
def forward(self, hidden_states):
@ -666,7 +666,7 @@ class BertModel(BertPreTrainedModel):
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
outputs = [sequence_output, pooled_output] + encoder_outputs[1:] # add hidden_states and attentions if they are here
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
@ -739,14 +739,14 @@ class BertForPreTraining(BertPreTrainedModel):
sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
outputs = [prediction_scores, seq_relationship_score] + outputs[2:] # add hidden states and attention if they are here
outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
if masked_lm_labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
outputs = [total_loss] + outputs
outputs = (total_loss,) + outputs
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
@ -815,11 +815,11 @@ class BertForMaskedLM(BertPreTrainedModel):
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
outputs = [prediction_scores] + outputs[2:] # Add hidden states and attention is they are here
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention is they are here
if masked_lm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
outputs = [masked_lm_loss] + outputs
outputs = (masked_lm_loss,) + outputs
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
@ -885,11 +885,11 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
seq_relationship_score = self.cls(pooled_output)
outputs = [seq_relationship_score] + outputs[2:] # add hidden states and attention if they are here
outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
if next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
outputs = [next_sentence_loss] + outputs
outputs = (next_sentence_loss,) + outputs
return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
@ -960,7 +960,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
outputs = [logits] + outputs[2:] # add hidden states and attention if they are here
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
if labels is not None:
if self.num_labels == 1:
@ -970,7 +970,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = [loss] + outputs
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
@ -1043,12 +1043,12 @@ class BertForMultipleChoice(BertPreTrainedModel):
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
outputs = [reshaped_logits] + outputs[2:] # add hidden states and attention if they are here
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
outputs = [loss] + outputs
outputs = (loss,) + outputs
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
@ -1119,7 +1119,7 @@ class BertForTokenClassification(BertPreTrainedModel):
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
outputs = [logits] + outputs[2:] # add hidden states and attention if they are here
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
@ -1130,7 +1130,7 @@ class BertForTokenClassification(BertPreTrainedModel):
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = [loss] + outputs
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
@ -1205,7 +1205,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
outputs = [start_logits, end_logits] + outputs[2:]
outputs = (start_logits, end_logits,) + outputs[2:]
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
@ -1221,6 +1221,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = [total_loss] + outputs
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)