mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
add functions to inspect model and optimizer status to trainer.py (#29838)
* add functions to get number of params which require grad, get optimizer group for parameters and get learning rates of param groups to trainer.py * add tests and raise ValueError when optimizer is None * add second layer to test and freeze its weigths * check if torch is available before running tests * use decorator to check if torch is available Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix test indentation Co-authored-by: Zach Mueller <muellerzr@gmail.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This commit is contained in:
parent
855b95ce34
commit
aac7099c92
@ -1049,6 +1049,36 @@ class Trainer:
|
||||
|
||||
return self.optimizer
|
||||
|
||||
def get_num_trainable_parameters(self):
|
||||
"""
|
||||
Get the number of trainable parameters.
|
||||
"""
|
||||
return sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
||||
|
||||
def get_learning_rates(self):
|
||||
"""
|
||||
Returns the learning rate of each parameter from self.optimizer.
|
||||
"""
|
||||
if self.optimizer is None:
|
||||
raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.")
|
||||
return [group["lr"] for group in self.optimizer.param_groups]
|
||||
|
||||
def get_optimizer_group(self, param: Optional[Union[str, torch.nn.parameter.Parameter]] = None):
|
||||
"""
|
||||
Returns optimizer group for a parameter if given, else returns all optimizer groups for params.
|
||||
|
||||
Args:
|
||||
param (`str` or `torch.nn.parameter.Parameter`, *optional*):
|
||||
The parameter for which optimizer group needs to be returned.
|
||||
"""
|
||||
if self.optimizer is None:
|
||||
raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.")
|
||||
if param is not None:
|
||||
for group in self.optimizer.param_groups:
|
||||
if param in group["params"]:
|
||||
return group
|
||||
return [group["params"] for group in self.optimizer.param_groups]
|
||||
|
||||
@staticmethod
|
||||
def get_optimizer_cls_and_kwargs(
|
||||
args: TrainingArguments, model: Optional[PreTrainedModel] = None
|
||||
|
@ -3832,3 +3832,41 @@ class HyperParameterSearchBackendsTest(unittest.TestCase):
|
||||
list(ALL_HYPERPARAMETER_SEARCH_BACKENDS.keys()),
|
||||
list(HPSearchBackend),
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class OptimizerAndModelInspectionTest(unittest.TestCase):
|
||||
def test_get_num_trainable_parameters(self):
|
||||
model = nn.Sequential(nn.Linear(128, 64), nn.Linear(64, 32))
|
||||
# in_features * out_features + bias
|
||||
layer_1 = 128 * 64 + 64
|
||||
layer_2 = 64 * 32 + 32
|
||||
trainer = Trainer(model=model)
|
||||
self.assertEqual(trainer.get_num_trainable_parameters(), layer_1 + layer_2)
|
||||
# Freeze the last layer
|
||||
for param in model[-1].parameters():
|
||||
param.requires_grad = False
|
||||
self.assertEqual(trainer.get_num_trainable_parameters(), layer_1)
|
||||
|
||||
def test_get_learning_rates(self):
|
||||
model = nn.Sequential(nn.Linear(128, 64))
|
||||
trainer = Trainer(model=model)
|
||||
with self.assertRaises(ValueError):
|
||||
trainer.get_learning_rates()
|
||||
trainer.create_optimizer()
|
||||
self.assertEqual(trainer.get_learning_rates(), [5e-05, 5e-05])
|
||||
|
||||
def test_get_optimizer_group(self):
|
||||
model = nn.Sequential(nn.Linear(128, 64))
|
||||
trainer = Trainer(model=model)
|
||||
# ValueError is raised if optimizer is None
|
||||
with self.assertRaises(ValueError):
|
||||
trainer.get_optimizer_group()
|
||||
trainer.create_optimizer()
|
||||
# Get groups
|
||||
num_groups = len(trainer.get_optimizer_group())
|
||||
self.assertEqual(num_groups, 2)
|
||||
# Get group of parameter
|
||||
param = next(model.parameters())
|
||||
group = trainer.get_optimizer_group(param)
|
||||
self.assertIn(param, group["params"])
|
||||
|
Loading…
Reference in New Issue
Block a user