diff --git a/src/transformers/models/gptj/modeling_tf_gptj.py b/src/transformers/models/gptj/modeling_tf_gptj.py
index ae0c83fae9b..a1071408fb0 100644
--- a/src/transformers/models/gptj/modeling_tf_gptj.py
+++ b/src/transformers/models/gptj/modeling_tf_gptj.py
@@ -222,7 +222,7 @@ class TFGPTJAttention(tf.keras.layers.Layer):
         key = self._split_heads(key, True)
         value = self._split_heads(value, False)
 
-        sincos = tf.gather(self.embed_positions, position_ids, axis=0)
+        sincos = tf.cast(tf.gather(self.embed_positions, position_ids, axis=0), hidden_states.dtype)
         sincos = tf.split(sincos, 2, axis=-1)
         if self.rotary_dim is not None:
             k_rot = key[:, :, :, : self.rotary_dim]
diff --git a/tests/utils/test_modeling_tf_core.py b/tests/utils/test_modeling_tf_core.py
index d1683d69cf7..d3585d7390a 100644
--- a/tests/utils/test_modeling_tf_core.py
+++ b/tests/utils/test_modeling_tf_core.py
@@ -274,16 +274,17 @@ class TFCoreModelTesterMixin:
     def test_mixed_precision(self):
         tf.keras.mixed_precision.set_global_policy("mixed_float16")
 
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        # try/finally block to ensure subsequent tests run in float32
+        try:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            for model_class in self.all_model_classes:
+                class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+                model = model_class(config)
+                outputs = model(class_inputs_dict)
 
-        for model_class in self.all_model_classes:
-            class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
-            model = model_class(config)
-            outputs = model(class_inputs_dict)
-
-            self.assertIsNotNone(outputs)
-
-        tf.keras.mixed_precision.set_global_policy("float32")
+                self.assertIsNotNone(outputs)
+        finally:
+            tf.keras.mixed_precision.set_global_policy("float32")
 
     @slow
     def test_train_pipeline_custom_model(self):
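
For context, a minimal, hypothetical sketch (not taken from the modeling code; the tensor shapes and values below are placeholders) of the dtype mismatch the added tf.cast guards against: under the mixed_float16 policy the layer's compute dtype is float16, while a precomputed float32 sin/cos table gathered with tf.gather stays float32 unless it is cast to hidden_states.dtype.

import tensorflow as tf

# Hypothetical standalone example; the real tensors live inside TFGPTJAttention.
tf.keras.mixed_precision.set_global_policy("mixed_float16")
try:
    # A float32 constant table, analogous to self.embed_positions.
    embed_positions = tf.constant([[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]])
    # Activations carry the policy's compute dtype (float16), like hidden_states.
    hidden_states = tf.zeros((1, 2, 4), dtype=tf.float16)

    sincos = tf.gather(embed_positions, [[0, 1]], axis=0)  # result is still float32
    sincos = tf.cast(sincos, hidden_states.dtype)           # the cast added in the diff
    sin, cos = tf.split(sincos, 2, axis=-1)                 # both halves are now float16

    print(sin.dtype, cos.dtype)  # float16 float16
finally:
    # Restore the default policy, mirroring the try/finally added to the test.
    tf.keras.mixed_precision.set_global_policy("float32")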