Fix StoppingCriteria ABC signature (#12918)

Change `score` -> `scores` because the argument is not positional-only, so you need consistently named parameters for the subclasses. The subclasses appear to favor `scores` over `score`.
This commit is contained in:
Will Frey 2021-07-28 12:47:15 -04:00 committed by GitHub
parent 63f2b9ab33
commit bf78f523aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -35,7 +35,7 @@ class StoppingCriteria(ABC):
"""Abstract base class for all stopping criteria that can be applied during generation."""
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
raise NotImplementedError("StoppingCriteria needs to be subclassed")