This commit is contained in:
Arthur Zucker 2024-03-25 21:57:31 +09:00
parent 8e9a2207b3
commit 00a09ed448

View File

@ -600,7 +600,9 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
# so we need to apply the function recursively.
def load(module: nn.Module, state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
unexpected_keys = []
missing_keys = []
args = (state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
if len([key for key in state_dict if key.startswith(prefix)]) > 0: