mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
[trainer: distributed_concat
] ensure all_gather
's inputs are contiguous (#20951)
[trainer: distributed_concat] ensure all_gather's input are contiguous
This commit is contained in:
parent
17292440c0
commit
9e6da0a7ed
@ -189,7 +189,7 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) ->
|
||||
try:
|
||||
if isinstance(tensor, (tuple, list)):
|
||||
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
|
||||
tensor = atleast_1d(tensor)
|
||||
tensor = atleast_1d(tensor).contiguous()
|
||||
output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(output_tensors, tensor)
|
||||
concat = torch.cat(output_tensors, dim=0)
|
||||
|
Loading…
Reference in New Issue
Block a user