Fix gather for TPU (#13813)

This commit is contained in:
Sylvain Gugger 2021-09-30 11:32:40 -04:00 committed by GitHub
parent 7db2a79b38
commit 269c3d1400
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -152,6 +152,8 @@ def nested_xla_mesh_reduce(tensors, name):
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
if tensors.ndim == 0:
tensors = tensors[None]
return xm.mesh_reduce(name, tensors, torch.cat)
else:
raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")