Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 18:22:34 +06:00)
Extend `nested_XXX` functions to mappings/dicts. (#19455)
* Extend `nested_XXX` functions to mappings/dicts.
* Update src/transformers/trainer_pt_utils.py
* Update src/transformers/trainer_pt_utils.py
* Update src/transformers/trainer_pt_utils.py
* Style updated file

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent b722a6be72
commit 335f9bcd34
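For context, Trainer uses these `nested_XXX` helpers to accumulate and post-process predictions; with this change they also accept dict-like (Mapping) structures, preserving the concrete mapping type. A minimal end-to-end sketch, with keys and shapes invented for illustration:

import torch
from transformers.trainer_pt_utils import nested_concat, nested_detach, nested_numpify

# Two "evaluation steps" that each produce a dict of tensors.
step1 = nested_detach({"logits": torch.randn(2, 3, requires_grad=True)})
step2 = nested_detach({"logits": torch.randn(2, 3, requires_grad=True)})

# Accumulate across steps, then convert to NumPy for metric computation.
all_preds = nested_concat(step1, step2, padding_index=-100)
all_preds = nested_numpify(all_preds)
print(all_preds["logits"].shape)  # (4, 3)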
src/transformers/trainer_pt_utils.py

@@ -104,7 +104,7 @@ def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
 def nested_concat(tensors, new_tensors, padding_index=-100):
     """
     Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
-    nested list/tuples of tensors.
+    nested list/tuples/dict of tensors.
     """
     assert type(tensors) == type(
         new_tensors
@@ -113,6 +113,10 @@ def nested_concat(tensors, new_tensors, padding_index=-100):
         return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
     elif isinstance(tensors, torch.Tensor):
         return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
+    elif isinstance(tensors, Mapping):
+        return type(tensors)(
+            {k: nested_concat(t, new_tensors[k], padding_index=padding_index) for k, t in tensors.items()}
+        )
     elif isinstance(tensors, np.ndarray):
         return numpy_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
     else:
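A quick illustration of the new Mapping branch in nested_concat (tensor shapes invented for the example): concatenation happens key by key, and tensors whose second dimensions differ are padded with padding_index first.

import torch
from transformers.trainer_pt_utils import nested_concat

batch1 = {"logits": torch.ones(2, 3), "labels": torch.zeros(2)}
batch2 = {"logits": torch.ones(4, 5), "labels": torch.zeros(4)}

merged = nested_concat(batch1, batch2, padding_index=-100)
print(merged["logits"].shape)  # torch.Size([6, 5]); the (2, 3) block is padded with -100
print(merged["labels"].shape)  # torch.Size([6]); 1-D tensors are concatenated directly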
@@ -140,9 +144,12 @@ def find_batch_size(tensors):
 
 
 def nested_numpify(tensors):
-    "Numpify `tensors` (even if it's a nested list/tuple of tensors)."
+    "Numpify `tensors` (even if it's a nested list/tuple/dict of tensors)."
     if isinstance(tensors, (list, tuple)):
         return type(tensors)(nested_numpify(t) for t in tensors)
+    if isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_numpify(t) for k, t in tensors.items()})
+
     t = tensors.cpu()
     if t.dtype == torch.bfloat16:
         # As of Numpy 1.21.4, NumPy does not support bfloat16 (see
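Similarly for nested_numpify, a minimal sketch (example data invented): each dict value is moved to CPU and converted to a NumPy array.

import torch
from transformers.trainer_pt_utils import nested_numpify

outputs = {"loss": torch.tensor(0.5), "logits": torch.randn(2, 3)}
arrays = nested_numpify(outputs)
print(type(arrays["logits"]))  # <class 'numpy.ndarray'>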
@@ -153,9 +160,11 @@ def nested_numpify(tensors):
 
 
 def nested_detach(tensors):
-    "Detach `tensors` (even if it's a nested list/tuple of tensors)."
+    "Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."
     if isinstance(tensors, (list, tuple)):
         return type(tensors)(nested_detach(t) for t in tensors)
+    elif isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
     return tensors.detach()
 
 
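And for nested_detach (example data invented): dict values are detached from the autograd graph recursively, just like list/tuple elements.

import torch
from transformers.trainer_pt_utils import nested_detach

outputs = {"logits": torch.randn(2, 3, requires_grad=True)}
detached = nested_detach(outputs)
print(detached["logits"].requires_grad)  # False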
@@ -165,6 +174,11 @@ def nested_xla_mesh_reduce(tensors, name):
 
         if isinstance(tensors, (list, tuple)):
             return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
+        if isinstance(tensors, Mapping):
+            return type(tensors)(
+                {k: nested_xla_mesh_reduce(t, f"{name}_{i}") for i, (k, t) in enumerate(tensors.items())}
+            )
+
         tensors = atleast_1d(tensors)
         return xm.mesh_reduce(name, tensors, torch.cat)
     else:
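The XLA variant follows the same pattern but needs a distinct reduce name per value, so it derives one from the enumeration index (e.g. "eval_preds_0", "eval_preds_1", ...). A hedged sketch, only meaningful on a TPU host with torch_xla installed (keys and shapes invented):

import torch
from transformers.trainer_pt_utils import nested_xla_mesh_reduce

preds = {"logits": torch.randn(8, 4), "hidden": torch.randn(8, 16)}
# Gathers each value across TPU cores; raises ImportError without torch_xla.
reduced = nested_xla_mesh_reduce(preds, "eval_preds")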
@@ -335,9 +349,12 @@ def expand_like(arrays, new_seq_length, padding_index=-100):
 
 
 def nested_truncate(tensors, limit):
-    "Truncate `tensors` at `limit` (even if it's a nested list/tuple of tensors)."
+    "Truncate `tensors` at `limit` (even if it's a nested list/tuple/dict of tensors)."
     if isinstance(tensors, (list, tuple)):
         return type(tensors)(nested_truncate(t, limit) for t in tensors)
+    if isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_truncate(t, limit) for k, t in tensors.items()})
+
     return tensors[:limit]
 
 
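Finally, nested_truncate with a dict (example tensors invented): each value is sliced to its first `limit` rows.

import torch
from transformers.trainer_pt_utils import nested_truncate

preds = {"logits": torch.randn(10, 4), "labels": torch.arange(10)}
truncated = nested_truncate(preds, 6)
print(truncated["logits"].shape)  # torch.Size([6, 4])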