Fix for non-contiguous label tensors in VisonEncoderDecoder (#21582)

* add prints

* add shape

* add reshape

* clean up
This commit is contained in:
Morgan McGuire 2023-02-20 08:23:46 +00:00 committed by GitHub
parent 2840272c5f
commit 011cc17a81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -625,7 +625,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
if labels is not None:
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))
if not return_dict:
if loss is not None: