Rename variables with unclear naming (#14122)

* Rename var

* Add comments
This commit is contained in:
Li-Huai (Allan) Lin 2021-10-23 01:05:45 +08:00 committed by GitHub
parent 05a2afc252
commit 62ccbe0960
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1494,13 +1494,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# key re-naming operations are never done on the keys
# that are loaded, but always on the keys of the newly initialized model
remove_prefix = not has_prefix_module and expects_prefix_module
add_prefix = has_prefix_module and not expects_prefix_module
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
add_prefix_to_model = has_prefix_module and not expects_prefix_module
if remove_prefix:
if remove_prefix_from_model:
expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)]
expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
elif add_prefix:
elif add_prefix_to_model:
expected_keys = [".".join([prefix, s]) for s in expected_keys]
missing_keys = list(set(expected_keys) - set(loaded_keys))
@ -1512,9 +1512,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
if remove_prefix:
if remove_prefix_from_model:
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
model_key = f"{prefix}.{checkpoint_key}"
elif add_prefix:
elif add_prefix_to_model:
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key = ".".join(checkpoint_key.split(".")[1:])
if (
@ -1539,7 +1541,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if _fast_init:
# retrieve unintialized modules and initialize
uninitialized_modules = model.retrieve_modules_from_names(
missing_keys, add_prefix=add_prefix, remove_prefix=remove_prefix
missing_keys, add_prefix=add_prefix_to_model, remove_prefix=remove_prefix_from_model
)
for module in uninitialized_modules:
model._init_weights(module)