Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 19:21:31 +06:00)
Fix in gather for SM distributed
commit 4e0410e927
parent 367c2ef53b
@@ -162,8 +162,8 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) ->
         if isinstance(tensor, (tuple, list)):
             return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
         output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
-        dist.all_gather(output_tensors, tensor)
+        output_tensors = [t if len(t.shape) > 0 else t[None] for t in output_tensors]
+        dist.all_gather(output_tensors, tensor)
         concat = torch.cat(output_tensors, dim=0)

         # truncate the dummy elements added by SequentialDistributedSampler
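The added line matters because per-step losses and scalar metrics are 0-dim tensors, and torch.cat cannot concatenate 0-dim tensors along dim=0; promoting each gather buffer to shape (1,) with t[None] keeps the subsequent torch.cat valid. Below is a minimal single-process sketch of that behavior. It is an illustration, not the library's code path: no process group is initialized, so the dist.all_gather call is only indicated in a comment, and world_size, the loss value, and num_total_examples are made-up values.

import torch

world_size = 2               # hypothetical world size, for illustration only
tensor = torch.tensor(0.5)   # a 0-dim scalar, e.g. a per-step loss
num_total_examples = 2       # hypothetical, mirrors distributed_concat's parameter

# Gather buffers cloned from the input, as in distributed_concat.
output_tensors = [tensor.clone() for _ in range(world_size)]

# Without the fix, concatenation fails:
# torch.cat(output_tensors, dim=0)
# -> RuntimeError: zero-dimensional tensor cannot be concatenated

# The fix from this commit: promote 0-dim buffers to shape (1,).
output_tensors = [t if len(t.shape) > 0 else t[None] for t in output_tensors]

# In the real function, dist.all_gather(output_tensors, tensor) runs here
# to fill the buffers with one tensor per rank.

concat = torch.cat(output_tensors, dim=0)
print(concat.shape)          # torch.Size([2])

# distributed_concat then truncates the dummy elements padded in by
# SequentialDistributedSampler so the result covers exactly the dataset:
concat = concat[:num_total_examples]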