Fix in gather for SM distributed

Sylvain Gugger 2021-09-27 11:57:18 -04:00
parent 367c2ef53b
commit 4e0410e927


@@ -162,8 +162,8 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) ->
         if isinstance(tensor, (tuple, list)):
             return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
         output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
-        dist.all_gather(output_tensors, tensor)
         output_tensors = [t if len(t.shape) > 0 else t[None] for t in output_tensors]
+        dist.all_gather(output_tensors, tensor)
         concat = torch.cat(output_tensors, dim=0)
         # truncate the dummy elements added by SequentialDistributedSampler
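The change reorders two lines in distributed_concat (in transformers' trainer_pt_utils.py): the list comprehension that promotes 0-dim gather buffers to 1-dim with t[None] now runs before dist.all_gather instead of after it. The ordering only matters for scalar tensors such as a 0-dim loss: the stock torch.distributed backend tolerated the old order, but the SageMaker distributed setup of the commit title evidently does not, so the buffers must already be 1-dim when the collective runs. Below is a minimal sketch of the function as patched, written against stock torch.distributed; the truncation step after the hunk is reconstructed from the trailing comment and the num_total_examples parameter, so treat it as an illustration rather than the exact file contents.

    from typing import Any, Optional

    import torch
    import torch.distributed as dist


    def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) -> Any:
        # Recurse into tuples/lists so nested prediction outputs are handled.
        if isinstance(tensor, (tuple, list)):
            return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
        output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
        # The fix: promote 0-dim gather buffers to 1-dim *before* the
        # collective, so the SM distributed all_gather never sees scalars.
        output_tensors = [t if len(t.shape) > 0 else t[None] for t in output_tensors]
        dist.all_gather(output_tensors, tensor)
        concat = torch.cat(output_tensors, dim=0)
        # Truncate the dummy elements added by SequentialDistributedSampler
        # so the result has exactly one row per real example.
        if num_total_examples is not None:
            concat = concat[:num_total_examples]
        return concat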