diff --git a/examples/run_openai_gpt.py b/examples/run_openai_gpt.py index 6e0a0abf0cb..7a434ceacaf 100644 --- a/examples/run_openai_gpt.py +++ b/examples/run_openai_gpt.py @@ -64,7 +64,7 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d for dataset in encoded_datasets: n_batch = len(dataset) input_ids = np.zeros((n_batch, 2, input_len), dtype=np.int64) - mc_token_mask = np.zeros((n_batch, 2, input_len), dtype=np.int64) + mc_token_ids = np.zeros((n_batch, 2), dtype=np.int64) lm_labels = np.full((n_batch, 2, input_len), fill_value=-1, dtype=np.int64) mc_labels = np.zeros((n_batch,), dtype=np.int64) for i, (story, cont1, cont2, mc_label), in enumerate(dataset): @@ -72,12 +72,12 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d with_cont2 = [start_token] + story[:cap_length] + [delimiter_token] + cont2[:cap_length] + [clf_token] input_ids[i, 0, :len(with_cont1)] = with_cont1 input_ids[i, 1, :len(with_cont2)] = with_cont2 - mc_token_mask[i, 0, len(with_cont1) - 1] = 1 - mc_token_mask[i, 1, len(with_cont2) - 1] = 1 + mc_token_ids[i, 0] = len(with_cont1) - 1 + mc_token_ids[i, 1] = len(with_cont2) - 1 lm_labels[i, 0, :len(with_cont1)-1] = with_cont1[1:] lm_labels[i, 1, :len(with_cont2)-1] = with_cont2[1:] mc_labels[i] = mc_label - all_inputs = (input_ids, mc_token_mask, lm_labels, mc_labels) + all_inputs = (input_ids, mc_token_ids, lm_labels, mc_labels) tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs)) return tensor_datasets @@ -197,8 +197,8 @@ def main(): tqdm_bar = tqdm(train_dataloader, desc="Training") for step, batch in enumerate(tqdm_bar): batch = tuple(t.to(device) for t in batch) - input_ids, mc_token_mask, lm_labels, mc_labels = batch - losses = model(input_ids, mc_token_mask, lm_labels, mc_labels) + input_ids, mc_token_ids, lm_labels, mc_labels = batch + losses = model(input_ids, mc_token_ids, lm_labels, mc_labels) loss = args.lm_coef * losses[0] + losses[1] loss.backward() optimizer.step() @@ -226,10 +226,10 @@ def main(): nb_eval_steps, nb_eval_examples = 0, 0 for batch in tqdm(eval_dataloader, desc="Evaluating"): batch = tuple(t.to(device) for t in batch) - input_ids, mc_token_mask, lm_labels, mc_labels = batch + input_ids, mc_token_ids, lm_labels, mc_labels = batch with torch.no_grad(): - _, mc_loss = model(input_ids, mc_token_mask, lm_labels, mc_labels) - _, mc_logits = model(input_ids, mc_token_mask) + _, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels) + _, mc_logits = model(input_ids, mc_token_ids) mc_logits = mc_logits.detach().cpu().numpy() mc_labels = mc_labels.to('cpu').numpy() diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index e6f3fc4efe7..60bf546c8c6 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -366,23 +366,16 @@ class OpenAIGPTMultipleChoiceHead(nn.Module): nn.init.normal_(self.linear.weight, std=0.02) nn.init.normal_(self.linear.bias, 0) - def forward(self, hidden_states, mc_token_mask): + def forward(self, hidden_states, mc_token_ids): # Classification logits - # hidden_states = hidden_states.view(-1, self.n_embd) - # mc_token_mask = mc_token_mask.view(-1, 1).expand_as(hidden_states) - mc_token_mask = mc_token_mask.float() - multiple_choice_h = hidden_states * mc_token_mask.unsqueeze(-1) - multiple_choice_h = multiple_choice_h.sum(dim=-2) - # flat = x[..., 0].contiguous().view(-1) - # multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :] - # multiple_choice_h = multiple_choice_h.view(-1, x.size(1), self.n_embd, 1) - # # This double transposition is there to replicate the behavior - # # of the noise_shape argument in the tensorflow - # # implementation. For more details, see - # # https://github.com/huggingface/pytorch-openai-transformer-lm/issues/11 - # multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2) - # multiple_choice_h = multiple_choice_h.contiguous().view(-1, self.n_embd) + # hidden_state (bsz, num_choices, seq_length, hidden_size) + # mc_token_ids (bsz, num_choices) + mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1)) + # (bsz, num_choices, 1, hidden_size) + multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2) + # (bsz, num_choices, hidden_size) multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1) + # (bsz, num_choices) return multiple_choice_logits @@ -727,7 +720,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): - """OpenAI GPT model with a Language Modeling and a Multiple Choice heads ("Improving Language Understanding by Generative Pre-Training"). + """OpenAI GPT model with a Language Modeling and a Multiple Choice head ("Improving Language Understanding by Generative Pre-Training"). OpenAI GPT use a single embedding matrix to store the word and special embeddings. Special tokens embeddings are additional tokens that are not pre-trained: [SEP], [CLS]... @@ -750,8 +743,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): config: a OpenAIGPTConfig class instance with the configuration to build a new model Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length] - were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, total_tokens_embeddings[ + `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the BPE token + indices selected in the range [0, total_tokens_embeddings[ + `mc_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token from + which we should take the hidden state to feed the multiple choice classifier (usually last token of the sequence) `position_ids`: an optional torch.LongTensor with the same shape as input_ids with the position indices (selected in the range [0, config.n_positions - 1[. `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids @@ -775,13 +770,13 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): Example usage: ```python # Already been converted into BPE token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - mc_token_mask = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]]) # (bsz, number of choice, seq length) + mc_token_ids = torch.LongTensor([[2], [1]]) # (bsz, number of choice) config = modeling_openai.OpenAIGPTConfig() model = modeling_openai.OpenAIGPTLMHeadModel(config) - lm_logits, multiple_choice_logits = model(input_ids, mc_token_mask) + lm_logits, multiple_choice_logits = model(input_ids, mc_token_ids) ``` """ @@ -799,10 +794,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): self.transformer.set_num_special_tokens(num_special_tokens) self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight) - def forward(self, input_ids, mc_token_mask, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None): + def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None): hidden_states = self.transformer(input_ids, position_ids, token_type_ids) lm_logits = self.lm_head(hidden_states) - mc_logits = self.multiple_choice_head(hidden_states, mc_token_mask) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids) losses = [] if lm_labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) diff --git a/tests/modeling_openai_test.py b/tests/modeling_openai_test.py index 81892a981ae..6baaaf677ab 100644 --- a/tests/modeling_openai_test.py +++ b/tests/modeling_openai_test.py @@ -89,11 +89,11 @@ class OpenAIGPTModelTest(unittest.TestCase): mc_labels = None lm_labels = None - mc_token_mask = None + mc_token_ids = None if self.use_labels: mc_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size) lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels) - mc_token_mask = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], 2).float() + mc_token_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices], self.seq_length).float() config = OpenAIGPTConfig( vocab_size_or_config_json_file=self.vocab_size, @@ -109,10 +109,10 @@ class OpenAIGPTModelTest(unittest.TestCase): initializer_range=self.initializer_range) return (config, input_ids, token_type_ids, position_ids, - mc_labels, lm_labels, mc_token_mask) + mc_labels, lm_labels, mc_token_ids) def create_openai_model(self, config, input_ids, token_type_ids, position_ids, - mc_labels, lm_labels, mc_token_mask): + mc_labels, lm_labels, mc_token_ids): model = OpenAIGPTModel(config) model.eval() hidden_states = model(input_ids, position_ids, token_type_ids) @@ -128,7 +128,7 @@ class OpenAIGPTModelTest(unittest.TestCase): def create_openai_lm_head(self, config, input_ids, token_type_ids, position_ids, - mc_labels, lm_labels, mc_token_mask): + mc_labels, lm_labels, mc_token_ids): model = OpenAIGPTLMHeadModel(config) model.eval() loss = model(input_ids, position_ids, token_type_ids, lm_labels) @@ -151,13 +151,13 @@ class OpenAIGPTModelTest(unittest.TestCase): []) def create_openai_double_heads(self, config, input_ids, token_type_ids, position_ids, - mc_labels, lm_labels, mc_token_mask): + mc_labels, lm_labels, mc_token_ids): model = OpenAIGPTDoubleHeadsModel(config) model.eval() - loss = model(input_ids, mc_token_mask, + loss = model(input_ids, mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels, token_type_ids=token_type_ids, position_ids=position_ids) - lm_logits, mc_logits = model(input_ids, mc_token_mask, position_ids=position_ids, token_type_ids=token_type_ids) + lm_logits, mc_logits = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids) outputs = { "loss": loss, "lm_logits": lm_logits,