Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)
Fix AdamWeightDecay for TF 2.11 (#20735)
* Fix AdamWeightDecay for TF
* Fix AdamWeightDecay for TF
* make fixup
parent a12c5cbcd8
commit 4f1788b34d
@@ -21,6 +21,12 @@ from typing import Callable, List, Optional, Union
 import tensorflow as tf
 
 
+if hasattr(tf.keras, "optimizer") and hasattr(tf.keras.optimizer, "legacy"):
+    Adam = tf.keras.optimizer.legacy.Adam
+else:
+    Adam = tf.keras.optimizers.Adam
+
+
 class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
     """
     Applies a warmup schedule on a given learning rate decay schedule.
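Keras 2.11 replaced the default optimizer implementation and kept the previous classes under a legacy namespace, so the hunk above binds Adam to whichever implementation the installed TensorFlow exposes. A minimal standalone sketch of the same feature-detection pattern (illustrative only, not part of the patch):

# Illustrative sketch (not part of the diff): replicate the alias selection
# to see which Adam implementation a given TensorFlow build resolves to.
import tensorflow as tf

# Same probe as the patch: prefer the legacy namespace when it exists,
# otherwise fall back to the current Adam class. hasattr never raises,
# so this runs on any TensorFlow 2.x install.
if hasattr(tf.keras, "optimizer") and hasattr(tf.keras.optimizer, "legacy"):
    Adam = tf.keras.optimizer.legacy.Adam
else:
    Adam = tf.keras.optimizers.Adam

print("Adam resolves to:", Adam.__module__ + "." + Adam.__name__)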
@@ -163,7 +169,7 @@ def create_optimizer(
     return optimizer, lr_schedule
 
 
-class AdamWeightDecay(tf.keras.optimizers.Adam):
+class AdamWeightDecay(Adam):
     """
     Adam enables L2 weight decay and clip_by_global_norm on gradients. Just adding the square of the weights to the
     loss function is *not* the correct way of using L2 regularization/weight decay with Adam, since that will interact
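Because AdamWeightDecay now inherits from the version-appropriate Adam alias instead of tf.keras.optimizers.Adam directly, its overridden internals keep matching the base class on both older and newer TensorFlow releases, while callers are unaffected. A usage sketch via create_optimizer, which the hunk above shows returning (optimizer, lr_schedule); the keyword names below are assumed from transformers' TF optimization utilities and are not part of this diff:

# Usage sketch (assumed API surface of transformers' TF utilities; keyword
# names are illustrative and not taken from this diff).
from transformers import create_optimizer

# Builds an AdamWeightDecay instance plus its warmup/decay schedule,
# matching the "return optimizer, lr_schedule" context line in the hunk.
optimizer, lr_schedule = create_optimizer(
    init_lr=5e-5,            # peak learning rate after warmup
    num_train_steps=10_000,  # total optimization steps
    num_warmup_steps=500,    # handled by the WarmUp schedule above
    weight_decay_rate=0.01,  # decoupled weight decay applied by AdamWeightDecay
)

# The returned optimizer can be passed to model.compile() regardless of
# whether the legacy or the new Keras optimizer base class was selected.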