[trainer: distributed_concat] ensure all_gather's inputs are contiguous (#20951)

Stas Bekman 2022-12-30 21:55:12 -08:00 committed by GitHub
parent 17292440c0
commit 9e6da0a7ed

@@ -189,7 +189,7 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) ->
     try:
         if isinstance(tensor, (tuple, list)):
             return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
-        tensor = atleast_1d(tensor)
+        tensor = atleast_1d(tensor).contiguous()
         output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
         dist.all_gather(output_tensors, tensor)
         concat = torch.cat(output_tensors, dim=0)
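
For context: torch.distributed.all_gather copies from the input tensor's underlying buffer, so a non-contiguous view (e.g. the result of a slice or transpose) can be gathered incorrectly or rejected by the communication backend. Below is a minimal standalone sketch, not part of the commit, of the property the fix enforces:

import torch

# A column slice is a strided view over the same storage:
# its elements are not packed densely in memory.
t = torch.arange(12).reshape(3, 4)
col = t[:, 0]                             # every 4th element of the storage
print(col.is_contiguous())                # False

# .contiguous() materializes the view into a densely packed copy,
# which is the layout collectives such as dist.all_gather expect.
# On an already contiguous tensor it returns the tensor unchanged,
# so the fix costs nothing in the common case.
print(col.contiguous().is_contiguous())   # True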