mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Fix tests of mixed precision now that experimental is deprecated (#17300)
* Fix tests of mixed precision now that experimental is deprecated * Fix mixed precision in training_args_tf.py too
This commit is contained in:
parent
6d211429ec
commit
651e48e1e5
@ -195,8 +195,7 @@ class TFTrainingArguments(TrainingArguments):
|
|||||||
|
|
||||||
# Set to float16 at first
|
# Set to float16 at first
|
||||||
if self.fp16:
|
if self.fp16:
|
||||||
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
|
tf.keras.mixed_precision.set_global_policy("mixed_float16")
|
||||||
tf.keras.mixed_precision.experimental.set_policy(policy)
|
|
||||||
|
|
||||||
if self.no_cuda:
|
if self.no_cuda:
|
||||||
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
|
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
|
||||||
@ -217,8 +216,7 @@ class TFTrainingArguments(TrainingArguments):
|
|||||||
if tpu:
|
if tpu:
|
||||||
# Set to bfloat16 in case of TPU
|
# Set to bfloat16 in case of TPU
|
||||||
if self.fp16:
|
if self.fp16:
|
||||||
policy = tf.keras.mixed_precision.experimental.Policy("mixed_bfloat16")
|
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
|
||||||
tf.keras.mixed_precision.experimental.set_policy(policy)
|
|
||||||
|
|
||||||
tf.config.experimental_connect_to_cluster(tpu)
|
tf.config.experimental_connect_to_cluster(tpu)
|
||||||
tf.tpu.experimental.initialize_tpu_system(tpu)
|
tf.tpu.experimental.initialize_tpu_system(tpu)
|
||||||
|
@ -205,7 +205,7 @@ class TFCoreModelTesterMixin:
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_mixed_precision(self):
|
def test_mixed_precision(self):
|
||||||
tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
|
tf.keras.mixed_precision.set_global_policy("mixed_float16")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
@ -216,7 +216,7 @@ class TFCoreModelTesterMixin:
|
|||||||
|
|
||||||
self.assertIsNotNone(outputs)
|
self.assertIsNotNone(outputs)
|
||||||
|
|
||||||
tf.keras.mixed_precision.experimental.set_policy("float32")
|
tf.keras.mixed_precision.set_global_policy("float32")
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_train_pipeline_custom_model(self):
|
def test_train_pipeline_custom_model(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user