Fix deepspeed loading (#37281)

* Update modeling_utils.py

* Update modeling_utils.py

* fix and remove all imports

* Update modeling_utils.py

* Update modeling_utils.py

* style

* Update modeling_utils.py
This commit is contained in:
Cyril Vallez 2025-04-05 17:05:45 +02:00 committed by GitHub
parent 0ef339ff1b
commit 84aa13dd85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -57,7 +57,7 @@ from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig, GenerationMixin
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
from .integrations.flash_attention import flash_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
@ -153,6 +153,10 @@ if is_safetensors_available():
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file
if is_deepspeed_available():
import deepspeed
logger = logging.get_logger(__name__)
@ -2021,8 +2025,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
import deepspeed
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model across all gpus, to avoid the overhead in time
# and memory copying it on CPU or each GPU first
@ -2662,8 +2664,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Since we are basically reusing the same old embeddings with new weight values, gathering is required
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
vocab_size = model_embeds.weight.shape[0]
else:
@ -2694,8 +2694,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Update new_num_tokens with the actual size of new_embeddings
if pad_to_multiple_of is not None:
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
new_num_tokens = new_embeddings.weight.shape[0]
else:
@ -2784,8 +2782,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
else:
@ -2830,8 +2826,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
added_num_tokens = new_num_tokens - old_num_tokens
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
self._init_added_embeddings_weights_with_mean(
old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
@ -2847,8 +2841,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
n = min(old_num_tokens, new_num_tokens)
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
params = [old_embeddings.weight, new_embeddings.weight]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
@ -2859,8 +2851,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# This ensures correct functionality when a Custom Embedding class is passed as input.
# The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979)
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
params = [old_embeddings.weight, new_embeddings.weight]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
old_embeddings.weight = new_embeddings.weight
@ -2918,8 +2908,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
old_num_tokens, old_lm_head_dim = (
old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
@ -2970,8 +2958,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
added_num_tokens = new_num_tokens - old_num_tokens
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
params = [old_lm_head.weight]
if has_new_lm_head_bias:
params += [old_lm_head.bias]
@ -2992,8 +2978,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
self._copy_lm_head_original_to_resized(
@ -3738,14 +3722,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return super().float(*args)
@classmethod
def get_init_context(
cls: Type[SpecificPreTrainedModelType],
is_quantized=None,
_is_ds_init_called=None,
):
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
import deepspeed
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
init_contexts = [
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
@ -4644,6 +4622,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
):
# Useful flags
is_quantized = hf_quantizer is not None
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
]
# Get all the keys of the state dicts that we have to initialize the model
if sharded_metadata is not None:
@ -4805,15 +4787,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
# Warmup cuda to load the weights much faster on devices
if device_map is not None: # and hf_quantizer is None:
if device_map is not None:
expanded_device_map = expand_device_map(device_map, expected_keys)
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
error_msgs = []
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
]
# Iterate on all the shards to load the weights
for shard_file in checkpoint_files:
# Skip the load for shards that only contain disk-offloaded weights
@ -4821,7 +4799,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
continue
map_location = "cpu"
if shard_file.endswith(".safetensors") and not is_hqq_or_bnb:
if shard_file.endswith(".safetensors") and not is_hqq_or_bnb and not is_deepspeed_zero3_enabled():
map_location = "meta"
elif (
device_map is not None
@ -5267,8 +5245,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
not_initialized_submodules = dict(self.named_modules())
# This will only initialize submodules that are not marked as initialized by the line above.
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
not_initialized_parameters = list(
set(
itertools.chain.from_iterable(