Distributed Trainer: 2 little fixes (#7461)

* reset model.config

* Update src/transformers/trainer.py

* use lower case tensor

* Just tensor change
commit 097049b81b (parent 0acd1ffa09)
Author: Sam Shleifer
Date: 2020-09-30 22:14:14 -04:00, committed by GitHub

@@ -203,7 +203,7 @@ def distributed_broadcast_scalars(
 ) -> "torch.Tensor":
     if is_torch_available():
         try:
-            tensorized_scalar = torch.Tensor(scalars).cuda()
+            tensorized_scalar = torch.tensor(scalars).cuda()
             output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
             torch.distributed.all_gather(output_tensors, tensorized_scalar)
             concat = torch.cat(output_tensors, dim=0)
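
The commit message does not spell out the motivation, but the distinction is a well-known PyTorch pitfall: uppercase torch.Tensor is an alias for the default float32 tensor constructor, so it silently casts integer inputs to float32 (and interprets a bare integer argument as a shape), whereas lowercase torch.tensor copies the data and infers the dtype from it. A minimal standalone sketch of the difference (illustrative only, not Trainer code):

import torch

# torch.Tensor is the float32 constructor: integer data is silently cast.
print(torch.Tensor([1, 2, 3]).dtype)    # torch.float32

# torch.tensor infers the dtype from the input data instead.
print(torch.tensor([1, 2, 3]).dtype)    # torch.int64
print(torch.tensor([1.5, 2.5]).dtype)   # torch.float32

# Sharper edge case: an integer argument is a *shape* for torch.Tensor
# (an uninitialized 1-D tensor of length 3) but a *value* for torch.tensor.
print(torch.Tensor(3).shape)            # torch.Size([3])
print(torch.tensor(3).shape)            # torch.Size([])

With torch.tensor, the scalars gathered across workers keep their natural dtype rather than being coerced to float32 before the all_gather.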