From 6806d3356791d8168e9673ad61b28fded4ed80e8 Mon Sep 17 00:00:00 2001 From: Alex Calderwood Date: Fri, 16 Aug 2024 11:36:55 -0700 Subject: [PATCH] Make beam_constraints.Constraint.advance() docstring more accurate (#32674) * Fix beam_constraints.Constraint.advance() docstring * Update src/transformers/generation/beam_constraints.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --------- Co-authored-by: Joao Gante Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- src/transformers/generation/beam_constraints.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/beam_constraints.py b/src/transformers/generation/beam_constraints.py index e6462f322c4..daf64209b79 100644 --- a/src/transformers/generation/beam_constraints.py +++ b/src/transformers/generation/beam_constraints.py @@ -48,10 +48,13 @@ class Constraint(ABC): @abstractmethod def advance(self): """ - When called, returns the token that would take this constraint one step closer to being fulfilled. + When called, returns the token(s) that would take this constraint one step closer to being fulfilled. Return: - token_ids(`torch.tensor`): Must be a tensor of a list of indexable tokens, not some integer. + token_ids (Union[int, List[int], None]): + - A single token ID (int) that advances the constraint, or + - A list of token IDs that could advance the constraint + - None if the constraint is completed or cannot be advanced """ raise NotImplementedError( f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."