Mirror of https://github.com/huggingface/transformers.git
Fixed passing scheduler-specific kwargs via TrainingArguments lr_scheduler_kwargs (#27595)
* Fix passing scheduler-specific kwargs through TrainingArguments `lr_scheduler_kwargs`
* Added test for lr_scheduler_kwargs
commit 2ca73e5ee3 (parent 0864dd3beb)
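
With this change, scheduler-specific options declared on TrainingArguments actually reach the scheduler. A minimal usage sketch, mirroring the new test added in this commit (the output directory and hyperparameter values are the test's illustrative ones):

from transformers import TrainingArguments

# lr_scheduler_kwargs is forwarded to the polynomial schedule (power / lr_end)
# once this fix is in place; values here mirror the test below.
args = TrainingArguments(
    "./regression",                  # output directory (illustrative)
    lr_scheduler_type="polynomial",
    lr_scheduler_kwargs={"power": 5.0, "lr_end": 1e-5},
    learning_rate=0.2,
    warmup_steps=2,
)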
src/transformers/trainer.py
@@ -1111,7 +1111,7 @@ class Trainer:
                 optimizer=self.optimizer if optimizer is None else optimizer,
                 num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                 num_training_steps=num_training_steps,
-                **self.args.lr_scheduler_kwargs,
+                scheduler_specific_kwargs=self.args.lr_scheduler_kwargs,
             )
             self._created_lr_scheduler = True
         return self.lr_scheduler
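
The old code unpacked `self.args.lr_scheduler_kwargs` with `**` straight into the `get_scheduler(...)` call, so scheduler-specific options such as `power` were handed to `get_scheduler` itself rather than to the schedule it builds; the fix routes them through the dedicated `scheduler_specific_kwargs` parameter that the hunk above shows `get_scheduler` accepting. A minimal sketch of the resulting call, with a toy optimizer assumed so it runs standalone:

import torch
from transformers import get_scheduler

# Assumed toy setup so the call is runnable on its own.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=0.2)

# Scheduler-specific options travel in the scheduler_specific_kwargs dict
# instead of being unpacked with ** into get_scheduler's signature.
lr_scheduler = get_scheduler(
    "polynomial",
    optimizer=optimizer,
    num_warmup_steps=2,
    num_training_steps=10,
    scheduler_specific_kwargs={"power": 5.0, "lr_end": 1e-5},
)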
tests/trainer/test_trainer.py
@@ -39,6 +39,7 @@ from transformers import (
     IntervalStrategy,
     PretrainedConfig,
     TrainingArguments,
+    get_polynomial_decay_schedule_with_warmup,
     is_torch_available,
     logging,
 )
@@ -643,6 +644,33 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertFalse(torch.allclose(trainer.model.b, b))
         self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)

+    def test_lr_scheduler_kwargs(self):
+        # test scheduler kwargs passed via TrainingArguments
+        train_dataset = RegressionDataset()
+        model = RegressionModel()
+        num_steps, num_warmup_steps = 10, 2
+        extra_kwargs = {"power": 5.0, "lr_end": 1e-5}  # Non-default arguments
+        args = TrainingArguments(
+            "./regression",
+            lr_scheduler_type="polynomial",
+            lr_scheduler_kwargs=extra_kwargs,
+            learning_rate=0.2,
+            warmup_steps=num_warmup_steps,
+        )
+        trainer = Trainer(model, args, train_dataset=train_dataset)
+        trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)
+
+        # Checking that the scheduler was created
+        self.assertIsNotNone(trainer.lr_scheduler)
+
+        # Checking that the correct args were passed
+        sched1 = trainer.lr_scheduler
+        sched2 = get_polynomial_decay_schedule_with_warmup(
+            trainer.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_steps, **extra_kwargs
+        )
+        self.assertEqual(sched1.lr_lambdas[0].args, sched2.lr_lambdas[0].args)
+        self.assertEqual(sched1.lr_lambdas[0].keywords, sched2.lr_lambdas[0].keywords)
+
     def test_reduce_lr_on_plateau_args(self):
         # test passed arguments for a custom ReduceLROnPlateau scheduler
         train_dataset = RegressionDataset(length=64)
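
The test's final assertions compare the `args` and `keywords` of `lr_lambdas[0]`, which works because the polynomial schedule is a LambdaLR whose lambda is built as a functools.partial carrying the scheduler-specific kwargs (as the test's use of `.args`/`.keywords` implies). A small standalone sketch to inspect this, with a toy optimizer assumed; the exact contents of `.keywords` depend on the transformers version:

import functools
import torch
from transformers import get_polynomial_decay_schedule_with_warmup

# Assumed toy setup so the schedule can be constructed standalone.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=0.2)

sched = get_polynomial_decay_schedule_with_warmup(
    optimizer, num_warmup_steps=2, num_training_steps=10, power=5.0, lr_end=1e-5
)
# The schedule's lambda is a functools.partial, which is what lets the test
# compare .args and .keywords to verify the kwargs that were applied.
print(isinstance(sched.lr_lambdas[0], functools.partial))
print(sched.lr_lambdas[0].keywords)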