fix steps_in_epoch variable when using max_steps

This commit is contained in:
wujindou 2021-02-03 11:00:53 +08:00
parent 71bdc076dd
commit 0e455b0112

View File

@ -910,7 +910,11 @@ class Trainer:
if self.args.past_index >= 0:
self._past = None
steps_in_epoch = len(epoch_iterator) if train_dataset_is_sized else self.args.max_steps
steps_in_epoch = (
len(epoch_iterator)
if train_dataset_is_sized
else self.args.max_steps * self.args.gradient_accumulation_steps
)
self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
for step, inputs in enumerate(epoch_iterator):