Check layer types for Optimizer construction (#10598)

* Check layer types for Optimizer construction

* Duplicate class
Author: Sylvain Gugger, 2021-03-08 16:40:11 -05:00 (committed by GitHub)
Parent: 821d518e03
Commit: 3ced9b3eb9
4 changed files with 71 additions and 3 deletions

src/transformers/trainer.py

@@ -80,6 +80,7 @@ from .trainer_pt_utils import (
     SequentialDistributedSampler,
     distributed_broadcast_scalars,
     distributed_concat,
+    get_parameter_names,
     nested_concat,
     nested_detach,
     nested_numpify,
@@ -613,14 +614,15 @@ class Trainer:
         Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
         """
         if self.optimizer is None:
-            no_decay = ["bias", "LayerNorm.weight"]
+            decay_parameters = get_parameter_names(self.model, [torch.nn.LayerNorm])
+            decay_parameters = [name for name in decay_parameters if "bias" not in name]
             optimizer_grouped_parameters = [
                 {
-                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
+                    "params": [p for n, p in self.model.named_parameters() if n in decay_parameters],
                     "weight_decay": self.args.weight_decay,
                 },
                 {
-                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
+                    "params": [p for n, p in self.model.named_parameters() if n not in decay_parameters],
                     "weight_decay": 0.0,
                 },
             ]

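Note (not part of the diff): the previous name-based filter only excluded parameters whose names contain "bias" or "LayerNorm.weight", so a LayerNorm stored under a different attribute name (for example `ln1`) still received weight decay. The new code selects decay parameters by layer type instead. A minimal sketch of the difference, using a hypothetical `TinyModel` that is not in the repository and assuming a transformers version that includes this commit:

    import torch

    from transformers.trainer_pt_utils import get_parameter_names

    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 8)
            # Parameter name is "ln1.weight", which never matches the substring "LayerNorm.weight".
            self.ln1 = torch.nn.LayerNorm(8)

    model = TinyModel()

    # Old behavior: exclude by name only, so "ln1.weight" still gets weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    old_decay = [n for n, _ in model.named_parameters() if not any(nd in n for nd in no_decay)]
    # -> ['linear.weight', 'ln1.weight']

    # New behavior: drop every parameter inside a LayerNorm (matched by type), then all biases.
    decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm])
    decay_parameters = [n for n in decay_parameters if "bias" not in n]
    # -> ['linear.weight']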
src/transformers/trainer_pt_utils.py

@@ -672,3 +672,19 @@ def save_state(self):
 
     path = os.path.join(self.args.output_dir, "trainer_state.json")
     self.state.save_to_json(path)
+
+
+def get_parameter_names(model, forbidden_layer_types):
+    """
+    Returns the names of the model parameters that are not inside a forbidden layer.
+    """
+    result = []
+    for name, child in model.named_children():
+        result += [
+            f"{name}.{n}"
+            for n in get_parameter_names(child, forbidden_layer_types)
+            if not isinstance(child, tuple(forbidden_layer_types))
+        ]
+    # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
+    result += list(model._parameters.keys())
+    return result

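As a quick illustration (a hypothetical snippet, not part of the diff), the recursion prefixes each child's parameter names with the child's attribute name, skips any child whose type is forbidden, and finally picks up parameters registered directly on a module with `nn.Parameter` through `model._parameters`:

    import torch

    from transformers.trainer_pt_utils import get_parameter_names

    class Block(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(4, 4)
            self.norm = torch.nn.LayerNorm(4)  # forbidden type: its parameters are dropped
            self.scale = torch.nn.Parameter(torch.ones(4))  # direct parameter, only visible via _parameters

    model = torch.nn.Sequential(Block(), Block())
    print(get_parameter_names(model, [torch.nn.LayerNorm]))
    # ['0.proj.weight', '0.proj.bias', '0.scale', '1.proj.weight', '1.proj.bias', '1.scale']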
tests/test_trainer.py

@@ -193,6 +193,20 @@ if is_torch_available():
             loss = torch.nn.functional.mse_loss(y, labels)
             return (loss, y, y) if self.double_output else (loss, y)
 
+    class TstLayer(torch.nn.Module):
+        def __init__(self, hidden_size):
+            super().__init__()
+            self.linear1 = torch.nn.Linear(hidden_size, hidden_size)
+            self.ln1 = torch.nn.LayerNorm(hidden_size)
+            self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
+            self.ln2 = torch.nn.LayerNorm(hidden_size)
+            self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
+
+        def forward(self, x):
+            h = self.ln1(torch.nn.functional.relu(self.linear1(x)))
+            h = torch.nn.functional.relu(self.linear2(x))
+            return self.ln2(x + h + self.bias)
+
     def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, pretrained=True, **kwargs):
         label_names = kwargs.get("label_names", None)
         train_dataset = RegressionDataset(length=train_len, label_names=label_names)
@@ -991,6 +1005,18 @@ class TrainerIntegrationTest(unittest.TestCase):
         # perfect world: fp32_init/2 == fp16_eval
         self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
 
+    def test_no_wd_param_group(self):
+        model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)]))
+        trainer = Trainer(model=model)
+        trainer.create_optimizer_and_scheduler(10)
+        # fmt: off
+        wd_names = ['0.linear1.weight', '0.linear2.weight', '1.0.linear1.weight', '1.0.linear2.weight', '1.1.linear1.weight', '1.1.linear2.weight']
+        # fmt: on
+        wd_params = [p for n, p in model.named_parameters() if n in wd_names]
+        no_wd_params = [p for n, p in model.named_parameters() if n not in wd_names]
+        self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params)
+        self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
+
 
 @require_torch
 @require_optuna

tests/test_trainer_utils.py

@@ -30,8 +30,23 @@ if is_torch_available():
         DistributedTensorGatherer,
         LabelSmoother,
         LengthGroupedSampler,
+        get_parameter_names,
     )
 
+    class TstLayer(torch.nn.Module):
+        def __init__(self, hidden_size):
+            super().__init__()
+            self.linear1 = torch.nn.Linear(hidden_size, hidden_size)
+            self.ln1 = torch.nn.LayerNorm(hidden_size)
+            self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
+            self.ln2 = torch.nn.LayerNorm(hidden_size)
+            self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
+
+        def forward(self, x):
+            h = self.ln1(torch.nn.functional.relu(self.linear1(x)))
+            h = torch.nn.functional.relu(self.linear2(x))
+            return self.ln2(x + h + self.bias)
+
 
 @require_torch
 class TrainerUtilsTest(unittest.TestCase):
@@ -117,3 +132,12 @@ class TrainerUtilsTest(unittest.TestCase):
         self.assertEqual(lengths[indices_process_0[0]], 50)
         # The indices should be a permutation of range(100)
         self.assertEqual(list(sorted(indices_process_0 + indices_process_1)), list(range(100)))
+
+    def test_get_parameter_names(self):
+        model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)]))
+        # fmt: off
+        self.assertEqual(
+            get_parameter_names(model, [torch.nn.LayerNorm]),
+            ['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias']
+        )
+        # fmt: on