Delete state_dict to release memory as early as possible (#18832)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent a26c752353
commit 563a8d58db
@@ -417,7 +417,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
 
     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
     # so we need to apply the function recursively.
-    def load(module: nn.Module, prefix=""):
+    def load(module: nn.Module, state_dict, prefix=""):
         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
         args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
         if is_deepspeed_zero3_enabled():
@@ -434,9 +434,12 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
 
         for name, child in module._modules.items():
             if child is not None:
-                load(child, prefix + name + ".")
+                load(child, state_dict, prefix + name + ".")
 
-    load(model_to_load, prefix=start_prefix)
+    load(model_to_load, state_dict, prefix=start_prefix)
+    # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
+    # it's safe to delete it.
+    del state_dict
 
     return error_msgs
 
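For context, here is a minimal sketch of the pattern this diff introduces (a hypothetical load_weights helper with a toy model, not the transformers implementation): the recursive helper receives state_dict as an explicit argument instead of closing over it, and the local copy is deleted as soon as loading completes, so its memory is released right away rather than only when the outer function returns.

from torch import nn


def load_weights(model: nn.Module, state_dict):
    # Work on a copy so the caller's dict is left untouched (mirrors the
    # `state_dict = state_dict.copy()` done earlier in the real function).
    state_dict = state_dict.copy()
    error_msgs = []

    def load(module: nn.Module, state_dict, prefix=""):
        # `_load_from_state_dict` only copies a module's own parameters,
        # so recurse into children with an extended key prefix.
        module._load_from_state_dict(state_dict, prefix, {}, True, [], [], error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(model, state_dict, prefix="")

    # The copy is now referenced only by this local name, so deleting it
    # frees the memory immediately instead of at the end of the function.
    del state_dict
    return error_msgs


if __name__ == "__main__":
    src = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
    dst = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
    print(load_weights(dst, src.state_dict()))  # [] -> no errors

For large checkpoints, dropping the copied state_dict at this point matters because the rest of model setup can still run for a while; keeping the copy alive until the function exits would roughly double peak memory during that window.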