Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Fix optimizer (#6717)
commit 02e8cd5584
parent 77abd1e79f
@@ -221,9 +221,9 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
             )
         return tf.no_op()
 
-    def apply_gradients(self, grads_and_vars, name=None):
+    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
         grads, tvars = list(zip(*grads_and_vars))
-        return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars), name=name,)
+        return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars), name=name, **kwargs)
 
     def _get_lr(self, var_device, var_dtype, apply_state):
         """Retrieves the learning rate with the given state."""
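The change lets apply_gradients accept and forward arbitrary keyword arguments to the parent Adam optimizer instead of rejecting them. Below is a minimal sketch of the patched override, assuming a TF 2.x release contemporary with this commit in which tf.keras.optimizers.Adam.apply_gradients accepts experimental_aggregate_gradients; the weight-decay logic of the real transformers class is omitted, so this is an illustration of the signature fix, not the full implementation.

import tensorflow as tf


class AdamWeightDecay(tf.keras.optimizers.Adam):
    # Simplified stand-in for transformers' AdamWeightDecay: only the
    # apply_gradients override touched by this commit is reproduced.

    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
        grads, tvars = list(zip(*grads_and_vars))
        # Forward any extra keyword arguments (for example ones passed by
        # Keras or distribution-strategy code) to the parent implementation.
        return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars), name=name, **kwargs)


# Usage: with the old signature (no **kwargs), passing an extra keyword
# argument such as experimental_aggregate_gradients raised a TypeError.
# That keyword is specific to the legacy TF 2.x optimizer API and is used
# here only to demonstrate the forwarding.
var = tf.Variable(1.0)
with tf.GradientTape() as tape:
    loss = var * var
grads = tape.gradient(loss, [var])

opt = AdamWeightDecay(learning_rate=1e-3)
opt.apply_gradients(zip(grads, [var]), experimental_aggregate_gradients=True)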