Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 19:21:31 +06:00)
[s2s] configure lr_scheduler from command line (#7641)

commit 06a973fd2a
parent 4a00613c24
@@ -4,7 +4,7 @@ import sys
 from dataclasses import dataclass, field
 from typing import Optional

-from seq2seq_trainer import Seq2SeqTrainer
+from seq2seq_trainer import Seq2SeqTrainer, arg_to_scheduler_choices
 from transformers import (
     AutoConfig,
     AutoModelForSeq2SeqLM,
@@ -63,6 +63,9 @@ class Seq2SeqTrainingArguments(TrainingArguments):
     attention_dropout: Optional[float] = field(
         default=None, metadata={"help": "Attention dropout probability. Goes into model.config."}
     )
+    lr_scheduler: Optional[str] = field(
+        default="linear", metadata={"help": f"Which lr scheduler to use. Selected in {arg_to_scheduler_choices}"}
+    )


 @dataclass
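For context, `Seq2SeqTrainingArguments` extends `TrainingArguments`, and transformers example scripts typically parse it with `HfArgumentParser`, which turns every dataclass field into a command-line flag; that is what makes the new `--lr_scheduler` option available. A minimal sketch of that mechanism, using a hypothetical stand-in dataclass rather than the real training arguments:

    # Minimal sketch, assuming the script parses its dataclasses with
    # HfArgumentParser (the usual pattern in transformers example scripts).
    # SchedulerArgs is a stand-in, not the real Seq2SeqTrainingArguments.
    from dataclasses import dataclass, field
    from typing import Optional

    from transformers import HfArgumentParser


    @dataclass
    class SchedulerArgs:
        lr_scheduler: Optional[str] = field(
            default="linear", metadata={"help": "Which lr scheduler to use."}
        )


    # Equivalent to passing `--lr_scheduler cosine` on the command line:
    (args,) = HfArgumentParser(SchedulerArgs).parse_args_into_dataclasses(
        args=["--lr_scheduler", "cosine"]
    )
    print(args.lr_scheduler)  # -> "cosine"

Judging by the `from seq2seq_trainer import ...` line in the first hunk, the remaining hunks below modify `seq2seq_trainer.py`, the module from which the new `arg_to_scheduler_choices` name is imported.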
@@ -8,7 +8,16 @@ from torch.utils.data import DistributedSampler, RandomSampler
 from transformers import Trainer
 from transformers.configuration_fsmt import FSMTConfig
 from transformers.file_utils import is_torch_tpu_available
-from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
+from transformers.optimization import (
+    Adafactor,
+    AdamW,
+    get_constant_schedule,
+    get_constant_schedule_with_warmup,
+    get_cosine_schedule_with_warmup,
+    get_cosine_with_hard_restarts_schedule_with_warmup,
+    get_linear_schedule_with_warmup,
+    get_polynomial_decay_schedule_with_warmup,
+)
 from transformers.trainer_pt_utils import get_tpu_sampler
@@ -20,6 +29,16 @@ except ImportError:

 logger = logging.getLogger(__name__)

+arg_to_scheduler = {
+    "linear": get_linear_schedule_with_warmup,
+    "cosine": get_cosine_schedule_with_warmup,
+    "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
+    "polynomial": get_polynomial_decay_schedule_with_warmup,
+    "constant": get_constant_schedule,
+    "constant_w_warmup": get_constant_schedule_with_warmup,
+}
+arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
+

 class Seq2SeqTrainer(Trainer):
     def __init__(self, config, data_args, *args, **kwargs):
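The new module-level dictionary maps each command-line choice to one of the schedule factories imported above, and `arg_to_scheduler_choices` feeds the `--lr_scheduler` help string. A standalone sketch of the lookup; the toy model, optimizer, and step counts are illustrative assumptions, not part of the commit:

    # Standalone sketch of the arg_to_scheduler lookup with real
    # transformers.optimization factories; everything else is a toy stand-in.
    import torch
    from transformers.optimization import (
        get_cosine_schedule_with_warmup,
        get_linear_schedule_with_warmup,
    )

    arg_to_scheduler = {
        "linear": get_linear_schedule_with_warmup,
        "cosine": get_cosine_schedule_with_warmup,
    }

    model = torch.nn.Linear(4, 4)  # stand-in for the real seq2seq model
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    schedule_func = arg_to_scheduler["cosine"]  # e.g. from --lr_scheduler cosine
    scheduler = schedule_func(optimizer, num_warmup_steps=2, num_training_steps=10)

    for step in range(3):
        optimizer.step()
        scheduler.step()
        print(step, scheduler.get_last_lr())  # lr warms up, then follows a cosine decay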
@@ -62,9 +81,21 @@ class Seq2SeqTrainer(Trainer):
             )

         if self.lr_scheduler is None:
-            self.lr_scheduler = get_linear_schedule_with_warmup(
-                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
-            )
+            self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
+        else:  # ignoring --lr_scheduler
+            logger.warn("scheduler is passed to `Seq2SeqTrainer`, `--lr_scheduler` arg is ignored.")
+
+    def _get_lr_scheduler(self, num_training_steps):
+        schedule_func = arg_to_scheduler[self.args.lr_scheduler]
+        if self.args.lr_scheduler == "constant":
+            scheduler = schedule_func(self.optimizer)
+        elif self.args.lr_scheduler == "constant_w_warmup":
+            scheduler = schedule_func(self.optimizer, num_warmup_steps=self.args.warmup_steps)
+        else:
+            scheduler = schedule_func(
+                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
+            )
+        return scheduler

     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
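The three branches in `_get_lr_scheduler` exist because the factories differ in signature: `get_constant_schedule` takes neither warmup nor a horizon, `get_constant_schedule_with_warmup` takes only `num_warmup_steps`, and the remaining schedules need both warmup and `num_training_steps`. A quick sketch against the real `transformers.optimization` signatures; the toy optimizer is an illustrative assumption:

    # Why _get_lr_scheduler needs three branches: differing factory signatures.
    import torch
    from transformers.optimization import (
        get_constant_schedule,
        get_constant_schedule_with_warmup,
        get_linear_schedule_with_warmup,
    )

    optimizer = torch.optim.AdamW(torch.nn.Linear(2, 2).parameters(), lr=1e-3)

    get_constant_schedule(optimizer)                                   # no extra args
    get_constant_schedule_with_warmup(optimizer, num_warmup_steps=10)  # warmup only
    get_linear_schedule_with_warmup(                                   # warmup + horizon
        optimizer, num_warmup_steps=10, num_training_steps=100
    )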