Support adamw_torch_8bit (#34993)

* var

* more

* test
Authored by fzyzcjy on 2025-01-21 21:17:49 +08:00; committed by GitHub
parent f82b19cb6f
commit dc10f7906a
3 changed files with 19 additions and 3 deletions

src/transformers/trainer.py

@@ -1618,7 +1618,10 @@ class Trainer:
                     "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
                 }
             )
-        elif args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
+        elif args.optim in [
+            OptimizerNames.ADAMW_TORCH_4BIT,
+            OptimizerNames.ADAMW_TORCH_8BIT,
+        ]:
             if not is_torchao_available() or version.parse(importlib.metadata.version("torchao")) < version.parse(
                 "0.4.0"
             ):
@@ -1631,9 +1634,14 @@ class Trainer:
                     "You need to have `torch>2.4` in order to use torch 4-bit optimizers. "
                     "Install it with `pip install --upgrade torch` it is available on pipy. Otherwise, you need to install torch nightly."
                 )
-            from torchao.prototype.low_bit_optim import AdamW4bit
-
-            optimizer_cls = AdamW4bit
+            from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
+
+            if args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
+                optimizer_cls = AdamW4bit
+            elif args.optim == OptimizerNames.ADAMW_TORCH_8BIT:
+                optimizer_cls = AdamW8bit
+            else:
+                raise ValueError("Invalid optimizer")
             optimizer_kwargs.update(adam_kwargs)
         elif args.optim in [
             OptimizerNames.SCHEDULE_FREE_ADAMW,
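For context, the branch above resolves to torchao's low-bit AdamW classes. Below is a minimal standalone sketch of the 8-bit optimizer the Trainer now wraps, assuming torchao >= 0.4.0 and torch > 2.4 are installed; placing the model on CUDA is an assumption here, not something this diff requires.

```python
# Sketch: using torchao's AdamW8bit directly, outside the Trainer.
# Assumes torchao >= 0.4.0 and torch > 2.4; the CUDA device is an assumption.
import torch
from torchao.prototype.low_bit_optim import AdamW8bit

model = torch.nn.Linear(128, 128).cuda()
optimizer = AdamW8bit(model.parameters(), lr=1e-4)  # optimizer states stored in 8 bits

loss = model(torch.randn(8, 128, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```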

src/transformers/training_args.py

@@ -154,6 +154,7 @@ class OptimizerNames(ExplicitEnum):
     ADAFACTOR = "adafactor"
     ADAMW_ANYPRECISION = "adamw_anyprecision"
     ADAMW_TORCH_4BIT = "adamw_torch_4bit"
+    ADAMW_TORCH_8BIT = "adamw_torch_8bit"
     ADEMAMIX = "ademamix"
     SGD = "sgd"
     ADAGRAD = "adagrad"
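The enum values double as the strings accepted by `TrainingArguments(optim=...)`. A quick REPL-style check of the new member, assuming a transformers build that includes this commit:

```python
from transformers.training_args import OptimizerNames

# The enum lookup round-trips the string value users pass as `optim`.
assert OptimizerNames("adamw_torch_8bit") is OptimizerNames.ADAMW_TORCH_8BIT
print(OptimizerNames.ADAMW_TORCH_8BIT.value)  # adamw_torch_8bit
```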

tests/trainer/test_trainer.py

@@ -5017,6 +5017,13 @@ if is_torch_available():
                     default_adam_kwargs,
                 )
             )
+        optim_test_params.append(
+            (
+                TrainingArguments(optim=OptimizerNames.ADAMW_TORCH_8BIT, output_dir="None"),
+                torchao.prototype.low_bit_optim.AdamW8bit,
+                default_adam_kwargs,
+            )
+        )
 
 
 @require_torch
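Mirroring the test entry above, end-user code opts in by name; a sketch, where `"out"` is a placeholder output directory and the torchao/torch version checks from trainer.py still apply:

```python
from transformers import TrainingArguments

# The Trainer resolves this string to torchao.prototype.low_bit_optim.AdamW8bit
# via the dispatch added in trainer.py.
args = TrainingArguments(output_dir="out", optim="adamw_torch_8bit")
```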