mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix typing error in Trainer class (prediction_step) (#11138)
* fix: docstrings in prediction_step * ci: Satisfy line length requirements * ci: character length requirements
This commit is contained in:
parent
ffe0761777
commit
f8e90d6fb9
@ -1966,7 +1966,7 @@ class Trainer:
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Perform an evaluation step on :obj:`model` using obj:`inputs`.
|
||||
|
||||
@ -1987,8 +1987,8 @@ class Trainer:
|
||||
gathering predictions.
|
||||
|
||||
Return:
|
||||
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
|
||||
labels (each being optional).
|
||||
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
|
||||
logits and labels (each being optional).
|
||||
"""
|
||||
has_labels = all(inputs.get(k) is not None for k in self.label_names)
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
Loading…
Reference in New Issue
Block a user