Better None gradients handling in TF Trainer (#4469)

* Better None gradients handling

* Apply Style

* Apply Style
This commit is contained in:
Julien Plu 2020-05-20 22:46:21 +02:00 committed by GitHub
parent e708bb75bf
commit fa2fbed3e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -141,7 +141,7 @@ class TFTrainer:
self.optimizer = tf.keras.optimizers.get( self.optimizer = tf.keras.optimizers.get(
{"class_name": self.args.optimizer_name, "config": {"learning_rate": self.args.learning_rate}} {"class_name": self.args.optimizer_name, "config": {"learning_rate": self.args.learning_rate}}
) )
logger.info("Created an/a {} optimizer".format(self.optimizer)) logger.info("Created an/a {} optimizer".format(self.args.optimizer_name))
def _create_checkpoint_manager(self, max_to_keep: int = 5, load_model: bool = True) -> None: def _create_checkpoint_manager(self, max_to_keep: int = 5, load_model: bool = True) -> None:
""" """
@ -335,12 +335,8 @@ class TFTrainer:
gradient / tf.cast(gradient_scale, gradient.dtype) for gradient in self.gradient_accumulator.gradients gradient / tf.cast(gradient_scale, gradient.dtype) for gradient in self.gradient_accumulator.gradients
] ]
gradients = [(tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients] gradients = [(tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients]
vars = self.model.trainable_variables
if self.args.mode in ["token-classification", "question-answering"]: self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
vars = [var for var in self.model.trainable_variables if "pooler" not in var.name]
self.optimizer.apply_gradients(list(zip(gradients, vars)))
self.gradient_accumulator.reset() self.gradient_accumulator.reset()
def _accumulate_next_gradients(self): def _accumulate_next_gradients(self):
@ -375,12 +371,10 @@ class TFTrainer:
def _forward(self, features, labels): def _forward(self, features, labels):
"""Forwards a training example and accumulates the gradients.""" """Forwards a training example and accumulates the gradients."""
per_example_loss, _ = self._run_model(features, labels, True) per_example_loss, _ = self._run_model(features, labels, True)
vars = self.model.trainable_variables gradients = tf.gradients(per_example_loss, self.model.trainable_variables)
gradients = [
if self.args.mode in ["token-classification", "question-answering"]: g if g is not None else tf.zeros_like(v) for g, v in zip(gradients, self.model.trainable_variables)
vars = [var for var in self.model.trainable_variables if "pooler" not in var.name] ]
gradients = self.optimizer.get_gradients(per_example_loss, vars)
self.gradient_accumulator(gradients) self.gradient_accumulator(gradients)