mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
Fix tests of mixed precision now that experimental is deprecated (#17300)
* Fix tests of mixed precision now that experimental is deprecated * Fix mixed precision in training_args_tf.py too
This commit is contained in:
parent
6d211429ec
commit
651e48e1e5
@ -195,8 +195,7 @@ class TFTrainingArguments(TrainingArguments):
|
||||
|
||||
# Set to float16 at first
|
||||
if self.fp16:
|
||||
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
|
||||
tf.keras.mixed_precision.experimental.set_policy(policy)
|
||||
tf.keras.mixed_precision.set_global_policy("mixed_float16")
|
||||
|
||||
if self.no_cuda:
|
||||
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
|
||||
@ -217,8 +216,7 @@ class TFTrainingArguments(TrainingArguments):
|
||||
if tpu:
|
||||
# Set to bfloat16 in case of TPU
|
||||
if self.fp16:
|
||||
policy = tf.keras.mixed_precision.experimental.Policy("mixed_bfloat16")
|
||||
tf.keras.mixed_precision.experimental.set_policy(policy)
|
||||
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
|
||||
|
||||
tf.config.experimental_connect_to_cluster(tpu)
|
||||
tf.tpu.experimental.initialize_tpu_system(tpu)
|
||||
|
@ -205,7 +205,7 @@ class TFCoreModelTesterMixin:
|
||||
|
||||
@slow
|
||||
def test_mixed_precision(self):
|
||||
tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
|
||||
tf.keras.mixed_precision.set_global_policy("mixed_float16")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@ -216,7 +216,7 @@ class TFCoreModelTesterMixin:
|
||||
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
tf.keras.mixed_precision.experimental.set_policy("float32")
|
||||
tf.keras.mixed_precision.set_global_policy("float32")
|
||||
|
||||
@slow
|
||||
def test_train_pipeline_custom_model(self):
|
||||
|
Loading…
Reference in New Issue
Block a user