Small fix to ExponentialDecayLengthPenalty docstring (#21308)

Currently, it incorrectly states that the exponential_decay_length_penalty tuple parameter is optional.

Also changed the corresponding type hint to be more specific.
This commit is contained in:
Nick Hill 2023-01-25 11:46:08 -08:00 committed by GitHub
parent 3a6e4a221c
commit 140c6edeb9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -825,7 +825,7 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
reached.
Args:
exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
exponential_decay_length_penalty (`tuple(int, float)`):
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
starts and `decay_factor` represents the factor of exponential decay
eos_token_id (`Union[int, List[int]]`):
@ -835,7 +835,10 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
"""
def __init__(
self, exponential_decay_length_penalty: Tuple, eos_token_id: Union[int, List[int]], input_ids_seq_length: int
self,
exponential_decay_length_penalty: Tuple[int, float],
eos_token_id: Union[int, List[int]],
input_ids_seq_length: int,
):
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
self.regulation_factor = exponential_decay_length_penalty[1]