mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Fix deepspeed with quantization (#37324)
* Update modeling_utils.py * Update modeling_utils.py
This commit is contained in:
parent
debfe904c9
commit
9db31ea585
@ -3719,19 +3719,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
@classmethod
|
||||
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
|
||||
# With deepspeed, we cannot initialize the model on meta device
|
||||
if is_deepspeed_zero3_enabled():
|
||||
init_contexts = [no_init_weights()]
|
||||
# We cannot initialize the model on meta device with deepspeed when not quantized
|
||||
if not is_quantized and not _is_ds_init_called:
|
||||
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||
init_contexts.extend(
|
||||
[
|
||||
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
|
||||
set_zero3_state(),
|
||||
]
|
||||
)
|
||||
init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
|
||||
elif is_quantized:
|
||||
init_contexts.append(set_quantized_state())
|
||||
init_contexts.extend([init_empty_weights(), set_quantized_state()])
|
||||
else:
|
||||
init_contexts = [no_init_weights(), init_empty_weights()]
|
||||
|
||||
@ -4800,7 +4795,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
continue
|
||||
|
||||
map_location = "cpu"
|
||||
if shard_file.endswith(".safetensors") and not is_hqq_or_bnb and not is_deepspeed_zero3_enabled():
|
||||
if (
|
||||
shard_file.endswith(".safetensors")
|
||||
and not is_hqq_or_bnb
|
||||
and not (is_deepspeed_zero3_enabled() and not is_quantized)
|
||||
):
|
||||
map_location = "meta"
|
||||
elif (
|
||||
device_map is not None
|
||||
@ -4822,7 +4821,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# Fix the key names
|
||||
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
|
||||
# Skip it with fsdp on ranks other than 0
|
||||
elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
|
||||
|
Loading…
Reference in New Issue
Block a user