Extend nested_XXX functions to mappings/dicts. (#19455)

* Extend `nested_XXX` functions to mappings/dicts.

* Update src/transformers/trainer_pt_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/trainer_pt_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/trainer_pt_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Style updated file

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Guillem Orellana Trullols 2022-10-11 14:13:21 +02:00 committed by GitHub
parent b722a6be72
commit 335f9bcd34

src/transformers/trainer_pt_utils.py

@@ -104,7 +104,7 @@ def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
 def nested_concat(tensors, new_tensors, padding_index=-100):
     """
     Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
-    nested list/tuples of tensors.
+    nested list/tuples/dict of tensors.
     """
     assert type(tensors) == type(
         new_tensors
@@ -113,6 +113,10 @@ def nested_concat(tensors, new_tensors, padding_index=-100):
         return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
     elif isinstance(tensors, torch.Tensor):
         return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
+    elif isinstance(tensors, Mapping):
+        return type(tensors)(
+            {k: nested_concat(t, new_tensors[k], padding_index=padding_index) for k, t in tensors.items()}
+        )
     elif isinstance(tensors, np.ndarray):
         return numpy_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
     else:
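
A minimal sketch of what the new `Mapping` branch in `nested_concat` enables, assuming `torch` and a `transformers` version containing this change are installed (the batch contents and shapes below are made up for illustration):

import torch
from transformers.trainer_pt_utils import nested_concat

# Two evaluation batches whose outputs are dicts of tensors
# (illustrative keys, not from any particular model).
batch1 = {"logits": torch.ones(2, 3), "hidden": torch.zeros(2, 4)}
batch2 = {"logits": torch.ones(4, 5), "hidden": torch.zeros(4, 4)}

# Keys are matched one-to-one; values are concatenated on the first dim
# and padded on the second with `padding_index` where lengths differ.
merged = nested_concat(batch1, batch2, padding_index=-100)
print(merged["logits"].shape)  # torch.Size([6, 5])
print(merged["hidden"].shape)  # torch.Size([6, 4])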
@@ -140,9 +144,12 @@ def find_batch_size(tensors):
 def nested_numpify(tensors):
-    "Numpify `tensors` (even if it's a nested list/tuple of tensors)."
+    "Numpify `tensors` (even if it's a nested list/tuple/dict of tensors)."
     if isinstance(tensors, (list, tuple)):
         return type(tensors)(nested_numpify(t) for t in tensors)
+    if isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_numpify(t) for k, t in tensors.items()})
     t = tensors.cpu()
     if t.dtype == torch.bfloat16:
         # As of Numpy 1.21.4, NumPy does not support bfloat16 (see
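
Likewise, `nested_numpify` now recurses into dict values. A small sketch under the same assumptions as above:

import torch
from transformers.trainer_pt_utils import nested_numpify

# Illustrative dict of outputs; every tensor value is moved to CPU
# and converted to a NumPy array, recursively.
outputs = {"loss": torch.tensor(0.5), "logits": torch.randn(2, 3)}
as_numpy = nested_numpify(outputs)
print(type(as_numpy["logits"]))  # <class 'numpy.ndarray'>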
@@ -153,9 +160,11 @@ def nested_numpify(tensors):
 def nested_detach(tensors):
-    "Detach `tensors` (even if it's a nested list/tuple of tensors)."
+    "Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."
     if isinstance(tensors, (list, tuple)):
         return type(tensors)(nested_detach(t) for t in tensors)
+    elif isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
     return tensors.detach()
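
`nested_detach` gets the same treatment, which matters when model outputs are dict-like and stored predictions should not keep their autograd graph. An illustrative sketch (hypothetical keys):

import torch
from transformers.trainer_pt_utils import nested_detach

logits = torch.randn(2, 3, requires_grad=True)
# A dict containing a nested tuple, to show the recursion through both.
outputs = {"logits": logits, "extra": (logits * 2,)}
detached = nested_detach(outputs)
print(detached["logits"].requires_grad)    # False
print(detached["extra"][0].requires_grad)  # False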
@@ -165,6 +174,11 @@ def nested_xla_mesh_reduce(tensors, name):
         if isinstance(tensors, (list, tuple)):
             return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
+        if isinstance(tensors, Mapping):
+            return type(tensors)(
+                {k: nested_xla_mesh_reduce(t, f"{name}_{i}") for i, (k, t) in enumerate(tensors.items())}
+            )
         tensors = atleast_1d(tensors)
         return xm.mesh_reduce(name, tensors, torch.cat)
 else:
@@ -335,9 +349,12 @@ def expand_like(arrays, new_seq_length, padding_index=-100):
 def nested_truncate(tensors, limit):
-    "Truncate `tensors` at `limit` (even if it's a nested list/tuple of tensors)."
+    "Truncate `tensors` at `limit` (even if it's a nested list/tuple/dict of tensors)."
     if isinstance(tensors, (list, tuple)):
         return type(tensors)(nested_truncate(t, limit) for t in tensors)
+    if isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_truncate(t, limit) for k, t in tensors.items()})
     return tensors[:limit]
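
Finally, `nested_truncate` can trim every value in a dict back to a common length, e.g. when gathered predictions carry extra padded samples. A sketch with made-up shapes:

import torch
from transformers.trainer_pt_utils import nested_truncate

# Hypothetical gathered predictions with 5 samples, truncated to 3.
gathered = {"preds": torch.arange(10).reshape(5, 2), "ids": torch.arange(5)}
truncated = nested_truncate(gathered, 3)
print(truncated["preds"].shape)  # torch.Size([3, 2])
print(truncated["ids"].shape)    # torch.Size([3])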