Fix gather for SageMaker model parallel

Sylvain Gugger 2021-09-27 13:11:58 -04:00
parent 4e0410e927
commit 1c96500088


@@ -1021,6 +1021,7 @@ if is_sagemaker_mp_enabled():
                 f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
             )
         all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
+        all_tensors = [t if len(t.shape) > 0 else t[None] for t in all_tensors]
         return torch.cat([t.cpu() for t in all_tensors], dim=0)
 
     def smp_nested_concat(tensor):
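
The added line guards against zero-dimensional (scalar) tensors, which torch.cat cannot concatenate. Below is a minimal sketch of the failure mode and the fix in plain PyTorch, outside of SageMaker model parallel; the `gathered` list here is a hypothetical stand-in for what smp.allgather would return from each data-parallel rank.

import torch

# Stand-in for tensors gathered from two ranks: 0-dim scalar losses.
gathered = [torch.tensor(0.5), torch.tensor(0.7)]

# torch.cat rejects zero-dimensional tensors, so gathering scalars would
# previously fail here:
# torch.cat(gathered, dim=0)  # RuntimeError: zero-dimensional tensor cannot be concatenated

# The fix: promote any 0-dim tensor to shape (1,) with `t[None]` first.
gathered = [t if len(t.shape) > 0 else t[None] for t in gathered]
print(torch.cat([t.cpu() for t in gathered], dim=0))  # tensor([0.5000, 0.7000])

Tensors that already have at least one dimension pass through unchanged, so the behavior for regular batched outputs is unaffected.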