[Modeling] Reduce runtime when loading missing keys (#36312)

* hoist keys

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* remove hoist

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
Kyle Sayers 2025-02-24 11:10:28 -05:00 committed by GitHub
parent 18276b03f7
commit 05dfed06d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4679,11 +4679,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step
if low_cpu_mem_usage:
for key in missing_keys:
if key in list(model_state_dict.keys()):
if key in model_state_dict:
key = key
elif f"{prefix}.{key}" in list(model_state_dict.keys()):
elif f"{prefix}.{key}" in model_state_dict:
key = f"{prefix}.{key}"
elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()):
elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in model_state_dict:
key = ".".join(key.split(".")[1:])
param = model_state_dict[key]