Cast logits to fp32 at the end of TF_T5 (#12332)

This change enables tf.keras.mixed_precision training with bfloat16 (bf16) by casting the lm_head logits to float32 before the loss computation.
This commit is contained in:
Michal Szutenberg 2021-08-03 21:02:59 +02:00 committed by GitHub
parent b7439675b8
commit f064e0a43d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -1407,6 +1407,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
else:
logits = self.lm_head(sequence_output)
logits = tf.cast(logits, tf.float32)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]: