Mirror of https://github.com/huggingface/transformers.git
Fixed passing scheduler-specific kwargs via TrainingArguments lr_scheduler_kwargs (#27595)
* Fix passing scheduler-specific kwargs through TrainingArguments `lr_scheduler_kwargs`
* Add test for `lr_scheduler_kwargs`
parent 0864dd3beb
commit 2ca73e5ee3
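For context, a minimal usage sketch of what this fix enables; it is not part of the diff, and the values simply mirror the new test below. Scheduler-specific options such as `power` and `lr_end` for the polynomial schedule are supplied through `TrainingArguments.lr_scheduler_kwargs` and, with this change, reach the scheduler that `Trainer` builds instead of being spread as unexpected keyword arguments into the `get_scheduler` call itself:

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./regression",
    learning_rate=0.2,
    warmup_steps=2,
    lr_scheduler_type="polynomial",
    lr_scheduler_kwargs={"power": 5.0, "lr_end": 1e-5},  # scheduler-specific kwargs
)
# Passing `args` to a Trainer now forwards lr_scheduler_kwargs to the
# polynomial decay schedule when the learning-rate scheduler is created.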
@@ -1111,7 +1111,7 @@ class Trainer:
                 optimizer=self.optimizer if optimizer is None else optimizer,
                 num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                 num_training_steps=num_training_steps,
-                **self.args.lr_scheduler_kwargs,
+                scheduler_specific_kwargs=self.args.lr_scheduler_kwargs,
             )
             self._created_lr_scheduler = True
         return self.lr_scheduler
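The hunk above changes how Trainer.create_scheduler hands the extra options over: instead of spreading `self.args.lr_scheduler_kwargs` into the `get_scheduler` call itself (which only accepts its own named parameters), they now travel through the `scheduler_specific_kwargs` argument shown in the added line, which `get_scheduler` forwards to the underlying schedule function. A standalone sketch of that call pattern, not part of the diff, using a throwaway model and optimizer and the same non-default values as the new test:

import torch
from transformers import get_scheduler

# Throwaway model/optimizer so the call is self-contained.
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.2)

lr_scheduler = get_scheduler(
    "polynomial",
    optimizer=optimizer,
    num_warmup_steps=2,
    num_training_steps=10,
    # Forwarded to the polynomial decay schedule rather than consumed by get_scheduler.
    scheduler_specific_kwargs={"power": 5.0, "lr_end": 1e-5},
)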
@@ -39,6 +39,7 @@ from transformers import (
     IntervalStrategy,
     PretrainedConfig,
     TrainingArguments,
+    get_polynomial_decay_schedule_with_warmup,
     is_torch_available,
     logging,
 )
@@ -643,6 +644,33 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertFalse(torch.allclose(trainer.model.b, b))
         self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)
 
+    def test_lr_scheduler_kwargs(self):
+        # test scheduler kwargs passed via TrainingArguments
+        train_dataset = RegressionDataset()
+        model = RegressionModel()
+        num_steps, num_warmup_steps = 10, 2
+        extra_kwargs = {"power": 5.0, "lr_end": 1e-5}  # Non-default arguments
+        args = TrainingArguments(
+            "./regression",
+            lr_scheduler_type="polynomial",
+            lr_scheduler_kwargs=extra_kwargs,
+            learning_rate=0.2,
+            warmup_steps=num_warmup_steps,
+        )
+        trainer = Trainer(model, args, train_dataset=train_dataset)
+        trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)
+
+        # Checking that the scheduler was created
+        self.assertIsNotNone(trainer.lr_scheduler)
+
+        # Checking that the correct args were passed
+        sched1 = trainer.lr_scheduler
+        sched2 = get_polynomial_decay_schedule_with_warmup(
+            trainer.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_steps, **extra_kwargs
+        )
+        self.assertEqual(sched1.lr_lambdas[0].args, sched2.lr_lambdas[0].args)
+        self.assertEqual(sched1.lr_lambdas[0].keywords, sched2.lr_lambdas[0].keywords)
+
     def test_reduce_lr_on_plateau_args(self):
         # test passed arguments for a custom ReduceLROnPlateau scheduler
         train_dataset = RegressionDataset(length=64)
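The final assertions in the new test rest on an implementation detail that the test itself implies: both schedules expose `lr_lambdas`, and comparing `lr_lambdas[0].args` and `lr_lambdas[0].keywords` only makes sense if those entries are `functools.partial` objects, so the comparison checks that `power` and `lr_end` were actually bound into the schedule. A rough, self-contained illustration of that comparison; the decay function here is a made-up stand-in, not the one from `transformers.optimization`:

import functools

# Stand-in for the real polynomial-decay lambda; only the bound keywords matter here.
def fake_poly_lambda(current_step, *, num_warmup_steps, num_training_steps, lr_end, power):
    return 1.0

sched1_lambda = functools.partial(
    fake_poly_lambda, num_warmup_steps=2, num_training_steps=10, lr_end=1e-5, power=5.0
)
sched2_lambda = functools.partial(
    fake_poly_lambda, num_warmup_steps=2, num_training_steps=10, lr_end=1e-5, power=5.0
)

# Mirrors the test's assertEqual calls on lr_lambdas[0].
assert sched1_lambda.args == sched2_lambda.args          # positional args bound by partial
assert sched1_lambda.keywords == sched2_lambda.keywords  # includes "power" and "lr_end"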