mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
Distributed Trainer: 2 little fixes (#7461)
* reset model.config * Update src/transformers/trainer.py * use lower case tensor * Just tensor change
This commit is contained in:
parent
0acd1ffa09
commit
097049b81b
@ -203,7 +203,7 @@ def distributed_broadcast_scalars(
|
||||
) -> "torch.Tensor":
|
||||
if is_torch_available():
|
||||
try:
|
||||
tensorized_scalar = torch.Tensor(scalars).cuda()
|
||||
tensorized_scalar = torch.tensor(scalars).cuda()
|
||||
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
|
||||
torch.distributed.all_gather(output_tensors, tensorized_scalar)
|
||||
concat = torch.cat(output_tensors, dim=0)
|
||||
|
Loading…
Reference in New Issue
Block a user