Fix tests of mixed precision now that experimental is deprecated (#17300)

* Fix tests of mixed precision now that experimental is deprecated

* Fix mixed precision in training_args_tf.py too
This commit is contained in:
Matt 2022-05-17 14:14:17 +01:00 committed by GitHub
parent 6d211429ec
commit 651e48e1e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 6 deletions

View File

@@ -195,8 +195,7 @@ class TFTrainingArguments(TrainingArguments):
# Set to float16 at first
if self.fp16:
-            policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
-            tf.keras.mixed_precision.experimental.set_policy(policy)
+            tf.keras.mixed_precision.set_global_policy("mixed_float16")
if self.no_cuda:
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
@@ -217,8 +216,7 @@ class TFTrainingArguments(TrainingArguments):
if tpu:
# Set to bfloat16 in case of TPU
if self.fp16:
-                policy = tf.keras.mixed_precision.experimental.Policy("mixed_bfloat16")
-                tf.keras.mixed_precision.experimental.set_policy(policy)
+                tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)

View File

@@ -205,7 +205,7 @@ class TFCoreModelTesterMixin:
@slow
def test_mixed_precision(self):
-        tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
+        tf.keras.mixed_precision.set_global_policy("mixed_float16")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -216,7 +216,7 @@ class TFCoreModelTesterMixin:
self.assertIsNotNone(outputs)
-        tf.keras.mixed_precision.experimental.set_policy("float32")
+        tf.keras.mixed_precision.set_global_policy("float32")
@slow
def test_train_pipeline_custom_model(self):