mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Check layer types for Optimizer construction (#10598)
* Check layer types for Optimizer construction * Duplicate class
This commit is contained in:
parent
821d518e03
commit
3ced9b3eb9
@ -80,6 +80,7 @@ from .trainer_pt_utils import (
|
||||
SequentialDistributedSampler,
|
||||
distributed_broadcast_scalars,
|
||||
distributed_concat,
|
||||
get_parameter_names,
|
||||
nested_concat,
|
||||
nested_detach,
|
||||
nested_numpify,
|
||||
@ -613,14 +614,15 @@ class Trainer:
|
||||
Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
|
||||
"""
|
||||
if self.optimizer is None:
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
decay_parameters = get_parameter_names(self.model, [torch.nn.LayerNorm])
|
||||
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"params": [p for n, p in self.model.named_parameters() if n in decay_parameters],
|
||||
"weight_decay": self.args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"params": [p for n, p in self.model.named_parameters() if n not in decay_parameters],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
|
@ -672,3 +672,19 @@ def save_state(self):
|
||||
|
||||
path = os.path.join(self.args.output_dir, "trainer_state.json")
|
||||
self.state.save_to_json(path)
|
||||
|
||||
|
||||
def get_parameter_names(model, forbidden_layer_types):
|
||||
"""
|
||||
Returns the names of the model parameters that are not inside a forbidden layer.
|
||||
"""
|
||||
result = []
|
||||
for name, child in model.named_children():
|
||||
result += [
|
||||
f"{name}.{n}"
|
||||
for n in get_parameter_names(child, forbidden_layer_types)
|
||||
if not isinstance(child, tuple(forbidden_layer_types))
|
||||
]
|
||||
# Add model specific parameters (defined with nn.Parameter) since they are not in any child.
|
||||
result += list(model._parameters.keys())
|
||||
return result
|
||||
|
@ -193,6 +193,20 @@ if is_torch_available():
|
||||
loss = torch.nn.functional.mse_loss(y, labels)
|
||||
return (loss, y, y) if self.double_output else (loss, y)
|
||||
|
||||
class TstLayer(torch.nn.Module):
|
||||
def __init__(self, hidden_size):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(hidden_size, hidden_size)
|
||||
self.ln1 = torch.nn.LayerNorm(hidden_size)
|
||||
self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
|
||||
self.ln2 = torch.nn.LayerNorm(hidden_size)
|
||||
self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
|
||||
|
||||
def forward(self, x):
|
||||
h = self.ln1(torch.nn.functional.relu(self.linear1(x)))
|
||||
h = torch.nn.functional.relu(self.linear2(x))
|
||||
return self.ln2(x + h + self.bias)
|
||||
|
||||
def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, pretrained=True, **kwargs):
|
||||
label_names = kwargs.get("label_names", None)
|
||||
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
|
||||
@ -991,6 +1005,18 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
# perfect world: fp32_init/2 == fp16_eval
|
||||
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
|
||||
|
||||
def test_no_wd_param_group(self):
|
||||
model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)]))
|
||||
trainer = Trainer(model=model)
|
||||
trainer.create_optimizer_and_scheduler(10)
|
||||
# fmt: off
|
||||
wd_names = ['0.linear1.weight', '0.linear2.weight', '1.0.linear1.weight', '1.0.linear2.weight', '1.1.linear1.weight', '1.1.linear2.weight']
|
||||
# fmt: on
|
||||
wd_params = [p for n, p in model.named_parameters() if n in wd_names]
|
||||
no_wd_params = [p for n, p in model.named_parameters() if n not in wd_names]
|
||||
self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params)
|
||||
self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_optuna
|
||||
|
@ -30,8 +30,23 @@ if is_torch_available():
|
||||
DistributedTensorGatherer,
|
||||
LabelSmoother,
|
||||
LengthGroupedSampler,
|
||||
get_parameter_names,
|
||||
)
|
||||
|
||||
class TstLayer(torch.nn.Module):
|
||||
def __init__(self, hidden_size):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(hidden_size, hidden_size)
|
||||
self.ln1 = torch.nn.LayerNorm(hidden_size)
|
||||
self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
|
||||
self.ln2 = torch.nn.LayerNorm(hidden_size)
|
||||
self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
|
||||
|
||||
def forward(self, x):
|
||||
h = self.ln1(torch.nn.functional.relu(self.linear1(x)))
|
||||
h = torch.nn.functional.relu(self.linear2(x))
|
||||
return self.ln2(x + h + self.bias)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TrainerUtilsTest(unittest.TestCase):
|
||||
@ -117,3 +132,12 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
self.assertEqual(lengths[indices_process_0[0]], 50)
|
||||
# The indices should be a permutation of range(100)
|
||||
self.assertEqual(list(sorted(indices_process_0 + indices_process_1)), list(range(100)))
|
||||
|
||||
def test_get_parameter_names(self):
|
||||
model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)]))
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
get_parameter_names(model, [torch.nn.LayerNorm]),
|
||||
['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias']
|
||||
)
|
||||
# fmt: on
|
||||
|
Loading…
Reference in New Issue
Block a user