mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Cast logits to fp32 at the end of TF_T5 (#12332)
This change enables tf.keras.mixed_precision with bf16
This commit is contained in:
parent
b7439675b8
commit
f064e0a43d
@ -1407,6 +1407,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
else:
|
||||
logits = self.lm_head(sequence_output)
|
||||
|
||||
logits = tf.cast(logits, tf.float32)
|
||||
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
|
Loading…
Reference in New Issue
Block a user