From f064e0a43d05a6bb1eb81e65e700f8e0f4ab04f9 Mon Sep 17 00:00:00 2001
From: Michal Szutenberg <37601244+szutenberg@users.noreply.github.com>
Date: Tue, 3 Aug 2021 21:02:59 +0200
Subject: [PATCH] Cast logits to fp32 at the end of TF_T5 (#12332)

This change enables tf.keras.mixed_precision with bf16
---
 src/transformers/models/t5/modeling_tf_t5.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py
index 9d5aee46258..17f5a1dd887 100644
--- a/src/transformers/models/t5/modeling_tf_t5.py
+++ b/src/transformers/models/t5/modeling_tf_t5.py
@@ -1407,6 +1407,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
         else:
             logits = self.lm_head(sequence_output)

+        logits = tf.cast(logits, tf.float32)
+
         loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)

         if not inputs["return_dict"]:
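
Note: a minimal sketch of how the bf16 mixed-precision path this patch enables might be exercised. The checkpoint name "t5-small" and the surrounding snippet are illustrative assumptions, not part of the patch; only the fp32 cast of the logits comes from the change above.

import tensorflow as tf
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

# Keras mixed precision with bfloat16: layers compute in bf16 while
# variables stay in float32. Must be set before the model is built.
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")

tokenizer = T5Tokenizer.from_pretrained("t5-small")  # illustrative checkpoint
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: Hello.", return_tensors="tf")
labels = tokenizer("Hallo.", return_tensors="tf").input_ids

outputs = model(inputs.input_ids, labels=labels)

# With this patch, the logits are cast to fp32 before compute_loss runs,
# so the returned logits are float32 even under the bf16 policy.
print(outputs.logits.dtype)  # expected: <dtype: 'float32'>

Casting only the final logits keeps the bulk of the forward pass in bf16 for speed while the loss (a numerically sensitive reduction over the vocabulary) is computed in fp32.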