Delete state_dict to release memory as early as possible (#18832)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-09-01 10:55:30 +02:00 committed by GitHub
parent a26c752353
commit 563a8d58db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -417,7 +417,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: nn.Module, prefix=""):
def load(module: nn.Module, state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
if is_deepspeed_zero3_enabled():
@ -434,9 +434,12 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
load(child, state_dict, prefix + name + ".")
load(model_to_load, prefix=start_prefix)
load(model_to_load, state_dict, prefix=start_prefix)
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
# it's safe to delete it.
del state_dict
return error_msgs