Allow nested tensors in predicted logits (#7542)

This commit is contained in:
Sylvain Gugger 2020-10-05 06:33:15 -04:00 committed by GitHub
parent 60de910e60
commit 0270256b27
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 2 deletions

View File

@ -48,6 +48,7 @@ from .trainer_utils import (
distributed_broadcast_scalars,
distributed_concat,
nested_concat,
nested_detach,
nested_numpify,
nested_xla_mesh_reduce,
set_seed,
@ -1466,16 +1467,18 @@ class Trainer:
logits = outputs[:]
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
# Remove the past from the logits.
logits = logits[: self.args.past_index - 1] + logits[self.args.past_index :]
if prediction_loss_only:
return (loss, None, None)
logits = tuple(logit.detach() for logit in logits)
logits = nested_detach(logits)
if len(logits) == 1:
logits = logits[0]
if has_labels:
labels = tuple(inputs.get(name).detach() for name in self.label_names)
labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
if len(labels) == 1:
labels = labels[0]
else:

View File

@ -154,6 +154,13 @@ def nested_concat(tensors, new_tensors, dim=0):
raise ImportError("Torch must be installed to use `nested_concat`")
def nested_deatch(tensors):
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t) for t in tensors)
return tensors.detach()
def nested_numpify(tensors):
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):