fix galore layerwise with frozen params (#29743)

This commit is contained in:
peterjc123 2024-03-20 18:06:52 +08:00 committed by GitHub
parent 8692aa88e2
commit a1a7454107
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 2 deletions

View File

@ -385,7 +385,8 @@ def get_scheduler(
scheduler_dict[param].step()
for param in optimizer_dict.keys():
param.register_post_accumulate_grad_hook(scheduler_hook)
if param.requires_grad:
param.register_post_accumulate_grad_hook(scheduler_hook)
return LayerWiseDummyScheduler()

View File

@ -1303,7 +1303,8 @@ class Trainer:
optimizer_dict[param].zero_grad()
for param in model.parameters():
param.register_post_accumulate_grad_hook(optimizer_hook)
if param.requires_grad:
param.register_post_accumulate_grad_hook(optimizer_hook)
optimizer_cls = LayerWiseDummyOptimizer
optimizer_kwargs.update({"optimizer_dict": optimizer_dict})