Fix optimizer (#6717)

Julien Plu 2020-08-26 17:12:44 +02:00 committed by GitHub
parent 77abd1e79f
commit 02e8cd5584


@@ -221,9 +221,9 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
             )
         return tf.no_op()
 
-    def apply_gradients(self, grads_and_vars, name=None):
+    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
         grads, tvars = list(zip(*grads_and_vars))
-        return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars), name=name,)
+        return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars), name=name, **kwargs)
 
     def _get_lr(self, var_device, var_dtype, apply_state):
        """Retrieves the learning rate with the given state."""