From dc10f7906aba0f12137d4679757cde8cb05073f2 Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Tue, 21 Jan 2025 21:17:49 +0800
Subject: [PATCH] Support adamw_torch_8bit (#34993)

* var

* more

* test
---
 src/transformers/trainer.py       | 14 +++++++++++---
 src/transformers/training_args.py |  1 +
 tests/trainer/test_trainer.py     |  7 +++++++
 3 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 8179ee9f530..f45ff46bdd8 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1618,7 +1618,10 @@ class Trainer:
                     "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
                 }
             )
-        elif args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
+        elif args.optim in [
+            OptimizerNames.ADAMW_TORCH_4BIT,
+            OptimizerNames.ADAMW_TORCH_8BIT,
+        ]:
             if not is_torchao_available() or version.parse(importlib.metadata.version("torchao")) < version.parse(
                 "0.4.0"
             ):
@@ -1631,9 +1634,14 @@ class Trainer:
                     "You need to have `torch>2.4` in order to use torch 4-bit optimizers. "
                     "Install it with `pip install --upgrade torch` it is available on pipy. Otherwise, you need to install torch nightly."
                 )
-            from torchao.prototype.low_bit_optim import AdamW4bit
+            from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
 
-            optimizer_cls = AdamW4bit
+            if args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
+                optimizer_cls = AdamW4bit
+            elif args.optim == OptimizerNames.ADAMW_TORCH_8BIT:
+                optimizer_cls = AdamW8bit
+            else:
+                raise ValueError("Invalid optimizer")
             optimizer_kwargs.update(adam_kwargs)
         elif args.optim in [
             OptimizerNames.SCHEDULE_FREE_ADAMW,
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index a7b2ba0db3a..00b9c82ec28 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -154,6 +154,7 @@ class OptimizerNames(ExplicitEnum):
     ADAFACTOR = "adafactor"
     ADAMW_ANYPRECISION = "adamw_anyprecision"
     ADAMW_TORCH_4BIT = "adamw_torch_4bit"
+    ADAMW_TORCH_8BIT = "adamw_torch_8bit"
     ADEMAMIX = "ademamix"
     SGD = "sgd"
     ADAGRAD = "adagrad"
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 0799f690a34..7df721b3f3c 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -5017,6 +5017,13 @@ if is_torch_available():
                 default_adam_kwargs,
             )
         )
+        optim_test_params.append(
+            (
+                TrainingArguments(optim=OptimizerNames.ADAMW_TORCH_8BIT, output_dir="None"),
+                torchao.prototype.low_bit_optim.AdamW8bit,
+                default_adam_kwargs,
+            )
+        )
 
 
 @require_torch
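
A minimal usage sketch, not part of the patch: it resolves the optimizer class the same way the test above does, via the existing Trainer.get_optimizer_cls_and_kwargs static method. The output_dir value is a placeholder, and the example assumes torchao >= 0.4.0 and torch > 2.4 are installed, as required by the code path touched in trainer.py.

    # Sketch: select the new 8-bit AdamW by its optim string and inspect the
    # class the Trainer would instantiate.
    from transformers import Trainer, TrainingArguments

    args = TrainingArguments(output_dir="out", optim="adamw_torch_8bit")
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
    print(optimizer_cls)  # expected: torchao.prototype.low_bit_optim AdamW8bit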