mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Small fix to ExponentialDecayLengthPenalty docstring (#21308)
Currently, it incorrectly states that the exponential_decay_length_penalty tuple parameter is optional. Also changed the corresponding type hint to be more specific.
This commit is contained in:
parent
3a6e4a221c
commit
140c6edeb9
@ -825,7 +825,7 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
|
||||
reached.
|
||||
|
||||
Args:
|
||||
exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
|
||||
exponential_decay_length_penalty (`tuple(int, float)`):
|
||||
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
|
||||
starts and `decay_factor` represents the factor of exponential decay
|
||||
eos_token_id (`Union[int, List[int]]`):
|
||||
@ -835,7 +835,10 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, exponential_decay_length_penalty: Tuple, eos_token_id: Union[int, List[int]], input_ids_seq_length: int
|
||||
self,
|
||||
exponential_decay_length_penalty: Tuple[int, float],
|
||||
eos_token_id: Union[int, List[int]],
|
||||
input_ids_seq_length: int,
|
||||
):
|
||||
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
|
||||
self.regulation_factor = exponential_decay_length_penalty[1]
|
||||
|
Loading…
Reference in New Issue
Block a user