From f13f1f8fb8a89d6405b49b0981c5350ebd52430c Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 11 May 2021 12:02:48 -0400 Subject: [PATCH] Test checkpointing (#11682) * Add test and see where CI is unhappy * Load with strict=False --- src/transformers/trainer.py | 13 ++++++++++++- tests/test_modeling_common.py | 7 +++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 934b55d0c09..8d79fe14ec9 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1059,7 +1059,18 @@ class Trainer: # We load the model state dict on the CPU to avoid an OOM error. state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu") # If the model is on the GPU, it still works! - self.model.load_state_dict(state_dict) + load_result = self.model.load_state_dict(state_dict, strict=False) + if len(load_result.missing_keys) != 0: + if load_result.missing_keys == self.model._keys_to_ignore_on_save: + self.model.tie_weights() + else: + logger.warn( + f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}." + ) + if len(load_result.unexpected_keys) != 0: + logger.warn( + f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}." + ) # If model was re-initialized, put it on the right device and update self.model_wrapped if model_reloaded: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 19469075adc..00b8080ff90 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -177,6 +177,13 @@ class ModelTesterMixin: for k in _keys_to_ignore_on_save: self.assertNotIn(k, state_dict_saved) + # Test we can load the state dict in the model, necessary for the checkpointing API in Trainer. + load_result = model.load_state_dict(state_dict_saved, strict=False) + self.assertTrue( + len(load_result.missing_keys) == 0 or load_result.missing_keys == model._keys_to_ignore_on_save + ) + self.assertTrue(len(load_result.unexpected_keys) == 0) + def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: module.weight.data.fill_(3)