From 05dfed06d780a24623ad9229e4746e7b5265135b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 24 Feb 2025 11:10:28 -0500 Subject: [PATCH] [Modeling] Reduce runtime when loading missing keys (#36312) * hoist keys Signed-off-by: Kyle Sayers * remove hoist Signed-off-by: Kyle Sayers --------- Signed-off-by: Kyle Sayers --- src/transformers/modeling_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 0d03de2addd..3ebd0eacfa6 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4679,11 +4679,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step if low_cpu_mem_usage: for key in missing_keys: - if key in list(model_state_dict.keys()): + if key in model_state_dict: key = key - elif f"{prefix}.{key}" in list(model_state_dict.keys()): + elif f"{prefix}.{key}" in model_state_dict: key = f"{prefix}.{key}" - elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()): + elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in model_state_dict: key = ".".join(key.split(".")[1:]) param = model_state_dict[key]