From 5a2b77a6c1dc54372d0569c6a681c69895eab904 Mon Sep 17 00:00:00 2001 From: Gerald Cuder <60609608+gcuder@users.noreply.github.com> Date: Tue, 21 Mar 2023 13:12:57 +0100 Subject: [PATCH] Fix error in mixed precision training of `TFCvtModel` (#22267) * Make sure CVT can be trained using mixed precision * Add test for keras-fit with mixed-precision * Update tests/models/cvt/test_modeling_tf_cvt.py Co-authored-by: Matt --------- Co-authored-by: gcuder Co-authored-by: Matt --- src/transformers/models/cvt/modeling_tf_cvt.py | 2 +- tests/models/cvt/test_modeling_tf_cvt.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/cvt/modeling_tf_cvt.py b/src/transformers/models/cvt/modeling_tf_cvt.py index 52cc6585a7a..6ad86071e47 100644 --- a/src/transformers/models/cvt/modeling_tf_cvt.py +++ b/src/transformers/models/cvt/modeling_tf_cvt.py @@ -93,7 +93,7 @@ class TFCvtDropPath(tf.keras.layers.Layer): return x keep_prob = 1 - self.drop_prob shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) - random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) + random_tensor = keep_prob + tf.random.uniform(shape, 0, 1, dtype=self.compute_dtype) random_tensor = tf.floor(random_tensor) return (x / keep_prob) * random_tensor diff --git a/tests/models/cvt/test_modeling_tf_cvt.py b/tests/models/cvt/test_modeling_tf_cvt.py index 4605f6782bd..484bd295d17 100644 --- a/tests/models/cvt/test_modeling_tf_cvt.py +++ b/tests/models/cvt/test_modeling_tf_cvt.py @@ -186,6 +186,12 @@ class TFCvtModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase) def test_keras_fit(self): super().test_keras_fit() + def test_keras_fit_mixed_precision(self): + policy = tf.keras.mixed_precision.Policy("mixed_float16") + tf.keras.mixed_precision.set_global_policy(policy) + super().test_keras_fit() + tf.keras.mixed_precision.set_global_policy("float32") + def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common()