add qgalore
commit 25278e805f
parent 3dca6a209e
@@ -155,6 +155,7 @@ from .utils import (
     is_ipex_available,
     is_lomo_available,
     is_peft_available,
+    is_q_galore_torch_available,
     is_safetensors_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
@@ -1288,6 +1289,128 @@ class Trainer:
             optimizer_cls = torch.optim.Adagrad
         elif args.optim == OptimizerNames.RMSPROP:
             optimizer_cls = torch.optim.RMSprop
+        elif args.optim in [OptimizerNames.QGALORE_ADAMW_8BIT, OptimizerNames.QGALORE_ADAMW_8BIT_LAYERWISE]:
+            if not is_q_galore_torch_available():
+                raise ImportError(
+                    "You need to install `q-galore-torch` in order to use QGaLore optimizers;"
+                    " install it with `pip install qgalore`."
+                )
+            from q_galore_torch import QGaLoreAdamW8bit
+
+            is_layerwise = args.optim.lower().endswith("layerwise")
+            if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
+                # TODO: check if this is True
+                raise NotImplementedError("Layer-wise QGaLore does not support DDP at this time")
+
+            optimizer_cls = QGaLoreAdamW8bit
+
+            if args.optim_target_modules is None:
+                raise ValueError(
+                    "You need to define `optim_target_modules` in order to properly use QGaLore optimizers."
+                )
+
+            if not isinstance(args.optim_target_modules, (list, str)):
+                raise ValueError(
+                    f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, a specific module name, or 'all-linear'; you passed {args.optim_target_modules}"
+                )
+
+            if model is None:
+                raise ValueError("You need to pass a model in order to correctly initialize a QGaLore optimizer.")
+
+            logger.warning(
+                "Activated QGaLore fine-tuning; depending on your model size and hardware, training might take a while before starting. Please be patient!"
+            )
+
+            all_linear = (
+                isinstance(args.optim_target_modules, str)
+                and args.optim_target_modules.replace("_", "-") == "all-linear"
+            )
+
+            galore_params = []
+            galore_params_names = []
+            for module_name, module in model.named_modules():
+                target_module_exists, is_regex = check_target_module_exists(
+                    args.optim_target_modules, module_name, return_is_regex=True
+                )
+
+                if not isinstance(module, nn.Linear):
+                    # Warn in case we match but it's not a linear layer
+                    if target_module_exists and not is_regex:
+                        logger.warning(
+                            f"{module_name} has been matched but ignored as QGaLore only supports linear layers. Please double check your `optim_target_modules`!"
+                        )
+
+                    continue
+
+                if not target_module_exists and not all_linear:
+                    continue
+
+                galore_params.append(module.weight)
+                galore_params_names.append(module_name + ".weight")
+
+            if len(galore_params) == 0:
+                raise ValueError(
+                    f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `optim_target_modules`."
+                )
+
+            non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]
+
+            # The default args are from the official repository: https://github.com/VITA-Group/Q-GaLore
+            galore_optim_kwargs = {
+                "rank": int(optim_args.pop("rank", 256)),
+                "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
+                "scale": float(optim_args.pop("scale", 0.25)),
+                "proj_type": optim_args.pop("proj_type", "std"),
+                "quant": optim_args.pop("quant", True),
+                "quant_n_bit": int(optim_args.pop("quant_n_bit", 4)),
+                "quant_group_size": int(optim_args.pop("quant_group_size", 256)),
+                "cos_threshold": float(optim_args.pop("cos_threshold", 0.4)),
+                "gamma_proj": int(optim_args.pop("gamma_proj", 2)),
+                "queue_size": int(optim_args.pop("queue_size", 5)),
+            }
+
+            param_groups = [
+                {"params": non_galore_params},
+                {"params": galore_params, **galore_optim_kwargs},
+            ]
+
+            if is_layerwise:
+                # For layer-wise optimizers, the optimization step is done through post-accumulation
+                # gradient hooks. The trick is to first attach these hooks to the model parameters, then
+                # create a dummy optimizer that will perform no-ops in the Trainer.
+                # See the original implementation or the nice implementation from @hiyouga
+                # here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
+                if args.gradient_accumulation_steps != 1:
+                    raise ValueError("Layer-wise QGaLore optimizers do not support gradient accumulation!")
+
+                optimizer_dict = {}
+                for param in non_galore_params:
+                    if param.requires_grad:
+                        param_groups = [{"params": [param]}]
+                        optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
+                # TODO: in the original repo, they multiply the update_proj_gap param by 2; to check
+                for param in galore_params:
+                    param_groups = [{"params": [param], **galore_optim_kwargs}]
+                    optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
+
+                def optimizer_hook(param):
+                    if (not hasattr(param, "float_grad")) and param.grad is None:
+                        return
+                    optimizer_dict[param].step()
+                    optimizer_dict[param].zero_grad()
+
+                id_galore_params = [id(p) for p in galore_params]
+
+                # TODO: strange, we are not applying this on every param here compared to galore
+                for param in model.parameters():
+                    if id(param) in id_galore_params or param.requires_grad:
+                        setattr(param, "backward_hook", optimizer_hook)
+
+                optimizer_cls = LayerWiseDummyOptimizer
+                optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
+
+            optimizer_kwargs.update({"params": param_groups})
+
         elif args.optim in [
             OptimizerNames.GALORE_ADAMW,
             OptimizerNames.GALORE_ADAMW_8BIT,
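The layer-wise branch above drives one tiny optimizer per parameter from gradient hooks instead of a single global optimizer step. The snippet below is a minimal, generic sketch of that pattern using PyTorch's register_post_accumulate_grad_hook (torch >= 2.1); the model, optimizer choice, and names are illustrative only, and it is not the exact QGaLore wiring, which instead attaches the hook as a `backward_hook` attribute on each parameter as shown in the diff.

# Generic sketch of per-parameter optimization via post-accumulation gradient hooks.
import torch
from torch import nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# One optimizer per parameter, so each parameter can be updated as soon as its
# gradient has been accumulated.
optimizer_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters() if p.requires_grad}

def optimizer_hook(param):
    # Called right after .grad has been accumulated for this parameter.
    if param.grad is None:
        return
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad()

for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(optimizer_hook)

# The training loop then only calls backward(); no global optimizer.step() is needed,
# which is why the Trainer is handed a no-op dummy optimizer in the diff above.
x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()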
@@ -174,6 +174,8 @@ class OptimizerNames(ExplicitEnum):
     GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
     GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
     GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
+    QGALORE_ADAMW_8BIT = "qgalore_adamw_8bit"
+    QGALORE_ADAMW_8BIT_LAYERWISE = "qgalore_adamw_8bit_layerwise"
     LOMO = "lomo"
     ADALOMO = "adalomo"
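With the two new enum values in place, the optimizers can be selected through TrainingArguments like any other optim choice. A minimal usage sketch, assuming `q-galore-torch` is installed and that `model` and `train_dataset` are placeholders defined elsewhere; the `optim_args` keys mirror the defaults parsed in the Trainer hunk above:

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="qgalore-out",
    optim="qgalore_adamw_8bit",  # or "qgalore_adamw_8bit_layerwise"
    optim_target_modules=[r".*attn.*", r".*mlp.*"],  # regexes matching nn.Linear modules, or "all-linear"
    optim_args="rank=256,update_proj_gap=200,scale=0.25,quant_n_bit=4",
    gradient_accumulation_steps=1,  # the layer-wise variant requires this to be 1
    per_device_train_batch_size=8,
)

trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()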
@@ -164,6 +164,7 @@ from .import_utils import (
     is_pytesseract_available,
     is_pytest_available,
     is_pytorch_quantization_available,
+    is_q_galore_torch_available,
     is_quanto_available,
     is_rjieba_available,
     is_sacremoses_available,
@@ -99,6 +99,7 @@ _av_available = importlib.util.find_spec("av") is not None
 _bitsandbytes_available = _is_package_available("bitsandbytes")
 _eetq_available = _is_package_available("eetq")
 _galore_torch_available = _is_package_available("galore_torch")
+_q_galore_torch_available = _is_package_available("q_galore_torch")
 _lomo_available = _is_package_available("lomo_optim")
 _torchao_available = _is_package_available("torchao")
 # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
@@ -346,6 +347,10 @@ def is_galore_torch_available():
     return _galore_torch_available
 
 
+def is_q_galore_torch_available():
+    return _q_galore_torch_available
+
+
 def is_lomo_available():
     return _lomo_available
 
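Once exported, the new availability helper can also guard optional imports in user code, mirroring the check inside the Trainer. A small hypothetical snippet:

from transformers.utils import is_q_galore_torch_available

if is_q_galore_torch_available():
    from q_galore_torch import QGaLoreAdamW8bit  # noqa: F401
else:
    raise ImportError("Install `q-galore-torch` to use the qgalore_* optimizers.")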