Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
Better support for resuming training (#8878)

commit 7c10dd22ae, parent 21db560df3
@@ -665,12 +665,12 @@ class Trainer:
         logger.info("***** Running training *****")
-        logger.info(" Num examples = %d", num_examples)
-        logger.info(" Num Epochs = %d", num_train_epochs)
-        logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
-        logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
-        logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
-        logger.info(" Total optimization steps = %d", max_steps)
+        logger.info(f" Num examples = {num_examples}")
+        logger.info(f" Num Epochs = {num_train_epochs}")
+        logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
+        logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
+        logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
+        logger.info(f" Total optimization steps = {max_steps}")

         self.state.epoch = 0
         epochs_trained = 0
@@ -680,13 +680,20 @@ class Trainer:
         if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
             self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
             epochs_trained = self.state.global_step // num_update_steps_per_epoch
-            steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
-            steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
+            if not self.args.ignore_data_skip:
+                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+                steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
+            else:
+                steps_trained_in_current_epoch = 0

             logger.info(" Continuing training from checkpoint, will skip to saved global_step")
-            logger.info(" Continuing training from epoch %d", epochs_trained)
-            logger.info(" Continuing training from global step %d", self.state.global_step)
-            logger.info(" Will skip the first %d batches in the first epoch", steps_trained_in_current_epoch)
+            logger.info(f" Continuing training from epoch {epochs_trained}")
+            logger.info(f" Continuing training from global step {self.state.global_step}")
+            if not self.args.ignore_data_skip:
+                logger.info(
+                    f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
+                    "batches in the first epoch."
+                )

         # Update the references
         self.callback_handler.model = self.model
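For reference, a small worked example of the bookkeeping above, with illustrative numbers that are not taken from the diff: 8 optimizer updates per epoch, gradient accumulation of 2, and a checkpoint saved at global step 19 mean resuming skips 2 full epochs and then 6 dataloader batches of the third epoch.

# Illustrative numbers only, mirroring the arithmetic in the hunk above.
num_update_steps_per_epoch = 8        # optimizer updates per epoch
gradient_accumulation_steps = 2       # dataloader batches per optimizer update
global_step = 19                      # value stored in trainer_state.json

epochs_trained = global_step // num_update_steps_per_epoch                  # 2 completed epochs
steps_trained_in_current_epoch = global_step % num_update_steps_per_epoch   # 3 updates into the next epoch
steps_trained_in_current_epoch *= gradient_accumulation_steps               # 6 batches to skip

print(epochs_trained, steps_trained_in_current_epoch)  # -> 2 6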
@@ -712,6 +719,13 @@ class Trainer:

         self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)

+        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
+        if not self.args.ignore_data_skip:
+            for epoch in range(epochs_trained):
+                # We just need to begin an iteration to create the randomization of the sampler.
+                for _ in train_dataloader:
+                    break
+
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
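The loop added here only needs to begin an iteration because a shuffling sampler draws its whole permutation up front, so breaking after the first batch leaves the RNG where a full epoch would have left it. A minimal standalone sketch of that behaviour with a plain PyTorch DataLoader (my illustration, not part of the diff; it assumes single-process loading with the default RandomSampler and a seeded global RNG):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(16).float())

def first_batch_of_third_epoch(consume_full_epochs):
    # Reset the global RNG, "replay" two epochs, then look at the start of epoch 3.
    torch.manual_seed(42)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    for _ in range(2):
        for _ in loader:
            if not consume_full_epochs:
                break  # the permutation was already drawn when iteration began
    return next(iter(loader))[0]

# Breaking early and iterating fully consume the same randomness,
# so epoch 3 starts with the same shuffled batch either way.
assert torch.equal(first_batch_of_third_epoch(False), first_batch_of_third_epoch(True))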
@@ -189,6 +189,10 @@ class TrainingArguments:
         model_parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
             If there are more than one devices, whether to use model parallelism to distribute the model's modules
             across devices or not.
+        ignore_data_skip (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
+            stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
+            step can take a long time) but will not yield the same results as the interrupted training would have.
     """

     output_dir: str = field(
@@ -350,6 +354,12 @@ class TrainingArguments:
     greater_is_better: Optional[bool] = field(
         default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."}
     )
+    ignore_data_skip: bool = field(
+        default=False,
+        metadata={
+            "help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
+        },
+    )

     def __post_init__(self):
         if self.disable_tqdm is None:
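How the new flag is meant to be used when resuming is sketched below. This is not part of the diff: `model`, `train_dataset`, and the checkpoint path are placeholders for the user's own objects, and resuming goes through the `model_path` argument of `Trainer.train`, as it does in the tests that follow.

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="output",
    num_train_epochs=3,
    ignore_data_skip=False,  # default: replay the data order up to the saved position
    # ignore_data_skip=True would start immediately, at the cost of not reproducing
    # the exact batches the interrupted run would have seen
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train(model_path="output/checkpoint-500")  # placeholder checkpoint directory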
@@ -554,6 +554,20 @@ class TrainerIntegrationTest(unittest.TestCase):
             self.assertEqual(b, b1)
             self.assertEqual(state, state1)

+            # Now check with a later checkpoint that it also works when we span over one epoch
+            checkpoint = os.path.join(tmpdir, "checkpoint-15")
+
+            # Reinitialize trainer and load model
+            model = RegressionPreTrainedModel.from_pretrained(checkpoint)
+            trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
+
+            trainer.train(model_path=checkpoint)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.assertEqual(state, state1)
+
         # With a regular model that is not a PreTrainedModel
         with tempfile.TemporaryDirectory() as tmpdir:
             trainer = get_regression_trainer(
@@ -578,6 +592,22 @@ class TrainerIntegrationTest(unittest.TestCase):
             self.assertEqual(b, b1)
             self.assertEqual(state, state1)

+            # Now check with a later checkpoint that it also works when we span over one epoch
+            checkpoint = os.path.join(tmpdir, "checkpoint-15")
+
+            # Reinitialize trainer and load model
+            model = RegressionModel()
+            state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
+            model.load_state_dict(state_dict)
+            trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
+
+            trainer.train(model_path=checkpoint)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.assertEqual(state, state1)
+
     def test_resume_training_with_gradient_accumulation(self):
         if torch.cuda.device_count() > 2:
             # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of