Add a condition for nested_detach (#31855)

fix bug: https://github.com/huggingface/transformers/issues/31852
This commit is contained in:
haikuoxin 2024-07-11 04:37:22 +08:00 committed by GitHub
parent 080e14b24c
commit c54af4c77e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -192,7 +192,7 @@ def nested_detach(tensors):
return type(tensors)(nested_detach(t) for t in tensors)
elif isinstance(tensors, Mapping):
return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
return tensors.detach()
return tensors.detach() if isinstance(tensors, torch.Tensor) else tensors
def nested_xla_mesh_reduce(tensors, name):