mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Distributed Trainer: 2 little fixes (#7461)
* reset model.config * Update src/transformers/trainer.py * use lower case tensor * Just tensor change
This commit is contained in:
parent
0acd1ffa09
commit
097049b81b
@ -203,7 +203,7 @@ def distributed_broadcast_scalars(
|
|||||||
) -> "torch.Tensor":
|
) -> "torch.Tensor":
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
try:
|
try:
|
||||||
tensorized_scalar = torch.Tensor(scalars).cuda()
|
tensorized_scalar = torch.tensor(scalars).cuda()
|
||||||
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
|
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
|
||||||
torch.distributed.all_gather(output_tensors, tensorized_scalar)
|
torch.distributed.all_gather(output_tensors, tensorized_scalar)
|
||||||
concat = torch.cat(output_tensors, dim=0)
|
concat = torch.cat(output_tensors, dim=0)
|
||||||
|
Loading…
Reference in New Issue
Block a user