mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix TFTrainer prediction output (#9662)
* Fix TFTrainer prediction output * Update trainer_tf.py * Fix TFTrainer prediction output * Fix evaluation_loss update in TFTrainer * Fix TFTrainer prediction output
This commit is contained in:
parent
9152f16023
commit
6312fed47d
@ -101,6 +101,7 @@ class TFTrainer:
|
||||
self.gradient_accumulator = GradientAccumulator()
|
||||
self.global_step = 0
|
||||
self.epoch_logging = 0
|
||||
self.eval_loss = tf.keras.metrics.Sum()
|
||||
|
||||
if tb_writer is not None:
|
||||
self.tb_writer = tb_writer
|
||||
@ -202,13 +203,8 @@ class TFTrainer:
|
||||
if num_examples < 0:
|
||||
raise ValueError("The training dataset must have an asserted cardinality")
|
||||
|
||||
approx = math.floor if self.args.dataloader_drop_last else math.ceil
|
||||
steps = approx(num_examples / self.args.eval_batch_size)
|
||||
ds = (
|
||||
test_dataset.repeat()
|
||||
.batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
|
||||
.prefetch(tf.data.experimental.AUTOTUNE)
|
||||
)
|
||||
steps = math.ceil(num_examples / self.args.eval_batch_size)
|
||||
ds = test_dataset.batch(self.args.eval_batch_size).prefetch(tf.data.experimental.AUTOTUNE)
|
||||
|
||||
return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples
|
||||
|
||||
@ -300,12 +296,14 @@ class TFTrainer:
|
||||
)
|
||||
|
||||
logger.info("***** Running %s *****", description)
|
||||
logger.info(" Num examples = %d", num_examples)
|
||||
logger.info(" Num examples in dataset = %d", num_examples)
|
||||
if description == "Evaluation":
|
||||
logger.info(" Num examples in used in evaluation = %d", self.args.eval_batch_size * steps)
|
||||
logger.info(" Batch size = %d", self.args.eval_batch_size)
|
||||
|
||||
label_ids: np.ndarray = None
|
||||
preds: np.ndarray = None
|
||||
self.eval_loss = tf.keras.metrics.Sum()
|
||||
self.eval_loss.reset_states()
|
||||
|
||||
# Reset the past mems state at the beginning of the evaluation if necessary.
|
||||
if self.args.past_index >= 0:
|
||||
@ -345,7 +343,7 @@ class TFTrainer:
|
||||
else:
|
||||
label_ids = np.append(label_ids, labels.numpy(), axis=0)
|
||||
|
||||
if step == steps:
|
||||
if step == steps - 1:
|
||||
break
|
||||
|
||||
if self.compute_metrics is not None and preds is not None and label_ids is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user