refactor decay_parameters production into its own function (#26152)

This commit is contained in:
Shijie Wu 2023-09-18 11:40:11 -04:00 committed by GitHub
parent 77ed9fa1a9
commit c63e27012d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -951,6 +951,17 @@ class Trainer:
optimizer = self.optimizer
self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
def get_decay_parameter_names(self, model) -> List[str]:
    """
    Return the names of the parameters of `model` that weight decay will be applied to.

    The candidates are whatever `get_parameter_names(model, ALL_LAYERNORM_LAYERS)` yields; every name containing
    the substring "bias" is then dropped.

    Note that some models implement their own layernorm instead of calling nn.LayerNorm. Weight decay could still
    apply to those modules, since this function only filters out instances of nn.LayerNorm.
    """
    candidate_names = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
    # Substring match on purpose: any parameter whose name mentions "bias" is exempt from decay.
    return [param_name for param_name in candidate_names if "bias" not in param_name]
def create_optimizer(self):
"""
Setup the optimizer.
@ -961,8 +972,7 @@ class Trainer:
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None:
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
decay_parameters = self.get_decay_parameter_names(opt_model)
optimizer_grouped_parameters = [
{
"params": [