layernorm_decay_fix (#35927)

* layernorm_decay_fix

* W293 fix

* ruff format fix

* black format

* ruff format

* erase last layer

* add test_get_parameter_names_rmsnorm

* rmsnorm fix
Ryoo Kwangrok 2025-02-04 19:01:49 +09:00 committed by GitHub
parent 2ba040a71f
commit b1954fd64a
6 changed files with 45 additions and 15 deletions


@@ -298,8 +298,7 @@ from transformers.trainer_pt_utils import get_parameter_names
 training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)
-decay_parameters = get_parameter_names(model, [nn.LayerNorm])
-decay_parameters = [name for name in decay_parameters if "bias" not in name]
+decay_parameters = get_parameter_names(model, [nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])
 optimizer_grouped_parameters = [
     {
         "params": [p for n, p in model.named_parameters() if n in decay_parameters],


@@ -237,8 +237,7 @@ from transformers.trainer_pt_utils import get_parameter_names
 training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)
-decay_parameters = get_parameter_names(model, [nn.LayerNorm])
-decay_parameters = [name for name in decay_parameters if "bias" not in name]
+decay_parameters = get_parameter_names(model, [nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])
 optimizer_grouped_parameters = [
     {
         "params": [p for n, p in model.named_parameters() if n in decay_parameters],


@@ -680,8 +680,7 @@ def main():
     # Instantiate custom data collator
     data_collator = DataCollatorCTCWithPadding(processor=processor)
-    decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm])
-    decay_parameters = [name for name in decay_parameters if "bias" not in name]
+    decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])
     optimizer_grouped_parameters = [
         {
             "params": [p for n, p in model.named_parameters() if n in decay_parameters],


@@ -1177,13 +1177,13 @@ class Trainer:
     def get_decay_parameter_names(self, model) -> List[str]:
         """
-        Get all parameter names that weight decay will be applied to
+        Get all parameter names that weight decay will be applied to.
-        Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
-        apply to those modules since this function only filter out instance of nn.LayerNorm
+        This function filters out parameters in two ways:
+        1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)
+        2. By parameter name patterns (containing 'bias', 'layernorm', or 'rmsnorm')
         """
-        decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
-        decay_parameters = [name for name in decay_parameters if "bias" not in name]
+        decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS, ["bias", "layernorm", "rmsnorm"])
         return decay_parameters
 
     def create_optimizer(self):
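A small sketch (my illustration, not part of the diff) of the two filtering paths the new docstring describes; `CustomLayerNorm` and `TinyModel` are made-up names. The point of the second path is that a normalization module which is not an `nn.LayerNorm` instance is still excluded, as long as its name matches one of the patterns:

```python
import torch
from torch import nn
from transformers.trainer_pt_utils import get_parameter_names

class CustomLayerNorm(nn.Module):
    """A hand-rolled norm layer that is *not* an nn.LayerNorm instance."""
    def __init__(self, hidden_size):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        self.layernorm = CustomLayerNorm(8)

model = TinyModel()
# Type-based filtering alone misses the custom norm layer...
print(get_parameter_names(model, [nn.LayerNorm]))
# ['linear.weight', 'linear.bias', 'layernorm.weight']
# ...while the added name-based filter excludes it (and biases) as well.
print(get_parameter_names(model, [nn.LayerNorm], ["bias", "layernorm", "rmsnorm"]))
# ['linear.weight']
```

`get_decay_parameter_names` performs the second call with `ALL_LAYERNORM_LAYERS` in place of `[nn.LayerNorm]`.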


@@ -1120,19 +1120,25 @@ def get_model_param_count(model, trainable_only=False):
     return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)
 
-def get_parameter_names(model, forbidden_layer_types):
+def get_parameter_names(model, forbidden_layer_types, forbidden_layer_names=None):
     """
     Returns the names of the model parameters that are not inside a forbidden layer.
     """
+    if forbidden_layer_names is None:
+        forbidden_layer_names = []
     result = []
     for name, child in model.named_children():
+        child_params = get_parameter_names(child, forbidden_layer_types, forbidden_layer_names)
         result += [
             f"{name}.{n}"
-            for n in get_parameter_names(child, forbidden_layer_types)
+            for n in child_params
             if not isinstance(child, tuple(forbidden_layer_types))
+            and not any(forbidden in f"{name}.{n}".lower() for forbidden in forbidden_layer_names)
         ]
-    # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
-    result += list(model._parameters.keys())
+    # Add model specific parameters that are not in any child
+    result += [
+        k for k in model._parameters.keys() if not any(forbidden in k.lower() for forbidden in forbidden_layer_names)
+    ]
     return result
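A usage sketch (an assumption on my part, not from the diff) showing two details of the implementation above: the patterns are matched against the lowercased dotted path, and parameters registered directly on a module via `nn.Parameter` are filtered by the same rule. `Block` and `Net` are hypothetical names:

```python
import torch
from torch import nn
from transformers.trainer_pt_utils import get_parameter_names

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(4))  # parameter registered directly on the module

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.final_RMSNorm = Block()
        self.proj = nn.Linear(4, 4)

names = get_parameter_names(Net(), [], ["bias", "rmsnorm"])
# "final_RMSNorm.scale" is dropped because "rmsnorm" appears in its lowercased path,
# and "proj.bias" is dropped by the "bias" pattern, leaving only the projection weight.
assert names == ["proj.weight"]
```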


@@ -244,6 +244,33 @@ class TrainerUtilsTest(unittest.TestCase):
         )
         # fmt: on
 
+    def test_get_parameter_names_rmsnorm(self):
+        class RMSNorm(nn.Module):
+            def __init__(self, hidden_size):
+                super().__init__()
+                self.weight = nn.Parameter(torch.ones(hidden_size))
+                self.bias = nn.Parameter(torch.zeros(hidden_size))
+
+        class ModelWithRMSNorm(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(128, 128)
+                self.rmsnorm = RMSNorm(128)
+                self.bias = nn.Parameter(torch.zeros(128))
+
+        model = ModelWithRMSNorm()
+
+        # Test both type-based and name-based filtering
+        decay_parameters = get_parameter_names(model, [], ["bias", "rmsnorm"])
+
+        # Parameters that should be in weight decay
+        self.assertIn("linear.weight", decay_parameters)
+
+        # Parameters that should NOT be in weight decay
+        self.assertNotIn("linear.bias", decay_parameters)
+        self.assertNotIn("rmsnorm.weight", decay_parameters)
+        self.assertNotIn("rmsnorm.bias", decay_parameters)
+        self.assertNotIn("bias", decay_parameters)
+
     def test_distributed_sampler_with_loop(self):
         batch_size = 16
         for length in [23, 64, 123]: