Fix TFTrainer prediction output (#9662)

* Fix TFTrainer prediction output

* Update trainer_tf.py

* Fix TFTrainer prediction output

* Fix evaluation_loss update in TFTrainer

* Fix TFTrainer prediction output
author Maria Janina Sarol, 2021-01-25 03:27:12 -06:00 (committed by GitHub)
parent 9152f16023
commit 6312fed47d


@@ -101,6 +101,7 @@ class TFTrainer:
         self.gradient_accumulator = GradientAccumulator()
         self.global_step = 0
         self.epoch_logging = 0
+        self.eval_loss = tf.keras.metrics.Sum()
 
         if tb_writer is not None:
             self.tb_writer = tb_writer
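
The loss accumulator is now created once in __init__ instead of on every prediction run. A minimal sketch (TensorFlow 2.x assumed, not code from this commit) of how a tf.keras.metrics.Sum accumulator is updated and then reset between runs:

    import tensorflow as tf

    # Accumulator created once, analogous to self.eval_loss in __init__.
    eval_loss = tf.keras.metrics.Sum()

    # Each per-batch loss is added to the running total during a prediction loop.
    for batch_loss in [0.5, 0.25, 0.25]:
        eval_loss.update_state(batch_loss)
    print(float(eval_loss.result()))  # 1.0

    # Before the next run the state is cleared rather than re-creating the metric.
    eval_loss.reset_states()
    print(float(eval_loss.result()))  # 0.0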
@@ -202,13 +203,8 @@ class TFTrainer:
         if num_examples < 0:
             raise ValueError("The training dataset must have an asserted cardinality")
 
-        approx = math.floor if self.args.dataloader_drop_last else math.ceil
-        steps = approx(num_examples / self.args.eval_batch_size)
-        ds = (
-            test_dataset.repeat()
-            .batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
-            .prefetch(tf.data.experimental.AUTOTUNE)
-        )
+        steps = math.ceil(num_examples / self.args.eval_batch_size)
+        ds = test_dataset.batch(self.args.eval_batch_size).prefetch(tf.data.experimental.AUTOTUNE)
 
         return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples
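
For context, a small illustration (toy numbers, not from the commit) of why batching without .repeat() together with a ceil-based step count covers each test example exactly once:

    import math
    import tensorflow as tf

    num_examples = 10
    eval_batch_size = 4

    # Fixed logic: 3 steps with batches of 4, 4 and 2 examples -> exactly 10 predictions.
    steps = math.ceil(num_examples / eval_batch_size)
    ds = tf.data.Dataset.range(num_examples).batch(eval_batch_size)
    seen = sum(int(batch.shape[0]) for batch in ds)
    print(steps, seen)  # 3 10

    # The removed code repeated the dataset, so iterating past the true end
    # wrapped around and yielded duplicate predictions instead of stopping.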
@@ -300,12 +296,14 @@ class TFTrainer:
         )
 
         logger.info("***** Running %s *****", description)
-        logger.info("  Num examples = %d", num_examples)
+        logger.info("  Num examples in dataset = %d", num_examples)
+        if description == "Evaluation":
+            logger.info("  Num examples in used in evaluation = %d", self.args.eval_batch_size * steps)
         logger.info("  Batch size = %d", self.args.eval_batch_size)
 
         label_ids: np.ndarray = None
         preds: np.ndarray = None
-        self.eval_loss = tf.keras.metrics.Sum()
+        self.eval_loss.reset_states()
 
         # Reset the past mems state at the beginning of the evaluation if necessary.
         if self.args.past_index >= 0:
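
A short arithmetic note on the new log lines (illustrative numbers, and assuming the evaluation dataset is still built with the repeat-and-batch logic that this commit removed only from the test path): with a repeated dataset, eval_batch_size * steps can exceed the dataset cardinality, which is why the two counts are now reported separately:

    import math

    num_examples = 10          # dataset cardinality
    eval_batch_size = 4
    dataloader_drop_last = False

    approx = math.floor if dataloader_drop_last else math.ceil
    steps = approx(num_examples / eval_batch_size)

    print(num_examples)              # "Num examples in dataset = 10"
    print(eval_batch_size * steps)   # "Num examples used in evaluation = 12"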
@@ -345,7 +343,7 @@ class TFTrainer:
                 else:
                     label_ids = np.append(label_ids, labels.numpy(), axis=0)
 
-            if step == steps:
+            if step == steps - 1:
                 break
 
         if self.compute_metrics is not None and preds is not None and label_ids is not None:
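
Finally, a toy illustration (not the trainer's actual loop) of the off-by-one that the last hunk fixes: with enumerate starting at 0, breaking at steps - 1 stops after exactly steps batches, while the old check let one extra batch through on a repeated dataset:

    steps = 3
    batches = ["b0", "b1", "b2", "b3 (extra)", "b4 (extra)"]

    processed = []
    for step, batch in enumerate(batches):
        processed.append(batch)
        if step == steps - 1:  # the old check `step == steps` broke one batch too late
            break

    print(processed)  # ['b0', 'b1', 'b2']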