mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
add RAdamScheduleFree optimizer (#35313)
* add RAdamScheduleFree optimizer * revert schedulefree version to the minimum requirement * refine is_schedulefree_available so that it can take min_version * refine documents --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
f5fff672db
commit
377d8e2b9c
@ -542,13 +542,17 @@ trainer = Trainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training.
|
||||
This script demonstrates how to fine-tune the [google/gemma-2b](https://huggingface.co/google/gemma-2b) model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training.
|
||||
|
||||
### Schedule Free Optimizer
|
||||
### Schedule-Free Optimizer
|
||||
|
||||
The Schedule-Free optimizers have been introduced in [The Road Less Scheduled](https://hf.co/papers/2405.15682).
|
||||
Supported optimizers for Schedule-Free are `schedule_free_radam`, `schedule_free_adamw` and `schedule_free_sgd`. First install schedulefree from pypi `pip install schedulefree`.
|
||||
|
||||
The Schedule Free optimizers have been introduced in [The Road Less Scheduled](https://hf.co/papers/2405.15682).
|
||||
Schedule-Free learning replaces the momentum of the base optimizer with a combination of averaging and interpolation, to completely remove the need to anneal the learning rate with a traditional schedule.
|
||||
Supported optimizers for SFO are `"schedule_free_adamw"` and `"schedule_free_sgd"`. First install schedulefree from pypi `pip install schedulefree`.
|
||||
Additionally, neither `warmup_steps` nor `warmup_ratio` parameters are required when using `schedule_free_radam`.
|
||||
|
||||
By default, we recommend setting `lr_scheduler_type="constant"` in the `TrainingArguments`. Setting other `lr_scheduler_type` would also work, but combining Schedule-Free with other learning rate schedules is not well-studied both in research and in practice, as it may affect the optimizer's intended behavior and performance guarantees.
|
||||
|
||||
Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on IMDB dataset in full precision:
|
||||
|
||||
@ -564,7 +568,8 @@ args = TrainingArguments(
|
||||
output_dir="./test-schedulefree",
|
||||
max_steps=1000,
|
||||
per_device_train_batch_size=4,
|
||||
optim="schedule_free_adamw",
|
||||
optim="schedule_free_radam",
|
||||
lr_scheduler_type="constant",
|
||||
gradient_checkpointing=True,
|
||||
logging_strategy="steps",
|
||||
logging_steps=1,
|
||||
|
@ -1644,28 +1644,44 @@ class Trainer:
|
||||
raise ValueError("Invalid optimizer")
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif args.optim in [
|
||||
OptimizerNames.SCHEDULE_FREE_RADAM,
|
||||
OptimizerNames.SCHEDULE_FREE_ADAMW,
|
||||
OptimizerNames.SCHEDULE_FREE_SGD,
|
||||
]:
|
||||
if not is_schedulefree_available():
|
||||
raise ImportError(
|
||||
"You need to install `schedulefree` in order to use schedulefree optimizers"
|
||||
" install it with `pip install schedulefree`"
|
||||
"You need to install `schedulefree` in order to use schedulefree optimizers. "
|
||||
"Install it with `pip install schedulefree.`"
|
||||
)
|
||||
if not is_accelerate_available("0.30.0"):
|
||||
raise ImportError("You need to have `accelerate>=0.30.0` to be able to use schedulefree optimizers")
|
||||
from schedulefree import AdamWScheduleFree, SGDScheduleFree
|
||||
|
||||
additional_optim_kwargs = {}
|
||||
if args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW:
|
||||
require_warmup = True
|
||||
|
||||
if args.optim == OptimizerNames.SCHEDULE_FREE_RADAM:
|
||||
if not is_schedulefree_available("1.4.0"):
|
||||
raise ImportError(
|
||||
"You need to install `schedulefree>=1.4.0` in order to use RAdamScheduleFree optimizer. "
|
||||
"Install it with `pip install schedulefree.`"
|
||||
)
|
||||
from schedulefree import RAdamScheduleFree
|
||||
|
||||
optimizer_cls = RAdamScheduleFree
|
||||
additional_optim_kwargs = adam_kwargs
|
||||
require_warmup = False
|
||||
elif args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW:
|
||||
optimizer_cls = AdamWScheduleFree
|
||||
additional_optim_kwargs = adam_kwargs
|
||||
elif args.optim == OptimizerNames.SCHEDULE_FREE_SGD:
|
||||
optimizer_cls = SGDScheduleFree
|
||||
else:
|
||||
raise ValueError("Invalid schedulefree optimizer")
|
||||
|
||||
additional_optim_kwargs["weight_decay"] = args.weight_decay
|
||||
additional_optim_kwargs["warmup_steps"] = args.warmup_steps
|
||||
if require_warmup:
|
||||
additional_optim_kwargs["warmup_steps"] = args.warmup_steps
|
||||
additional_optim_kwargs.update(
|
||||
{
|
||||
"weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)),
|
||||
|
@ -182,6 +182,7 @@ class OptimizerNames(ExplicitEnum):
|
||||
LOMO = "lomo"
|
||||
ADALOMO = "adalomo"
|
||||
GROKADAMW = "grokadamw"
|
||||
SCHEDULE_FREE_RADAM = "schedule_free_radam"
|
||||
SCHEDULE_FREE_ADAMW = "schedule_free_adamw"
|
||||
SCHEDULE_FREE_SGD = "schedule_free_sgd"
|
||||
|
||||
|
@ -89,6 +89,7 @@ FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()
|
||||
TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
|
||||
|
||||
ACCELERATE_MIN_VERSION = "0.26.0"
|
||||
SCHEDULEFREE_MIN_VERSION = "1.2.6"
|
||||
FSDP_MIN_VERSION = "1.12.0"
|
||||
GGUF_MIN_VERSION = "0.10.0"
|
||||
XLA_FSDPV2_MIN_VERSION = "2.2.0"
|
||||
@ -108,7 +109,7 @@ _fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
|
||||
_galore_torch_available = _is_package_available("galore_torch")
|
||||
_lomo_available = _is_package_available("lomo_optim")
|
||||
_grokadamw_available = _is_package_available("grokadamw")
|
||||
_schedulefree_available = _is_package_available("schedulefree")
|
||||
_schedulefree_available, _schedulefree_version = _is_package_available("schedulefree", return_version=True)
|
||||
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
|
||||
_bs4_available = importlib.util.find_spec("bs4") is not None
|
||||
_coloredlogs_available = _is_package_available("coloredlogs")
|
||||
@ -410,8 +411,8 @@ def is_grokadamw_available():
|
||||
return _grokadamw_available
|
||||
|
||||
|
||||
def is_schedulefree_available():
|
||||
return _schedulefree_available
|
||||
def is_schedulefree_available(min_version: str = SCHEDULEFREE_MIN_VERSION):
|
||||
return _schedulefree_available and version.parse(_schedulefree_version) >= version.parse(min_version)
|
||||
|
||||
|
||||
def is_pyctcdecode_available():
|
||||
|
@ -1865,14 +1865,38 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
self.get_auto_remove_tmp_dir(),
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="schedule_free_adamw",
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir,
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="schedule_free_adamw",
|
||||
lr_scheduler_type="constant",
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_schedulefree
|
||||
@require_torch_gpu
|
||||
def test_schedulefree_radam(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir,
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
lr_scheduler_type="constant",
|
||||
optim="schedule_free_radam",
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
Loading…
Reference in New Issue
Block a user