Fix encoder-decoder models when labels is passed (#15172)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent e79a0faeae
commit 24e2fa1590
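The change is the same one-token fix in all three encoder-decoder wrappers: the wrappers compute the loss themselves and call the decoder without labels, so with return_dict=False the decoder returns a plain tuple whose first element is the logits. They therefore sit at index 0, not index 1 (index 1 only holds the logits when a loss is prepended to the tuple). The snippet below is not part of the commit; it is a minimal sketch using GPT-2 as a stand-in decoder purely to illustrate how the tuple layout shifts, and assumes transformers and torch are installed.

# Illustrative only: shows why decoder_outputs[0] (not [1]) holds the logits
# when the decoder is called without labels and with return_dict=False.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
decoder = GPT2LMHeadModel.from_pretrained("gpt2")
inputs = tokenizer("hello world", return_tensors="pt")

# No labels -> the tuple starts with the logits at index 0.
outputs = decoder(**inputs, return_dict=False)
print(outputs[0].shape)  # (batch, seq_len, vocab_size)

# With labels -> the tuple is (loss, logits, ...), so the logits shift to
# index 1, which is the layout the old wrapper code wrongly assumed.
outputs = decoder(**inputs, labels=inputs["input_ids"], return_dict=False)
print(outputs[0].shape, outputs[1].shape)  # scalar loss, then logits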
@@ -529,7 +529,7 @@ class EncoderDecoderModel(PreTrainedModel):
         loss = None
         if labels is not None:
             warnings.warn(DEPRECATION_WARNING, FutureWarning)
-            logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
+            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
@@ -549,7 +549,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
         # Compute loss independent from decoder (as some shift the logits inside them)
         loss = None
         if labels is not None:
-            logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
+            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
@@ -503,7 +503,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
         # Compute loss independent from decoder (as some shift the logits inside them)
         loss = None
         if labels is not None:
-            logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
+            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
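A quick way to exercise the fixed branch is to pass labels together with return_dict=False to one of the wrappers. The sketch below is a hedged usage example, not part of the commit: the BERT checkpoint names are placeholders, and with return_dict=False the wrapper is expected to return a tuple whose first element is the loss computed from the now correctly indexed logits.

# Hedged usage sketch: drives the code path that this commit fixes.
from transformers import BertTokenizer, EncoderDecoderModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased"  # example checkpoints only
)
# Needed so the wrapper can build decoder_input_ids from the labels.
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

inputs = tokenizer("hello world", return_tensors="pt")
labels = tokenizer("hallo welt", return_tensors="pt").input_ids

# labels + return_dict=False hits the fixed decoder_outputs[0] indexing.
outputs = model(input_ids=inputs.input_ids, labels=labels, return_dict=False)
loss = outputs[0]  # first tuple element should be the loss
loss.backward()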