Better support for resuming training (#8878)

Author: Sylvain Gugger, 2020-12-01 13:45:21 -05:00 (committed by GitHub)
commit 7c10dd22ae, parent 21db560df3
3 changed files with 65 additions and 11 deletions
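In short: when Trainer.train(model_path=...) finds a trainer_state.json in the checkpoint folder, it now works out how many epochs and how many raw batches were already consumed, replays them so the resumed run sees the remaining data in the same order as an uninterrupted run, and exposes a new ignore_data_skip argument to opt out of that replay. The snippet below only sketches the bookkeeping arithmetic visible in the diff; the function name and the example numbers are illustrative, not part of the API.

# Sketch of the resume bookkeeping (illustrative names, not the real API).
def resume_offsets(global_step, num_update_steps_per_epoch, gradient_accumulation_steps, ignore_data_skip=False):
    epochs_trained = global_step // num_update_steps_per_epoch
    if not ignore_data_skip:
        # Optimizer steps already taken inside the current epoch, converted back
        # into raw batches via the gradient accumulation factor.
        steps_trained_in_current_epoch = (global_step % num_update_steps_per_epoch) * gradient_accumulation_steps
    else:
        steps_trained_in_current_epoch = 0
    return epochs_trained, steps_trained_in_current_epoch

# Resuming at global step 15 with 10 update steps per epoch and 2 accumulation steps:
# one full epoch is done and 10 raw batches must be skipped in the current epoch.
print(resume_offsets(15, 10, 2))  # (1, 10)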

--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py

@@ -665,12 +665,12 @@ class Trainer:
         )
         logger.info("***** Running training *****")
-        logger.info("  Num examples = %d", num_examples)
-        logger.info("  Num Epochs = %d", num_train_epochs)
-        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
-        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
-        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
-        logger.info("  Total optimization steps = %d", max_steps)
+        logger.info(f"  Num examples = {num_examples}")
+        logger.info(f"  Num Epochs = {num_train_epochs}")
+        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
+        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
+        logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
+        logger.info(f"  Total optimization steps = {max_steps}")

         self.state.epoch = 0
         epochs_trained = 0
@@ -680,13 +680,20 @@ class Trainer:
         if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
             self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
             epochs_trained = self.state.global_step // num_update_steps_per_epoch
-            steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
-            steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
+            if not self.args.ignore_data_skip:
+                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+                steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
+            else:
+                steps_trained_in_current_epoch = 0

             logger.info("  Continuing training from checkpoint, will skip to saved global_step")
-            logger.info("  Continuing training from epoch %d", epochs_trained)
-            logger.info("  Continuing training from global step %d", self.state.global_step)
-            logger.info("  Will skip the first %d batches in the first epoch", steps_trained_in_current_epoch)
+            logger.info(f"  Continuing training from epoch {epochs_trained}")
+            logger.info(f"  Continuing training from global step {self.state.global_step}")
+            if not self.args.ignore_data_skip:
+                logger.info(
+                    f"  Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
+                    "batches in the first epoch."
+                )

         # Update the references
         self.callback_handler.model = self.model
@@ -712,6 +719,13 @@ class Trainer:
         self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)

+        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
+        if not self.args.ignore_data_skip:
+            for epoch in range(epochs_trained):
+                # We just need to begin an iteration to create the randomization of the sampler.
+                for _ in train_dataloader:
+                    break
+
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
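Note on the skip loop added above: with a shuffling sampler, merely starting an epoch is what draws that epoch's permutation from the RNG, so beginning (and immediately abandoning) an iteration for each already-trained epoch leaves the random state where an uninterrupted run would have it. The following is a standalone sketch of that idea with a toy dataset and a default single-process DataLoader; exact RNG behaviour can differ across PyTorch versions, so treat it as an illustration rather than a guarantee.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset of the integers 0..7; everything here is illustrative.
dataset = TensorDataset(torch.arange(8))

def epoch_order(loader):
    # Flatten one epoch of batches into a single list of sample values.
    return [x.item() for batch in loader for x in batch[0]]

# Uninterrupted run: record the sample order of the third epoch.
torch.manual_seed(0)
loader = DataLoader(dataset, batch_size=2, shuffle=True)
orders = [epoch_order(loader) for _ in range(3)]

# "Resumed" run: same seed, but the first two epochs are only started,
# mirroring the skip loop in the diff, before the third epoch is read.
torch.manual_seed(0)
loader = DataLoader(dataset, batch_size=2, shuffle=True)
for _ in range(2):
    for _ in loader:
        break

assert epoch_order(loader) == orders[2]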

--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py

@@ -189,6 +189,10 @@ class TrainingArguments:
         model_parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
             If there are more than one devices, whether to use model parallelism to distribute the model's modules
             across devices or not.
+        ignore_data_skip (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
+            stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
+            step can take a long time) but will not yield the same results as the interrupted training would have.
     """

     output_dir: str = field(
@@ -350,6 +354,12 @@
     greater_is_better: Optional[bool] = field(
         default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."}
     )
+    ignore_data_skip: bool = field(
+        default=False,
+        metadata={
+            "help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
+        },
+    )

     def __post_init__(self):
         if self.disable_tqdm is None:
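From the user side the new field is just another TrainingArguments knob. A brief usage sketch follows; the output directory and checkpoint name are placeholders, and the resume call reflects the model_path entry point used by the tests in this commit.

from transformers import TrainingArguments

# Opt out of the data replay: resuming starts faster, but the resumed run will
# not see the remaining data in the same order as the interrupted run would have.
args = TrainingArguments(output_dir="output", ignore_data_skip=True)
print(args.ignore_data_skip)  # True

# With a Trainer built from these args, resuming goes through the model_path
# argument, e.g. trainer.train(model_path="output/checkpoint-15").
# Example scripts expose the same switch on the command line via HfArgumentParser
# (e.g. --ignore_data_skip).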

--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py

@@ -554,6 +554,20 @@ class TrainerIntegrationTest(unittest.TestCase):
             self.assertEqual(b, b1)
             self.assertEqual(state, state1)

+            # Now check with a later checkpoint that it also works when we span over one epoch
+            checkpoint = os.path.join(tmpdir, "checkpoint-15")
+
+            # Reinitialize trainer and load model
+            model = RegressionPreTrainedModel.from_pretrained(checkpoint)
+            trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
+
+            trainer.train(model_path=checkpoint)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.assertEqual(state, state1)
+
         # With a regular model that is not a PreTrainedModel
         with tempfile.TemporaryDirectory() as tmpdir:
             trainer = get_regression_trainer(
@@ -578,6 +592,22 @@
             self.assertEqual(b, b1)
             self.assertEqual(state, state1)

+            # Now check with a later checkpoint that it also works when we span over one epoch
+            checkpoint = os.path.join(tmpdir, "checkpoint-15")
+
+            # Reinitialize trainer and load model
+            model = RegressionModel()
+            state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
+            model.load_state_dict(state_dict)
+            trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
+
+            trainer.train(model_path=checkpoint)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.assertEqual(state, state1)
+
     def test_resume_training_with_gradient_accumulation(self):
         if torch.cuda.device_count() > 2:
             # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of