mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Fix deepspeed loading (#37281)

* Update modeling_utils.py
* Update modeling_utils.py
* fix and remove all imports
* Update modeling_utils.py
* Update modeling_utils.py
* style
* Update modeling_utils.py
This commit is contained in:
parent 0ef339ff1b
commit 84aa13dd85
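Reading the hunks below as a whole, the fix has four parts: the many function-local `import deepspeed` statements are replaced by one module-level import guarded by `is_deepspeed_available()`; `get_init_context` gets a compact, typed signature; the `is_hqq_or_bnb` flag is computed once with the other "useful flags" instead of just before the shard loop; and safetensors shards are no longer mapped to the `"meta"` device when DeepSpeed ZeRO-3 is enabled, presumably because ZeRO-3 loading needs real tensor values rather than shape-only placeholders.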
@@ -57,7 +57,7 @@ from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
 from .generation import CompileConfig, GenerationConfig, GenerationMixin
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
-from .integrations.deepspeed import _load_state_dict_into_zero3_model
+from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flex_attention import flex_attention_forward
 from .integrations.sdpa_attention import sdpa_attention_forward
@@ -153,6 +153,10 @@ if is_safetensors_available():
     from safetensors.torch import load_file as safe_load_file
     from safetensors.torch import save_file as safe_save_file
 
+
+if is_deepspeed_available():
+    import deepspeed
+
 logger = logging.get_logger(__name__)
 
 
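These first two hunks establish the pattern the rest of the diff depends on: deepspeed is imported once at module level, guarded by the `is_deepspeed_available()` helper that the first hunk pulls in from `.integrations.deepspeed`. A minimal self-contained sketch of the idiom, with a hypothetical stand-in for that helper:

import importlib.util

def is_deepspeed_available() -> bool:
    # Stand-in for the transformers helper imported in the first hunk.
    return importlib.util.find_spec("deepspeed") is not None

if is_deepspeed_available():
    import deepspeed  # bound once at import time

# Later code can reference deepspeed.zero.* directly on ZeRO-3 code paths,
# which is exactly what lets the hunks below drop their local imports.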
@@ -2021,8 +2025,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         )
 
         if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
-            import deepspeed
-
             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
             # this immediately partitions the model across all gpus, to avoid the overhead in time
             # and memory copying it on CPU or each GPU first
@@ -2662,8 +2664,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Since we are basically reusing the same old embeddings with new weight values, gathering is required
         is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
         if is_deepspeed_zero3_enabled() and not is_quantized:
-            import deepspeed
-
             with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
                 vocab_size = model_embeds.weight.shape[0]
         else:
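This hunk and the similar ones that follow share one ZeRO-3 access pattern: each rank holds only a shard of every parameter, so even reading a shape requires a temporary gather. A hedged sketch of the two variants used in this file (`deepspeed.zero.GatheredParameters` is the real API; the wrapper functions are illustrative):

import deepspeed  # assumes the guarded module-level import above

def read_vocab_size(embedding):
    # modifier_rank=None: read-only access; the gathered copy is discarded on exit.
    with deepspeed.zero.GatheredParameters(embedding.weight, modifier_rank=None):
        return embedding.weight.shape[0]

def copy_rows(old_emb, new_emb, n):
    # modifier_rank=0: rank 0 mutates the gathered tensors, and the updates
    # are re-partitioned across ranks when the context exits.
    with deepspeed.zero.GatheredParameters([old_emb.weight, new_emb.weight], modifier_rank=0):
        new_emb.weight.data[:n, :] = old_emb.weight.data[:n, :]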
@@ -2694,8 +2694,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Update new_num_tokens with the actual size of new_embeddings
         if pad_to_multiple_of is not None:
             if is_deepspeed_zero3_enabled() and not is_quantized:
-                import deepspeed
-
                 with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
                     new_num_tokens = new_embeddings.weight.shape[0]
             else:
@@ -2784,8 +2782,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
         is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
         if is_deepspeed_zero3_enabled() and not is_quantized:
-            import deepspeed
-
             with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
                 old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
         else:
@@ -2830,8 +2826,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
             added_num_tokens = new_num_tokens - old_num_tokens
             if is_deepspeed_zero3_enabled() and not is_quantized:
-                import deepspeed
-
                 with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
                     self._init_added_embeddings_weights_with_mean(
                         old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
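`_init_added_embeddings_weights_with_mean` itself is outside this diff; as a rough sketch of what mean-based initialization of added rows can look like (an illustrative helper, not the transformers implementation):

import torch

def init_added_rows_with_mean(old_weight: torch.Tensor, new_weight: torch.Tensor, old_num_tokens: int) -> None:
    # New token rows start at the mean of the learned rows, keeping them
    # close to the existing embedding distribution.
    mean = old_weight[:old_num_tokens].mean(dim=0, keepdim=True)
    with torch.no_grad():
        new_weight[old_num_tokens:].copy_(mean)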
@@ -2847,8 +2841,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         n = min(old_num_tokens, new_num_tokens)
 
         if is_deepspeed_zero3_enabled() and not is_quantized:
-            import deepspeed
-
             params = [old_embeddings.weight, new_embeddings.weight]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                 new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
@@ -2859,8 +2851,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # This ensures correct functionality when a Custom Embedding class is passed as input.
         # The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979)
         if is_deepspeed_zero3_enabled() and not is_quantized:
-            import deepspeed
-
             params = [old_embeddings.weight, new_embeddings.weight]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                 old_embeddings.weight = new_embeddings.weight
@@ -2918,8 +2908,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
         is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
         if is_deepspeed_zero3_enabled() and not is_quantized:
-            import deepspeed
-
             with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
                 old_num_tokens, old_lm_head_dim = (
                     old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
@@ -2970,8 +2958,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
             added_num_tokens = new_num_tokens - old_num_tokens
             if is_deepspeed_zero3_enabled() and not is_quantized:
-                import deepspeed
-
                 params = [old_lm_head.weight]
                 if has_new_lm_head_bias:
                     params += [old_lm_head.bias]
@@ -2992,8 +2978,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
 
         if is_deepspeed_zero3_enabled() and not is_quantized:
-            import deepspeed
-
             params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                 self._copy_lm_head_original_to_resized(
@@ -3738,14 +3722,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         return super().float(*args)
 
     @classmethod
-    def get_init_context(
-        cls: Type[SpecificPreTrainedModelType],
-        is_quantized=None,
-        _is_ds_init_called=None,
-    ):
+    def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
         if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
-            import deepspeed
-
             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
             init_contexts = [
                 deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
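The `deepspeed.zero.Init` context assembled here is what the earlier "activating zero.init()" log message refers to: constructing a model inside it partitions each parameter across ranks at creation time, so no rank ever materializes the full model. A hedged standalone sketch (`deepspeed.zero.Init` and its `config_dict_or_path` argument are real API; the config dict is a minimal illustration, and this should be run under a distributed launcher such as `deepspeed` or `torchrun`):

import deepspeed
import torch.nn as nn

ds_config = {"zero_optimization": {"stage": 3}}  # illustrative minimal config

with deepspeed.zero.Init(config_dict_or_path=ds_config):
    # Weights are sharded as they are allocated, rather than being created
    # in full on every rank and partitioned afterwards.
    model = nn.Linear(4096, 4096)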
@@ -4644,6 +4622,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     ):
         # Useful flags
         is_quantized = hf_quantizer is not None
+        is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
+            QuantizationMethod.HQQ,
+            QuantizationMethod.BITS_AND_BYTES,
+        ]
 
         # Get all the keys of the state dicts that we have to initialize the model
         if sharded_metadata is not None:
@@ -4805,15 +4787,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
 
         # Warmup cuda to load the weights much faster on devices
-        if device_map is not None:  # and hf_quantizer is None:
+        if device_map is not None:
             expanded_device_map = expand_device_map(device_map, expected_keys)
             caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
 
         error_msgs = []
-        is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
-            QuantizationMethod.HQQ,
-            QuantizationMethod.BITS_AND_BYTES,
-        ]
         # Iterate on all the shards to load the weights
         for shard_file in checkpoint_files:
             # Skip the load for shards that only contain disk-offloaded weights
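Design note: together with the `@@ -4644,6 +4622,10 @@` hunk above, this moves the `is_hqq_or_bnb` computation into the "useful flags" block at the top of the method, so the flag is defined once and visible to every later code path instead of being recomputed just before the shard loop.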
@@ -4821,7 +4799,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 continue
 
             map_location = "cpu"
-            if shard_file.endswith(".safetensors") and not is_hqq_or_bnb:
+            if shard_file.endswith(".safetensors") and not is_hqq_or_bnb and not is_deepspeed_zero3_enabled():
                 map_location = "meta"
             elif (
                 device_map is not None
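This one-line condition is the behavioral core of the fix: mapping a safetensors shard to `"meta"` gives a cheap shape-only load, but under ZeRO-3 the state dict presumably has to carry real values for `_load_state_dict_into_zero3_model` to scatter into the partitioned parameters, hence the new exclusion. A sketch of the resulting decision (an illustrative refactor; the real code inlines this and has further `elif` branches, truncated in the hunk):

def pick_map_location(shard_file: str, is_hqq_or_bnb: bool, zero3_enabled: bool) -> str:
    # "meta" is a shape-only placeholder load, used only when nothing
    # downstream needs actual tensor values on this pass.
    if shard_file.endswith(".safetensors") and not is_hqq_or_bnb and not zero3_enabled:
        return "meta"
    return "cpu"

assert pick_map_location("model-00001.safetensors", False, False) == "meta"
assert pick_map_location("model-00001.safetensors", False, True) == "cpu"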
@@ -5267,8 +5245,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         not_initialized_submodules = dict(self.named_modules())
         # This will only initialize submodules that are not marked as initialized by the line above.
         if is_deepspeed_zero3_enabled() and not is_quantized:
-            import deepspeed
-
             not_initialized_parameters = list(
                 set(
                     itertools.chain.from_iterable(