[s2s] configure lr_scheduler from command line (#7641)

This commit is contained in:
Suraj Patil 2020-10-08 22:36:35 +05:30 committed by GitHub
parent 4a00613c24
commit 06a973fd2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 3 deletions

View File

@ -4,7 +4,7 @@ import sys
from dataclasses import dataclass, field
from typing import Optional
from seq2seq_trainer import Seq2SeqTrainer
from seq2seq_trainer import Seq2SeqTrainer, arg_to_scheduler_choices
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
@ -63,6 +63,9 @@ class Seq2SeqTrainingArguments(TrainingArguments):
attention_dropout: Optional[float] = field(
default=None, metadata={"help": "Attention dropout probability. Goes into model.config."}
)
lr_scheduler: Optional[str] = field(
default="linear", metadata={"help": f"Which lr scheduler to use. Selected in {arg_to_scheduler_choices}"}
)
@dataclass

View File

@ -8,7 +8,16 @@ from torch.utils.data import DistributedSampler, RandomSampler
from transformers import Trainer
from transformers.configuration_fsmt import FSMTConfig
from transformers.file_utils import is_torch_tpu_available
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
from transformers.optimization import (
Adafactor,
AdamW,
get_constant_schedule,
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
)
from transformers.trainer_pt_utils import get_tpu_sampler
@ -20,6 +29,16 @@ except ImportError:
logger = logging.getLogger(__name__)
arg_to_scheduler = {
"linear": get_linear_schedule_with_warmup,
"cosine": get_cosine_schedule_with_warmup,
"cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
"polynomial": get_polynomial_decay_schedule_with_warmup,
"constant": get_constant_schedule,
"constant_w_warmup": get_constant_schedule_with_warmup,
}
arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
class Seq2SeqTrainer(Trainer):
def __init__(self, config, data_args, *args, **kwargs):
@ -62,9 +81,21 @@ class Seq2SeqTrainer(Trainer):
)
if self.lr_scheduler is None:
self.lr_scheduler = get_linear_schedule_with_warmup(
self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
else: # ignoring --lr_scheduler
logger.warn("scheduler is passed to `Seq2SeqTrainer`, `--lr_scheduler` arg is ignored.")
def _get_lr_scheduler(self, num_training_steps):
schedule_func = arg_to_scheduler[self.args.lr_scheduler]
if self.args.lr_scheduler == "constant":
scheduler = schedule_func(self.optimizer)
elif self.args.lr_scheduler == "constant_w_warmup":
scheduler = schedule_func(self.optimizer, num_warmup_steps=self.args.warmup_steps)
else:
scheduler = schedule_func(
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
)
return scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):