Fix passing scheduler-specific kwargs through TrainingArguments `lr_scheduler_kwargs` (#27595)

* Fix passing scheduler-specific kwargs through TrainingArguments `lr_scheduler_kwargs`

* Added test for lr_scheduler_kwargs
Author: Charbel Abi Daher
Date: 2023-11-28 08:33:45 +01:00 (committed by GitHub)
parent 0864dd3beb
commit 2ca73e5ee3
2 changed files with 29 additions and 1 deletion
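
For context, a minimal sketch of the user-facing behavior this commit fixes: scheduler-specific arguments are passed as a dict through `TrainingArguments.lr_scheduler_kwargs` and forwarded to the chosen schedule. The output directory and kwarg values below are illustrative, not taken from the commit.

    from transformers import TrainingArguments

    # Scheduler-specific kwargs for the polynomial schedule, passed as a dict.
    # Before this fix, Trainer unpacked them directly into get_scheduler(),
    # which rejected them as unexpected keyword arguments.
    args = TrainingArguments(
        output_dir="./out",  # illustrative path
        lr_scheduler_type="polynomial",
        lr_scheduler_kwargs={"power": 2.0, "lr_end": 1e-6},  # illustrative values
        learning_rate=5e-5,
        warmup_steps=100,
    )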

src/transformers/trainer.py

@@ -1111,7 +1111,7 @@ class Trainer:
                 optimizer=self.optimizer if optimizer is None else optimizer,
                 num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                 num_training_steps=num_training_steps,
-                **self.args.lr_scheduler_kwargs,
+                scheduler_specific_kwargs=self.args.lr_scheduler_kwargs,
             )
             self._created_lr_scheduler = True
         return self.lr_scheduler
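
Note (not part of the diff): the old `**self.args.lr_scheduler_kwargs` broke because `get_scheduler` does not accept arbitrary top-level keyword arguments; scheduler-specific options go through its `scheduler_specific_kwargs` parameter, which it forwards to the schedule function. A standalone sketch of the fixed call, with an illustrative model and optimizer:

    import torch
    from transformers import get_scheduler

    model = torch.nn.Linear(4, 1)                              # illustrative model
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.2)  # illustrative optimizer

    # Scheduler-specific options are bundled into one dict rather than
    # unpacked as top-level keyword arguments.
    scheduler = get_scheduler(
        "polynomial",
        optimizer=optimizer,
        num_warmup_steps=2,
        num_training_steps=10,
        scheduler_specific_kwargs={"power": 5.0, "lr_end": 1e-5},
    )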

tests/trainer/test_trainer.py

@@ -39,6 +39,7 @@ from transformers import (
     IntervalStrategy,
     PretrainedConfig,
     TrainingArguments,
+    get_polynomial_decay_schedule_with_warmup,
     is_torch_available,
     logging,
 )
@@ -643,6 +644,33 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertFalse(torch.allclose(trainer.model.b, b))
         self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)
 
+    def test_lr_scheduler_kwargs(self):
+        # test scheduler kwargs passed via TrainingArguments
+        train_dataset = RegressionDataset()
+        model = RegressionModel()
+        num_steps, num_warmup_steps = 10, 2
+        extra_kwargs = {"power": 5.0, "lr_end": 1e-5}  # Non-default arguments
+        args = TrainingArguments(
+            "./regression",
+            lr_scheduler_type="polynomial",
+            lr_scheduler_kwargs=extra_kwargs,
+            learning_rate=0.2,
+            warmup_steps=num_warmup_steps,
+        )
+        trainer = Trainer(model, args, train_dataset=train_dataset)
+        trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)
+
+        # Checking that the scheduler was created
+        self.assertIsNotNone(trainer.lr_scheduler)
+
+        # Checking that the correct args were passed
+        sched1 = trainer.lr_scheduler
+        sched2 = get_polynomial_decay_schedule_with_warmup(
+            trainer.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_steps, **extra_kwargs
+        )
+        self.assertEqual(sched1.lr_lambdas[0].args, sched2.lr_lambdas[0].args)
+        self.assertEqual(sched1.lr_lambdas[0].keywords, sched2.lr_lambdas[0].keywords)
+
     def test_reduce_lr_on_plateau_args(self):
         # test passed arguments for a custom ReduceLROnPlateau scheduler
         train_dataset = RegressionDataset(length=64)
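
Note on the assertions in the added test: the schedule factories in `transformers.optimization` build a `LambdaLR` whose lr function is a `functools.partial` (which is why `lr_lambdas[0]` exposes `.args` and `.keywords` at all), so two schedulers created with the same kwargs expose identical bound arguments. A minimal sketch of that comparison pattern, using a hypothetical placeholder in place of the real decay function:

    from functools import partial

    def poly_lambda(step, *, num_warmup_steps, num_training_steps, power, lr_end):
        # Hypothetical placeholder for the real polynomial-decay computation.
        return 1.0

    a = partial(poly_lambda, num_warmup_steps=2, num_training_steps=10, power=5.0, lr_end=1e-5)
    b = partial(poly_lambda, num_warmup_steps=2, num_training_steps=10, power=5.0, lr_end=1e-5)
    assert a.args == b.args and a.keywords == b.keywords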