mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
FEAT / Trainer: Add adamw 4bit optimizer (#31865)
* add 4bit optimizer
* style
* fix msg
* style
* add qgalore
* Revert "add qgalore"
This reverts commit 25278e805f
.
* style
* version check
This commit is contained in:
parent
6baa6f276a
commit
c42d264549
@ -168,6 +168,7 @@ from .utils import (
|
||||
is_torch_npu_available,
|
||||
is_torch_xla_available,
|
||||
is_torch_xpu_available,
|
||||
is_torchao_available,
|
||||
logging,
|
||||
strtobool,
|
||||
)
|
||||
@ -1451,7 +1452,23 @@ class Trainer:
|
||||
"gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
|
||||
}
|
||||
)
|
||||
elif args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
|
||||
if not is_torchao_available() or version.parse(importlib.metadata.version("torchao")) < version.parse(
|
||||
"0.4.0"
|
||||
):
|
||||
raise ImportError(
|
||||
"You need to have `torchao>=0.4.0` in order to use torch 4-bit optimizers."
|
||||
"Install it with `pip install torchao` or follow the instructions here: https://github.com/pytorch/ao"
|
||||
)
|
||||
if version.parse(importlib.metadata.version("torch")) < version.parse("2.3"):
|
||||
raise ImportError(
|
||||
"You need to have `torch>=2.3` in order to use torch 4-bit optimizers. "
|
||||
"Install it with `pip install --upgrade torch`"
|
||||
)
|
||||
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||
|
||||
optimizer_cls = AdamW4bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
else:
|
||||
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
|
||||
return optimizer_cls, optimizer_kwargs
|
||||
|
@ -154,6 +154,7 @@ class OptimizerNames(ExplicitEnum):
|
||||
ADAMW_APEX_FUSED = "adamw_apex_fused"
|
||||
ADAFACTOR = "adafactor"
|
||||
ADAMW_ANYPRECISION = "adamw_anyprecision"
|
||||
ADAMW_TORCH_4BIT = "adamw_torch_4bit"
|
||||
SGD = "sgd"
|
||||
ADAGRAD = "adagrad"
|
||||
ADAMW_BNB = "adamw_bnb_8bit"
|
||||
|
@ -99,6 +99,7 @@ from transformers.utils import (
|
||||
is_apex_available,
|
||||
is_bitsandbytes_available,
|
||||
is_safetensors_available,
|
||||
is_torchao_available,
|
||||
is_torchdistx_available,
|
||||
)
|
||||
from transformers.utils.hp_naming import TrialShortNamer
|
||||
@ -4210,6 +4211,16 @@ if is_torch_available():
|
||||
dict(default_adam_kwargs, **default_anyprecision_kwargs),
|
||||
)
|
||||
)
|
||||
if is_torchao_available():
|
||||
import torchao
|
||||
|
||||
optim_test_params.append(
|
||||
(
|
||||
TrainingArguments(optim=OptimizerNames.ADAMW_TORCH_4BIT, output_dir="None"),
|
||||
torchao.prototype.low_bit_optim.AdamW4bit,
|
||||
default_adam_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
|
Loading…
Reference in New Issue
Block a user