Fix TF Trainer loss calculation (#6998)

* create branch for issue #6968

* First attempt to fix incorrect tf trainer loss calculation

* Fix training loss in metric

* fix tf trainer evaluation loss

* apply count_instances_in_batch() for eval and test datasets

* prototype of using a new argument in trainer_tf.py to fix loss issue

* some renaming and fixes, in particular for evaluation methods

* fix bugs to have a running version

* change to @staticmethod

* apply style
Yih-Dar 2020-09-15 11:41:00 +02:00 committed by GitHub
parent b0cbcdb05b
commit cb061e78e1
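
In short, the fix scales each per-example loss by the number of real (non-padded, i.e. label != -100) instances in the global batch instead of by a fixed batch size, and then reports the accumulated loss divided by the number of steps only. A minimal standalone sketch of that idea follows; the helper and variable names here are illustrative, not the exact TFTrainer API:

import tensorflow as tf

def scale_loss(per_example_loss, labels):
    # Count only real instances; padded positions are labeled -100 and must not dilute the average.
    nb_instances = tf.reduce_sum(tf.cast(labels != -100, dtype=tf.int32))
    return per_example_loss / tf.cast(nb_instances, dtype=per_example_loss.dtype)

loss_metric = tf.keras.metrics.Sum()

# Each step accumulates the already-scaled loss rather than the raw per-example loss.
per_example_loss = tf.constant([0.5, 1.0, 1.5])
labels = tf.constant([1, 0, -100])  # one padded instance in this toy batch
loss_metric.update_state(scale_loss(per_example_loss, labels))

# When reporting, divide by the number of steps only, not by steps * batch_size.
steps = 1
training_loss = loss_metric.result() / steps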


@@ -9,6 +9,7 @@ from typing import Callable, Dict, Optional, Tuple
 import numpy as np
 import tensorflow as tf
 from packaging.version import parse
+from tensorflow.python.distribute.values import PerReplica
 
 from .integrations import is_comet_available, is_wandb_available
 from .modeling_tf_utils import TFPreTrainedModel
@@ -363,7 +364,7 @@ class TFTrainer:
         else:
             metrics = {}
 
-        metrics["eval_loss"] = self.eval_loss.result().numpy() / (steps * self.args.eval_batch_size)
+        metrics["eval_loss"] = self.eval_loss.result().numpy() / steps
 
         for key in list(metrics.keys()):
             if not key.startswith("eval_"):
@@ -441,21 +442,28 @@ class TFTrainer:
         return output.metrics
 
-    def prediction_step(self, features: tf.Tensor, labels: tf.Tensor) -> tf.Tensor:
+    def prediction_step(
+        self, features: tf.Tensor, labels: tf.Tensor, nb_instances_in_global_batch: tf.Tensor
+    ) -> tf.Tensor:
         """
         Compute the prediction on features and update the loss with labels.
 
         Subclass and override to inject some custom behavior.
         """
         per_example_loss, logits = self.run_model(features, labels, False)
+        scaled_loss = per_example_loss / tf.cast(nb_instances_in_global_batch, dtype=per_example_loss.dtype)
 
-        self.eval_loss.update_state(per_example_loss)
+        self.eval_loss.update_state(scaled_loss)
 
         return logits
 
     @tf.function
     def distributed_prediction_steps(self, batch):
-        logits = self.args.strategy.run(self.prediction_step, batch)
+
+        nb_instances_in_batch = self._compute_nb_instances(batch)
+        inputs = self._get_step_inputs(batch, nb_instances_in_batch)
+
+        logits = self.args.strategy.run(self.prediction_step, inputs)
 
         return logits
@@ -542,7 +550,7 @@ class TFTrainer:
 
                     self.distributed_training_steps(batch)
 
-                    training_loss = self.train_loss.result() / ((step + 1) * self.total_train_batch_size)
+                    training_loss = self.train_loss.result() / (step + 1)
 
                     if self.args.debug:
                         logs = {}
@@ -592,14 +600,14 @@ class TFTrainer:
             # Clean the state at the end of training
             delattr(self, "_past")
 
-    def training_step(self, features, labels):
+    def training_step(self, features, labels, nb_instances_in_global_batch):
         """
         Perform a training step on features and labels.
 
         Subclass and override to inject some custom behavior.
         """
         per_example_loss, _ = self.run_model(features, labels, True)
-        scaled_loss = per_example_loss / self.total_train_batch_size
+        scaled_loss = per_example_loss / tf.cast(nb_instances_in_global_batch, dtype=per_example_loss.dtype)
         gradients = tf.gradients(scaled_loss, self.model.trainable_variables)
         gradients = [
             g if g is not None else tf.zeros_like(v) for g, v in zip(gradients, self.model.trainable_variables)
@@ -608,14 +616,14 @@ class TFTrainer:
         if self.args.gradient_accumulation_steps > 1:
             self.gradient_accumulator(gradients)
 
-        self.train_loss.update_state(per_example_loss)
+        self.train_loss.update_state(scaled_loss)
 
         if self.args.gradient_accumulation_steps == 1:
             return gradients
 
-    def apply_gradients(self, features, labels):
+    def apply_gradients(self, features, labels, nb_instances_in_global_batch):
         if self.args.gradient_accumulation_steps == 1:
-            gradients = self.training_step(features, labels)
+            gradients = self.training_step(features, labels, nb_instances_in_global_batch)
 
             self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
         else:
@@ -625,7 +633,7 @@
                 }
                 reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
 
-                self.training_step(reduced_features, reduced_labels)
+                self.training_step(reduced_features, reduced_labels, nb_instances_in_global_batch)
 
                 features = {
                     k: tf.concat(
@@ -650,7 +658,35 @@
     @tf.function
     def distributed_training_steps(self, batch):
         with self.args.strategy.scope():
-            self.args.strategy.run(self.apply_gradients, batch)
+
+            nb_instances_in_batch = self._compute_nb_instances(batch)
+            inputs = self._get_step_inputs(batch, nb_instances_in_batch)
+
+            self.args.strategy.run(self.apply_gradients, inputs)
+
+    @staticmethod
+    def _compute_nb_instances(batch):
+
+        labels = batch[-1]
+        if isinstance(labels, PerReplica):
+            labels = tf.concat(labels.values, axis=0)
+
+        nb_instances = tf.reduce_sum(tf.cast(labels != -100, dtype=tf.int32))
+
+        return nb_instances
+
+    @staticmethod
+    def _get_step_inputs(batch, nb_instances):
+
+        features, labels = batch
+
+        if isinstance(labels, PerReplica):
+            # need to make a `PerReplica` object for `nb_instances` as well
+            nb_instances = PerReplica([nb_instances] * len(labels.values))
+
+        step_inputs = (features, labels, nb_instances)
+
+        return step_inputs
 
     def run_model(self, features, labels, training):
         """