mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix gather for SageMaker model parallel
This commit is contained in:
parent
4e0410e927
commit
1c96500088
@ -1021,6 +1021,7 @@ if is_sagemaker_mp_enabled():
|
||||
f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
|
||||
)
|
||||
all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
|
||||
all_tensors = [t if len(t.shape) > 0 else t[None] for t in all_tensors]
|
||||
return torch.cat([t.cpu() for t in all_tensors], dim=0)
|
||||
|
||||
def smp_nested_concat(tensor):
|
||||
|
Loading…
Reference in New Issue
Block a user