mirror of https://github.com/huggingface/transformers.git
Fix gather for TPU (#13813)
This commit is contained in:
parent 7db2a79b38
commit 269c3d1400
@@ -152,6 +152,8 @@ def nested_xla_mesh_reduce(tensors, name):
         if isinstance(tensors, (list, tuple)):
             return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
+        if tensors.ndim == 0:
+            tensors = tensors[None]
         return xm.mesh_reduce(name, tensors, torch.cat)
     else:
         raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
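
For context, a minimal sketch (not part of the commit) of why the added 0-dim check is needed: `xm.mesh_reduce` uses `torch.cat` as the reduce function, and `torch.cat` cannot concatenate zero-dimensional (scalar) tensors such as a gathered loss value, so they are first promoted to shape (1,) with `tensors[None]`.

import torch

# Scalar (0-dim) tensors, e.g. per-replica loss values gathered across TPU cores.
scalars = [torch.tensor(0.5), torch.tensor(0.7)]

# torch.cat raises a RuntimeError on zero-dimensional inputs.
try:
    torch.cat(scalars)
except RuntimeError as e:
    print(f"cat on 0-dim tensors fails: {e}")

# The fix: promote 0-dim tensors to shape (1,) before concatenating,
# mirroring the `if tensors.ndim == 0` branch added in this commit.
promoted = [t[None] if t.ndim == 0 else t for t in scalars]
print(torch.cat(promoted))  # tensor([0.5000, 0.7000])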