diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 873df4aa86a..9e03a234b77 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -1121,4 +1121,9 @@
     - local: internal/time_series_utils
       title: Utilities for Time Series
     title: Internal helpers
+  - sections:
+    - local: reference/environment_variables
+      title: Environment Variables
+    title: Reference
   title: API
+
diff --git a/docs/source/en/reference/environment_variables.md b/docs/source/en/reference/environment_variables.md
new file mode 100644
index 00000000000..fc20c08f9e6
--- /dev/null
+++ b/docs/source/en/reference/environment_variables.md
@@ -0,0 +1,58 @@
+
+
+# Environment Variables
+
+## HF_ENABLE_PARALLEL_LOADING
+
+Disabled by default. When enabled, torch and safetensors weights are loaded in parallel, which can significantly decrease the time it takes to load large models, often producing speed-ups of around ~50%.
+
+Can be set to a string equal to `"false"` or `"true"`, e.g. `os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"`.
+
+For example, `facebook/opt-30b` on an AWS EC2 g4dn.metal instance loads in ~30s with this enabled vs. ~55s without it.
+
+Profile before committing to this environment variable; it will not produce speed-ups for smaller models.
+
+```py
+import os
+
+os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"
+
+from transformers import pipeline
+
+model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
+```
+
+## HF_PARALLEL_LOADING_WORKERS
+
+Determines how many threads to use when parallel loading is enabled. The default is `8`.
+
+If the number of files being loaded is smaller than the number of threads specified, only as many threads as there are files are actually spawned.
+
+For example, if you specify 8 workers and there are only 2 files, only 2 workers will be spawned.
+
+Tune this value as you see fit.
+
+```py
+import os
+
+os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"
+os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4"
+
+from transformers import pipeline
+
+model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
+```
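Since the two variables above work together, a quick way to evaluate them on your own setup is to time a load with the flag flipped on and off. The snippet below is an illustrative sketch only (it is not part of this PR); the model name and worker count are placeholders, and the actual savings depend on shard count, disk, and network speed.

```py
# Hypothetical timing sketch, not part of this PR: set the variables before
# importing transformers, then measure how long the checkpoint takes to load.
import os
import time

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"  # flip to "false" to compare
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "8"    # placeholder worker count

from transformers import pipeline

start = time.perf_counter()
pipe = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
print(f"Loaded in {time.perf_counter() - start:.1f}s")
```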
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index ced2231d131..58d422b37b8 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -27,6 +27,7 @@ import shutil
 import tempfile
 import warnings
 from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import contextmanager
 from dataclasses import dataclass
 from enum import Enum
@@ -870,6 +871,116 @@ def _load_state_dict_into_meta_model(
     return disk_offload_index, cpu_offload_index
 
 
+def load_shard_file(args):
+    (
+        shard_file,
+        state_dict,
+        disk_only_shard_files,
+        is_hqq_or_bnb,
+        is_quantized,
+        device_map,
+        hf_quantizer,
+        key_renaming_mapping,
+        weights_only,
+        model_to_load,
+        expected_keys,
+        reverse_key_renaming_mapping,
+        disk_offload_folder,
+        disk_offload_index,
+        cpu_offload_folder,
+        cpu_offload_index,
+        is_offloaded_safetensors,
+        keep_in_fp32_regex,
+        unexpected_keys,
+        device_mesh,
+    ) = args
+
+    # Skip the load for shards that only contain disk-offloaded weights
+    if shard_file in disk_only_shard_files:
+        return [], disk_offload_index, cpu_offload_index
+
+    map_location = "cpu"
+    if (
+        shard_file.endswith(".safetensors")
+        and not is_hqq_or_bnb
+        and not (is_deepspeed_zero3_enabled() and not is_quantized)
+    ):
+        map_location = "meta"
+    elif (
+        device_map is not None
+        and hf_quantizer is not None
+        and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
+        and (
+            hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
+            or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
+        )
+    ):
+        map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
+
+    # If shard_file is "", we use the existing state_dict instead of loading it
+    if shard_file != "":
+        state_dict = load_state_dict(
+            shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
+        )
+
+    # Fix the key names
+    state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
+
+    error_msgs = []
+
+    if is_deepspeed_zero3_enabled() and not is_quantized:
+        error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
+    # Skip it with fsdp on ranks other than 0
+    elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
+        disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
+            model_to_load,
+            state_dict,
+            shard_file,
+            expected_keys,
+            reverse_key_renaming_mapping,
+            device_map=device_map,
+            disk_offload_folder=disk_offload_folder,
+            disk_offload_index=disk_offload_index,
+            cpu_offload_folder=cpu_offload_folder,
+            cpu_offload_index=cpu_offload_index,
+            hf_quantizer=hf_quantizer,
+            is_safetensors=is_offloaded_safetensors,
+            keep_in_fp32_regex=keep_in_fp32_regex,
+            unexpected_keys=unexpected_keys,
+            device_mesh=device_mesh,
+        )
+
+    return error_msgs, disk_offload_index, cpu_offload_index
+
+
+def load_shard_files_with_threadpool(args_list):
+    num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
+
+    # Do not spawn more workers than needed
+    num_workers = min(len(args_list), num_workers)
+
+    logger.info(f"Loading model weights in parallel with {num_workers} workers...")
+
+    error_msgs = []
+
+    with ThreadPoolExecutor(max_workers=num_workers) as executor:
+        with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
+            futures = [executor.submit(load_shard_file, arg) for arg in args_list]
+            for future in as_completed(futures):
+                result = future.result()
+                (
+                    _error_msgs,
+                    disk_offload_index,
+                    cpu_offload_index,
+                ) = result
+
+                error_msgs += _error_msgs
+
+                pbar.update(1)
+
+    return error_msgs, disk_offload_index, cpu_offload_index
+
+
 def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
     if variant is not None:
         path, name = weights_name.rsplit(".", 1)
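For reviewers skimming the hunk above: the core of `load_shard_files_with_threadpool` is a plain `ThreadPoolExecutor`/`as_completed` fan-out over per-shard work items, with the pool size capped by the number of files. The sketch below is a simplified stand-in (not the actual implementation); `load_one` plays the role of `load_shard_file` and returns only a list of error messages.

```py
# Illustrative sketch of the fan-out pattern used above, with a stand-in task.
import os
from concurrent.futures import ThreadPoolExecutor, as_completed


def load_one(path):
    # Stand-in for load_shard_file: pretend to load one shard, return error messages.
    return []


def load_all(paths):
    # Cap the pool size at the number of files, mirroring the PR's worker logic.
    num_workers = min(len(paths), int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")))
    error_msgs = []
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(load_one, p) for p in paths]
        for future in as_completed(futures):
            error_msgs += future.result()
    return error_msgs


print(load_all(["shard-00001.safetensors", "shard-00002.safetensors"]))
```

Threads (rather than processes) are presumably sufficient here because most of the time goes into file I/O and tensor copies, which release the GIL.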
@@ -4973,9 +5084,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                 cpu_offload_folder = tempfile.mkdtemp()
                 cpu_offload_index = {}
 
-        # For nice tqdm bars
-        if checkpoint_files is not None and len(checkpoint_files) > 1:
-            checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards")
         # To be able to iterate, even if we don't use it if the state_dict is already provided
         elif state_dict is not None:
             checkpoint_files = [""]
@@ -4993,64 +5101,48 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
             expanded_device_map = expand_device_map(device_map, expected_keys)
             caching_allocator_warmup(model_to_load, expanded_device_map, hf_quantizer)
 
+        # Prepare the arguments shared by the serial and parallel shard-loading paths
+        args_list = [
+            (
+                shard_file,
+                state_dict,
+                disk_only_shard_files,
+                is_hqq_or_bnb,
+                is_quantized,
+                device_map,
+                hf_quantizer,
+                key_renaming_mapping,
+                weights_only,
+                model_to_load,
+                expected_keys,
+                reverse_key_renaming_mapping,
+                disk_offload_folder,
+                disk_offload_index,
+                cpu_offload_folder,
+                cpu_offload_index,
+                is_offloaded_safetensors,
+                keep_in_fp32_regex,
+                unexpected_keys,
+                device_mesh,
+            )
+            for shard_file in checkpoint_files
+        ]
+
         error_msgs = []
-        # Iterate on all the shards to load the weights
-        for shard_file in checkpoint_files:
-            # Skip the load for shards that only contain disk-offloaded weights
-            if shard_file in disk_only_shard_files:
-                continue
-            map_location = "cpu"
-            if (
-                shard_file.endswith(".safetensors")
-                and not is_hqq_or_bnb
-                and not (is_deepspeed_zero3_enabled() and not is_quantized)
-            ):
-                map_location = "meta"
-            elif (
-                device_map is not None
-                and hf_quantizer is not None
-                and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
-                and (
-                    hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
-                    or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
-                )
-            ):
-                map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
+        if (
+            os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
+            and not is_deepspeed_zero3_enabled()
+        ):
+            _error_msgs, disk_offload_index, cpu_offload_index = load_shard_files_with_threadpool(args_list)
+            error_msgs += _error_msgs
+        else:
+            if len(args_list) > 1:
+                args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")
 
-            # If shard_file is "", we use the existing state_dict instead of loading it
-            if shard_file != "":
-                state_dict = load_state_dict(
-                    shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
-                )
-
-            # Fix the key names
-            state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
-
-            if is_deepspeed_zero3_enabled() and not is_quantized:
-                error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
-            # Skip it with fsdp on ranks other than 0
-            elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
-                disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
-                    model_to_load,
-                    state_dict,
-                    shard_file,
-                    expected_keys,
-                    reverse_key_renaming_mapping,
-                    device_map=device_map,
-                    disk_offload_folder=disk_offload_folder,
-                    disk_offload_index=disk_offload_index,
-                    cpu_offload_folder=cpu_offload_folder,
-                    cpu_offload_index=cpu_offload_index,
-                    hf_quantizer=hf_quantizer,
-                    is_safetensors=is_offloaded_safetensors,
-                    keep_in_fp32_regex=keep_in_fp32_regex,
-                    unexpected_keys=unexpected_keys,
-                    device_mesh=device_mesh,
-                )
-
-            # force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
-            del state_dict
+            for args in args_list:
+                _error_msgs, disk_offload_index, cpu_offload_index = load_shard_file(args)
+                error_msgs += _error_msgs
 
         # Adjust offloaded weights name and save if needed
         if disk_offload_index is not None and len(disk_offload_index) > 0:
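Note how the parallel path is only taken when `HF_ENABLE_PARALLEL_LOADING` is truthy and DeepSpeed ZeRO-3 is not enabled; otherwise the shards go through the same `load_shard_file` helper serially. A minimal sketch of that gate, assuming `ENV_VARS_TRUE_VALUES` is the usual `{"1", "ON", "YES", "TRUE"}` set defined in `transformers.utils` (check the source for the authoritative definition):

```py
# Sketch of the environment-variable gate; the truthy set is an assumption here.
import os

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}


def parallel_loading_enabled(deepspeed_zero3_enabled: bool = False) -> bool:
    enabled = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
    return enabled and not deepspeed_zero3_enabled


os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"
print(parallel_loading_enabled())  # True
```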
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 896b8771c41..ca4e1cc3d42 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -297,6 +297,27 @@ if is_torch_available():
         hub.TRANSFORMERS_CACHE = transformers_cache
 
 
+# Need to be serializable, which means they cannot be in a test class method
+class TestGammaBetaNorm(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.gamma = torch.nn.Parameter(torch.ones(1))
+        self.beta = torch.nn.Parameter(torch.zeros(1))
+
+    def forward(self):
+        return self.gamma.sum() + self.beta.sum()
+
+
+class TestModelGammaBeta(PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.LayerNorm = TestGammaBetaNorm()
+        self.post_init()
+
+    def forward(self):
+        return self.LayerNorm()
+
+
 if is_flax_available():
     from transformers import FlaxBertModel
 
@@ -1636,24 +1657,6 @@ class ModelUtilsTest(TestCasePlus):
         torch.testing.assert_close(outputs_from_saved["logits"], outputs["logits"])
 
     def test_warning_for_beta_gamma_parameters(self):
-        class TestGammaBetaNorm(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.gamma = torch.nn.Parameter(torch.ones(1))
-                self.beta = torch.nn.Parameter(torch.zeros(1))
-
-            def forward(self):
-                return self.gamma.sum() + self.beta.sum()
-
-        class TestModelGammaBeta(PreTrainedModel):
-            def __init__(self, config):
-                super().__init__(config)
-                self.LayerNorm = TestGammaBetaNorm()
-                self.post_init()
-
-            def forward(self):
-                return self.LayerNorm()
-
         logger = logging.get_logger("transformers.modeling_utils")
         config = PretrainedConfig()
         warning_msg_gamma = "`LayerNorm.gamma` -> `LayerNorm.weight`"
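A natural follow-up, not included in this PR, would be a test asserting that the parallel path yields the same weights as the default serial path. A hypothetical pytest-style sketch (the tiny checkpoint name is a placeholder; with a single-shard model it mostly exercises the gating code path rather than the thread pool itself):

```py
# Hypothetical test sketch, not part of this PR: parallel vs. serial loading parity.
import torch
from transformers import AutoModelForCausalLM


def test_parallel_loading_matches_serial(monkeypatch):
    model_id = "hf-internal-testing/tiny-random-gpt2"  # placeholder tiny checkpoint

    serial = AutoModelForCausalLM.from_pretrained(model_id)

    monkeypatch.setenv("HF_ENABLE_PARALLEL_LOADING", "true")
    monkeypatch.setenv("HF_PARALLEL_LOADING_WORKERS", "2")
    parallel = AutoModelForCausalLM.from_pretrained(model_id)

    serial_sd, parallel_sd = serial.state_dict(), parallel.state_dict()
    assert serial_sd.keys() == parallel_sd.keys()
    for key in serial_sd:
        torch.testing.assert_close(serial_sd[key], parallel_sd[key])
```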