mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add a condition for nested_detach (#31855)
fix bug: https://github.com/huggingface/transformers/issues/31852
This commit is contained in:
parent
080e14b24c
commit
c54af4c77e
@ -192,7 +192,7 @@ def nested_detach(tensors):
|
||||
return type(tensors)(nested_detach(t) for t in tensors)
|
||||
elif isinstance(tensors, Mapping):
|
||||
return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
|
||||
return tensors.detach()
|
||||
return tensors.detach() if isinstance(tensors, torch.Tensor) else tensors
|
||||
|
||||
|
||||
def nested_xla_mesh_reduce(tensors, name):
|
||||
|
Loading…
Reference in New Issue
Block a user