From f8e90d6fb962716d0a79f71013f2485ac0982146 Mon Sep 17 00:00:00 2001 From: Jannis Born Date: Thu, 8 Apr 2021 14:22:25 +0200 Subject: [PATCH] Fix typing error in Trainer class (prediction_step) (#11138) * fix: docstrings in prediction_step * ci: Satisfy line length requirements * ci: character length requirements --- src/transformers/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 7c33981b6d9..33c14d921ca 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1966,7 +1966,7 @@ class Trainer: inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool, ignore_keys: Optional[List[str]] = None, - ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: """ Perform an evaluation step on :obj:`model` using obj:`inputs`. @@ -1987,8 +1987,8 @@ class Trainer: gathering predictions. Return: - Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and - labels (each being optional). + Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, + logits and labels (each being optional). """ has_labels = all(inputs.get(k) is not None for k in self.label_names) inputs = self._prepare_inputs(inputs)