mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Fix error in mixed precision training of TFCvtModel
(#22267)
* Make sure CVT can be trained using mixed precision * Add test for keras-fit with mixed-precision * Update tests/models/cvt/test_modeling_tf_cvt.py Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> --------- Co-authored-by: gcuder <Gerald.Cuder@iacapps.com> Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
parent
330d8b991f
commit
5a2b77a6c1
@ -93,7 +93,7 @@ class TFCvtDropPath(tf.keras.layers.Layer):
|
|||||||
return x
|
return x
|
||||||
keep_prob = 1 - self.drop_prob
|
keep_prob = 1 - self.drop_prob
|
||||||
shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
|
shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
|
||||||
random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
|
random_tensor = keep_prob + tf.random.uniform(shape, 0, 1, dtype=self.compute_dtype)
|
||||||
random_tensor = tf.floor(random_tensor)
|
random_tensor = tf.floor(random_tensor)
|
||||||
return (x / keep_prob) * random_tensor
|
return (x / keep_prob) * random_tensor
|
||||||
|
|
||||||
|
@ -186,6 +186,12 @@ class TFCvtModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
|||||||
def test_keras_fit(self):
|
def test_keras_fit(self):
|
||||||
super().test_keras_fit()
|
super().test_keras_fit()
|
||||||
|
|
||||||
|
def test_keras_fit_mixed_precision(self):
|
||||||
|
policy = tf.keras.mixed_precision.Policy("mixed_float16")
|
||||||
|
tf.keras.mixed_precision.set_global_policy(policy)
|
||||||
|
super().test_keras_fit()
|
||||||
|
tf.keras.mixed_precision.set_global_policy("float32")
|
||||||
|
|
||||||
def test_forward_signature(self):
|
def test_forward_signature(self):
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user