Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 10:12:23 +06:00)
Add support for FSDP+QLoRA and DeepSpeed ZeRO3+QLoRA (#29587)
* fsdp+qlora related changes
* fixes
* Update quantization_config.py
* support fsdp+qlora and dsz3+qlora
* Update quantization_config.py
* Update modeling_utils.py
* Update modeling_utils.py
* Update modeling_utils.py
* Update modeling_utils.py
* Update modeling_utils.py
* Update modeling_utils.py
* handle fsdp+qlora and dsz3+qlora correctly while model loading
* fix param count
* quality
* fsdp related changes
* fsdp changes only when using LoRA/QLoRA
* add accelerate version check
* refactor, update min accelerate version and add tests
  1. Update minimum accelerate version to 0.26.0
  2. Clean the trainer wrt accelerate version checks
  3. FSDP refactor and test for fsdp config
  4. use `itemsize` instead of `dtype2bytes` dict
* fix test
* Address comments
  Co-Authored-By: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* fix the conditional flag
* fix conditional flag
* address comments
  Co-Authored-By: Zach Mueller <7831895+muellerzr@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Zach Mueller <7831895+muellerzr@users.noreply.github.com>
This commit is contained in: parent d3801aae2e, commit 350c5d1566
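For context before the diff: this change lets a 4-bit (QLoRA) model be trained under FSDP or DeepSpeed ZeRO-3. The sketch below shows how the new `bnb_4bit_quant_storage` option is expected to be used together with LoRA; the model id, LoRA hyperparameters, and launch setup are illustrative assumptions, not part of this commit.

    import torch
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        # New in this commit: store the packed 4-bit weights as bf16 so FSDP/ZeRO-3
        # can flatten them together with the other bf16 parameters.
        bnb_4bit_quant_storage=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "some-org/some-7b-model",  # placeholder model id
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
    )
    model = get_peft_model(model, LoraConfig(r=8, target_modules=["q_proj", "v_proj"]))
    # Training then runs through Trainer with an FSDP- or ZeRO-3-enabled accelerate config.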
@@ -1,6 +1,7 @@
 import importlib.metadata
 import warnings
 from copy import deepcopy
+from inspect import signature

 from packaging import version

@@ -179,6 +180,11 @@ def _replace_with_bnb_linear(
                 ):
                     pass
                 else:
+                    extra_kwargs = (
+                        {"quant_storage": quantization_config.bnb_4bit_quant_storage}
+                        if "quant_storage" in list(signature(bnb.nn.Linear4bit).parameters)
+                        else {}
+                    )
                     model._modules[name] = bnb.nn.Linear4bit(
                         in_features,
                         out_features,
@@ -186,6 +192,7 @@ def _replace_with_bnb_linear(
                         quantization_config.bnb_4bit_compute_dtype,
                         compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                         quant_type=quantization_config.bnb_4bit_quant_type,
+                        **extra_kwargs,
                     )
                     has_been_replaced = True
                     # Store the module class in case we need to transpose the weight later

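Note on the bitsandbytes hunks above: `quant_storage` is only forwarded to `bnb.nn.Linear4bit` when the installed bitsandbytes version actually declares that parameter, which the `signature(...)` check detects, so older bitsandbytes releases keep working unchanged. A minimal, self-contained sketch of that guard pattern (the helper and the stand-in functions below are illustrative, not from the diff):

    from inspect import signature

    def filter_supported_kwargs(fn, **candidates):
        # Keep only keyword arguments that `fn` declares, so newer options
        # degrade gracefully on older library versions.
        return {k: v for k, v in candidates.items() if k in signature(fn).parameters}

    def old_api(x):                      # stand-in for an older Linear4bit signature
        return x

    def new_api(x, quant_storage=None):  # stand-in for a newer signature
        return x, quant_storage

    assert filter_supported_kwargs(old_api, quant_storage="bf16") == {}
    assert filter_supported_kwargs(new_api, quant_storage="bf16") == {"quant_storage": "bf16"}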
@@ -54,6 +54,7 @@ from .pytorch_utils import ( # noqa: F401
     prune_linear_layer,
 )
 from .quantizers import AutoHfQuantizer, HfQuantizer
+from .quantizers.quantizers_utils import get_module_from_name
 from .safetensors_conversion import auto_conversion
 from .utils import (
     ADAPTER_SAFE_WEIGHTS_NAME,
@@ -496,7 +497,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)


-def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
+def load_state_dict(checkpoint_file: Union[str, os.PathLike], is_quantized: bool = False):
     """
     Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
     """
@@ -512,8 +513,9 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             return safe_load_file(checkpoint_file)
     try:
         if (
-            is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0
-        ) or (is_fsdp_enabled() and not is_local_dist_rank_0()):
+            (is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0)
+            or (is_fsdp_enabled() and not is_local_dist_rank_0())
+        ) and not is_quantized:
             map_location = "meta"
         else:
             map_location = "cpu"
@@ -718,6 +720,7 @@ def _load_state_dict_into_meta_model(

     old_keys = []
     new_keys = []
+    is_quantized = hf_quantizer is not None
     for key in state_dict.keys():
         new_key = None
         if "gamma" in key:
@@ -797,7 +800,7 @@ def _load_state_dict_into_meta_model(
         elif param_device == "cpu" and state_dict_index is not None:
             state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
         elif (
-            hf_quantizer is None
+            not is_quantized
             or (not hf_quantizer.requires_parameters_quantization)
             or (not hf_quantizer.check_quantized_param(model, param, param_name, state_dict))
         ):
@@ -805,6 +808,14 @@ def _load_state_dict_into_meta_model(
             set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
         else:
             hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
+            # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
+            # and then cast it to CPU to avoid excessive memory usage on each GPU
+            # in comparison to the sharded model across GPUs.
+            if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
+                module, tensor_name = get_module_from_name(model, param_name)
+                value = getattr(module, tensor_name)
+                value = type(value)(value.data.to("cpu"), **value.__dict__)
+                setattr(module, tensor_name, value)
     # TODO: consider removing used param_parts from state_dict before return

     return error_msgs, offload_index, state_dict_index
@@ -1070,7 +1081,9 @@ class ModuleUtilsMixin:
                 # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are
                 # used for the 4bit quantization (uint8 tensors are stored)
                 if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
-                    total_numel.append(param.numel() * 2)
+                    total_numel.append(
+                        param.numel() * 2 * self.hf_quantizer.quantization_config.bnb_4bit_quant_storage.itemsize
+                    )
                 else:
                     total_numel.append(param.numel())

@@ -1805,10 +1818,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         old_embeddings_requires_grad = old_embeddings.weight.requires_grad
         new_embeddings.requires_grad_(old_embeddings_requires_grad)
         self.set_input_embeddings(new_embeddings)
+        is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None

         # Update new_num_tokens with the actual size of new_embeddings
         if pad_to_multiple_of is not None:
-            if is_deepspeed_zero3_enabled():
+            if is_deepspeed_zero3_enabled() and not is_quantized:
                 import deepspeed

                 with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
@@ -1882,7 +1896,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if new_num_tokens is None:
             return old_embeddings

-        if is_deepspeed_zero3_enabled():
+        is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
+        if is_deepspeed_zero3_enabled() and not is_quantized:
             import deepspeed

             with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
@@ -1921,7 +1936,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # numbers of tokens to copy
         n = min(old_num_tokens, new_num_tokens)

-        if is_deepspeed_zero3_enabled():
+        if is_deepspeed_zero3_enabled() and not is_quantized:
             import deepspeed

             params = [old_embeddings.weight, new_embeddings.weight]
@@ -1958,7 +1973,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if new_num_tokens is None:
             return old_lm_head

-        if is_deepspeed_zero3_enabled():
+        is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
+        if is_deepspeed_zero3_enabled() and not is_quantized:
             import deepspeed

             with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
@@ -2000,7 +2016,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

         num_tokens_to_copy = min(old_num_tokens, new_num_tokens)

-        if is_deepspeed_zero3_enabled():
+        if is_deepspeed_zero3_enabled() and not is_quantized:
             import deepspeed

             params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
@@ -3036,6 +3052,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             if low_cpu_mem_usage is None:
                 low_cpu_mem_usage = True
                 logger.warning("`low_cpu_mem_usage` was None, now set to True since model is quantized.")
+        is_quantized = hf_quantizer is not None

         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
         # index of the files.
@@ -3365,7 +3382,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Instantiate model.
         init_contexts = [no_init_weights(_enable=_fast_init)]

-        if is_deepspeed_zero3_enabled():
+        if is_deepspeed_zero3_enabled() and not is_quantized:
             import deepspeed

             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
@@ -3564,7 +3581,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             }
             if "skip_keys" in inspect.signature(dispatch_model).parameters:
                 device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
-            dispatch_model(model, **device_map_kwargs)
+            if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
+                dispatch_model(model, **device_map_kwargs)

         if hf_quantizer is not None:
             hf_quantizer.postprocess_model(model)
@@ -3610,6 +3628,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         keep_in_fp32_modules=None,
     ):
         is_safetensors = False
+        is_quantized = hf_quantizer is not None

         if device_map is not None and "disk" in device_map.values():
             archive_file = (
@@ -3735,7 +3754,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 if param.device == torch.device("meta"):
                     value = torch.empty(*param.size(), dtype=target_dtype)
                     if (
-                        hf_quantizer is None
+                        not is_quantized
                         or getattr(hf_quantizer, "requires_parameters_quantization", False)
                         or not hf_quantizer.check_quantized_param(
                             model, param_value=value, param_name=key, state_dict={}
@@ -3765,7 +3784,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         else:
             not_initialized_submodules = dict(model.named_modules())
         # This will only initialize submodules that are not marked as initialized by the line above.
-        if is_deepspeed_zero3_enabled():
+        if is_deepspeed_zero3_enabled() and not is_quantized:
             import deepspeed

             not_initialized_parameters = list(
@@ -3909,7 +3928,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
                 if shard_file in disk_only_shard_files:
                     continue
-                state_dict = load_state_dict(shard_file)
+                state_dict = load_state_dict(shard_file, is_quantized=is_quantized)

                 # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
                 # matching the weights in the model.
@@ -3922,15 +3941,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     ignore_mismatched_sizes,
                 )
                 if low_cpu_mem_usage:
-                    if is_fsdp_enabled() and not is_local_dist_rank_0():
+                    if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
                         for key, param in model_to_load.state_dict().items():
                             if param.device == torch.device("meta"):
-                                if hf_quantizer is None:
-                                    set_module_tensor_to_device(
-                                        model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
-                                    )
-                                else:
-                                    hf_quantizer.create_quantized_param(model, param, key, "cpu", state_dict)
+                                set_module_tensor_to_device(
+                                    model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
+                                )
                     else:
                         new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
                             model_to_load,

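Two things worth spelling out about the modeling_utils hunks above. First, when a shard is loaded for a quantized model under FSDP or ZeRO-3, each parameter is quantized on the GPU and the packed tensor is then moved back to CPU, so a rank does not keep full GPU copies of weights that will be sharded anyway. Second, the `num_parameters` fix: `bnb.nn.Params4bit` packs two 4-bit weights per byte, so the true weight count is the packed element count times 2 times the byte size of the storage dtype, which is exactly `param.numel() * 2 * bnb_4bit_quant_storage.itemsize`. A small worked check (pure arithmetic, no bitsandbytes required):

    import torch

    original_weights = 4096 * 4096  # weights of one Linear layer before quantization

    # Default storage (torch.uint8): one packed element holds two 4-bit weights.
    packed_uint8 = original_weights // 2
    assert packed_uint8 * 2 * torch.uint8.itemsize == original_weights

    # With bnb_4bit_quant_storage=torch.bfloat16, each 2-byte element holds four
    # 4-bit weights; numel() shrinks, but the same formula recovers the true count.
    packed_bf16 = original_weights // 4
    assert packed_bf16 * 2 * torch.bfloat16.itemsize == original_weights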
@@ -1776,6 +1776,7 @@ class Trainer:

         if delay_optimizer_creation:
             if use_accelerator_prepare:
+                self._fsdp_qlora_plugin_updates()
                 self.model = self.accelerator.prepare(self.model)
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)

@@ -4156,3 +4157,20 @@ class Trainer:
             ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
             ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
             ds_plugin.hf_ds_config.trainer_config_process(self.args, auto_find_batch_size)
+
+    def _fsdp_qlora_plugin_updates(self):
+        if self.is_fsdp_enabled and _is_peft_model(self.model):
+            from peft import LoraConfig
+            from peft.utils.other import fsdp_auto_wrap_policy
+
+            if isinstance(self.model.active_peft_config, LoraConfig):
+                fsdp_plugin = self.accelerator.state.fsdp_plugin
+                fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
+            if (
+                getattr(self.model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+                and self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage.is_floating_point
+                and version.parse(accelerate_version) > version.parse("0.27.0")
+            ):
+                fsdp_plugin.set_mixed_precision(
+                    self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
+                )

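The new `_fsdp_qlora_plugin_updates` hook above only fires for PEFT models under FSDP: it swaps in PEFT's `fsdp_auto_wrap_policy` and, when the 4-bit storage dtype is floating point and Accelerate is newer than 0.27.0, overrides FSDP mixed precision to that storage dtype. A small sketch of just that gate (the installed Accelerate version string below is assumed for illustration):

    import torch
    from packaging import version

    quant_storage = torch.bfloat16  # from BitsAndBytesConfig(bnb_4bit_quant_storage=...)
    accelerate_version = "0.29.0"   # assumed installed version

    use_storage_dtype_for_mixed_precision = (
        quant_storage.is_floating_point
        and version.parse(accelerate_version) > version.parse("0.27.0")
    )
    assert use_storage_dtype_for_mixed_precision   # bf16 storage + accelerate > 0.27.0
    assert not torch.uint8.is_floating_point       # default uint8 storage keeps FSDP defaults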
@@ -1721,8 +1721,10 @@ class TrainingArguments:
             for fsdp_option in self.fsdp:
                 if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
                     # set environment variable for FSDP sharding strategy
-                    os.environ[f"{prefix}SHARDING_STRATEGY"] = str(
-                        FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1
+                    os.environ[f"{prefix}SHARDING_STRATEGY"] = (
+                        str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
+                        if is_accelerate_available("0.26.0")
+                        else fsdp_option.upper()
                     )
                 elif fsdp_option == FSDPOption.OFFLOAD:
                     os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true"

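The `TrainingArguments` hunk above exports the FSDP sharding strategy as the 1-based index that Accelerate >= 0.26.0 expects, while older Accelerate versions still receive the strategy name. A small sketch of that mapping (the strategy list is assumed to mirror Accelerate's `FSDP_SHARDING_STRATEGY` constant and is not copied from this diff):

    FSDP_SHARDING_STRATEGY = ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD", "HYBRID_SHARD_ZERO2"]

    def sharding_strategy_env_value(fsdp_option: str, accelerate_at_least_0_26: bool) -> str:
        # Accelerate >= 0.26.0 expects the 1-based index, older versions the name itself.
        return (
            str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
            if accelerate_at_least_0_26
            else fsdp_option.upper()
        )

    assert sharding_strategy_env_value("full_shard", True) == "1"
    assert sharding_strategy_env_value("full_shard", False) == "FULL_SHARD"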
@@ -225,6 +225,8 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
         bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`):
             This flag is used for nested quantization where the quantization constants from the first quantization are
             quantized again.
+        bnb_4bit_quant_storage (`torch.dtype` or str, *optional*, defaults to `torch.uint8`):
+            This sets the storage type to pack the quanitzed 4-bit prarams.
         kwargs (`Dict[str, Any]`, *optional*):
             Additional parameters from which to initialize the configuration object.
     """
@@ -240,6 +242,7 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
         bnb_4bit_compute_dtype=None,
         bnb_4bit_quant_type="fp4",
         bnb_4bit_use_double_quant=False,
+        bnb_4bit_quant_storage=None,
         **kwargs,
     ):
         self.quant_method = QuantizationMethod.BITS_AND_BYTES
@@ -265,6 +268,15 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
         else:
             raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")

+        if bnb_4bit_quant_storage is None:
+            self.bnb_4bit_quant_storage = torch.uint8
+        elif isinstance(bnb_4bit_quant_storage, str):
+            self.bnb_4bit_quant_storage = getattr(torch, bnb_4bit_quant_storage)
+        elif isinstance(bnb_4bit_quant_storage, torch.dtype):
+            self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
+        else:
+            raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype")
+
         self.post_init()

     @property
@@ -345,6 +357,7 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
         """
         output = copy.deepcopy(self.__dict__)
         output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1]
+        output["bnb_4bit_quant_storage"] = str(output["bnb_4bit_quant_storage"]).split(".")[1]
         output["load_in_4bit"] = self.load_in_4bit
         output["load_in_8bit"] = self.load_in_8bit

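As the `BitsAndBytesConfig` hunks above show, `bnb_4bit_quant_storage` accepts either a `torch.dtype` or its string name, defaults to `torch.uint8`, and `to_dict()` serializes it back to the bare dtype name. A short usage sketch, assuming a transformers build that includes this change:

    import torch
    from transformers import BitsAndBytesConfig

    cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="bfloat16")
    assert cfg.bnb_4bit_quant_storage is torch.bfloat16   # string resolved via getattr(torch, ...)
    assert BitsAndBytesConfig(load_in_4bit=True).bnb_4bit_quant_storage is torch.uint8  # default
    assert cfg.to_dict()["bnb_4bit_quant_storage"] == "bfloat16"  # "torch.bfloat16" -> "bfloat16"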
@@ -15,6 +15,7 @@
 import itertools
 import os
 import unittest
+from copy import deepcopy
 from functools import partial

 from parameterized import parameterized
@@ -171,6 +172,44 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
                 self.assertEqual(v, self.fsdp_config[k])
         self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")

+    @parameterized.expand(params, name_func=_parameterized_custom_name_func)
+    def test_fsdp_config_transformers_auto_wrap(self, sharding_strategy, dtype):
+        output_dir = self.get_auto_remove_tmp_dir()
+        fsdp_config = deepcopy(self.fsdp_config)
+        del fsdp_config["min_num_params"]
+        fsdp_config["transformer_layer_cls_to_wrap"] = "BertLayer"
+        kwargs = {
+            "output_dir": output_dir,
+            "train_len": 128,
+            "save_steps": 5,
+            "learning_rate": 0.1,
+            "fsdp": f"{sharding_strategy} offload auto_wrap",
+            "fsdp_config": fsdp_config,
+        }
+        kwargs[dtype] = True
+        prefix = "FSDP_"
+        with mockenv_context(**self.dist_env_1_gpu):
+            trainer = get_regression_trainer(**kwargs)
+            self.assertEqual(trainer.args.fsdp[0], sharding_strategy)
+            self.assertEqual(trainer.args.fsdp[1], FSDPOption.OFFLOAD)
+            self.assertEqual(trainer.args.fsdp[2], FSDPOption.AUTO_WRAP)
+            fsdp_sharding_strategy = (
+                str(FSDP_SHARDING_STRATEGY.index(sharding_strategy.upper()) + 1)
+                if is_accelerate_available("0.26.0")
+                else sharding_strategy.upper()
+            )
+            self.assertEqual(os.environ[f"{prefix}SHARDING_STRATEGY"], fsdp_sharding_strategy)
+            self.assertEqual(os.environ[f"{prefix}OFFLOAD_PARAMS"], "true")
+            self.assertEqual(os.environ[f"{prefix}AUTO_WRAP_POLICY"], "TRANSFORMER_BASED_WRAP")
+            self.assertEqual(
+                os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"], ",".join(fsdp_config["transformer_layer_cls_to_wrap"])
+            )
+            self.assertEqual(os.environ[f"{prefix}BACKWARD_PREFETCH"], fsdp_config["backward_prefetch"].upper())
+            self.assertEqual(os.environ[f"{prefix}FORWARD_PREFETCH"], fsdp_config["forward_prefetch"])
+            self.assertEqual(os.environ[f"{prefix}USE_ORIG_PARAMS"], fsdp_config["use_orig_params"])
+            self.assertEqual(os.environ[f"{prefix}SYNC_MODULE_STATES"], fsdp_config["sync_module_states"])
+            self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")
+
     @parameterized.expand(params, name_func=_parameterized_custom_name_func)
     @require_torch_multi_accelerator
     @slow
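The new parameterized test above exercises the sharding-strategy mapping end to end through `get_regression_trainer` and a mocked single-GPU distributed environment. Assuming the hunk lands in the usual tests/fsdp/test_fsdp.py (the path is not shown in this diff), it can be run in isolation with pytest:

    import pytest

    pytest.main(["-k", "test_fsdp_config_transformers_auto_wrap", "tests/fsdp/test_fsdp.py"])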