Allow Exclusion of Input IDs from RepetitionPenaltyLogitsProcessor (#37625)

* Allow exclusion of input IDs for repetition penalty

* Add logit proc tests for rep penalty exclusion

* Expose rep pen flag through generate

* Only slice if needed

* Keep current rep pen default behavior

* Revert exposing rep pen changes through generate

* Fix test arg

* Update src/transformers/generation/logits_process.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Rename to rep penalty kwarg

* Add custom repetition penalty processor example

* Validate prompt_ignore_length

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Alex Brooks 2025-04-21 08:46:05 -06:00 committed by GitHub
parent 1077603410
commit a42ba80fa5
2 changed files with 75 additions and 3 deletions
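The upshot of the change: a new `prompt_ignore_length` argument on `RepetitionPenaltyLogitsProcessor` makes the repetition penalty skip the prompt tokens and only penalize tokens generated after them. A minimal sketch of the difference, calling the processor directly on hand-built logits (the token IDs and score values below are illustrative, not taken from the diff):

```python
import torch
from transformers import RepetitionPenaltyLogitsProcessor

# one sequence: two prompt tokens (4, 5) followed by two generated tokens (7, 7)
input_ids = torch.tensor([[4, 5, 7, 7]])
scores = torch.ones(1, 10)  # uniform positive logits over a toy vocabulary of 10

# default behavior: prompt tokens 4 and 5 are penalized along with the generated token 7
default_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
print(default_proc(input_ids, scores.clone())[0, [4, 5, 7]])  # 0.5, 0.5, 0.5

# with prompt_ignore_length set to the prompt length, only token 7 is penalized
prompt_aware_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0, prompt_ignore_length=2)
print(prompt_aware_proc(input_ids, scores.clone())[0, [4, 5, 7]])  # 1.0, 1.0, 0.5
```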

src/transformers/generation/logits_process.py

@@ -292,7 +292,8 @@ class TemperatureLogitsWarper(LogitsProcessor):
 class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
     r"""
     [`LogitsProcessor`] that prevents the repetition of previous tokens through a penalty. This penalty is applied at
-    most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt.
+    most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt
+    by default.
 
     In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a penalty of around
     1.2 to achieve a good balance between truthful generation and lack of repetition. To penalize and reduce
@@ -303,11 +304,13 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
         penalty (`float`):
             The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
             tokens. Between 0.0 and 1.0 rewards previously generated tokens.
+        prompt_ignore_length (`int`, *optional*):
+            The original input ids sequence length, which if provided, will not be used in the penalty calculation.
 
     Examples:
 
     ```py
-    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, RepetitionPenaltyLogitsProcessor
 
     >>> # Initializing the model and tokenizer for it
     >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
@@ -323,17 +326,36 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
     >>> penalized_ids = model.generate(**inputs, repetition_penalty=1.1)
     >>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
     I'm not going to be able to do that. I'll just have to go out and play
+
+    >>> # We can also exclude the input prompt by creating an instance of this class
+    >>> # with a `prompt_ignore_length` and passing it as a custom logit processor
+    >>> rep_pen_processor = RepetitionPenaltyLogitsProcessor(
+    ...     penalty=1.1,
+    ...     prompt_ignore_length=inputs["input_ids"].shape[-1]
+    ... )
+
+    >>> penalized_ids = model.generate(**inputs, logits_processor=[rep_pen_processor])
+    >>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
+    I'm not going to be able to do that. I'm going to have to go through a lot of things, and
     ```
     """
 
-    def __init__(self, penalty: float):
+    def __init__(self, penalty: float, prompt_ignore_length: Optional[int] = None):
         if not isinstance(penalty, float) or not (penalty > 0):
             raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
 
+        if prompt_ignore_length is not None and (
+            not isinstance(prompt_ignore_length, int) or prompt_ignore_length < 0
+        ):
+            raise ValueError(f"`prompt_ignore_length` has to be a positive integer, but is {prompt_ignore_length}")
+
         self.penalty = penalty
+        self.prompt_ignore_length = prompt_ignore_length
 
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if self.prompt_ignore_length:
+            input_ids = input_ids[:, self.prompt_ignore_length :]
+
         score = torch.gather(scores, 1, input_ids)
 
         # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
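The hunk above is truncated right after the `gather` call. For orientation, here is a self-contained sketch of the full penalty step as this processor applies it; the `torch.where`/`scatter` tail is reconstructed from the surrounding logic rather than copied from the diff, so treat it as an approximation of the library code:

```python
from typing import Optional

import torch


def apply_repetition_penalty(
    input_ids: torch.LongTensor,
    scores: torch.FloatTensor,
    penalty: float,
    prompt_ignore_length: Optional[int] = None,
) -> torch.FloatTensor:
    """Standalone sketch mirroring RepetitionPenaltyLogitsProcessor.__call__."""
    if prompt_ignore_length:
        # drop the prompt so only tokens generated after it are penalized
        input_ids = input_ids[:, prompt_ignore_length:]
    score = torch.gather(scores, 1, input_ids)
    # if score < 0, multiplying by penalty (> 1) pushes it further down;
    # if score > 0, dividing by penalty reduces it
    score = torch.where(score < 0, score * penalty, score / penalty)
    # write the penalized values back without modifying `scores` in place
    return scores.scatter(1, input_ids, score)
```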

tests/generation/test_logits_process.py

@@ -203,6 +203,56 @@ class LogitsProcessorTest(unittest.TestCase):
         # processor should not change logits in-place
         self.assertFalse(torch.all(scores == processed_scores))
 
+    def test_repetition_penalty_dist_process_exclusion_no_new_input_ids(self):
+        input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
+        vocab_size = 10
+
+        scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
+
+        # give values special values
+        scores[0, 0] = -(1 / vocab_size)
+        scores[1, 5] = 4 / vocab_size
+
+        rep_penalty_proc = RepetitionPenaltyLogitsProcessor(
+            penalty=2.0,
+            prompt_ignore_length=input_ids.shape[-1],
+        )
+
+        processed_scores = rep_penalty_proc(input_ids, scores)
+
+        # Because input IDs were provided & we call with the same input
+        # IDs that we initialize with, it should be the same as calling
+        # with no input IDs, so no scores should be penalized.
+        self.assertTrue(torch.all(scores == processed_scores))
+
+    def test_repetition_penalty_dist_process_exclusion_with_new_input_ids(self):
+        orig_input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
+        curr_input_ids = torch.tensor([[0, 1, 0, 1], [5, 0, 5, 0]], device=torch_device, dtype=torch.long)
+        vocab_size = 10
+
+        scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
+
+        # give values special values
+        scores[0, 0] = -(1 / vocab_size)
+        scores[1, 5] = 4 / vocab_size
+
+        rep_penalty_proc = RepetitionPenaltyLogitsProcessor(
+            penalty=2.0,
+            prompt_ignore_length=orig_input_ids.shape[-1],
+        )
+
+        processed_scores = rep_penalty_proc(curr_input_ids, scores)
+
+        # check that values were correctly changed
+        self.assertAlmostEqual(processed_scores[0, 0].item(), -(1 / vocab_size) * 2)
+        self.assertAlmostEqual(processed_scores[0, 1].item(), (1 / vocab_size) / 2)
+        self.assertAlmostEqual(processed_scores[1, 0].item(), (1 / vocab_size) / 2)
+        self.assertAlmostEqual(processed_scores[1, 5].item(), (4 / vocab_size) / 2)
+
+        # processor should not change logits in-place
+        self.assertFalse(torch.all(scores == processed_scores))
+
     def test_encoder_repetition_penalty_dist_process(self):
         input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
         vocab_size = 10
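Back to the two new tests above: a quick sanity check of the values asserted in `test_repetition_penalty_dist_process_exclusion_with_new_input_ids`. With `prompt_ignore_length=2` the first two columns of `curr_input_ids` are sliced off, leaving `[[0, 1], [5, 0]]`, so exactly those tokens are penalized; negative logits are multiplied by the penalty and positive ones divided by it:

```python
vocab_size = 10
penalty = 2.0

# row 0: token 0 starts negative, token 1 keeps the uniform value
print(-(1 / vocab_size) * penalty)  # -0.2 -> expected processed_scores[0, 0]
print((1 / vocab_size) / penalty)   # 0.05 -> expected processed_scores[0, 1]

# row 1: token 0 keeps the uniform value, token 5 was boosted to 4 / vocab_size
print((1 / vocab_size) / penalty)   # 0.05 -> expected processed_scores[1, 0]
print((4 / vocab_size) / penalty)   # 0.2  -> expected processed_scores[1, 5]
```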