From 563a8d58db0acd088f62167f23671ba2f88bae9c Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 1 Sep 2022 10:55:30 +0200 Subject: [PATCH] Delete `state_dict` to release memory as early as possible (#18832) Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: ydshieh --- src/transformers/modeling_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d77258c94ea..04196633e14 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -417,7 +417,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix): # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants # so we need to apply the function recursively. - def load(module: nn.Module, prefix=""): + def load(module: nn.Module, state_dict, prefix=""): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) if is_deepspeed_zero3_enabled(): @@ -434,9 +434,12 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix): for name, child in module._modules.items(): if child is not None: - load(child, prefix + name + ".") + load(child, state_dict, prefix + name + ".") - load(model_to_load, prefix=start_prefix) + load(model_to_load, state_dict, prefix=start_prefix) + # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so + # it's safe to delete it. + del state_dict return error_msgs