diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 97dc1d3c00a..77f0bc117e9 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2629,7 +2629,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step
         if low_cpu_mem_usage:
             for key in missing_keys:
-                if key.startswith(prefix):
+                if key in list(model_state_dict.keys()):
+                    key = key
+                elif f"{prefix}.{key}" in list(model_state_dict.keys()):
+                    key = f"{prefix}.{key}"
+                elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()):
                     key = ".".join(key.split(".")[1:])
                 param = model_state_dict[key]
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index a05d729a18c..1de50c8f90f 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -3166,6 +3166,27 @@ class ModelUtilsTest(TestCasePlus):
             ):
                 _ = ModelWithHead.from_pretrained(tmp_dir)
 
+    @require_torch_gpu
+    def test_pretrained_low_mem_new_config(self):
+        # Checking for 1 model (the same one described in the issue).
+        model_ids = ["gpt2"]
+
+        for model_id in model_ids:
+            model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path=model_id)
+            model_config.n_layer = 48
+            model_config.n_head = 25
+            model_config.n_embd = 1600
+            model = AutoModelForCausalLM.from_pretrained(
+                pretrained_model_name_or_path=model_id,
+                config=model_config,
+                ignore_mismatched_sizes=True,
+                torch_dtype=torch.float16,
+                low_cpu_mem_usage=True,
+            )
+            model_ref = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_id)
+
+            self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)
+
 
 @require_torch
 @is_staging_test
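
For readers skimming the patch: the modeling_utils.py hunk changes how a missing key is mapped onto the instantiated model's state dict when low_cpu_mem_usage=True. Below is a minimal standalone sketch of that resolution order, not the transformers code path itself; the helper name, the toy state-dict keys, and the "transformer" prefix are made up for illustration.

from typing import Optional, Set


def resolve_key(key: str, state_dict_keys: Set[str], prefix: str) -> Optional[str]:
    """Mirror the lookup order in the patch: exact match, prefix prepended, prefix stripped."""
    if key in state_dict_keys:
        return key  # checkpoint key already matches the instantiated model
    if f"{prefix}.{key}" in state_dict_keys:
        return f"{prefix}.{key}"  # model carries the base-model prefix, key does not
    stripped = ".".join(key.split(".")[1:])
    if key.startswith(prefix) and stripped in state_dict_keys:
        return stripped  # key carries an extra base-model prefix the model does not
    return None  # no usable match for this key


# Toy example with a GPT-2-style prefix and two parameter names.
keys = {"transformer.h.0.attn.c_attn.weight", "lm_head.weight"}
print(resolve_key("transformer.h.0.attn.c_attn.weight", keys, "transformer"))  # exact match
print(resolve_key("h.0.attn.c_attn.weight", keys, "transformer"))              # prefix prepended
print(resolve_key("transformer.lm_head.weight", keys, "transformer"))          # prefix stripped

The old code handled only the last of these three cases, which is why loading with a modified config (as in the new test, which resizes gpt2 to a GPT-2 XL-like shape with ignore_mismatched_sizes=True) could fail under low_cpu_mem_usage.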