Fix GPT language model loss here as well

Catalin Voss 2019-03-24 13:36:46 -07:00
parent 5938f31fa7
commit 2e6f5ffb96


@@ -716,8 +716,16 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[:, :-1]
+            shift_labels = lm_labels[:, 1:]
+
+            # In tensorflow, it's [batch, d_0, d_1, ..., d_{r-1}, num_classes]
+            # in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
+            # We just flatten the tokens out this way.
             loss_fct = CrossEntropyLoss(ignore_index=-1)
-            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                            shift_labels.view(-1))
             return loss
         return lm_logits
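
For reference, a minimal standalone sketch of the same shifted-loss computation. This is not part of the commit: the shapes are made up, and the .contiguous() calls are added only so that .view() works on the sliced (non-contiguous) tensors in this toy setting.

import torch
from torch.nn import CrossEntropyLoss

# Toy shapes (illustrative only): batch of 2 sequences, 5 tokens each, vocabulary of 10.
lm_logits = torch.randn(2, 5, 10)               # [batch, seq_len, vocab]
lm_labels = torch.randint(0, 10, (2, 5))        # [batch, seq_len]

# Shift so that the logits at position i are scored against the token at position i + 1.
shift_logits = lm_logits[:, :-1].contiguous()   # [batch, seq_len - 1, vocab]
shift_labels = lm_labels[:, 1:].contiguous()    # [batch, seq_len - 1]

# CrossEntropyLoss expects [N, num_classes] logits and [N] targets, so flatten the token dimension.
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1))
print(loss.item())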