diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 63a1ddfc33c..334a8e4d959 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -203,7 +203,7 @@ def distributed_broadcast_scalars( ) -> "torch.Tensor": if is_torch_available(): try: - tensorized_scalar = torch.Tensor(scalars).cuda() + tensorized_scalar = torch.tensor(scalars).cuda() output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())] torch.distributed.all_gather(output_tensors, tensorized_scalar) concat = torch.cat(output_tensors, dim=0)