Fix test failure on DeepSpeed (#29444)

* Fix test failure

* use item
This commit is contained in:
Zach Mueller 2024-03-06 07:11:53 -05:00 committed by GitHub
parent 0a5b0516f8
commit 9322576e2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2011,7 +2011,10 @@ class Trainer:
is_accelerate_available()
and self.accelerator.distributed_type == DistributedType.DEEPSPEED
):
grad_norm = model.get_global_grad_norm().item()
grad_norm = model.get_global_grad_norm()
# In some cases the grad norm may not return a float
if hasattr(grad_norm, "item"):
grad_norm = grad_norm.item()
else:
grad_norm = _grad_norm.item() if _grad_norm is not None else None