mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[DOCS] Example for LogitsProcessor
class (#24848)
* make docs
* fixup
* resolved
* remove debugs
* Revert "fixup"
This reverts commit 5e0f636aae
.
* prev (ignore)
* fixup broke some files
* remove files
* reverting modeling_reformer
* lang fix
This commit is contained in:
parent
35c04596f8
commit
0c41765df4
@ -193,12 +193,38 @@ class TemperatureLogitsWarper(LogitsWarper):
|
||||
|
||||
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
[`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
|
||||
[`LogitsProcessor`] that prevents the repetition of previous tokens through an exponential penalty. This technique
|
||||
shares some similarities with coverage mechanisms and other aimed at reducing repetition. During the text
|
||||
generation process, the probability distribution for the next token is determined using a formula that incorporates
|
||||
token scores based on their occurrence in the generated sequence. Tokens with higher scores are less likely to be
|
||||
selected. The formula can be seen in the original [paper](https://arxiv.org/pdf/1909.05858.pdf). According to the
|
||||
paper a penalty of around 1.2 yields a good balance between truthful generation and lack of repetition.
|
||||
|
||||
Args:
|
||||
repetition_penalty (`float`):
|
||||
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
>>> # Initializing the model and tokenizer for it
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
>>> inputs = tokenizer(["I'm not going to"], return_tensors="pt")
|
||||
|
||||
>>> # This shows a normal generate without any specific parameters
|
||||
>>> summary_ids = model.generate(inputs["input_ids"], max_length=20)
|
||||
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
|
||||
I'm not going to lie, I'm not going to lie. I'm not going to lie
|
||||
|
||||
>>> # This generates a penalty for repeated tokens
|
||||
>>> penalized_ids = model.generate(inputs["input_ids"], max_length=20, repetition_penalty=1.2)
|
||||
>>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
|
||||
I'm not going to lie, I was really excited about this. It's a great game
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, penalty: float):
|
||||
|
Loading…
Reference in New Issue
Block a user