mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Allow nested tensors in predicted logits (#7542)
This commit is contained in:
parent
60de910e60
commit
0270256b27
@ -48,6 +48,7 @@ from .trainer_utils import (
|
||||
distributed_broadcast_scalars,
|
||||
distributed_concat,
|
||||
nested_concat,
|
||||
nested_detach,
|
||||
nested_numpify,
|
||||
nested_xla_mesh_reduce,
|
||||
set_seed,
|
||||
@ -1466,16 +1467,18 @@ class Trainer:
|
||||
logits = outputs[:]
|
||||
if self.args.past_index >= 0:
|
||||
self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
|
||||
# Remove the past from the logits.
|
||||
logits = logits[: self.args.past_index - 1] + logits[self.args.past_index :]
|
||||
|
||||
if prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
logits = tuple(logit.detach() for logit in logits)
|
||||
logits = nested_detach(logits)
|
||||
if len(logits) == 1:
|
||||
logits = logits[0]
|
||||
|
||||
if has_labels:
|
||||
labels = tuple(inputs.get(name).detach() for name in self.label_names)
|
||||
labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
|
||||
if len(labels) == 1:
|
||||
labels = labels[0]
|
||||
else:
|
||||
|
@ -154,6 +154,13 @@ def nested_concat(tensors, new_tensors, dim=0):
|
||||
raise ImportError("Torch must be installed to use `nested_concat`")
|
||||
|
||||
|
||||
def nested_deatch(tensors):
|
||||
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_detach(t) for t in tensors)
|
||||
return tensors.detach()
|
||||
|
||||
|
||||
def nested_numpify(tensors):
|
||||
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
|
Loading…
Reference in New Issue
Block a user