Introduce warmup_ratio training argument (#10229)

Introduce warmup_ratio training argument in both
TrainingArguments and TFTrainingArguments classes (#6673)
This commit is contained in:
Tanmay Garg 2021-02-18 22:53:33 +05:30 committed by GitHub
parent 2acae50a0c
commit d7f38c5d1d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 4 deletions

View File

@ -615,10 +615,16 @@ class Trainer:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if self.lr_scheduler is None:
warmup_steps = (
self.args.warmup_steps
if self.args.warmup_steps > 0
else math.ceil(num_training_steps * self.args.warmup_ratio)
)
self.lr_scheduler = get_scheduler(
self.args.lr_scheduler_type,
self.optimizer,
num_warmup_steps=self.args.warmup_steps,
num_warmup_steps=warmup_steps,
num_training_steps=num_training_steps,
)

View File

@ -218,10 +218,16 @@ class TFTrainer:
TFTrainer's init through :obj:`optimizers`, or subclass and override this method.
"""
if not self.optimizer and not self.lr_scheduler:
warmup_steps = (
self.args.warmup_steps
if self.args.warmup_steps > 0
else math.ceil(num_training_steps * self.args.warmup_ratio)
)
self.optimizer, self.lr_scheduler = create_optimizer(
self.args.learning_rate,
num_training_steps,
self.args.warmup_steps,
warmup_steps,
adam_beta1=self.args.adam_beta1,
adam_beta2=self.args.adam_beta2,
adam_epsilon=self.args.adam_epsilon,

View File

@ -131,8 +131,11 @@ class TrainingArguments:
lr_scheduler_type (:obj:`str` or :class:`~transformers.SchedulerType`, `optional`, defaults to :obj:`"linear"`):
The scheduler type to use. See the documentation of :class:`~transformers.SchedulerType` for all possible
values.
warmup_ratio (:obj:`float`, `optional`, defaults to 0.0):
Ratio of total training steps used for a linear warmup from 0 to :obj:`learning_rate`.
warmup_steps (:obj:`int`, `optional`, defaults to 0):
Number of steps used for a linear warmup from 0 to :obj:`learning_rate`.
Number of steps used for a linear warmup from 0 to :obj:`learning_rate`. Overrides any effect of
:obj:`warmup_ratio`.
logging_dir (:obj:`str`, `optional`):
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
`runs/**CURRENT_DATETIME_HOSTNAME**`.
@ -324,6 +327,9 @@ class TrainingArguments:
default="linear",
metadata={"help": "The scheduler type to use."},
)
warmup_ratio: float = field(
default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}
)
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
@ -495,6 +501,13 @@ class TrainingArguments:
elif not isinstance(self.report_to, list):
self.report_to = [self.report_to]
if self.warmup_ratio < 0 or self.warmup_ratio > 1:
raise ValueError("warmup_ratio must lie in range [0,1]")
elif self.warmup_ratio > 0 and self.warmup_steps > 0:
logger.info(
"Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training"
)
def __repr__(self):
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
# those deprecated arguments are removed form TrainingArguments. (TODO: v5)

View File

@ -94,8 +94,11 @@ class TFTrainingArguments(TrainingArguments):
max_steps (:obj:`int`, `optional`, defaults to -1):
If set to a positive number, the total number of training steps to perform. Overrides
:obj:`num_train_epochs`.
warmup_ratio (:obj:`float`, `optional`, defaults to 0.0):
Ratio of total training steps used for a linear warmup from 0 to :obj:`learning_rate`.
warmup_steps (:obj:`int`, `optional`, defaults to 0):
Number of steps used for a linear warmup from 0 to :obj:`learning_rate`.
Number of steps used for a linear warmup from 0 to :obj:`learning_rate`. Overrides any effect of
:obj:`warmup_ratio`.
logging_dir (:obj:`str`, `optional`):
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
`runs/**CURRENT_DATETIME_HOSTNAME**`.