From a42ba80fa520c784c8f11a973ca9034e5f859b79 Mon Sep 17 00:00:00 2001
From: Alex Brooks
Date: Mon, 21 Apr 2025 08:46:05 -0600
Subject: [PATCH] Allow Exclusion of Input IDs from RepetitionPenaltyLogitsProcessor (#37625)

* Allow exclusion of input IDs for repetition penalty

* Add logit proc tests for rep penalty exclusion

* Expose rep pen flag through generate

* Only slice if needed

* keep current rep pen default behavior

* Revert exposing reppen changes through generate

* Fix test arg

* Update src/transformers/generation/logits_process.py

Co-authored-by: Joao Gante

* Rename to rep penalty kwarg

* Add custom repetition penalty processor example

* Validate prompt_ignore_length

---------

Co-authored-by: Joao Gante
---
 src/transformers/generation/logits_process.py | 28 +++++++++--
 tests/generation/test_logits_process.py       | 50 +++++++++++++++++++
 2 files changed, 75 insertions(+), 3 deletions(-)

diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 16c04478f08..352fff9e637 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -292,7 +292,8 @@ class TemperatureLogitsWarper(LogitsProcessor):
 class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
     r"""
     [`LogitsProcessor`] that prevents the repetition of previous tokens through a penalty. This penalty is applied at
-    most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt.
+    most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt
+    by default.
 
     In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a penalty of around
     1.2 to achieve a good balance between truthful generation and lack of repetition. To penalize and reduce
@@ -303,11 +304,13 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
         penalty (`float`):
             The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
             tokens. Between 0.0 and 1.0 rewards previously generated tokens.
+        prompt_ignore_length (`int`, *optional*):
+            The original input ids sequence length, which if provided, will not be used in the penalty calculation.
 
     Examples:
 
     ```py
-    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, RepetitionPenaltyLogitsProcessor
 
     >>> # Initializing the model and tokenizer for it
     >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
@@ -323,17 +326,36 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
     >>> penalized_ids = model.generate(**inputs, repetition_penalty=1.1)
     >>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
     I'm not going to be able to do that. I'll just have to go out and play
+
+    >>> # We can also exclude the input prompt by creating an instance of this class
+    >>> # with a `prompt_ignore_length` and passing it as a custom logit processor
+    >>> rep_pen_processor = RepetitionPenaltyLogitsProcessor(
+    ...     penalty=1.1,
+    ...     prompt_ignore_length=inputs["input_ids"].shape[-1]
+    ... )
+    >>> penalized_ids = model.generate(**inputs, logits_processor=[rep_pen_processor])
+    >>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
+    I'm not going to be able to do that. I'm going to have to go through a lot of things, and
     ```
     """
 
-    def __init__(self, penalty: float):
+    def __init__(self, penalty: float, prompt_ignore_length: Optional[int] = None):
         if not isinstance(penalty, float) or not (penalty > 0):
             raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
 
+        if prompt_ignore_length is not None and (
+            not isinstance(prompt_ignore_length, int) or prompt_ignore_length < 0
+        ):
+            raise ValueError(f"`prompt_ignore_length` has to be a positive integer, but is {prompt_ignore_length}")
+
         self.penalty = penalty
+        self.prompt_ignore_length = prompt_ignore_length
 
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if self.prompt_ignore_length:
+            input_ids = input_ids[:, self.prompt_ignore_length :]
+
         score = torch.gather(scores, 1, input_ids)
 
         # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py
index 2647e6677d0..ea0a7581e5c 100644
--- a/tests/generation/test_logits_process.py
+++ b/tests/generation/test_logits_process.py
@@ -203,6 +203,56 @@ class LogitsProcessorTest(unittest.TestCase):
         # processor should not change logits in-place
         self.assertFalse(torch.all(scores == processed_scores))
 
+    def test_repetition_penalty_dist_process_exclusion_no_new_input_ids(self):
+        input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
+        vocab_size = 10
+
+        scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
+
+        # give values special values
+        scores[0, 0] = -(1 / vocab_size)
+        scores[1, 5] = 4 / vocab_size
+
+        rep_penalty_proc = RepetitionPenaltyLogitsProcessor(
+            penalty=2.0,
+            prompt_ignore_length=input_ids.shape[-1],
+        )
+
+        processed_scores = rep_penalty_proc(input_ids, scores)
+
+        # Because input IDs were provided & we call with the same input
+        # IDs that we initialize with, it should be the same as calling
+        # with no input IDs, so no scores should be penalized.
+        self.assertTrue(torch.all(scores == processed_scores))
+
+    def test_repetition_penalty_dist_process_exclusion_with_new_input_ids(self):
+        orig_input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
+        curr_input_ids = torch.tensor([[0, 1, 0, 1], [5, 0, 5, 0]], device=torch_device, dtype=torch.long)
+        vocab_size = 10
+
+        scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
+
+        # give values special values
+        scores[0, 0] = -(1 / vocab_size)
+        scores[1, 5] = 4 / vocab_size
+
+        rep_penalty_proc = RepetitionPenaltyLogitsProcessor(
+            penalty=2.0,
+            prompt_ignore_length=orig_input_ids.shape[-1],
+        )
+
+        processed_scores = rep_penalty_proc(curr_input_ids, scores)
+
+        # check that values were correctly changed
+        self.assertAlmostEqual(processed_scores[0, 0].item(), -(1 / vocab_size) * 2)
+        self.assertAlmostEqual(processed_scores[0, 1].item(), (1 / vocab_size) / 2)
+
+        self.assertAlmostEqual(processed_scores[1, 0].item(), (1 / vocab_size) / 2)
+        self.assertAlmostEqual(processed_scores[1, 5].item(), (4 / vocab_size) / 2)
+
+        # processor should not change logits in-place
+        self.assertFalse(torch.all(scores == processed_scores))
+
     def test_encoder_repetition_penalty_dist_process(self):
         input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
         vocab_size = 10
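Not part of the patch itself: a minimal standalone sketch of the new `prompt_ignore_length` behavior, assuming a transformers build that already includes this change. It calls the processor directly on hypothetical toy logits (the same style as the new tests) to contrast the default behavior, where prompt tokens are also penalized, with the prompt-excluding behavior.

```py
# Toy setup (hypothetical, not from the patch): batch of 1, vocab of 5, uniform logits.
import torch
from transformers import RepetitionPenaltyLogitsProcessor

prompt_ids = torch.tensor([[2, 3]])        # the original prompt
generated_ids = torch.tensor([[2, 3, 4]])  # prompt plus one newly generated token (4)
scores = torch.full((1, 5), 0.5)

# Default behavior: prompt tokens 2 and 3 are penalized along with generated token 4.
default_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
print(default_proc(generated_ids, scores))  # logits for tokens 2, 3, 4 drop to 0.25

# With prompt_ignore_length, the first `prompt_ignore_length` positions are sliced off
# input_ids before the penalty is applied, so only generated token 4 is penalized.
prompt_excluding_proc = RepetitionPenaltyLogitsProcessor(
    penalty=2.0,
    prompt_ignore_length=prompt_ids.shape[-1],
)
print(prompt_excluding_proc(generated_ids, scores))  # only the logit for token 4 drops
```

Since the slicing only happens when `prompt_ignore_length` is truthy (per the "Only slice if needed" change), leaving it unset or passing `0` keeps the existing default behavior.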