Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)
layernorm_decay_fix (#35927)

* layernorm_decay_fix
* W293 fix
* ruff format fix
* black format
* ruff format
* erase last layer
* add test_get_parameter_names_rmsnorm
* rmsnorm fix
This commit is contained in:
parent 2ba040a71f
commit b1954fd64a
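In short: `get_parameter_names` gains an optional third argument with parameter-name patterns, so weight decay now skips not only parameters inside `nn.LayerNorm`-type modules but also any parameter whose qualified name contains "bias", "layernorm", or "rmsnorm", and the separate `"bias" not in name` filtering line is dropped. A minimal usage sketch mirroring the documentation hunks below, assuming `model` and a `training_args` (TrainingArguments) instance are already in scope and AdamW is the chosen optimizer:

import torch
from torch import nn
from transformers.trainer_pt_utils import get_parameter_names

# Exclude norm/bias parameters from weight decay: by module type and by name pattern.
decay_parameters = get_parameter_names(model, [nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])

optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if n in decay_parameters],
        "weight_decay": training_args.weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=training_args.learning_rate)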
@@ -298,8 +298,7 @@ from transformers.trainer_pt_utils import get_parameter_names
 
 training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)
 
-decay_parameters = get_parameter_names(model, [nn.LayerNorm])
-decay_parameters = [name for name in decay_parameters if "bias" not in name]
+decay_parameters = get_parameter_names(model, [nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])
 optimizer_grouped_parameters = [
     {
         "params": [p for n, p in model.named_parameters() if n in decay_parameters],
@@ -237,8 +237,7 @@ from transformers.trainer_pt_utils import get_parameter_names
 
 training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)
 
-decay_parameters = get_parameter_names(model, [nn.LayerNorm])
-decay_parameters = [name for name in decay_parameters if "bias" not in name]
+decay_parameters = get_parameter_names(model, [nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])
 optimizer_grouped_parameters = [
     {
         "params": [p for n, p in model.named_parameters() if n in decay_parameters],
@@ -680,8 +680,7 @@ def main():
     # Instantiate custom data collator
     data_collator = DataCollatorCTCWithPadding(processor=processor)
 
-    decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm])
-    decay_parameters = [name for name in decay_parameters if "bias" not in name]
+    decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])
     optimizer_grouped_parameters = [
         {
             "params": [p for n, p in model.named_parameters() if n in decay_parameters],
@@ -1177,13 +1177,13 @@ class Trainer:
 
     def get_decay_parameter_names(self, model) -> List[str]:
         """
-        Get all parameter names that weight decay will be applied to
+        Get all parameter names that weight decay will be applied to.
 
-        Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
-        apply to those modules since this function only filter out instance of nn.LayerNorm
+        This function filters out parameters in two ways:
+        1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)
+        2. By parameter name patterns (containing 'bias', 'layernorm', or 'rmsnorm')
         """
-        decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
-        decay_parameters = [name for name in decay_parameters if "bias" not in name]
+        decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS, ["bias", "layernorm", "rmsnorm"])
         return decay_parameters
 
     def create_optimizer(self):
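As the rewritten docstring says, exclusion now happens in two ways: by layer type (ALL_LAYERNORM_LAYERS) and by name pattern. The second path is what fixes weight decay for models whose norm layers are custom classes that never appear in ALL_LAYERNORM_LAYERS. A rough, self-contained illustration of the difference, where TinyRMSNorm and TinyBlock are hypothetical stand-ins rather than transformers classes:

import torch
from torch import nn
from transformers.trainer_pt_utils import get_parameter_names

class TinyRMSNorm(nn.Module):
    # Hypothetical custom norm: not a subclass of nn.LayerNorm, so the type check misses it.
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn_proj = nn.Linear(64, 64)
        self.post_attention_layernorm = TinyRMSNorm(64)

model = TinyBlock()

# Type-based filtering alone keeps the custom norm's weight in the decay set...
print(get_parameter_names(model, [nn.LayerNorm]))
# ['attn_proj.weight', 'attn_proj.bias', 'post_attention_layernorm.weight']

# ...while the added name-based filter drops it (and the biases) by substring match.
print(get_parameter_names(model, [nn.LayerNorm], ["bias", "layernorm", "rmsnorm"]))
# ['attn_proj.weight']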
@@ -1120,19 +1120,25 @@ def get_model_param_count(model, trainable_only=False):
     return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)
 
 
-def get_parameter_names(model, forbidden_layer_types):
+def get_parameter_names(model, forbidden_layer_types, forbidden_layer_names=None):
     """
     Returns the names of the model parameters that are not inside a forbidden layer.
     """
+    if forbidden_layer_names is None:
+        forbidden_layer_names = []
     result = []
     for name, child in model.named_children():
+        child_params = get_parameter_names(child, forbidden_layer_types, forbidden_layer_names)
         result += [
             f"{name}.{n}"
-            for n in get_parameter_names(child, forbidden_layer_types)
+            for n in child_params
             if not isinstance(child, tuple(forbidden_layer_types))
+            and not any(forbidden in f"{name}.{n}".lower() for forbidden in forbidden_layer_names)
         ]
-    # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
-    result += list(model._parameters.keys())
+    # Add model specific parameters that are not in any child
+    result += [
+        k for k in model._parameters.keys() if not any(forbidden in k.lower() for forbidden in forbidden_layer_names)
+    ]
     return result
 
 
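Two details of the updated helper are easy to miss: parameters registered directly on a module with `nn.Parameter` (not inside any child) are now filtered by the same name patterns, and the match is a plain substring test on the lowercased qualified name. A small sketch under those assumptions, where Demo is a made-up module:

import torch
from torch import nn
from transformers.trainer_pt_utils import get_parameter_names

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_norm = nn.LayerNorm(8)         # excluded by the isinstance check on nn.LayerNorm
        self.scale = nn.Parameter(torch.ones(8))  # model-level parameter, only name filtering applies

names = get_parameter_names(Demo(), [nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])
print(names)  # ['scale'] -- 'layer_norm.*' is dropped by type; 'scale' matches no forbidden substring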
@@ -244,6 +244,33 @@ class TrainerUtilsTest(unittest.TestCase):
         )
         # fmt: on
 
+    def test_get_parameter_names_rmsnorm(self):
+        class RMSNorm(nn.Module):
+            def __init__(self, hidden_size):
+                super().__init__()
+                self.weight = nn.Parameter(torch.ones(hidden_size))
+                self.bias = nn.Parameter(torch.zeros(hidden_size))
+
+        class ModelWithRMSNorm(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(128, 128)
+                self.rmsnorm = RMSNorm(128)
+                self.bias = nn.Parameter(torch.zeros(128))
+
+        model = ModelWithRMSNorm()
+        # Test both type-based and name-based filtering
+        decay_parameters = get_parameter_names(model, [], ["bias", "rmsnorm"])
+
+        # Parameters that should be in weight decay
+        self.assertIn("linear.weight", decay_parameters)
+
+        # Parameters that should NOT be in weight decay
+        self.assertNotIn("linear.bias", decay_parameters)
+        self.assertNotIn("rmsnorm.weight", decay_parameters)
+        self.assertNotIn("rmsnorm.bias", decay_parameters)
+        self.assertNotIn("bias", decay_parameters)
+
     def test_distributed_sampler_with_loop(self):
         batch_size = 16
         for length in [23, 64, 123]: