mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
[Modeling] Reduce runtime when loading missing keys (#36312)
* hoist keys Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * remove hoist Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> --------- Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
parent
18276b03f7
commit
05dfed06d7
@ -4679,11 +4679,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step
|
# This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step
|
||||||
if low_cpu_mem_usage:
|
if low_cpu_mem_usage:
|
||||||
for key in missing_keys:
|
for key in missing_keys:
|
||||||
if key in list(model_state_dict.keys()):
|
if key in model_state_dict:
|
||||||
key = key
|
key = key
|
||||||
elif f"{prefix}.{key}" in list(model_state_dict.keys()):
|
elif f"{prefix}.{key}" in model_state_dict:
|
||||||
key = f"{prefix}.{key}"
|
key = f"{prefix}.{key}"
|
||||||
elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()):
|
elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in model_state_dict:
|
||||||
key = ".".join(key.split(".")[1:])
|
key = ".".join(key.split(".")[1:])
|
||||||
param = model_state_dict[key]
|
param = model_state_dict[key]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user