diff --git a/pytorch_pretrained_bert/__init__.py b/pytorch_pretrained_bert/__init__.py
index c9405493645..0a9e41266dd 100644
--- a/pytorch_pretrained_bert/__init__.py
+++ b/pytorch_pretrained_bert/__init__.py
@@ -5,7 +5,8 @@ from .modeling import (BertConfig, BertModel, BertForPreTraining,
                        BertForMaskedLM, BertForNextSentencePrediction,
                        BertForSequenceClassification, BertForMultipleChoice,
                        BertForTokenClassification, BertForQuestionAnswering)
-from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, OpenAIGPTDoubleHeadsModel
+from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
+                              OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
 from .optimization import BertAdam
 from .optimization_openai import OpenAIAdam
 from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py
index 9442b1ed69b..bde481c7b15 100644
--- a/pytorch_pretrained_bert/modeling_openai.py
+++ b/pytorch_pretrained_bert/modeling_openai.py
@@ -267,11 +267,11 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
         nn.init.normal_(self.linear.weight, std = 0.02)
         nn.init.normal_(self.linear.bias, 0)
 
-    def forward(self, hidden_states, classification_token_mask):
+    def forward(self, hidden_states, multiple_choice_token_mask):
         # Classification logits
         # hidden_states = hidden_states.view(-1, self.n_embd)
-        # classification_token_mask = classification_token_mask.view(-1, 1).expand_as(hidden_states)
-        multiple_choice_h = hidden_states * classification_token_mask.unsqueeze(-1)
+        # multiple_choice_token_mask = multiple_choice_token_mask.view(-1, 1).expand_as(hidden_states)
+        multiple_choice_h = hidden_states * multiple_choice_token_mask.unsqueeze(-1)
         multiple_choice_h = multiple_choice_h.sum(dim=-2)
         # flat = x[..., 0].contiguous().view(-1)
         # multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :]
@@ -496,8 +496,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(lm_logits, lm_labels)
+            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
             return loss
         return lm_logits
 
@@ -515,15 +515,14 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         self.transformer.set_num_special_tokens(num_special_tokens)
         self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
 
-    def forward(self, input_ids, classification_token_mask, position_ids=None, token_type_ids=None,
+    def forward(self, input_ids, multiple_choice_token_mask, position_ids=None, token_type_ids=None,
                 lm_labels=None, multiple_choice_labels=None):
-        """
-        input_ids as to be of shape B x C x S
+        """ input_ids should be of shape B x C x S
             lm_labels can be masked using the -1 value
         """
         hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
         lm_logits = self.lm_head(hidden_states)
-        multiple_choice_logits = self.multiple_choice_head(hidden_states, classification_token_mask)
+        multiple_choice_logits = self.multiple_choice_head(hidden_states, multiple_choice_token_mask)
         losses = []
         if lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
diff --git a/tests/modeling_openai_test.py b/tests/modeling_openai_test.py
index 539fbda9e4e..0a711664435 100644
--- a/tests/modeling_openai_test.py
+++ b/tests/modeling_openai_test.py
@@ -22,7 +22,8 @@
 import random
 
 import torch
 
-from pytorch_pretrained_bert import (OpenAIGPTConfig, OpenAIGPTModel, OpenAIGPTDoubleHeadsModel)
+from pytorch_pretrained_bert import (OpenAIGPTConfig, OpenAIGPTModel,
+                                     OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
 
 class OpenAIGPTModelTest(unittest.TestCase):
@@ -89,11 +90,11 @@ class OpenAIGPTModelTest(unittest.TestCase):
             multiple_choice_labels = None
             lm_labels = None
-            classification_token_mask = None
+            multiple_choice_token_mask = None
             if self.use_labels:
                 multiple_choice_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
                 lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
-                classification_token_mask = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], 2).float()
+                multiple_choice_token_mask = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], 2).float()
 
             config = OpenAIGPTConfig(
                 vocab_size_or_config_json_file=self.vocab_size,
@@ -109,10 +110,10 @@ class OpenAIGPTModelTest(unittest.TestCase):
                 initializer_range=self.initializer_range)
 
             return (config, input_ids, token_type_ids, position_ids,
-                    multiple_choice_labels, lm_labels, classification_token_mask)
+                    multiple_choice_labels, lm_labels, multiple_choice_token_mask)
 
         def create_openai_model(self, config, input_ids, token_type_ids, position_ids,
-                                multiple_choice_labels, lm_labels, classification_token_mask):
+                                multiple_choice_labels, lm_labels, multiple_choice_token_mask):
             model = OpenAIGPTModel(config)
             hidden_states = model(input_ids, position_ids, token_type_ids)
             outputs = {
@@ -126,12 +127,34 @@ class OpenAIGPTModelTest(unittest.TestCase):
                 [self.batch_size, self.n_choices, self.seq_length, self.n_embd])
 
+        def create_openai_lm_head(self, config, input_ids, token_type_ids, position_ids,
+                                  multiple_choice_labels, lm_labels, multiple_choice_token_mask):
+            model = OpenAIGPTLMHeadModel(config)
+            loss = model(input_ids, position_ids, token_type_ids, lm_labels)
+            lm_logits = model(input_ids, position_ids, token_type_ids)
+            outputs = {
+                "loss": loss,
+                "lm_logits": lm_logits,
+            }
+            return outputs
+
+        def check_openai_lm_head_output(self, result):
+            total_voc = self.n_ctx + self.n_special + self.vocab_size
+            self.parent.assertListEqual(
+                list(result["lm_logits"].size()),
+                [self.batch_size, self.n_choices, self.seq_length, total_voc])
+
+        def check_openai_lm_head_loss_output(self, result):
+            self.parent.assertListEqual(
+                list(result["loss"].size()),
+                [])
+
         def create_openai_double_heads(self, config, input_ids, token_type_ids, position_ids,
-                                       multiple_choice_labels, lm_labels, classification_token_mask):
+                                       multiple_choice_labels, lm_labels, multiple_choice_token_mask):
             model = OpenAIGPTDoubleHeadsModel(config)
-            loss = model(input_ids, classification_token_mask, position_ids,
+            loss = model(input_ids, multiple_choice_token_mask, position_ids,
                          token_type_ids, lm_labels, multiple_choice_labels)
-            lm_logits, multiple_choice_logits = model(input_ids, classification_token_mask, position_ids, token_type_ids)
+            lm_logits, multiple_choice_logits = model(input_ids, multiple_choice_token_mask, position_ids, token_type_ids)
             outputs = {
                 "loss": loss,
                 "lm_logits": lm_logits,
@@ -167,6 +190,10 @@ class OpenAIGPTModelTest(unittest.TestCase):
         output_result = tester.create_openai_model(*config_and_inputs)
         tester.check_openai_model_output(output_result)
 
+        output_result = tester.create_openai_lm_head(*config_and_inputs)
+        tester.check_openai_lm_head_output(output_result)
+        tester.check_openai_lm_head_loss_output(output_result)
+
         output_result = tester.create_openai_double_heads(*config_and_inputs)
         tester.check_openai_double_heads_output(output_result)
         tester.check_openai_double_heads_loss_output(output_result)
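
Review note: the substantive fix here is the LM loss in OpenAIGPTLMHeadModel. torch.nn.CrossEntropyLoss treats the second dimension of its input as the class dimension, so calling it directly on logits of shape (batch, n_choices, seq_len, vocab) against (batch, n_choices, seq_len) labels is wrong and errors out; the patch flattens logits to (N, C) and labels to (N,), and passes ignore_index=-1 so that positions labeled -1 (the masking convention the new docstring mentions) drop out of the loss. A minimal standalone sketch of that behavior, with arbitrary illustrative shapes rather than the model's real ones:

import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab = 2, 5, 11          # arbitrary sizes for illustration
lm_logits = torch.randn(batch, seq_len, vocab)
lm_labels = torch.randint(0, vocab, (batch, seq_len))
lm_labels[:, 0] = -1                      # -1 marks positions to exclude

loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)),  # (N, C) logits
                lm_labels.view(-1))                      # (N,) targets
print(loss.size())  # torch.Size([]): a scalar, which is what check_openai_lm_head_loss_output asserts

Since view(-1, V) collapses all leading dimensions, the same two lines work whether the logits are 3-D (single-sequence use) or 4-D (the B x C x S multiple-choice layout).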