diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index 9d5aee46258..17f5a1dd887 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -1407,6 +1407,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling else: logits = self.lm_head(sequence_output) + logits = tf.cast(logits, tf.float32) + loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) if not inputs["return_dict"]: