Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
make sure lr is not a tensor (#37881)
* make sure lr is not a tensor

* revert change from #37704

* clean up to reduce extra LoC

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
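Background for the change, with a minimal standalone sketch (an illustration, not code from the commit): the learning rate read back from the optimizer's `param_groups` or from a scheduler can be a 0-d `torch.Tensor` rather than a Python float, which then leaks into the training logs as a tensor. Normalizing it with `.item()` at the source keeps the logged value a plain float.

```python
import torch

# Hypothetical repro, not Trainer code: simulate an optimizer whose lr is
# stored as a tensor, as can happen when the lr is created or updated as one.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-3)
optimizer.param_groups[0]["lr"] = torch.tensor(1e-3)  # simulate a tensor lr

last_lr = optimizer.param_groups[0]["lr"]  # what the ReduceLROnPlateau branch reads
print(torch.is_tensor(last_lr))            # True: this would reach the logs as a tensor

# The normalization this commit enforces in _get_learning_rate:
if torch.is_tensor(last_lr):
    last_lr = last_lr.item()
print(type(last_lr))                       # <class 'float'>
```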
This commit is contained in:
parent 7be92f9a94
commit 4eb6acc896
@@ -3079,9 +3079,7 @@ class Trainer:
             if grad_norm is not None:
                 logs["grad_norm"] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
             if learning_rate is not None:
-                logs["learning_rate"] = (
-                    learning_rate.item() if isinstance(learning_rate, torch.Tensor) else learning_rate
-                )
+                logs["learning_rate"] = learning_rate
             else:
                 logs["learning_rate"] = self._get_learning_rate()

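One practical reason the log payload should hold plain floats rather than tensors (my reading of the motivation, not stated in the commit message): downstream log consumers such as JSON-based trackers cannot serialize `torch.Tensor` values. A small sketch:

```python
import json

import torch

# Tensors in the logs dict break JSON serialization.
logs = {"learning_rate": torch.tensor(1e-3), "grad_norm": torch.tensor(0.5)}
try:
    json.dumps(logs)
except TypeError as err:
    print(f"not serializable: {err}")

# After converting tensors to Python scalars, the payload serializes cleanly.
logs = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in logs.items()}
print(json.dumps(logs))
```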
@@ -921,8 +921,9 @@ def _get_learning_rate(self):
             last_lr = self.optimizer.param_groups[0]["lr"]
         else:
             last_lr = self.lr_scheduler.get_last_lr()[0]
-            if torch.is_tensor(last_lr):
-                last_lr = last_lr.item()
+
+        if torch.is_tensor(last_lr):
+            last_lr = last_lr.item()
     return last_lr

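Putting the second hunk in context: the tensor check now sits outside the inner `else`, so it also covers the `ReduceLROnPlateau` branch that reads the lr straight from `optimizer.param_groups`. A standalone sketch of that logic (a hypothetical helper written for illustration; the `ReduceLROnPlateau` guard is surrounding context not shown in the hunk):

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau


def read_last_lr(optimizer, lr_scheduler):
    """Illustrative helper mirroring the post-commit _get_learning_rate logic."""
    if isinstance(lr_scheduler, ReduceLROnPlateau):
        # mirror the Trainer: for ReduceLROnPlateau, read the lr from the optimizer
        last_lr = optimizer.param_groups[0]["lr"]
    else:
        last_lr = lr_scheduler.get_last_lr()[0]

    # The check now applies to both branches, so a tensor lr from either source
    # is converted to a plain Python float before it reaches the logs.
    if torch.is_tensor(last_lr):
        last_lr = last_lr.item()
    return last_lr
```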