Fix error in mixed precision training of TFCvtModel (#22267)

* Make sure CVT can be trained using mixed precision

* Add test for keras-fit with mixed-precision

* Update tests/models/cvt/test_modeling_tf_cvt.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

---------

Co-authored-by: gcuder <Gerald.Cuder@iacapps.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
Gerald Cuder 2023-03-21 13:12:57 +01:00 committed by GitHub
parent 330d8b991f
commit 5a2b77a6c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 1 deletions

View File

@ -93,7 +93,7 @@ class TFCvtDropPath(tf.keras.layers.Layer):
return x return x
keep_prob = 1 - self.drop_prob keep_prob = 1 - self.drop_prob
shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) random_tensor = keep_prob + tf.random.uniform(shape, 0, 1, dtype=self.compute_dtype)
random_tensor = tf.floor(random_tensor) random_tensor = tf.floor(random_tensor)
return (x / keep_prob) * random_tensor return (x / keep_prob) * random_tensor

View File

@ -186,6 +186,12 @@ class TFCvtModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
def test_keras_fit(self): def test_keras_fit(self):
super().test_keras_fit() super().test_keras_fit()
def test_keras_fit_mixed_precision(self):
policy = tf.keras.mixed_precision.Policy("mixed_float16")
tf.keras.mixed_precision.set_global_policy(policy)
super().test_keras_fit()
tf.keras.mixed_precision.set_global_policy("float32")
def test_forward_signature(self): def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common() config, _ = self.model_tester.prepare_config_and_inputs_for_common()