# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""

import math
import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_
import logging

logger = logging.getLogger(__name__)

__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam",
           "WarmupCosineWithHardRestartsSchedule", "WarmupCosineWithWarmupRestartsSchedule", "SCHEDULES"]


class LRSchedule(object):
    """Parent of all LRSchedules here.

    The base schedule is constant: its multiplier is always 1. Subclasses
    override :meth:`get_lr_` to shape the multiplier as a function of the
    training progress.
    """
    # Set to True in subclasses where progressing beyond t_total steps doesn't make sense.
    warn_t_total = False

    def __init__(self, warmup=0.002, t_total=-1, **kw):
        """
        :param warmup:  what fraction of t_total steps will be used for linear warmup;
                        must be in [0.0, 1.0[ or -1 (which is stored as 0., i.e. no warmup)
        :param t_total: how many training steps (updates) are planned; a non-positive
                        value disables the schedule (the multiplier is then always 1.)
        :param kw:      forwarded to the next ``__init__`` in the MRO
        :raises ValueError: if ``warmup`` is outside [0.0, 1.0[ and is not -1
        """
        super(LRSchedule, self).__init__(**kw)
        # t_total == 0 is handled like a negative value: no progress can be computed from it.
        if t_total <= 0:
            logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
        if not 0.0 <= warmup < 1.0 and not warmup == -1:
            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
        warmup = max(warmup, 0.)
        self.warmup, self.t_total = float(warmup), float(t_total)
        # Largest progress value a "beyond t_total" warning has been emitted for so far.
        self.warned_for_t_total_at_progress = -1

    def get_lr(self, step, nowarn=False):
        """
        :param step:    which of t_total steps we're on
        :param nowarn:  set to True to suppress warning regarding training beyond specified 't_total' steps
        :return:        learning rate multiplier for current update
        """
        # Schedule disabled; this also avoids dividing by zero when t_total == 0.
        if self.t_total <= 0:
            return 1.
        progress = step / self.t_total
        ret = self.get_lr_(progress)
        # Warn when training runs past t_total. Only schedules that set `warn_t_total`
        # do this, and only once for any given progress value.
        if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress:
            logger.warning(
                "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly."
                .format(ret, self.__class__.__name__))
            self.warned_for_t_total_at_progress = progress
        return ret

    def get_lr_(self, progress):
        """
        :param progress:    value between 0 and 1 (unless going beyond t_total steps) specifying training progress
        :return:            learning rate multiplier for current update
        """
        # The base class applies no schedule; subclasses override this method.
        return 1.


class WarmupCosineSchedule(LRSchedule):
    """
    Cosine learning rate schedule with linear warmup. Cosine after warmup is without restarts.
    """
    warn_t_total = True

    def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
        """
        :param warmup:      see LRSchedule
        :param t_total:     see LRSchedule
        :param cycles:      number of cycles. Default: 0.5, corresponding to cosine decay from 1.
                            at progress==warmup and 0 at progress==1.
        :param kw:          see LRSchedule
        """
        super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
        self.cycles = cycles

    def get_lr_(self, progress):
        """Linear ramp up to ``warmup``, then cosine over the remaining progress."""
        if progress < self.warmup:
            return progress / self.warmup
        progress = (progress - self.warmup) / (1 - self.warmup)   # progress after warmup
        return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))


class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
    """
    Cosine learning rate schedule with linear warmup and hard restarts (if cycles > 1).
    """
    def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
        """
        :param warmup:      see LRSchedule
        :param t_total:     see LRSchedule
        :param cycles:      number of cosine cycles after warmup; must be >= 1
        :param kw:          see LRSchedule
        """
        super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
        assert cycles >= 1., "cycles should be >= 1"

    def get_lr_(self, progress):
        """Linear ramp up to ``warmup``, then a cosine decay that restarts at 1. every cycle."""
        if progress < self.warmup:
            return progress / self.warmup
        progress = (progress - self.warmup) / (1 - self.warmup)     # progress after warmup
        # The modulo restarts the half-cosine at the beginning of each cycle.
        return 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1)))


class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
    """
    Cosine learning rate schedule with linear warmups and linear warmup restarts.
    The same warmup rate is used for warmup restarts as for initial warmup.
    The total effective fraction of warmup steps over all cycles is warmup * cycles!
    """
    def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
        """
        :param warmup:      see LRSchedule; ``warmup * cycles`` must stay below 1
        :param t_total:     see LRSchedule
        :param cycles:      number of warmup + cosine cycles; must be >= 1
        :param kw:          see LRSchedule
        """
        assert warmup * cycles < 1., "warmup * cycles should be < 1"
        # Progress is folded into a single cycle in get_lr_, so the warmup fraction is
        # rescaled to that per-cycle range. The -1 ("no warmup") marker is passed through.
        warmup = warmup * cycles if warmup >= 0 else warmup
        super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)

    def get_lr_(self, progress):
        """Within every cycle: linear ramp up to ``warmup``, then a half-cosine decay."""
        progress = progress * self.cycles % 1.     # position inside the current cycle
        if progress < self.warmup:
            return progress / self.warmup
        progress = (progress - self.warmup) / (1 - self.warmup)     # progress after warmup
        return 0.5 * (1. + math.cos(math.pi * progress))


