Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
add util for ram efficient loading of model when using fsdp (#25107)
* add util for ram efficient loading of model when using fsdp
* make fix-copies
* fixes 😅
* docs
* making it further easier to use
* rename the function
* refactor to handle fsdp ram efficiency in `from_pretrained`
* fixes
* fixes
* fixes
* update
* fixes
* revert `load_pretrained_model_only_on_rank0`
* resolve `load_from_checkpoint`
Parent: 4e1dee0e8e
Commit: c4c0ceff09
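
This change wires RAM-efficient FSDP loading into `from_pretrained` and the `Trainer` (diff below). A hedged usage sketch of the intended workflow follows; the checkpoint name, output directory, and layer class are placeholders, the dataset/tokenizer wiring is omitted, and it assumes the script is launched with `accelerate launch` or `torchrun` with FSDP enabled so that `ACCELERATE_USE_FSDP` is set:

from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    fsdp="full_shard auto_wrap",
    fsdp_config={  # keys are now accepted without the `fsdp_` prefix (see the training_args.py hunks below)
        "transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],  # placeholder layer class
        "sync_module_states": "true",  # rank 0 broadcasts the real weights after wrapping
    },
)

# With FSDP enabled, `from_pretrained` now forces `low_cpu_mem_usage=True`; only rank 0
# materializes real tensors while the other ranks keep parameters on the meta device.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder checkpoint
trainer = Trainer(model=model, args=args)  # plus train_dataset, tokenizer, etc. as usual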
src/transformers/modeling_utils.py

@@ -73,6 +73,7 @@ from .utils import (
     is_torch_tpu_available,
     logging,
     replace_return_docstrings,
+    strtobool,
 )
 from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
 from .utils.import_utils import ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled, is_torch_fx_proxy
@@ -106,6 +107,14 @@ logger = logging.get_logger(__name__)
 _init_weights = True
 
 
+def is_fsdp_enabled():
+    return strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
+
+
+def is_fsdp_enabled_and_dist_rank_0():
+    return is_fsdp_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() == 0
+
+
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
     from smdistributed.modelparallel import __version__ as SMP_VERSION
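
The two helpers above gate every FSDP-specific branch in this diff: `is_fsdp_enabled` keys off the `ACCELERATE_USE_FSDP` environment variable that Accelerate sets for FSDP runs, and `is_fsdp_enabled_and_dist_rank_0` narrows that to rank 0 of an initialized process group. A minimal, hedged sketch of using them as guards (outside a distributed FSDP launch the block is simply skipped):

import torch.distributed as dist

from transformers.modeling_utils import is_fsdp_enabled, is_fsdp_enabled_and_dist_rank_0

# Under an FSDP launch, non-zero ranks can skip expensive work and wait for rank 0's broadcast.
if is_fsdp_enabled() and not is_fsdp_enabled_and_dist_rank_0():
    print(f"rank {dist.get_rank()}: skipping weight materialization, waiting for rank 0 broadcast")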
@@ -458,7 +467,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             )
         return safe_load_file(checkpoint_file)
     try:
-        if is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0:
+        if (
+            (is_deepspeed_zero3_enabled() or is_fsdp_enabled())
+            and torch.distributed.is_initialized()
+            and torch.distributed.get_rank() > 0
+        ):
             map_location = "meta"
         else:
             map_location = "cpu"
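
With this change, ranks other than 0 load the checkpoint with `map_location="meta"`, so tensors carry only shape and dtype and no storage is allocated. A quick, self-contained illustration of why that keeps host RAM flat (plain PyTorch, independent of this diff):

import torch

# A meta tensor has a shape and dtype but no backing storage, so "loading" a checkpoint
# onto the meta device costs essentially no host RAM on that rank.
t = torch.empty(1024, 1024, device="meta")
print(t.shape, t.dtype, t.device)       # torch.Size([1024, 1024]) torch.float32 meta
print(t.element_size() * t.nelement())  # 4194304 logical bytes, none of them allocated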
@@ -2283,6 +2296,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         commit_hash = kwargs.pop("_commit_hash", None)
         variant = kwargs.pop("variant", None)
 
+        if is_fsdp_enabled():
+            low_cpu_mem_usage = True
+
         if use_auth_token is not None:
             warnings.warn(
                 "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
@@ -3238,7 +3254,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             model_buffers = {".".join([prefix, key]) for key in model_buffers}
             unexpected_keys = list(unexpected_keys - model_buffers)
 
-        if device_map is None:
+        model.tie_weights()
+        if device_map is None and not is_fsdp_enabled():
             ptrs = collections.defaultdict(list)
             for name, tensor in model.state_dict().items():
                 id_tensor = id_tensor_storage(tensor)
@@ -3443,23 +3460,35 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     )
 
                 if low_cpu_mem_usage:
-                    new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
-                        model_to_load,
-                        state_dict,
-                        loaded_keys,
-                        start_prefix,
-                        expected_keys,
-                        device_map=device_map,
-                        offload_folder=offload_folder,
-                        offload_index=offload_index,
-                        state_dict_folder=state_dict_folder,
-                        state_dict_index=state_dict_index,
-                        dtype=dtype,
-                        is_quantized=is_quantized,
-                        is_safetensors=is_safetensors,
-                        keep_in_fp32_modules=keep_in_fp32_modules,
-                    )
-                    error_msgs += new_error_msgs
+                    if not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
+                        new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
+                            model_to_load,
+                            state_dict,
+                            loaded_keys,
+                            start_prefix,
+                            expected_keys,
+                            device_map=device_map,
+                            offload_folder=offload_folder,
+                            offload_index=offload_index,
+                            state_dict_folder=state_dict_folder,
+                            state_dict_index=state_dict_index,
+                            dtype=dtype,
+                            is_quantized=is_quantized,
+                            is_safetensors=is_safetensors,
+                            keep_in_fp32_modules=keep_in_fp32_modules,
+                        )
+                        error_msgs += new_error_msgs
+                    else:
+                        for key, param in model_to_load.state_dict().items():
+                            if param.device == torch.device("meta"):
+                                if not (is_quantized):
+                                    set_module_tensor_to_device(
+                                        model, key, "cpu", torch.empty(*param.size(), dtype=dtype)
+                                    )
+                                else:
+                                    set_module_quantized_tensor_to_device(
+                                        model, key, "cpu", torch.empty(*param.size(), dtype=dtype)
+                                    )
                 else:
                     error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
 
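
The `else` branch above is what keeps non-zero ranks cheap: parameters still sitting on the meta device are swapped for empty CPU tensors of the right shape, and FSDP's `sync_module_states` later broadcasts the real values from rank 0. A hedged, single-process sketch of that materialization step (assumes `accelerate` is installed; the tiny `nn.Linear` stands in for a real model):

import torch
from torch import nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# Build a module whose parameters live on the meta device (no storage allocated).
with init_empty_weights():
    model = nn.Linear(4, 4)

# Replace each meta parameter with an empty CPU tensor; in the FSDP path the real
# values arrive later via rank 0's sync_module_states broadcast.
for name, param in model.state_dict().items():
    if param.device == torch.device("meta"):
        set_module_tensor_to_device(model, name, "cpu", value=torch.empty(*param.size(), dtype=param.dtype))

print(next(model.parameters()).device)  # cpu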
src/transformers/trainer.py

@@ -465,10 +465,6 @@ class Trainer:
             ):
                 self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
 
-            self.forward_prefetch = False
-            if self.args.fsdp_config.get("forward_prefetch", False):
-                self.forward_prefetch = True
-
             self.limit_all_gathers = False
             if self.args.fsdp_config.get("limit_all_gathers", False):
                 self.limit_all_gathers = True
@@ -1379,12 +1375,12 @@ class Trainer:
             auto_wrapper_callable = None
             default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
             fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get(
-                "fsdp_transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
+                "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
             )
 
-            if self.args.fsdp_config["fsdp_min_num_params"] > 0:
+            if self.args.fsdp_config["min_num_params"] > 0:
                 auto_wrap_policy = functools.partial(
-                    size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
+                    size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["min_num_params"]
                 )
             elif fsdp_transformer_layer_cls_to_wrap is not None:
                 transformer_cls_to_wrap = set()
@@ -1517,7 +1513,12 @@ class Trainer:
             if resume_from_checkpoint is None:
                 raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
 
-        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled:
+        if (
+            resume_from_checkpoint is not None
+            and not is_sagemaker_mp_enabled()
+            and not self.is_deepspeed_enabled
+            and not self.is_fsdp_enabled
+        ):
             self._load_from_checkpoint(resume_from_checkpoint)
 
         # If model was re-initialized, put it on the right device and update self.model_wrapped
@@ -1651,7 +1652,7 @@ class Trainer:
 
         model = self._wrap_model(self.model_wrapped)
 
-        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
+        if (is_sagemaker_mp_enabled() or self.is_fsdp_enabled) and resume_from_checkpoint is not None:
             self._load_from_checkpoint(resume_from_checkpoint, model)
 
         # as the model is wrapped, don't use `accelerator.prepare`
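
Together, the two Trainer hunks above defer checkpoint restoration under FSDP: the early `_load_from_checkpoint` call is skipped when `self.is_fsdp_enabled`, and the checkpoint is instead loaded after `_wrap_model`, into the wrapped model. Continuing the hedged sketch from the commit header (the `trainer` object and checkpoint path are placeholders):

# Resuming under FSDP now loads the checkpoint into the already-wrapped model.
trainer.train(resume_from_checkpoint="out/checkpoint-500")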
@@ -3886,7 +3887,6 @@ class Trainer:
                 fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
                     "limit_all_gathers", fsdp_plugin.limit_all_gathers
                 )
-                fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", fsdp_plugin.use_orig_params)
 
         if self.is_deepspeed_enabled:
             if getattr(self.args, "hf_deepspeed_config", None) is None:
src/transformers/training_args.py

@@ -436,13 +436,13 @@ class TrainingArguments:
                 deepspeed json config file (e.g., `ds_config.json`) or an already loaded json file as `dict`.
 
                 A List of config and its options:
-                    - fsdp_min_num_params (`int`, *optional*, defaults to `0`):
+                    - min_num_params (`int`, *optional*, defaults to `0`):
                         FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is
                         passed).
-                    - fsdp_transformer_layer_cls_to_wrap (`List[str]`, *optional*):
+                    - transformer_layer_cls_to_wrap (`List[str]`, *optional*):
                         List of transformer layer class names (case-sensitive) to wrap, e.g, `BertLayer`, `GPTJBlock`,
                         `T5Block` .... (useful only when `fsdp` flag is passed).
-                    - fsdp_backward_prefetch (`str`, *optional*)
+                    - backward_prefetch (`str`, *optional*)
                         FSDP's backward prefetch mode. Controls when to prefetch next set of parameters (useful only when
                         `fsdp` field is passed).
 
@@ -454,7 +454,7 @@ class TrainingArguments:
                         - `"backward_post"` : This prefetches the next set of parameters after the current set of
                           parameter’s
                           gradient computation.
-                    - fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`)
+                    - forward_prefetch (`bool`, *optional*, defaults to `False`)
                         FSDP's forward prefetch mode (useful only when `fsdp` field is passed).
                         If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the
                         forward pass.
@@ -462,6 +462,14 @@ class TrainingArguments:
                         FSDP's limit_all_gathers (useful only when `fsdp` field is passed).
                         If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight
                         all-gathers.
+                    - use_orig_params (`bool`, *optional*, defaults to `False`)
+                        If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed
+                        frozen and trainable parameters. Useful in cases such as parameter-efficient fine-tuning. Please
+                        refer to this
+                        [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019)
+                    - sync_module_states (`bool`, *optional*, defaults to `True`)
+                        If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to
+                        ensure they are the same across all ranks after initialization
                     - xla (`bool`, *optional*, defaults to `False`):
                         Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature
                         and its API may evolve in the future.
@@ -1520,44 +1528,44 @@ class TrainingArguments:
             self.fsdp_config = {}
 
         if isinstance(self.fsdp_config, str):
             if len(self.fsdp) == 0:
                 warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.")
             with io.open(self.fsdp_config, "r", encoding="utf-8") as f:
                 self.fsdp_config = json.load(f)
+                for k, v in self.fsdp_config.items():
+                    if k.startswith("fsdp_"):
+                        self.fsdp_config[k.replace("fsdp_", "")] = v
+                        del self.fsdp_config[k]
 
         if self.fsdp_min_num_params > 0:
             warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)
 
-        self.fsdp_config["fsdp_min_num_params"] = max(
-            self.fsdp_config.get("fsdp_min_num_params", 0), self.fsdp_min_num_params
-        )
+        self.fsdp_config["min_num_params"] = max(self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params)
 
-        # if fsdp_config["fsdp_transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
-        if isinstance(self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None), str):
-            self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = [
-                self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
-            ]
+        # if fsdp_config["transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
+        if isinstance(self.fsdp_config.get("transformer_layer_cls_to_wrap", None), str):
+            self.fsdp_config["transformer_layer_cls_to_wrap"] = [self.fsdp_config["transformer_layer_cls_to_wrap"]]
 
         if self.fsdp_transformer_layer_cls_to_wrap is not None:
             warnings.warn(
                 "using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning
             )
-            self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
-                "fsdp_transformer_layer_cls_to_wrap", []
+            self.fsdp_config["transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
+                "transformer_layer_cls_to_wrap", []
             ) + [self.fsdp_transformer_layer_cls_to_wrap]
 
-        if len(self.fsdp) == 0 and self.fsdp_config["fsdp_min_num_params"] > 0:
-            warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")
+        if len(self.fsdp) == 0 and self.fsdp_config["min_num_params"] > 0:
+            warnings.warn("`min_num_params` is useful only when `--fsdp` is specified.")
 
-        if len(self.fsdp) == 0 and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
-            warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
+        if len(self.fsdp) == 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
+            warnings.warn("`transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
 
         if (
             len(self.fsdp) > 0
-            and self.fsdp_config["fsdp_min_num_params"] > 0
-            and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None
+            and self.fsdp_config["min_num_params"] > 0
+            and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None
         ):
-            raise ValueError(
-                "`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive."
-            )
+            raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.")
         self.fsdp_config["xla"] = self.fsdp_config.get("xla", False)
         self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False)
         if self.fsdp_config["xla"]:
@@ -1583,23 +1591,29 @@ class TrainingArguments:
                 FSDP_SHARDING_STRATEGY,
             )
 
+            prefix = "FSDP_"
             for fsdp_option in self.fsdp:
                 if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
                     # set environment variable for FSDP sharding strategy
-                    os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
+                    os.environ[f"{prefix}SHARDING_STRATEGY"] = str(
+                        FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1
+                    )
                 elif fsdp_option == FSDPOption.OFFLOAD:
-                    os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
+                    os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true"
                 elif fsdp_option == FSDPOption.AUTO_WRAP:
-                    os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
-                    if self.fsdp_config["fsdp_min_num_params"] > 0:
-                        os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
-                        os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
-                    elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
-                        os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
-                            self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
+                    os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
+                    if self.fsdp_config["min_num_params"] > 0:
+                        os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(self.fsdp_config["min_num_params"])
+                        os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
+                    elif self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
+                        os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"] = ",".join(
+                            self.fsdp_config["transformer_layer_cls_to_wrap"]
                         )
             prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
-            os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
+            os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
+            os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefetch", "false")
+            os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")
+            os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "false")
 
         if self.tpu_metrics_debug:
             warnings.warn(
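
For reference, a standalone sketch of the environment-variable plumbing in the hunk above, runnable without transformers; the `FSDP_SHARDING_STRATEGY` list and the `"SIZE_BASED_WRAP"` policy name are assumptions that mirror Accelerate's constants, and the `fsdp`/`fsdp_config` inputs are placeholders:

import os

FSDP_SHARDING_STRATEGY = ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"]  # assumed ordering, mirroring accelerate

fsdp = ["full_shard", "auto_wrap", "offload"]
fsdp_config = {"min_num_params": 100_000, "sync_module_states": "true"}

prefix = "FSDP_"
for fsdp_option in fsdp:
    if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
        # 1-indexed to match torch's ShardingStrategy enum values
        os.environ[f"{prefix}SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
    elif fsdp_option == "offload":
        os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true"
    elif fsdp_option == "auto_wrap" and fsdp_config["min_num_params"] > 0:
        os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(fsdp_config["min_num_params"])
        os.environ[f"{prefix}AUTO_WRAP_POLICY"] = "SIZE_BASED_WRAP"
os.environ[f"{prefix}SYNC_MODULE_STATES"] = fsdp_config.get("sync_module_states", "true")

print({k: v for k, v in os.environ.items() if k.startswith(prefix)})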