mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Also fix loss function issue with the double head models
This commit is contained in:
parent
472857c47f
commit
0dd796e359
@@ -698,8 +698,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
|
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
|
||||||
losses = []
|
losses = []
|
||||||
if lm_labels is not None:
|
if lm_labels is not None:
|
||||||
|
shift_logits = lm_logits[:, :-1]
|
||||||
|
shift_labels = lm_labels[:, 1:]
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)))
|
losses.append(loss_fct(shift_logits.view(-1,
|
||||||
|
shift_logits.size(-1)), shift_labels.view(-1)))
|
||||||
if mc_labels is not None:
|
if mc_labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
|
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
|
||||||
|
@@ -811,8 +811,11 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
|||||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
|
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
|
||||||
losses = []
|
losses = []
|
||||||
if lm_labels is not None:
|
if lm_labels is not None:
|
||||||
|
shift_logits = lm_logits[:, :-1]
|
||||||
|
shift_labels = lm_labels[:, 1:]
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)))
|
losses.append(loss_fct(shift_logits.view(-1,
|
||||||
|
shift_logits.size(-1)), shift_labels.view(-1)))
|
||||||
if mc_labels is not None:
|
if mc_labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
|
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
|
||||||
|
Loading…
Reference in New Issue
Block a user