From 074c869bbebd9ad1b8ec1c52ecc506ba982e8483 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Thu, 11 Apr 2019 20:53:50 +0200 Subject: [PATCH] fix OpenAIGPTMultipleChoiceHead --- pytorch_pretrained_bert/modeling_openai.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index fb3d0cadb7a..b6252d097f3 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -371,8 +371,8 @@ class OpenAIGPTMultipleChoiceHead(nn.Module): def forward(self, hidden_states, mc_token_ids): # Classification logits # hidden_state (bsz, num_choices, seq_length, hidden_size) - # mc_token_ids (bsz, num_choices, 1) - mc_token_ids = mc_token_ids.unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1)) + # mc_token_ids (bsz, num_choices) + mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1)) # (bsz, num_choices, 1, hidden_size) multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2) # (bsz, num_choices, hidden_size)