Fix encoder-decoder models when labels is passed (#15172)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-01-26 10:14:46 +01:00 committed by GitHub
parent e79a0faeae
commit 24e2fa1590
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 3 additions and 3 deletions

View File

@ -529,7 +529,7 @@ class EncoderDecoderModel(PreTrainedModel):
loss = None
if labels is not None:
warnings.warn(DEPRECATION_WARNING, FutureWarning)
logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))

View File

@ -549,7 +549,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
# Compute loss independent from decoder (as some shift the logits inside them)
loss = None
if labels is not None:
logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))

View File

@ -503,7 +503,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
# Compute loss independent from decoder (as some shift the logits inside them)
loss = None
if labels is not None:
logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))