Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-30 17:52:35 +06:00)
Delete state_dict to release memory as early as possible (#18832)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent a26c752353
commit 563a8d58db
@@ -417,7 +417,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
     # so we need to apply the function recursively.
-    def load(module: nn.Module, prefix=""):
+    def load(module: nn.Module, state_dict, prefix=""):
         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
         args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
         if is_deepspeed_zero3_enabled():
@@ -434,9 +434,12 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
         for name, child in module._modules.items():
             if child is not None:
-                load(child, prefix + name + ".")
+                load(child, state_dict, prefix + name + ".")

-    load(model_to_load, prefix=start_prefix)
+    load(model_to_load, state_dict, prefix=start_prefix)
+    # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
+    # it's safe to delete it.
+    del state_dict

     return error_msgs
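For context, a minimal self-contained sketch of the pattern this change applies (the StateDict and load_all names below are hypothetical, not part of transformers): by threading the dict through the recursive helper as an explicit argument instead of a closure variable, the final `del state_dict` drops the last reference held inside the function, so under CPython's reference counting the potentially multi-gigabyte dict is reclaimed immediately rather than at the end of `_load_state_dict_into_model`.

# Minimal sketch, assuming CPython reference-counting semantics.
# Names (StateDict, load_all, load) are illustrative only.
import weakref


class StateDict(dict):
    """dict subclass only so a weakref can observe when it is collected."""


def load_all(tree, state_dict):
    # Work on a copy, mirroring the commit comment: deleting it cannot
    # affect the caller's dict.
    state_dict = StateDict(state_dict)

    def load(node, state_dict, prefix=""):
        # `state_dict` is a parameter, not a closure variable, so `load`
        # retains no reference to it once the recursion unwinds.
        for name, child in node.items():
            load(child, state_dict, prefix + name + ".")

    load(tree, state_dict)

    probe = weakref.ref(state_dict)
    del state_dict  # drop the last strong reference; the dict is freed here
    assert probe() is None  # confirms it was collected immediately

    # Anything that runs from here on no longer pays for the dict's memory.
    return "done"


print(load_all({"encoder": {"layers": {}}}, {"weight": [0.0] * 8}))

The weakref probe is only there to make the early collection observable; the actual commit relies on the same reference-counting behavior without it.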