Fix typing error in Trainer class (prediction_step) (#11138)

* fix: docstrings in prediction_step

* ci: Satisfy line length requirements

* ci: character length requirements
This commit is contained in:
Jannis Born 2021-04-08 14:22:25 +02:00 committed by GitHub
parent ffe0761777
commit f8e90d6fb9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1966,7 +1966,7 @@ class Trainer:
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on :obj:`model` using obj:`inputs`.
@ -1987,8 +1987,8 @@ class Trainer:
gathering predictions.
Return:
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
labels (each being optional).
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
logits and labels (each being optional).
"""
has_labels = all(inputs.get(k) is not None for k in self.label_names)
inputs = self._prepare_inputs(inputs)