TF: Add missing cast to GPT-J (#18201)

* Fix TF GPT-J tests

* add try/finally block
Joao Gante 2022-07-19 15:58:42 +01:00 committed by GitHub
parent 05ed569c79
commit ec6cd7633f
2 changed files with 11 additions and 10 deletions
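
Both hunks below revolve around the Keras mixed-precision policy, which is process-global state: under "mixed_float16", layers compute in float16 while their weights stay float32. A standalone sketch of that behaviour (illustrative only, not part of the diff):

import tensorflow as tf

# The policy applies to every layer built after it is set: activations use the
# compute dtype (float16), while weights keep the variable dtype (float32).
tf.keras.mixed_precision.set_global_policy("mixed_float16")
layer = tf.keras.layers.Dense(2)
print(layer.compute_dtype)   # float16
print(layer.variable_dtype)  # float32

tf.keras.mixed_precision.set_global_policy("float32")  # restore the default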


@@ -222,7 +222,7 @@ class TFGPTJAttention(tf.keras.layers.Layer):
         key = self._split_heads(key, True)
         value = self._split_heads(value, False)
-        sincos = tf.gather(self.embed_positions, position_ids, axis=0)
+        sincos = tf.cast(tf.gather(self.embed_positions, position_ids, axis=0), hidden_states.dtype)
         sincos = tf.split(sincos, 2, axis=-1)
 
         if self.rotary_dim is not None:
             k_rot = key[:, :, :, : self.rotary_dim]
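
That one-line cast is the whole model-side fix: self.embed_positions is a float32 table, so tf.gather returns a float32 tensor even when hidden_states is float16 under the mixed_float16 policy, and the rotary-embedding math downstream then fails on the dtype mismatch. A minimal, self-contained repro of the failure mode, using hypothetical shapes rather than GPT-J's real ones:

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")
try:
    # layer output is float16 under the policy; the position table stays float32
    hidden_states = tf.keras.layers.Dense(4)(tf.random.normal((1, 2, 4)))
    embed_positions = tf.random.normal((8, 4), dtype=tf.float32)
    sincos = tf.gather(embed_positions, tf.constant([0, 1]), axis=0)  # float32

    try:
        _ = hidden_states * sincos  # float16 * float32 -> raises
    except (tf.errors.InvalidArgumentError, TypeError):
        print("fails without the cast")

    sincos = tf.cast(sincos, hidden_states.dtype)  # the fix from the hunk above
    _ = hidden_states * sincos  # both float16 now
finally:
    tf.keras.mixed_precision.set_global_policy("float32")  # don't leak the policy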


@@ -274,16 +274,17 @@ class TFCoreModelTesterMixin:
     def test_mixed_precision(self):
         tf.keras.mixed_precision.set_global_policy("mixed_float16")
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
-        for model_class in self.all_model_classes:
-            class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
-            model = model_class(config)
-            outputs = model(class_inputs_dict)
+        # try/finally block to ensure subsequent tests run in float32
+        try:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            for model_class in self.all_model_classes:
+                class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+                model = model_class(config)
+                outputs = model(class_inputs_dict)
 
-            self.assertIsNotNone(outputs)
-
-        tf.keras.mixed_precision.set_global_policy("float32")
+                self.assertIsNotNone(outputs)
+        finally:
+            tf.keras.mixed_precision.set_global_policy("float32")
 
     @slow
     def test_train_pipeline_custom_model(self):
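
The test-side change matters because set_global_policy mutates global state: previously, a failed assertion inside the loop skipped the trailing set_global_policy("float32") call and left every subsequent test silently running in mixed_float16. The same guard can be factored into a reusable context manager; a sketch with a hypothetical helper name, not code from the PR:

import contextlib
import tensorflow as tf

@contextlib.contextmanager
def global_policy(name):
    # hypothetical helper: temporarily switch the mixed-precision policy
    previous = tf.keras.mixed_precision.global_policy()
    tf.keras.mixed_precision.set_global_policy(name)
    try:
        yield
    finally:
        # runs even when the body raises, e.g. on a failed assertion
        tf.keras.mixed_precision.set_global_policy(previous)

# usage: the float32 default is back no matter how the block exits
with global_policy("mixed_float16"):
    assert tf.keras.layers.Dense(1).compute_dtype == "float16"
assert tf.keras.mixed_precision.global_policy().name == "float32"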