Make beam_constraints.Constraint.advance() docstring more accurate (#32674)

* Fix beam_constraints.Constraint.advance() docstring

* Update src/transformers/generation/beam_constraints.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
Alex Calderwood 2024-08-16 11:36:55 -07:00 committed by GitHub
parent 8ec028aded
commit 6806d33567
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -48,10 +48,13 @@ class Constraint(ABC):
@abstractmethod
def advance(self):
"""
When called, returns the token that would take this constraint one step closer to being fulfilled.
When called, returns the token(s) that would take this constraint one step closer to being fulfilled.
Return:
token_ids(`torch.tensor`): Must be a tensor of a list of indexable tokens, not some integer.
token_ids (Union[int, List[int], None]):
- A single token ID (int) that advances the constraint, or
- A list of token IDs that could advance the constraint
- None if the constraint is completed or cannot be advanced
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."