Enhance Model Loading By Providing Parallelism Via an Optional Env Flag (#36835)

* Get parallel loader working. Include tests.

* Update the tests for parallel loading

* Rename env variables.

* Add docs for parallel model weight loading.

* Touch up parallel model loading docs.

* Touch up parallel model loading docs again.

* Edit comment in test_modeling_utils_parallel_loading.py

* Make sure HF_PARALLEL_LOADING_WORKERS is spelled correctly in modeling_utils.py

* Correct times for parallelized loading; previous times were for a "hot" filesystem

* Update parallel model loading so the spawn method is encapsulated. DRY up the code by leveraging get_submodule.

* Update docs on model loading parallelism so that details on setting the multiprocessing start method are removed, now that the package handles this step internally.

* Fix style on model loading parallelism changes.

* Merge latest version of master's modeling_utils.

* Removed unused variable.

* Fix argument packing for the parallel loader.

* Fix state dict being undefined in the parallel model loader.

* Rename variables used in parallel model loading for clarity. Use get_module_from_name().

* Switch to the use of threads for parallel model loading.

* Update docs for parallel loading.

* Remove the use of json.loads when evaluating HF_ENABLE_PARALLEL_LOADING. Prefer simple casting.

* Move parallelized shard loading into its own function.

* Remove use of is_true(). Favor checking env var true values for HF_ENABLE_PARALLEL_LOADING.

* Update copyright to 2025 in readme for parallel model loading.

* Remove garbage collection line in load_shard_file, implicit garbage collection already occurs.

* Run formatter on modeling_utils.py

* Apply style fixes

* Delete tests/utils/test_modeling_utils_parallel_loading.py

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
Aaron V 2025-05-23 12:39:47 -04:00 committed by GitHub
parent 1ed19360b1
commit d5f992f5e6
4 changed files with 234 additions and 76 deletions


@@ -1121,4 +1121,9 @@
    - local: internal/time_series_utils
      title: Utilities for Time Series
    title: Internal helpers
  - sections:
    - local: reference/environment_variables
      title: Environment Variables
    title: Reference
  title: API


@@ -0,0 +1,58 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Environment Variables
## HF_ENABLE_PARALLEL_LOADING
By default this is disabled. When enabled, torch and safetensors based weights are loaded in parallel, which can significantly decrease the time it takes to load large models, often producing speed-ups of around 50%.
Can be set to a string equal to `"false"` or `"true"`, e.g. `os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"`.
For example, `facebook/opt-30b` loads in ~30s on an AWS EC2 g4dn.metal instance with this enabled vs. ~55s without it.
Profile before committing to this environment variable; it will not produce speed-ups for smaller models.
```py
import os
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"
from transformers import pipeline
model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
```
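To follow the profiling advice above, a minimal timing sketch like the one below can be used to compare runs with the flag on and off (the model id and the use of `AutoModelForCausalLM` here are illustrative assumptions, not requirements of the flag).
```py
import os
import time

# Assumption: set to "false" for a baseline run, then compare the two timings.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"

from transformers import AutoModelForCausalLM

start = time.perf_counter()
model = AutoModelForCausalLM.from_pretrained("facebook/opt-30b", device_map="auto")
print(f"Loaded in {time.perf_counter() - start:.1f}s")
```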
## HF_PARALLEL_LOADING_WORKERS
Determines how many threads should be used when parallel loading is enabled. Default is `8`.
If the number of files being loaded is less than the number of threads specified, only as many threads as there are files are spawned.
For example, if you specify 8 workers and there are only 2 files, only 2 workers are spawned.
Tune as you see fit.
```py
import os
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4"
from transformers import pipeline
model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
```
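As a rough standalone sketch of the clamping described above (this mirrors the behavior; it is not the library's internal API):
```py
import os

os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4"

num_shard_files = 2  # e.g. a checkpoint split into 2 shard files
effective_workers = min(num_shard_files, int(os.environ["HF_PARALLEL_LOADING_WORKERS"]))
print(effective_workers)  # -> 2
```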


@@ -27,6 +27,7 @@ import shutil
import tempfile
import warnings
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
@@ -870,6 +871,116 @@ def _load_state_dict_into_meta_model(
    return disk_offload_index, cpu_offload_index
def load_shard_file(args):
    (
        shard_file,
        state_dict,
        disk_only_shard_files,
        is_hqq_or_bnb,
        is_quantized,
        device_map,
        hf_quantizer,
        key_renaming_mapping,
        weights_only,
        model_to_load,
        expected_keys,
        reverse_key_renaming_mapping,
        disk_offload_folder,
        disk_offload_index,
        cpu_offload_folder,
        cpu_offload_index,
        is_offloaded_safetensors,
        keep_in_fp32_regex,
        unexpected_keys,
        device_mesh,
    ) = args

    # Skip the load for shards that only contain disk-offloaded weights
    if shard_file in disk_only_shard_files:
        return [], disk_offload_index, cpu_offload_index

    map_location = "cpu"
    if (
        shard_file.endswith(".safetensors")
        and not is_hqq_or_bnb
        and not (is_deepspeed_zero3_enabled() and not is_quantized)
    ):
        map_location = "meta"
    elif (
        device_map is not None
        and hf_quantizer is not None
        and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
        and (
            hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
            or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
        )
    ):
        map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])

    # If shard_file is "", we use the existing state_dict instead of loading it
    if shard_file != "":
        state_dict = load_state_dict(
            shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
        )

    # Fix the key names
    state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}

    error_msgs = []

    if is_deepspeed_zero3_enabled() and not is_quantized:
        error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
    # Skip it with fsdp on ranks other than 0
    elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
        disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
            model_to_load,
            state_dict,
            shard_file,
            expected_keys,
            reverse_key_renaming_mapping,
            device_map=device_map,
            disk_offload_folder=disk_offload_folder,
            disk_offload_index=disk_offload_index,
            cpu_offload_folder=cpu_offload_folder,
            cpu_offload_index=cpu_offload_index,
            hf_quantizer=hf_quantizer,
            is_safetensors=is_offloaded_safetensors,
            keep_in_fp32_regex=keep_in_fp32_regex,
            unexpected_keys=unexpected_keys,
            device_mesh=device_mesh,
        )

    return error_msgs, disk_offload_index, cpu_offload_index


def load_shard_files_with_threadpool(args_list):
    num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))

    # Do not spawn any more workers than you need
    num_workers = min(len(args_list), num_workers)

    logger.info(f"Loading model weights in parallel with {num_workers} workers...")

    error_msgs = []

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
            futures = [executor.submit(load_shard_file, arg) for arg in args_list]
            for future in as_completed(futures):
                result = future.result()
                (
                    _error_msgs,
                    disk_offload_index,
                    cpu_offload_index,
                ) = result

                error_msgs += _error_msgs

                pbar.update(1)

    return error_msgs, disk_offload_index, cpu_offload_index
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None:
path, name = weights_name.rsplit(".", 1)
@@ -4973,9 +5084,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
            cpu_offload_folder = tempfile.mkdtemp()
            cpu_offload_index = {}

        # For nice tqdm bars
        if checkpoint_files is not None and len(checkpoint_files) > 1:
            checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards")
        # To be able to iterate, even if we don't use it if the state_dict is already provided
        elif state_dict is not None:
            checkpoint_files = [""]
@@ -4993,64 +5101,48 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
            expanded_device_map = expand_device_map(device_map, expected_keys)
            caching_allocator_warmup(model_to_load, expanded_device_map, hf_quantizer)

        # Prepare arguments so they are compatible with both serial and parallel shard loading
        args_list = [
            (
                shard_file,
                state_dict,
                disk_only_shard_files,
                is_hqq_or_bnb,
                is_quantized,
                device_map,
                hf_quantizer,
                key_renaming_mapping,
                weights_only,
                model_to_load,
                expected_keys,
                reverse_key_renaming_mapping,
                disk_offload_folder,
                disk_offload_index,
                cpu_offload_folder,
                cpu_offload_index,
                is_offloaded_safetensors,
                keep_in_fp32_regex,
                unexpected_keys,
                device_mesh,
            )
            for shard_file in checkpoint_files
        ]
        error_msgs = []

        # Iterate on all the shards to load the weights
        for shard_file in checkpoint_files:
            # Skip the load for shards that only contain disk-offloaded weights
            if shard_file in disk_only_shard_files:
                continue
            map_location = "cpu"
            if (
                shard_file.endswith(".safetensors")
                and not is_hqq_or_bnb
                and not (is_deepspeed_zero3_enabled() and not is_quantized)
            ):
                map_location = "meta"
            elif (
                device_map is not None
                and hf_quantizer is not None
                and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
                and (
                    hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
                    or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
                )
            ):
                map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
        if (
            os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
            and not is_deepspeed_zero3_enabled()
        ):
            _error_msgs, disk_offload_index, cpu_offload_index = load_shard_files_with_threadpool(args_list)
            error_msgs += _error_msgs
        else:
            if len(args_list) > 1:
                args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")
            # If shard_file is "", we use the existing state_dict instead of loading it
            if shard_file != "":
                state_dict = load_state_dict(
                    shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
                )
            # Fix the key names
            state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
            if is_deepspeed_zero3_enabled() and not is_quantized:
                error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
            # Skip it with fsdp on ranks other than 0
            elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
                disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
                    model_to_load,
                    state_dict,
                    shard_file,
                    expected_keys,
                    reverse_key_renaming_mapping,
                    device_map=device_map,
                    disk_offload_folder=disk_offload_folder,
                    disk_offload_index=disk_offload_index,
                    cpu_offload_folder=cpu_offload_folder,
                    cpu_offload_index=cpu_offload_index,
                    hf_quantizer=hf_quantizer,
                    is_safetensors=is_offloaded_safetensors,
                    keep_in_fp32_regex=keep_in_fp32_regex,
                    unexpected_keys=unexpected_keys,
                    device_mesh=device_mesh,
                )
            # force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
            del state_dict
            for args in args_list:
                _error_msgs, disk_offload_index, cpu_offload_index = load_shard_file(args)
                error_msgs += _error_msgs

        # Adjust offloaded weights name and save if needed
        if disk_offload_index is not None and len(disk_offload_index) > 0:


@@ -297,6 +297,27 @@ if is_torch_available():
        hub.TRANSFORMERS_CACHE = transformers_cache

    # Need to be serializable, which means they cannot be in a test class method
    class TestGammaBetaNorm(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.gamma = torch.nn.Parameter(torch.ones(1))
            self.beta = torch.nn.Parameter(torch.zeros(1))

        def forward(self):
            return self.gamma.sum() + self.beta.sum()

    class TestModelGammaBeta(PreTrainedModel):
        def __init__(self, config):
            super().__init__(config)
            self.LayerNorm = TestGammaBetaNorm()
            self.post_init()

        def forward(self):
            return self.LayerNorm()


if is_flax_available():
    from transformers import FlaxBertModel
@@ -1636,24 +1657,6 @@ class ModelUtilsTest(TestCasePlus):
        torch.testing.assert_close(outputs_from_saved["logits"], outputs["logits"])

    def test_warning_for_beta_gamma_parameters(self):
        class TestGammaBetaNorm(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gamma = torch.nn.Parameter(torch.ones(1))
                self.beta = torch.nn.Parameter(torch.zeros(1))

            def forward(self):
                return self.gamma.sum() + self.beta.sum()

        class TestModelGammaBeta(PreTrainedModel):
            def __init__(self, config):
                super().__init__(config)
                self.LayerNorm = TestGammaBetaNorm()
                self.post_init()

            def forward(self):
                return self.LayerNorm()

        logger = logging.get_logger("transformers.modeling_utils")
        config = PretrainedConfig()
        warning_msg_gamma = "`LayerNorm.gamma` -> `LayerNorm.weight`"