Generate: add SequenceBiasLogitsProcessor (#24334)
parent 45f71d793d
commit 5f0801d174
@@ -141,6 +141,9 @@ generation.
 [[autodoc]] NoRepeatNGramLogitsProcessor
     - __call__
 
+[[autodoc]] SequenceBiasLogitsProcessor
+    - __call__
+
 [[autodoc]] NoBadWordsLogitsProcessor
     - __call__
 
@@ -970,6 +970,7 @@ else:
             "PhrasalConstraint",
             "PrefixConstrainedLogitsProcessor",
             "RepetitionPenaltyLogitsProcessor",
+            "SequenceBiasLogitsProcessor",
             "StoppingCriteria",
             "StoppingCriteriaList",
             "TemperatureLogitsWarper",
@@ -4733,6 +4734,7 @@ if TYPE_CHECKING:
         PhrasalConstraint,
         PrefixConstrainedLogitsProcessor,
         RepetitionPenaltyLogitsProcessor,
+        SequenceBiasLogitsProcessor,
         StoppingCriteria,
         StoppingCriteriaList,
         TemperatureLogitsWarper,
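The two hunks above only add the new class to the library's public exports. A quick smoke test of the new import path (a sketch, not part of the commit; assumes a torch-enabled install, and the bias values are arbitrary):

```python
from transformers import SequenceBiasLogitsProcessor

processor = SequenceBiasLogitsProcessor(sequence_bias={(42,): 5.0})
print(type(processor).__name__)  # SequenceBiasLogitsProcessor
```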
@@ -56,6 +56,7 @@ else:
         "NoRepeatNGramLogitsProcessor",
         "PrefixConstrainedLogitsProcessor",
         "RepetitionPenaltyLogitsProcessor",
+        "SequenceBiasLogitsProcessor",
         "EncoderRepetitionPenaltyLogitsProcessor",
         "TemperatureLogitsWarper",
         "TopKLogitsWarper",
@@ -182,6 +183,7 @@ if TYPE_CHECKING:
         NoRepeatNGramLogitsProcessor,
         PrefixConstrainedLogitsProcessor,
         RepetitionPenaltyLogitsProcessor,
+        SequenceBiasLogitsProcessor,
         TemperatureLogitsWarper,
         TopKLogitsWarper,
         TopPLogitsWarper,
@@ -142,11 +142,8 @@ class GenerationConfig(PushToHubMixin):
         no_repeat_ngram_size (`int`, *optional*, defaults to 0):
             If set to int > 0, all ngrams of that size can only occur once.
         bad_words_ids(`List[List[int]]`, *optional*):
-            List of token ids that are not allowed to be generated. In order to get the token ids of the words that
-            should not appear in the generated text, make sure to set `add_prefix_space=True` when initializing the
-            tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The `add_prefix_space`
-            argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours come from
-            `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
+            List of list of token ids that are not allowed to be generated. Check
+            [`~generation.NoBadWordsLogitsProcessor`] for further documentation and examples.
         force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*):
             List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple list of
             words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this
@@ -183,6 +180,10 @@ class GenerationConfig(PushToHubMixin):
             A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
             forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token
             of index 123.
+        sequence_bias (`Dict[Tuple[int], float]`, *optional*):
+            Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
+            sequence being selected, while negative biases do the opposite. Check
+            [`~generation.SequenceBiasLogitsProcessor`] for further documentation and examples.
 
     > Parameters that define the output variables of `generate`
 
@@ -262,6 +263,7 @@ class GenerationConfig(PushToHubMixin):
         self.suppress_tokens = kwargs.pop("suppress_tokens", None)
         self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
         self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
+        self.sequence_bias = kwargs.pop("sequence_bias", None)
 
         # Parameters that define the output variables of `generate`
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
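With `sequence_bias` stored on `GenerationConfig`, the knob can be used directly as a `generate()` keyword argument. A condensed version of the docstring example added later in this diff (a sketch; the model, prompt, and resulting text are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt")

# Token ids should come from a tokenizer with `add_prefix_space=True`, so " Trump" maps to the intended ids
tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)
biased_sequence = tuple(tokenizer_with_prefix_space(["Trump"], add_special_tokens=False).input_ids[0])

out = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias={biased_sequence: -10.0})
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```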
@@ -15,7 +15,7 @@
 
 import inspect
 import math
-from typing import Callable, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Tuple, Union
 
 import numpy as np
 import torch
@@ -539,23 +539,208 @@ class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
         return scores
 
 
-class NoBadWordsLogitsProcessor(LogitsProcessor):
+class SequenceBiasLogitsProcessor(LogitsProcessor):
     """
-    [`LogitsProcessor`] that enforces that specified sequences will never be sampled.
+    [`LogitsProcessor`] that applies an additive bias on sequences. The bias is applied to the last token of a sequence
+    when the next generated token can complete it. Consequently, to take the most of biasing sequences with more than
+    one token, consider using beam methods (to gracefully work around partially completed sequences that have a
+    negative bias) and applying the bias to their prefixes (to ensure the bias is applied earlier).
+
+    <Tip>
+
+    In order to get the token ids of the sequences that you want to bias, make sure to set `add_prefix_space=True` when
+    initializing the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The
+    `add_prefix_space` argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours
+    come from `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
+
+    </Tip>
+
+    Args:
+        sequence_bias (`Dict[Tuple[int], float]`):
+            Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
+            sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias
+            will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to be
+            completed (in the token selection step after this processor is applied).
+
+    Examples:
+
+    ```python
+    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    >>> inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt")
+
+    >>> summary_ids = model.generate(inputs["input_ids"], max_new_tokens=4)
+    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
+    The full name of Donald is Donald J. Trump Jr
+
+    >>> # Now let's control generation through a bias. Please note that the tokenizer is initialized differently!
+    >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)
+
+
+    >>> def get_tokens_as_tuple(word):
+    ...     return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0])
+
+
+    >>> # If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations
+    >>> sequence_bias = {get_tokens_as_tuple("Trump"): -10.0}
+    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias)
+    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
+    The full name of Donald is Donald J. Donald,
+
+    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
+    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
+    The full name of Donald is Donald Rumsfeld,
+
+    >>> # We can also add a positive bias to nudge the model towards specific tokens or continuations
+    >>> sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0}
+    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
+    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
+    The full name of Donald is Donald Duck.
+    ```
+    """
+
+    def __init__(self, sequence_bias: Dict[Tuple[int], float]):
+        self.sequence_bias = sequence_bias
+        self._validate_arguments()
+
+        # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
+        # is infered in the first usage, which inhibits initializing here)
+        self.sequences_length_greater_than_1 = []
+        self.length_1_bias = None
+        self.length_greather_than_1_bias = None
+        self.prepared_bias_variables = False
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        # 1 - Prepares the bias tensors. This is only needed the first time the logit processor is called.
+        if not self.prepared_bias_variables:
+            self._prepare_bias_variables(scores)
+
+        # 2 - prepares an empty bias to add
+        bias = torch.zeros_like(scores)
+
+        # 3 - include the bias from length = 1
+        bias += self.length_1_bias
+
+        # 4 - include the bias from length > 1, after determining which biased sequences may be completed.
+        # `matching_mask` is a (batch_size, vocab_size) boolean mask that is True for all tokens whose corresponding
+        # bias should be applied. The bias is applied on the last token of the sequence, if (and only if) the sequence
+        # may become complete this iteration.
+        matching_mask = torch.zeros_like(scores, dtype=torch.bool)
+        for sequence_ids in self.sequences_length_greater_than_1:
+            if len(sequence_ids) > input_ids.shape[1]:  # the sequence is longer than the context, ignore
+                continue
+            prefix_length = len(sequence_ids) - 1
+            last_token = sequence_ids[-1]
+            matching_rows = torch.eq(
+                input_ids[:, -prefix_length:],
+                torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype, device=input_ids.device),
+            ).prod(dim=1)
+            matching_mask[:, last_token] |= matching_rows.bool()
+        bias += torch.where(matching_mask, self.length_greather_than_1_bias, 0.0)
+
+        # 5 - apply the bias to the scores
+        scores = scores + bias
+        return scores
+
+    def _prepare_bias_variables(self, scores: torch.FloatTensor):
+        vocabulary_size = scores.shape[-1]
+        sequence_bias = self.sequence_bias
+        tokens_with_bias = []
+
+        # Check biased tokens out of bounds
+        invalid_biases = []
+        for sequence_ids in sequence_bias:
+            for token_id in sequence_ids:
+                if token_id >= vocabulary_size:
+                    invalid_biases.append(token_id)
+        if len(invalid_biases) > 0:
+            raise ValueError(
+                f"The model vocabulary size is {vocabulary_size}, but the following tokens were being biased: "
+                f"{invalid_biases}"
+            )
+
+        # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied
+        # with simpler logic.
+        self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
+        self.length_greather_than_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
+        for sequence_ids, bias in sequence_bias.items():
+            if len(sequence_ids) == 1:
+                self.length_1_bias[sequence_ids[-1]] = bias
+            else:
+                self.sequences_length_greater_than_1.append(sequence_ids)
+                if self.length_greather_than_1_bias[sequence_ids[-1]] != 0.0:
+                    raise ValueError(
+                        "Setting a bias on sequences that share a common token termination is not yet supported. "
+                        "Please open an issue if you see this error message (after checking that it doesn't already "
+                        "exist)."
+                    )
+                self.length_greather_than_1_bias[sequence_ids[-1]] = bias
+            tokens_with_bias.append(sequence_ids[-1])
+
+        self.prepared_bias_variables = True
+
+    def _validate_arguments(self):
+        sequence_bias = self.sequence_bias
+        if not isinstance(sequence_bias, dict) or len(sequence_bias) == 0:
+            raise ValueError(f"`sequence_bias` has to be a non-empty dictionary, but is {sequence_bias}.")
+        if any(not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()):
+            raise ValueError(f"`sequence_bias` has to be a dict with tuples as keys, but is {sequence_bias}.")
+        if any(
+            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in sequence_ids)
+            or len(sequence_ids) == 0
+            for sequence_ids in sequence_bias.keys()
+        ):
+            raise ValueError(
+                f"Each key in `sequence_bias` has to be a non-empty tuple of positive integers, but is "
+                f"{sequence_bias}."
+            )
+        if any(not isinstance(bias, float) for bias in sequence_bias.values()):
+            raise ValueError(f"`sequence_bias` has to be a dict with floats as values, but is {sequence_bias}.")
+
+
+class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
+    """
+    [`LogitsProcessor`] that enforces that specified sequences will never be selected.
+
+    <Tip>
+
+    In order to get the token ids of the words that should not appear in the generated text, make sure to set
+    `add_prefix_space=True` when initializing the tokenizer, and use `tokenizer(bad_words,
+    add_special_tokens=False).input_ids`. The `add_prefix_space` argument is only supported for some slow tokenizers,
+    as fast tokenizers' prefixing behaviours come from `pre tokenizers`. Read more
+    [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
+
+    </Tip>
+
     Args:
         bad_words_ids (`List[List[int]]`):
-            List of list of token ids that are not allowed to be generated. In order to get the token ids of the words
-            that should not appear in the generated text, make sure to set `add_prefix_space=True` when initializing
-            the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The `add_prefix_space`
-            argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours come from
-            `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
+            List of list of token ids that are not allowed to be generated.
         eos_token_id (`Union[int, List[int]]`):
             The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
     """
 
     def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
-        if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
+        self.bad_word_ids = bad_words_ids
+        self._validate_arguments()
+
+        # Filter EOS token from bad_words_ids
+        if eos_token_id is None:
+            eos_token_id = []
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        bad_words_ids = list(
+            filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids)
+        )
+
+        # Forbidding a sequence is equivalent to setting its bias to -inf
+        sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
+        super().__init__(sequence_bias=sequence_bias)
+
+    def _validate_arguments(self):
+        bad_words_ids = self.bad_word_ids
+        if not isinstance(bad_words_ids, list) or len(bad_words_ids) == 0:
             raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
         if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
             raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
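To make the completion rule concrete, here is a small standalone sketch (not part of the commit) that applies the new processor directly to toy logits; the vocabulary, token ids, and bias values are made up:

```python
import torch

from transformers import SequenceBiasLogitsProcessor

processor = SequenceBiasLogitsProcessor(sequence_bias={(1,): 10.0, (2, 3): -10.0})
input_ids = torch.tensor([[0, 1, 2]])  # one sequence, whose context ends in token 2
scores = torch.zeros((1, 5))  # zero logits over a toy vocabulary of 5 tokens

biased_scores = processor(input_ids, scores)
# Token 1 gets +10.0 unconditionally (length-1 sequence). Token 3 gets -10.0 only because
# the context ends in 2, i.e. picking 3 next would complete the biased sequence (2, 3).
print(biased_scores.tolist())  # [[0.0, 10.0, 0.0, -10.0, 0.0]]
```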
@@ -567,113 +752,6 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
                 f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
             )
-
-        if eos_token_id is None:
-            eos_token_id = []
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-
-        bad_words_ids = list(
-            filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids)
-        )
-        self.bad_words_id_length_1 = []
-        self.bad_words_id_length_greater_than_1 = []
-        for word in bad_words_ids:
-            if len(word) == 1:
-                self.bad_words_id_length_1.append(word[0])
-            else:
-                self.bad_words_id_length_greater_than_1.append(word)
-
-        self.static_bad_words_mask: Optional[torch.LongTensor] = None
-
-        for banned_token_seq in self.bad_words_id_length_greater_than_1:
-            if len(banned_token_seq) == 0:
-                raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        if self.static_bad_words_mask is None and len(self.bad_words_id_length_1) > 0:
-            self.static_bad_words_mask = self._calc_static_bad_word_mask(scores)
-
-        dynamic_banned_tokens = self._calc_banned_bad_words_ids(input_ids.tolist())
-        scores = self._set_scores_to_inf_for_banned_tokens(scores, dynamic_banned_tokens)
-
-        return scores
-
-    def _calc_static_bad_word_mask(self, scores: torch.FloatTensor) -> torch.BoolTensor:
-        static_bad_words_mask = torch.zeros(scores.shape[1])
-        static_bad_words_mask[self.bad_words_id_length_1] = 1
-        return static_bad_words_mask.unsqueeze(0).to(scores.device).bool()
-
-    def _tokens_match(self, prev_tokens: List[int], tokens: List[int]) -> bool:
-        if len(tokens) == 0:
-            # if bad word tokens is just one token always ban it
-            return True
-        elif len(tokens) > len(prev_tokens):
-            # if bad word tokens are longer then prev input_ids they can't be equal
-            return False
-        else:
-            return prev_tokens[-len(tokens) :] == tokens
-
-    def _calc_banned_bad_words_ids(self, prev_input_ids: List[List[int]]) -> Iterable[int]:
-        banned_tokens = []
-        for prev_input_ids_slice in prev_input_ids:
-            banned_tokens_slice = []
-            for banned_token_seq in self.bad_words_id_length_greater_than_1:
-                if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]):
-                    banned_tokens_slice.append(banned_token_seq[-1])
-
-            banned_tokens.append(banned_tokens_slice)
-
-        return banned_tokens
-
-    def _set_scores_to_inf_for_banned_tokens(
-        self, scores: torch.Tensor, banned_tokens: List[List[int]]
-    ) -> torch.Tensor:
-        """
-        Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a
-        list of list of banned tokens to ban in the format [[batch index, vocabulary position],...
-
-        Args:
-            scores: logits distribution of shape (batch size, vocabulary size)
-            banned_tokens: list of list of tokens to ban of length (batch_size)
-        """
-        banned_mask_list = []
-        for idx, batch_banned_tokens in enumerate(banned_tokens):
-            for token in batch_banned_tokens:
-                # Eliminates invalid bad word IDs that are over the vocabulary size.
-                if token <= scores.shape[1]:
-                    banned_mask_list.append([idx, token])
-                else:
-                    logger.error(
-                        f"An invalid bad word ID is defined: {token}. This ID is not contained in the "
-                        "vocabulary, and is therefore ignored."
-                    )
-        if not banned_mask_list and self.static_bad_words_mask is None:
-            return scores
-
-        else:
-            if banned_mask_list:
-                indices = torch.ones(len(banned_mask_list))
-                banned_mask = torch.LongTensor(banned_mask_list, device=indices.device)
-                # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
-                # [ 0  1  1 ]
-                # [ 0  0  0 ]
-                # [ 1  0  0 ]
-
-                banned_mask = (
-                    torch.sparse.LongTensor(banned_mask.t(), indices, scores.size())
-                    .to(scores.device)
-                    .to_dense()
-                    .bool()
-                )
-
-                if self.static_bad_words_mask is not None:
-                    banned_mask = torch.bitwise_or(banned_mask, self.static_bad_words_mask)
-            else:
-                banned_mask = self.static_bad_words_mask
-
-            scores = scores.masked_fill(banned_mask, -float("inf"))
-            return scores
 
 
 class PrefixConstrainedLogitsProcessor(LogitsProcessor):
     r"""
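The net effect of the removal above: `NoBadWordsLogitsProcessor` keeps its public interface but drops its own masking machinery, since forbidding a sequence is now just sequence biasing with a bias of `-inf`. A sketch of the equivalence (the token ids are illustrative):

```python
from transformers import NoBadWordsLogitsProcessor, SequenceBiasLogitsProcessor

bad_words_ids = [[5], [7, 3]]
via_subclass = NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids, eos_token_id=0)
via_bias = SequenceBiasLogitsProcessor(sequence_bias={(5,): float("-inf"), (7, 3): float("-inf")})
# Both ban token 5 outright, and ban token 3 whenever the running context ends in token 7.
```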
@@ -56,6 +56,7 @@ from .logits_process import (
     NoRepeatNGramLogitsProcessor,
     PrefixConstrainedLogitsProcessor,
     RepetitionPenaltyLogitsProcessor,
+    SequenceBiasLogitsProcessor,
     SuppressTokensAtBeginLogitsProcessor,
     SuppressTokensLogitsProcessor,
     TemperatureLogitsWarper,
@@ -842,8 +843,9 @@ class GenerationMixin:
         # instantiate processors list
         processors = LogitsProcessorList()
 
-        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
-        # all samplers can be found in `generation_utils_samplers.py`
+        if generation_config.sequence_bias is not None:
+            processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))
+
         if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0:
             processors.append(
                 HammingDiversityLogitsProcessor(
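`generate()` only builds the processor when `sequence_bias` is set on the generation config; the same effect can be had by constructing it manually and passing it through the pre-existing `logits_processor` argument. A sketch, not part of the commit (token id 42 is arbitrary):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, SequenceBiasLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The full name of Donald is Donald", return_tensors="pt")

custom_processors = LogitsProcessorList([SequenceBiasLogitsProcessor(sequence_bias={(42,): -10.0})])
out = model.generate(inputs["input_ids"], max_new_tokens=4, logits_processor=custom_processors)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```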
@@ -240,6 +240,13 @@ class RepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class SequenceBiasLogitsProcessor(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class StoppingCriteria(metaclass=DummyObject):
     _backends = ["torch"]
 
@@ -46,6 +46,7 @@ if is_torch_available():
         NoRepeatNGramLogitsProcessor,
         PrefixConstrainedLogitsProcessor,
         RepetitionPenaltyLogitsProcessor,
+        SequenceBiasLogitsProcessor,
         TemperatureLogitsWarper,
         TopKLogitsWarper,
         TopPLogitsWarper,
@@ -512,6 +513,27 @@ class LogitsProcessorTest(unittest.TestCase):
         filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())
         self.assertTrue(torch.allclose(scores, filtered_scores, atol=1e-3))
 
+    def test_bias_dist_processor(self):
+        vocab_size = 5
+        batch_size = 2
+
+        input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
+        positive_bias = {(1,): 100.0, (4,): 100.0}
+        negative_bias = {(1, 0): -100.0, (0, 1, 2): -100.0, (1, 3, 1, 3): -100.0}
+        sequence_bias = {**positive_bias, **negative_bias}
+
+        # scores = 0 to facilitate checks
+        scores = torch.zeros((batch_size, vocab_size), dtype=torch.float, device=torch_device)
+
+        bias_dist_proc = SequenceBiasLogitsProcessor(sequence_bias=sequence_bias)
+        filtered_scores = bias_dist_proc(input_ids, scores.clone())
+
+        # batch 1: positive bias: tokens (1, 4); negative bias: tokens (0, 3); neutral: tokens (2)
+        # batch 2: positive bias: tokens (1, 4); negative bias: tokens (0, 2); neutral: tokens (3)
+        self.assertListEqual(
+            filtered_scores.tolist(), [[-100.0, 100.0, 0.0, -100.0, 100.0], [-100.0, 100.0, -100.0, 0.0, 100.0]]
+        )
+
     def test_processor_list(self):
         batch_size = 4
         sequence_length = 10
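Working through the expected tensor in the new test by hand (a sketch, not part of the commit) confirms the masking logic:

```python
# batch 1, context [0, 1, 3, 1]:
#   (1,) and (4,) are length-1 sequences, so tokens 1 and 4 always get +100
#   (1, 0): the context ends in 1, so token 0 gets -100
#   (1, 3, 1, 3): the context ends in [1, 3, 1], so token 3 gets -100
#   (0, 1, 2): the context does not end in [0, 1], so token 2 keeps its score
# batch 2, context [0, 1, 0, 1]:
#   tokens 1 and 4 get +100; (1, 0) hits token 0; (0, 1, 2) hits token 2; token 3 is untouched
expected = [
    [-100.0, 100.0, 0.0, -100.0, 100.0],
    [-100.0, 100.0, -100.0, 0.0, 100.0],
]
```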