mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
Update logits_process.py docstrings (#25971)
This commit is contained in:
parent
3319eb5490
commit
d65c4a4fed
@ -272,7 +272,7 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
|||||||
[`LogitsProcessor`] that prevents the repetition of previous tokens through an exponential penalty. This technique
|
[`LogitsProcessor`] that prevents the repetition of previous tokens through an exponential penalty. This technique
|
||||||
shares some similarities with coverage mechanisms and other aimed at reducing repetition. During the text
|
shares some similarities with coverage mechanisms and other aimed at reducing repetition. During the text
|
||||||
generation process, the probability distribution for the next token is determined using a formula that incorporates
|
generation process, the probability distribution for the next token is determined using a formula that incorporates
|
||||||
token scores based on their occurrence in the generated sequence. Tokens with higher scores are less likely to be
|
token scores based on their occurrence in the generated sequence. Tokens with higher scores are more likely to be
|
||||||
selected. The formula can be seen in the original [paper](https://arxiv.org/pdf/1909.05858.pdf). According to the
|
selected. The formula can be seen in the original [paper](https://arxiv.org/pdf/1909.05858.pdf). According to the
|
||||||
paper a penalty of around 1.2 yields a good balance between truthful generation and lack of repetition.
|
paper a penalty of around 1.2 yields a good balance between truthful generation and lack of repetition.
|
||||||
|
|
||||||
@ -328,7 +328,7 @@ class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
|||||||
hallucination_penalty (`float`):
|
hallucination_penalty (`float`):
|
||||||
The parameter for hallucination penalty. 1.0 means no penalty.
|
The parameter for hallucination penalty. 1.0 means no penalty.
|
||||||
encoder_input_ids (`torch.LongTensor`):
|
encoder_input_ids (`torch.LongTensor`):
|
||||||
The encoder_input_ids that should not be repeated within the decoder ids.
|
The encoder_input_ids that should be repeated within the decoder ids.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor):
|
def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor):
|
||||||
|
Loading…
Reference in New Issue
Block a user