Remove DT_DOUBLE from the T5 graph (#17891)

This commit is contained in:
Michal Szutenberg 2022-06-29 11:23:49 +02:00 committed by GitHub
parent 6aae59d0b5
commit babd7b1a92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -268,7 +268,7 @@ class TFT5Attention(tf.keras.layers.Layer):
max_exact = num_buckets // 2
is_small = tf.math.less(relative_position, max_exact)
relative_position_if_large = max_exact + tf.cast(
tf.math.log(relative_position / max_exact)
tf.math.log(tf.cast(relative_position, tf.float32) / tf.cast(max_exact, tf.float32))
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact),
dtype=relative_position.dtype,