diff --git a/src/transformers/optimization_tf.py b/src/transformers/optimization_tf.py
index e2b2a961ca1..58ff287d8bf 100644
--- a/src/transformers/optimization_tf.py
+++ b/src/transformers/optimization_tf.py
@@ -21,6 +21,12 @@ from typing import Callable, List, Optional, Union
 
 import tensorflow as tf
 
+if hasattr(tf.keras, "optimizers") and hasattr(tf.keras.optimizers, "legacy"):
+    Adam = tf.keras.optimizers.legacy.Adam
+else:
+    Adam = tf.keras.optimizers.Adam
+
+
 class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
     """
     Applies a warmup schedule on a given learning rate decay schedule.
@@ -163,7 +169,7 @@ def create_optimizer(
     return optimizer, lr_schedule
 
 
-class AdamWeightDecay(tf.keras.optimizers.Adam):
+class AdamWeightDecay(Adam):
     """
     Adam enables L2 weight decay and clip_by_global_norm on gradients. Just adding the square of the weights to the
     loss function is *not* the correct way of using L2 regularization/weight decay with Adam, since that will interact
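
Reviewer note: TF 2.11 moved the pre-2.11 optimizer implementations under `tf.keras.optimizers.legacy`, so the module-level shim above binds `Adam` to the legacy class when it exists and falls back to `tf.keras.optimizers.Adam` on older releases, keeping `AdamWeightDecay` subclassing the implementation it was written against. A minimal sketch for trying the patched module locally, assuming the patch is applied, that `create_optimizer` stays importable from the top-level `transformers` package with its current `init_lr` / `num_train_steps` / `num_warmup_steps` / `weight_decay_rate` parameters, and with arbitrary illustration values:

```python
# Sketch only: shows which Adam implementation ends up backing AdamWeightDecay
# on the current TensorFlow install.
import tensorflow as tf

from transformers import create_optimizer

# True on TF >= 2.11, where the original optimizer classes live under the
# legacy namespace; False on older releases, where the else branch is taken.
print(hasattr(tf.keras.optimizers, "legacy"))

# create_optimizer returns (AdamWeightDecay instance, learning-rate schedule),
# matching the `return optimizer, lr_schedule` context in the hunk above.
optimizer, lr_schedule = create_optimizer(
    init_lr=2e-5,
    num_train_steps=1_000,
    num_warmup_steps=100,
    weight_decay_rate=0.01,
)

# The first base class in the MRO is whichever Adam class the shim selected.
print(type(optimizer).__mro__[1])
```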