diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 606a137a3ef..123d1ff8d02 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1523,10 +1523,6 @@ class Trainer:
         if self.is_world_process_zero():
             self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
 
-        # Maybe delete some older checkpoints.
-        if self.is_world_process_zero():
-            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
-
         # Save RNG state in non-distributed training
         rng_states = {
             "python": random.getstate(),
@@ -1552,6 +1548,10 @@ class Trainer:
         else:
             torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
 
+        # Maybe delete some older checkpoints.
+        if self.is_world_process_zero():
+            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+
     def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
         if checkpoint is None:
@@ -1924,7 +1924,7 @@ class Trainer:
                 ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
             else:
                 regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
-                if regex_match and regex_match.groups():
+                if regex_match is not None and regex_match.groups() is not None:
                     ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
 
         checkpoints_sorted = sorted(ordering_and_checkpoint_path)
@@ -1932,10 +1932,8 @@
         # Make sure we don't delete the best model.
         if self.state.best_model_checkpoint is not None:
             best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
-            checkpoints_sorted[best_model_index], checkpoints_sorted[-1] = (
-                checkpoints_sorted[-1],
-                checkpoints_sorted[best_model_index],
-            )
+            for i in range(best_model_index, len(checkpoints_sorted) - 2):
+                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
         return checkpoints_sorted
 
     def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
@@ -1947,7 +1945,17 @@ class Trainer:
         if len(checkpoints_sorted) <= self.args.save_total_limit:
             return
 
-        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
+        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
+        # we don't do to allow resuming.
+        save_total_limit = self.args.save_total_limit
+        if (
+            self.state.best_model_checkpoint is not None
+            and self.args.save_total_limit == 1
+            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
+        ):
+            save_total_limit = 2
+
+        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
         checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
         for checkpoint in checkpoints_to_be_deleted:
             logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index eca71a39fb7..e1933804c24 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -21,6 +21,7 @@ import random
 import re
 import tempfile
 import unittest
+from pathlib import Path
 
 import numpy as np
 
@@ -45,6 +46,7 @@ from transformers.testing_utils import (
     require_torch_multi_gpu,
     slow,
 )
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.utils.hp_naming import TrialShortNamer
 
 
@@ -1048,6 +1050,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         trainer.train()
         self.assertTrue(isinstance(trainer.state.total_flos, float))
 
+    def check_checkpoint_deletion(self, trainer, output_dir, expected):
+        # Make fake checkpoints
+        for n in [5, 10, 15, 20, 25]:
+            os.makedirs(os.path.join(output_dir, f"{PREFIX_CHECKPOINT_DIR}-{n}"), exist_ok=True)
+        trainer._rotate_checkpoints(output_dir=output_dir)
+        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{PREFIX_CHECKPOINT_DIR}-*")]
+        values = [int(re.match(f".*{PREFIX_CHECKPOINT_DIR}-([0-9]+)", d).groups()[0]) for d in glob_checkpoints]
+        self.assertSetEqual(set(values), set(expected))
+
+    def test_checkpoint_rotation(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Without best model at end
+            trainer = get_regression_trainer(output_dir=tmp_dir, save_total_limit=2)
+            self.check_checkpoint_deletion(trainer, tmp_dir, [20, 25])
+
+            # With best model at end
+            trainer = get_regression_trainer(output_dir=tmp_dir, load_best_model_at_end=True, save_total_limit=2)
+            trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5")
+            self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25])
+
+            # Edge case: we don't always honor save_total_limit=1 if load_best_model_at_end=True to be able to resume
+            # from checkpoint
+            trainer = get_regression_trainer(output_dir=tmp_dir, load_best_model_at_end=True, save_total_limit=1)
+            trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-25")
+            self.check_checkpoint_deletion(trainer, tmp_dir, [25])
+
+            trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5")
+            self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25])
+
     def check_mem_metrics(self, trainer, check_func):
         metrics = trainer.train().metrics
         check_func("init_mem_cpu_alloc_delta", metrics)
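
For illustration only (not part of the patch): a minimal standalone sketch of the rotation decision added to _rotate_checkpoints above. The variable names here are plain stand-ins for the Trainer state (self.state.best_model_checkpoint, the sorted checkpoint list, and args.save_total_limit); with save_total_limit=1 and a best checkpoint that is not the most recent one, the effective limit is bumped to 2 so the latest checkpoint survives for resuming.

# Hypothetical stand-ins for the Trainer state used in the patched method.
checkpoints_sorted = ["checkpoint-5", "checkpoint-25"]
best_model_checkpoint = "checkpoint-5"
save_total_limit = 1

# Mirror of the patched logic: keep one extra checkpoint when the best model is not
# the most recent checkpoint, so training can still resume from the latest one.
if (
    best_model_checkpoint is not None
    and save_total_limit == 1
    and checkpoints_sorted[-1] != best_model_checkpoint
):
    save_total_limit = 2

number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
print(checkpoints_sorted[:number_of_checkpoints_to_delete])  # [] -> both checkpoints are kept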