make sure lr is not a tensor (#37881)

* make sure lr is not a tensor

* revert change from #37704

* clean up to reduce extra LoC

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Wing Lian 2025-04-30 08:23:39 -04:00 committed by GitHub
parent 7be92f9a94
commit 4eb6acc896
2 changed files with 4 additions and 5 deletions
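
Context for the fix (not stated in the commit message itself): newer PyTorch releases let an optimizer's learning rate be a zero-dim tensor rather than a Python float, in which case a tensor can end up in the Trainer's logs dict and break consumers that expect plain numbers, JSON serialization being the obvious one. A minimal sketch of the failure mode, assuming a PyTorch version that accepts tensor learning rates:

import json

import torch

# Recent PyTorch accepts a tensor-valued lr (e.g. for capturable/fused optimizers).
params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=torch.tensor(0.01))

last_lr = opt.param_groups[0]["lr"]
print(torch.is_tensor(last_lr))  # True: a zero-dim tensor, not a float

# A tensor in the logs dict breaks plain-number consumers such as JSON:
try:
    json.dumps({"learning_rate": last_lr})
except TypeError as err:
    print(err)  # Object of type Tensor is not JSON serializable

# .item() converts the zero-dim tensor back to a Python float:
print(last_lr.item())  # 0.01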


@@ -3079,9 +3079,7 @@ class Trainer:
         if grad_norm is not None:
             logs["grad_norm"] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
         if learning_rate is not None:
-            logs["learning_rate"] = (
-                learning_rate.item() if isinstance(learning_rate, torch.Tensor) else learning_rate
-            )
+            logs["learning_rate"] = learning_rate
         else:
             logs["learning_rate"] = self._get_learning_rate()


@@ -921,8 +921,9 @@ def _get_learning_rate(self):
             last_lr = self.optimizer.param_groups[0]["lr"]
         else:
             last_lr = self.lr_scheduler.get_last_lr()[0]
-        if torch.is_tensor(last_lr):
-            last_lr = last_lr.item()
+
+    if torch.is_tensor(last_lr):
+        last_lr = last_lr.item()
     return last_lr
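
Net effect of the two hunks: the tensor-to-float conversion lives in one place and applies to last_lr regardless of which branch produced it. A standalone sketch of the resulting control flow, with optimizer and lr_scheduler as hypothetical stand-ins for the Trainer attributes:

import torch


def get_learning_rate(optimizer, lr_scheduler):
    # ReduceLROnPlateau does not track a last lr the same way, so the
    # Trainer reads the optimizer's param group directly.
    if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        last_lr = optimizer.param_groups[0]["lr"]
    else:
        last_lr = lr_scheduler.get_last_lr()[0]

    # The hoisted, single normalization point: whichever branch ran,
    # a tensor lr becomes a plain Python float before it is returned.
    if torch.is_tensor(last_lr):
        last_lr = last_lr.item()
    return last_lr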