class WarmupConstantSchedule(LRSchedule):
    """
    Applies linear warmup. After warmup always returns 1..
    """
    def get_lr_(self, progress):
        """Linear ramp up to ``warmup``, then constant 1."""
        if progress < self.warmup:
            return progress / self.warmup
        return 1.


class WarmupLinearSchedule(LRSchedule):
    """
    Linear warmup. Linear decay after warmup.
    """
    warn_t_total = True

    def get_lr_(self, progress):
        """Linear ramp up to ``warmup``, then linear decay reaching 0 at progress == 1."""
        if progress < self.warmup:
            return progress / self.warmup
        # Clamped at 0 so the multiplier never turns negative past t_total.
        return max((progress - 1.) / (self.warmup - 1.), 0.)


# Maps the `schedule` argument of BertAdam to the schedule class to instantiate.
SCHEDULES = {
    None: LRSchedule,
    "none": LRSchedule,
    "warmup_cosine": WarmupCosineSchedule,
    "warmup_constant": WarmupConstantSchedule,
    "warmup_linear": WarmupLinearSchedule
}


class BertAdam(Optimizer):
    """Implements BERT version of Adam algorithm with weight decay fix.

    Params:
        lr: learning rate
        warmup: portion of t_total for the warmup, -1  means no warmup. Default: -1
        t_total: total number of training steps for the learning
            rate schedule, -1  means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
        schedule: schedule to use for the warmup (see above).
            Can be 'warmup_linear', 'warmup_constant', 'warmup_cosine', or a LRSchedule object.
            Default: 'warmup_linear'
        b1: Adams b1. Default: 0.9
        b2: Adams b2. Default: 0.999
        e: Adams epsilon. Default: 1e-6
        weight_decay: Weight decay. Default: 0.01
        max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
    """
    def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
                 b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
        """Validate the hyper-parameters and build the schedule object if a name was given.

        Raises:
            ValueError: if `lr`, `schedule`, `b1`, `b2` or `e` is invalid.
        """
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES:
            raise ValueError("Invalid schedule parameter: {}".format(schedule))
        if not 0.0 <= b1 < 1.0:
            raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
        if not 0.0 <= b2 < 1.0:
            raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
        if not e >= 0.0:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
        # initialize schedule object
        if not isinstance(schedule, LRSchedule):
            schedule_type = SCHEDULES[schedule]
            schedule = schedule_type(warmup=warmup, t_total=t_total)
        else:
            if warmup != -1 or t_total != -1:
                logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided. "
                               "Please specify custom warmup and t_total in LRSchedule object.")
        defaults = dict(lr=lr, schedule=schedule,
                        b1=b1, b2=b2, e=e, weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)
        super(BertAdam, self).__init__(params, defaults)

    def get_lr(self):
        """Return the scheduled learning rate of every parameter, in iteration order.

        Returns ``[0]`` as soon as a parameter without optimizer state is
        encountered (e.g. before the first call to :meth:`step`).
        """
        lr = []
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if not state:
                    return [0]
                lr_scheduled = group['lr']
                lr_scheduled *= group['schedule'].get_lr(state['step'])
                lr.append(lr_scheduled)
        return lr

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by `closure`, or None if no closure was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if not state:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['next_m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['next_v'] = torch.zeros_like(p.data)

                next_m, next_v = state['next_m'], state['next_v']
                beta1, beta2 = group['b1'], group['b2']

                # Add grad clipping
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time.
                # The scaling factor is passed by keyword: the positional-scalar
                # overloads `add_(Number, Tensor)` / `addcmul_(Number, Tensor, Tensor)`
                # are deprecated.
                next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
                next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                update = next_m / (next_v.sqrt() + group['e'])

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                if group['weight_decay'] > 0.0:
                    update += group['weight_decay'] * p.data

                lr_scheduled = group['lr']
                lr_scheduled *= group['schedule'].get_lr(state['step'])

                update_with_lr = lr_scheduled * update
                p.data.add_(-update_with_lr)

                state['step'] += 1

                # Note: no bias correction of the moment estimates is applied here.

        return loss