Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
refactor decay_parameters production into its own function (#26152)
parent 77ed9fa1a9
commit c63e27012d
@@ -951,6 +951,17 @@ class Trainer:
             optimizer = self.optimizer
         self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
 
+    def get_decay_parameter_names(self, model) -> List[str]:
+        """
+        Get all parameter names that weight decay will be applied to
+
+        Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
+        apply to those modules since this function only filter out instance of nn.LayerNorm
+        """
+        decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
+        decay_parameters = [name for name in decay_parameters if "bias" not in name]
+        return decay_parameters
+
     def create_optimizer(self):
         """
         Setup the optimizer.
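The extracted method keeps the existing behaviour: collect every parameter name except those owned by nn.LayerNorm modules, then drop biases. Below is a minimal, self-contained sketch of that filtering; the helper name collect_decay_parameter_names and the toy model are illustrative stand-ins, not library API.

from typing import List

import torch.nn as nn


def collect_decay_parameter_names(model: nn.Module) -> List[str]:
    # Approximates get_parameter_names(model, ALL_LAYERNORM_LAYERS):
    # skip parameters that belong to nn.LayerNorm modules, then drop biases.
    names = []
    for module_name, module in model.named_modules():
        if isinstance(module, nn.LayerNorm):
            continue
        for param_name, _ in module.named_parameters(recurse=False):
            full_name = f"{module_name}.{param_name}" if module_name else param_name
            if "bias" not in full_name:
                names.append(full_name)
    return names


model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4), nn.Linear(4, 2))
print(collect_decay_parameter_names(model))  # ['0.weight', '2.weight']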
@@ -961,8 +972,7 @@ class Trainer:
         opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
 
         if self.optimizer is None:
-            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
-            decay_parameters = [name for name in decay_parameters if "bias" not in name]
+            decay_parameters = self.get_decay_parameter_names(opt_model)
             optimizer_grouped_parameters = [
                 {
                     "params": [
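For context, a hedged sketch of how the helper's output feeds the parameter groups that create_optimizer builds in this hunk. The import paths for get_parameter_names and ALL_LAYERNORM_LAYERS are assumptions about the library layout (not shown in the diff), and the weight_decay / lr values are illustrative only.

import torch
import torch.nn as nn
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS  # assumed module path
from transformers.trainer_pt_utils import get_parameter_names  # assumed module path

model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4), nn.Linear(4, 2))

# The same two steps the refactored helper performs.
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]

# Parameters named in decay_parameters get weight decay; biases and layer-norm weights do not.
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if n in decay_parameters and p.requires_grad],
        "weight_decay": 0.01,  # illustrative value
    },
    {
        "params": [p for n, p in model.named_parameters() if n not in decay_parameters and p.requires_grad],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=5e-5)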