Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Fix TF Roberta for mixed precision training (#11675)
Commit d9b286272c (parent a135f59536)
@@ -541,7 +541,9 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
         extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
-        extended_attention_mask = tf.multiply(tf.subtract(1.0, extended_attention_mask), -10000.0)
+        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
+        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
+        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
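For context, the change replaces the Python float literals 1.0 and -10000.0 with tf.constant values created in embedding_output's dtype, so the additive attention mask is built entirely in the model's compute dtype (float16 under a mixed precision policy). Below is a minimal sketch of the pattern, not the library's actual code: embedding_output and attention_mask are hypothetical stand-ins for the tensors the real layer works with.

import tensorflow as tf

# Stand-ins for the layer's real tensors. Under a `mixed_float16` Keras policy,
# layer outputs such as the embeddings are float16, so the mask constants must
# be created in that same dtype.
embedding_output = tf.random.uniform((2, 4, 8), dtype=tf.float16)
attention_mask = tf.constant([[1, 1, 1, 0], [1, 1, 0, 0]])

# Broadcastable mask of shape (batch, 1, 1, seq_len), cast to the compute dtype.
extended_attention_mask = tf.cast(
    attention_mask[:, tf.newaxis, tf.newaxis, :], dtype=embedding_output.dtype
)

# Dtype-matched constants instead of plain Python literals.
one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)

# Positions to attend to become 0.0, masked positions become -10000.0; adding
# this to the raw attention scores effectively removes the masked positions.
extended_attention_mask = tf.multiply(
    tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst
)

print(extended_attention_mask.dtype)  # float16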