diff --git a/src/transformers/trainer_tf.py b/src/transformers/trainer_tf.py
index d6c92ebc056..ed5d34a8bf0 100644
--- a/src/transformers/trainer_tf.py
+++ b/src/transformers/trainer_tf.py
@@ -101,6 +101,7 @@ class TFTrainer:
         self.gradient_accumulator = GradientAccumulator()
         self.global_step = 0
         self.epoch_logging = 0
+        self.eval_loss = tf.keras.metrics.Sum()
 
         if tb_writer is not None:
             self.tb_writer = tb_writer
@@ -202,13 +203,8 @@ class TFTrainer:
         if num_examples < 0:
             raise ValueError("The training dataset must have an asserted cardinality")
 
-        approx = math.floor if self.args.dataloader_drop_last else math.ceil
-        steps = approx(num_examples / self.args.eval_batch_size)
-        ds = (
-            test_dataset.repeat()
-            .batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
-            .prefetch(tf.data.experimental.AUTOTUNE)
-        )
+        steps = math.ceil(num_examples / self.args.eval_batch_size)
+        ds = test_dataset.batch(self.args.eval_batch_size).prefetch(tf.data.experimental.AUTOTUNE)
 
         return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples
 
@@ -300,12 +296,14 @@ class TFTrainer:
         )
 
         logger.info("***** Running %s *****", description)
-        logger.info("  Num examples = %d", num_examples)
+        logger.info("  Num examples in dataset = %d", num_examples)
+        if description == "Evaluation":
+            logger.info("  Num examples used in evaluation = %d", self.args.eval_batch_size * steps)
         logger.info("  Batch size = %d", self.args.eval_batch_size)
 
         label_ids: np.ndarray = None
         preds: np.ndarray = None
-        self.eval_loss = tf.keras.metrics.Sum()
+        self.eval_loss.reset_states()
 
         # Reset the past mems state at the beginning of the evaluation if necessary.
         if self.args.past_index >= 0:
@@ -345,7 +343,7 @@ class TFTrainer:
             else:
                 label_ids = np.append(label_ids, labels.numpy(), axis=0)
 
-            if step == steps:
+            if step == steps - 1:
                 break
 
         if self.compute_metrics is not None and preds is not None and label_ids is not None:
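
Reviewer note, not part of the patch: the sketch below checks the step accounting this diff relies on. With steps = ceil(num_examples / eval_batch_size), no .repeat(), and no drop_remainder, breaking at step == steps - 1 visits every batch exactly once, including the final partial one; and a single tf.keras.metrics.Sum reused via reset_states() starts each evaluation from zero. The dataset and batch sizes here are illustrative assumptions, not values taken from the trainer.

import math

import tensorflow as tf

num_examples = 10    # hypothetical dataset size
eval_batch_size = 4  # hypothetical batch size

# Batch without repeat() or drop_remainder, as the patched dataset helper now does.
ds = tf.data.Dataset.range(num_examples).batch(eval_batch_size)
steps = math.ceil(num_examples / eval_batch_size)  # 3 batches: 4 + 4 + 2

seen = 0
for step, batch in enumerate(ds):
    seen += int(batch.shape[0])
    if step == steps - 1:  # mirrors the corrected break condition
        break
assert seen == num_examples  # each example is evaluated exactly once

# Metric reuse pattern: build once (as in __init__), reset per evaluation run.
eval_loss = tf.keras.metrics.Sum()
eval_loss.update_state(1.5)
eval_loss.reset_states()
assert float(eval_loss.result()) == 0.0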