mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Enhance Model Loading By Providing Parallelism, Uses Optional Env Flag (#36835)
* Get parallel loader working. Include tests. * Update the tests for parallel loading * Rename env variables. * Add docs for parallel model weight loading. * Touch up parallel model loading docs. * Touch up parallel model loading docs again. * Edit comment in test_modeling_utils_parallel_loading.py * Make sure HF_PARALLEL_LOADING_WORKERS is spelled correctly in modeling_utils.py * Correct times for parallelized loading, previous times were for a "hot" filesystem * Update parallel model loading so the spawn method is encapsulated. DRY up the code by leveraging get_submodule. * Update docs on model loading parallelism so that details on setting the multiprocessing start method are removed, now that the package handles this step internally. * Fix style on model loading parallelism changes. * Merge latest version of master's modeling_utils. * Removed unused variable. * Fix argument packing for the parallel loader. * Fix state dict being undefined in the parallel model loader. * Rename variables used in parallel model loading for clarity. Use get_module_from_name(). * Switch to the use of threads for parallel model loading. * Update docs for parallel loading. * Remove the use of json.loads when evaluating HF_ENABLE_PARALLEL_LOADING. Prefer simple casting. * Move parallelized shard loading into its own function. * Remove use of is_true(). Favor checking env var true values for HF_ENABLE_PARALLEL_LOADING. * Update copyright to 2025 in readme for paralell model loading. * Remove garbage collection line in load_shard_file, implicit garbage collection already occurs. * Run formatter on modeling_utils.py * Apply style fixes * Delete tests/utils/test_modeling_utils_parallel_loading.py --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
parent
1ed19360b1
commit
d5f992f5e6
@ -1121,4 +1121,9 @@
|
||||
- local: internal/time_series_utils
|
||||
title: Utilities for Time Series
|
||||
title: Internal helpers
|
||||
- sections:
|
||||
- local: reference/environment_variables
|
||||
title: Environment Variables
|
||||
title: Reference
|
||||
title: API
|
||||
|
||||
|
58
docs/source/en/reference/environment_variables.md
Normal file
58
docs/source/en/reference/environment_variables.md
Normal file
@ -0,0 +1,58 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Environment Variables
|
||||
|
||||
## HF_ENABLE_PARALLEL_LOADING
|
||||
|
||||
By default this is disabled. Enables the loading of torch and safetensor based weights to be loaded in parallel. Can decrease the time to load large models significantly, often times producing speed ups around ~50%.
|
||||
|
||||
Can be set to a string equal to `"false"` or `"true"`. e.g. `os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"`.
|
||||
|
||||
e.g. `facebook/opt-30b` on an AWS EC2 g4dn.metal instance can be made to load in ~30s with this enabled vs ~55s without it.
|
||||
|
||||
Profile before committing to using this environment variable, this will not produce speed ups for smaller models.
|
||||
|
||||
```py
|
||||
import os
|
||||
|
||||
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"
|
||||
|
||||
from transformers import pipeline
|
||||
|
||||
model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
|
||||
```
|
||||
|
||||
## HF_PARALLEL_LOADING_WORKERS
|
||||
|
||||
Determines how many threads should be used when parallel loading is enabled. Default is `8`.
|
||||
|
||||
If the number of files that are being loaded is less than the number of threads specified, the number that is actually spawned will be equal to the number of files.
|
||||
|
||||
e.g. If you specify 8 workers, and there are only 2 files, only 2 workers will be spawned.
|
||||
|
||||
Tune as you see fit.
|
||||
|
||||
```py
|
||||
import os
|
||||
|
||||
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"
|
||||
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4"
|
||||
|
||||
from transformers import pipeline
|
||||
|
||||
model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
|
||||
```
|
@ -27,6 +27,7 @@ import shutil
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
@ -870,6 +871,116 @@ def _load_state_dict_into_meta_model(
|
||||
return disk_offload_index, cpu_offload_index
|
||||
|
||||
|
||||
def load_shard_file(args):
|
||||
(
|
||||
shard_file,
|
||||
state_dict,
|
||||
disk_only_shard_files,
|
||||
is_hqq_or_bnb,
|
||||
is_quantized,
|
||||
device_map,
|
||||
hf_quantizer,
|
||||
key_renaming_mapping,
|
||||
weights_only,
|
||||
model_to_load,
|
||||
expected_keys,
|
||||
reverse_key_renaming_mapping,
|
||||
disk_offload_folder,
|
||||
disk_offload_index,
|
||||
cpu_offload_folder,
|
||||
cpu_offload_index,
|
||||
is_offloaded_safetensors,
|
||||
keep_in_fp32_regex,
|
||||
unexpected_keys,
|
||||
device_mesh,
|
||||
) = args
|
||||
|
||||
# Skip the load for shards that only contain disk-offloaded weights
|
||||
if shard_file in disk_only_shard_files:
|
||||
return [], disk_offload_index, cpu_offload_index
|
||||
|
||||
map_location = "cpu"
|
||||
if (
|
||||
shard_file.endswith(".safetensors")
|
||||
and not is_hqq_or_bnb
|
||||
and not (is_deepspeed_zero3_enabled() and not is_quantized)
|
||||
):
|
||||
map_location = "meta"
|
||||
elif (
|
||||
device_map is not None
|
||||
and hf_quantizer is not None
|
||||
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
|
||||
and (
|
||||
hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
|
||||
or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
|
||||
)
|
||||
):
|
||||
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
|
||||
|
||||
# If shard_file is "", we use the existing state_dict instead of loading it
|
||||
if shard_file != "":
|
||||
state_dict = load_state_dict(
|
||||
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
|
||||
)
|
||||
|
||||
# Fix the key names
|
||||
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
|
||||
|
||||
error_msgs = []
|
||||
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
|
||||
# Skip it with fsdp on ranks other than 0
|
||||
elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
|
||||
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
|
||||
model_to_load,
|
||||
state_dict,
|
||||
shard_file,
|
||||
expected_keys,
|
||||
reverse_key_renaming_mapping,
|
||||
device_map=device_map,
|
||||
disk_offload_folder=disk_offload_folder,
|
||||
disk_offload_index=disk_offload_index,
|
||||
cpu_offload_folder=cpu_offload_folder,
|
||||
cpu_offload_index=cpu_offload_index,
|
||||
hf_quantizer=hf_quantizer,
|
||||
is_safetensors=is_offloaded_safetensors,
|
||||
keep_in_fp32_regex=keep_in_fp32_regex,
|
||||
unexpected_keys=unexpected_keys,
|
||||
device_mesh=device_mesh,
|
||||
)
|
||||
|
||||
return error_msgs, disk_offload_index, cpu_offload_index
|
||||
|
||||
|
||||
def load_shard_files_with_threadpool(args_list):
|
||||
num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
|
||||
|
||||
# Do not spawn anymore workers than you need
|
||||
num_workers = min(len(args_list), num_workers)
|
||||
|
||||
logger.info(f"Loading model weights in parallel with {num_workers} workers...")
|
||||
|
||||
error_msgs = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
|
||||
futures = [executor.submit(load_shard_file, arg) for arg in args_list]
|
||||
for future in as_completed(futures):
|
||||
result = future.result()
|
||||
(
|
||||
_error_msgs,
|
||||
disk_offload_index,
|
||||
cpu_offload_index,
|
||||
) = result
|
||||
|
||||
error_msgs += _error_msgs
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
return error_msgs, disk_offload_index, cpu_offload_index
|
||||
|
||||
|
||||
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
||||
if variant is not None:
|
||||
path, name = weights_name.rsplit(".", 1)
|
||||
@ -4973,9 +5084,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
cpu_offload_folder = tempfile.mkdtemp()
|
||||
cpu_offload_index = {}
|
||||
|
||||
# For nice tqdm bars
|
||||
if checkpoint_files is not None and len(checkpoint_files) > 1:
|
||||
checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards")
|
||||
# To be able to iterate, even if we don't use it if the state_dict is already provided
|
||||
elif state_dict is not None:
|
||||
checkpoint_files = [""]
|
||||
@ -4993,64 +5101,48 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
expanded_device_map = expand_device_map(device_map, expected_keys)
|
||||
caching_allocator_warmup(model_to_load, expanded_device_map, hf_quantizer)
|
||||
|
||||
# Prepare and compatabilize arguments for serial and parallel shard loading
|
||||
args_list = [
|
||||
(
|
||||
shard_file,
|
||||
state_dict,
|
||||
disk_only_shard_files,
|
||||
is_hqq_or_bnb,
|
||||
is_quantized,
|
||||
device_map,
|
||||
hf_quantizer,
|
||||
key_renaming_mapping,
|
||||
weights_only,
|
||||
model_to_load,
|
||||
expected_keys,
|
||||
reverse_key_renaming_mapping,
|
||||
disk_offload_folder,
|
||||
disk_offload_index,
|
||||
cpu_offload_folder,
|
||||
cpu_offload_index,
|
||||
is_offloaded_safetensors,
|
||||
keep_in_fp32_regex,
|
||||
unexpected_keys,
|
||||
device_mesh,
|
||||
)
|
||||
for shard_file in checkpoint_files
|
||||
]
|
||||
|
||||
error_msgs = []
|
||||
# Iterate on all the shards to load the weights
|
||||
for shard_file in checkpoint_files:
|
||||
# Skip the load for shards that only contain disk-offloaded weights
|
||||
if shard_file in disk_only_shard_files:
|
||||
continue
|
||||
|
||||
map_location = "cpu"
|
||||
if (
|
||||
shard_file.endswith(".safetensors")
|
||||
and not is_hqq_or_bnb
|
||||
and not (is_deepspeed_zero3_enabled() and not is_quantized)
|
||||
):
|
||||
map_location = "meta"
|
||||
elif (
|
||||
device_map is not None
|
||||
and hf_quantizer is not None
|
||||
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
|
||||
and (
|
||||
hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
|
||||
or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
|
||||
)
|
||||
):
|
||||
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
|
||||
if (
|
||||
os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
|
||||
and not is_deepspeed_zero3_enabled()
|
||||
):
|
||||
_error_msgs, disk_offload_index, cpu_offload_index = load_shard_files_with_threadpool(args_list)
|
||||
error_msgs += _error_msgs
|
||||
else:
|
||||
if len(args_list) > 1:
|
||||
args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")
|
||||
|
||||
# If shard_file is "", we use the existing state_dict instead of loading it
|
||||
if shard_file != "":
|
||||
state_dict = load_state_dict(
|
||||
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
|
||||
)
|
||||
|
||||
# Fix the key names
|
||||
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
|
||||
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
|
||||
# Skip it with fsdp on ranks other than 0
|
||||
elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
|
||||
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
|
||||
model_to_load,
|
||||
state_dict,
|
||||
shard_file,
|
||||
expected_keys,
|
||||
reverse_key_renaming_mapping,
|
||||
device_map=device_map,
|
||||
disk_offload_folder=disk_offload_folder,
|
||||
disk_offload_index=disk_offload_index,
|
||||
cpu_offload_folder=cpu_offload_folder,
|
||||
cpu_offload_index=cpu_offload_index,
|
||||
hf_quantizer=hf_quantizer,
|
||||
is_safetensors=is_offloaded_safetensors,
|
||||
keep_in_fp32_regex=keep_in_fp32_regex,
|
||||
unexpected_keys=unexpected_keys,
|
||||
device_mesh=device_mesh,
|
||||
)
|
||||
|
||||
# force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
|
||||
del state_dict
|
||||
for args in args_list:
|
||||
_error_msgs, disk_offload_index, cpu_offload_index = load_shard_file(args)
|
||||
error_msgs += _error_msgs
|
||||
|
||||
# Adjust offloaded weights name and save if needed
|
||||
if disk_offload_index is not None and len(disk_offload_index) > 0:
|
||||
|
@ -297,6 +297,27 @@ if is_torch_available():
|
||||
hub.TRANSFORMERS_CACHE = transformers_cache
|
||||
|
||||
|
||||
# Need to be serializable, which means they cannot be in a test class method
|
||||
class TestGammaBetaNorm(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gamma = torch.nn.Parameter(torch.ones(1))
|
||||
self.beta = torch.nn.Parameter(torch.zeros(1))
|
||||
|
||||
def forward(self):
|
||||
return self.gamma.sum() + self.beta.sum()
|
||||
|
||||
|
||||
class TestModelGammaBeta(PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.LayerNorm = TestGammaBetaNorm()
|
||||
self.post_init()
|
||||
|
||||
def forward(self):
|
||||
return self.LayerNorm()
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from transformers import FlaxBertModel
|
||||
|
||||
@ -1636,24 +1657,6 @@ class ModelUtilsTest(TestCasePlus):
|
||||
torch.testing.assert_close(outputs_from_saved["logits"], outputs["logits"])
|
||||
|
||||
def test_warning_for_beta_gamma_parameters(self):
|
||||
class TestGammaBetaNorm(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gamma = torch.nn.Parameter(torch.ones(1))
|
||||
self.beta = torch.nn.Parameter(torch.zeros(1))
|
||||
|
||||
def forward(self):
|
||||
return self.gamma.sum() + self.beta.sum()
|
||||
|
||||
class TestModelGammaBeta(PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.LayerNorm = TestGammaBetaNorm()
|
||||
self.post_init()
|
||||
|
||||
def forward(self):
|
||||
return self.LayerNorm()
|
||||
|
||||
logger = logging.get_logger("transformers.modeling_utils")
|
||||
config = PretrainedConfig()
|
||||
warning_msg_gamma = "`LayerNorm.gamma` -> `LayerNorm.weight`"
|
||||
|
Loading…
Reference in New Issue
Block a user