Mirror of https://github.com/huggingface/transformers.git
Allow Exclusion of Input IDs from RepetitionPenaltyLogitsProcessor (#37625)
* Allow exclusion of input IDs for repetition penalty
* Add logit proc tests for rep penalty exclusion
* Expose rep pen flag through generate
* Only slice if needed
* keep current rep pen default behavior
* Revert exposing reppen changes through generate
* Fix test arg
* Update src/transformers/generation/logits_process.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Rename to rep penalty kwarg
* Add custom repetition penalty processor example
* Validate prompt_ignore_length

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in: parent 1077603410 · commit a42ba80fa5
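In short, this change adds an optional `prompt_ignore_length` argument to `RepetitionPenaltyLogitsProcessor`, so that the prompt tokens can be excluded from the repetition penalty. A minimal usage sketch, mirroring the docstring example added in the diff below (the prompt text here is illustrative; the checkpoint is the one used in that example):

```py
from transformers import AutoModelForCausalLM, AutoTokenizer, RepetitionPenaltyLogitsProcessor

# Illustrative checkpoint and prompt; any causal LM works the same way.
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
inputs = tokenizer("I'm not going to be able to do that", return_tensors="pt")

# Penalize repetition only among newly generated tokens: pass the prompt
# length as `prompt_ignore_length` and supply the processor explicitly.
rep_pen = RepetitionPenaltyLogitsProcessor(
    penalty=1.1,
    prompt_ignore_length=inputs["input_ids"].shape[-1],
)
out = model.generate(**inputs, logits_processor=[rep_pen])
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```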
src/transformers/generation/logits_process.py
@@ -292,7 +292,8 @@ class TemperatureLogitsWarper(LogitsProcessor):
 class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
     r"""
     [`LogitsProcessor`] that prevents the repetition of previous tokens through a penalty. This penalty is applied at
-    most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt.
+    most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt
+    by default.
 
     In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a penalty of around
     1.2 to achieve a good balance between truthful generation and lack of repetition. To penalize and reduce
@@ -303,11 +304,13 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
         penalty (`float`):
             The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
             tokens. Between 0.0 and 1.0 rewards previously generated tokens.
+        prompt_ignore_length (`int`, *optional*):
+            The original input ids sequence length, which if provided, will not be used in the penalty calculation.
 
     Examples:
 
     ```py
-    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, RepetitionPenaltyLogitsProcessor
 
     >>> # Initializing the model and tokenizer for it
     >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
@@ -323,17 +326,36 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
     >>> penalized_ids = model.generate(**inputs, repetition_penalty=1.1)
     >>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
     I'm not going to be able to do that. I'll just have to go out and play
+
+    >>> # We can also exclude the input prompt by creating an instance of this class
+    >>> # with a `prompt_ignore_length` and passing it as a custom logit processor
+    >>> rep_pen_processor = RepetitionPenaltyLogitsProcessor(
+    ...     penalty=1.1,
+    ...     prompt_ignore_length=inputs["input_ids"].shape[-1]
+    ... )
+    >>> penalized_ids = model.generate(**inputs, logits_processor=[rep_pen_processor])
+    >>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
+    I'm not going to be able to do that. I'm going to have to go through a lot of things, and
     ```
     """
 
-    def __init__(self, penalty: float):
+    def __init__(self, penalty: float, prompt_ignore_length: Optional[int] = None):
         if not isinstance(penalty, float) or not (penalty > 0):
             raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
 
+        if prompt_ignore_length is not None and (
+            not isinstance(prompt_ignore_length, int) or prompt_ignore_length < 0
+        ):
+            raise ValueError(f"`prompt_ignore_length` has to be a positive integer, but is {prompt_ignore_length}")
+
         self.penalty = penalty
+        self.prompt_ignore_length = prompt_ignore_length
 
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if self.prompt_ignore_length:
+            input_ids = input_ids[:, self.prompt_ignore_length :]
+
         score = torch.gather(scores, 1, input_ids)
 
         # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
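The hunk above ends at the comment noting that negative scores have to be multiplied (and positive ones divided) so that a repeated token always becomes less likely; the lines that actually apply the penalty sit just past the shown context. A self-contained sketch of that rule together with the new prompt-slicing step, assuming the multiply-if-negative / divide-if-positive behavior the comment describes:

```py
import torch

def apply_repetition_penalty(scores, input_ids, penalty, prompt_ignore_length=None):
    # Optionally drop the prompt tokens so only generated tokens are penalized.
    if prompt_ignore_length:
        input_ids = input_ids[:, prompt_ignore_length:]
    score = torch.gather(scores, 1, input_ids)
    # A negative logit gets more negative (multiplied), a positive one shrinks
    # (divided), so a repeated token becomes less likely either way.
    score = torch.where(score < 0, score * penalty, score / penalty)
    return scores.scatter(1, input_ids, score)  # returns a new tensor, no in-place change

scores = torch.zeros(1, 10)
scores[0, 3] = scores[0, 7] = 2.0
ids = torch.tensor([[7, 3]])  # token 7 came from the prompt, token 3 was generated
print(apply_repetition_penalty(scores, ids, penalty=2.0, prompt_ignore_length=1))
# token 3 drops from 2.0 to 1.0; token 7 keeps 2.0 because the prompt is sliced off
```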
@@ -203,6 +203,56 @@ class LogitsProcessorTest(unittest.TestCase):
         # processor should not change logits in-place
         self.assertFalse(torch.all(scores == processed_scores))
 
+    def test_repetition_penalty_dist_process_exclusion_no_new_input_ids(self):
+        input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
+        vocab_size = 10
+
+        scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
+
+        # give values special values
+        scores[0, 0] = -(1 / vocab_size)
+        scores[1, 5] = 4 / vocab_size
+
+        rep_penalty_proc = RepetitionPenaltyLogitsProcessor(
+            penalty=2.0,
+            prompt_ignore_length=input_ids.shape[-1],
+        )
+
+        processed_scores = rep_penalty_proc(input_ids, scores)
+
+        # Because input IDs were provided & we call with the same input
+        # IDs that we initialize with, it should be the same as calling
+        # with no input IDs, so no scores should be penalized.
+        self.assertTrue(torch.all(scores == processed_scores))
+
+    def test_repetition_penalty_dist_process_exclusion_with_new_input_ids(self):
+        orig_input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
+        curr_input_ids = torch.tensor([[0, 1, 0, 1], [5, 0, 5, 0]], device=torch_device, dtype=torch.long)
+        vocab_size = 10
+
+        scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
+
+        # give values special values
+        scores[0, 0] = -(1 / vocab_size)
+        scores[1, 5] = 4 / vocab_size
+
+        rep_penalty_proc = RepetitionPenaltyLogitsProcessor(
+            penalty=2.0,
+            prompt_ignore_length=orig_input_ids.shape[-1],
+        )
+
+        processed_scores = rep_penalty_proc(curr_input_ids, scores)
+
+        # check that values were correctly changed
+        self.assertAlmostEqual(processed_scores[0, 0].item(), -(1 / vocab_size) * 2)
+        self.assertAlmostEqual(processed_scores[0, 1].item(), (1 / vocab_size) / 2)
+
+        self.assertAlmostEqual(processed_scores[1, 0].item(), (1 / vocab_size) / 2)
+        self.assertAlmostEqual(processed_scores[1, 5].item(), (4 / vocab_size) / 2)
+
+        # processor should not change logits in-place
+        self.assertFalse(torch.all(scores == processed_scores))
+
     def test_encoder_repetition_penalty_dist_process(self):
         input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
         vocab_size = 10