fix OpenAIGPTMultipleChoiceHead

This commit is contained in:
thomwolf 2019-04-11 20:53:50 +02:00
parent 724eb45cef
commit 074c869bbe

View File

@ -371,8 +371,8 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
def forward(self, hidden_states, mc_token_ids):
# Classification logits
# hidden_state (bsz, num_choices, seq_length, hidden_size)
# mc_token_ids (bsz, num_choices, 1)
mc_token_ids = mc_token_ids.unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
# mc_token_ids (bsz, num_choices)
mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
# (bsz, num_choices, 1, hidden_size)
multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
# (bsz, num_choices, hidden_size)