Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Generate: add SequenceBiasLogitsProcessor (#24334)
This commit is contained in: parent 45f71d793d, commit 5f0801d174
@@ -141,6 +141,9 @@ generation.

[[autodoc]] NoRepeatNGramLogitsProcessor
    - __call__

[[autodoc]] SequenceBiasLogitsProcessor
    - __call__

[[autodoc]] NoBadWordsLogitsProcessor
    - __call__

@@ -970,6 +970,7 @@ else:
        "PhrasalConstraint",
        "PrefixConstrainedLogitsProcessor",
        "RepetitionPenaltyLogitsProcessor",
        "SequenceBiasLogitsProcessor",
        "StoppingCriteria",
        "StoppingCriteriaList",
        "TemperatureLogitsWarper",

@@ -4733,6 +4734,7 @@ if TYPE_CHECKING:
        PhrasalConstraint,
        PrefixConstrainedLogitsProcessor,
        RepetitionPenaltyLogitsProcessor,
        SequenceBiasLogitsProcessor,
        StoppingCriteria,
        StoppingCriteriaList,
        TemperatureLogitsWarper,

@@ -56,6 +56,7 @@ else:
        "NoRepeatNGramLogitsProcessor",
        "PrefixConstrainedLogitsProcessor",
        "RepetitionPenaltyLogitsProcessor",
        "SequenceBiasLogitsProcessor",
        "EncoderRepetitionPenaltyLogitsProcessor",
        "TemperatureLogitsWarper",
        "TopKLogitsWarper",

@@ -182,6 +183,7 @@ if TYPE_CHECKING:
        NoRepeatNGramLogitsProcessor,
        PrefixConstrainedLogitsProcessor,
        RepetitionPenaltyLogitsProcessor,
        SequenceBiasLogitsProcessor,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,

@@ -142,11 +142,8 @@ class GenerationConfig(PushToHubMixin):
        no_repeat_ngram_size (`int`, *optional*, defaults to 0):
            If set to int > 0, all ngrams of that size can only occur once.
        bad_words_ids(`List[List[int]]`, *optional*):
            List of token ids that are not allowed to be generated. In order to get the token ids of the words that
            should not appear in the generated text, make sure to set `add_prefix_space=True` when initializing the
            tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The `add_prefix_space`
            argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours come from
            `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
            List of list of token ids that are not allowed to be generated. Check
            [`~generation.NoBadWordsLogitsProcessor`] for further documentation and examples.
        force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*):
            List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple list of
            words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this

@@ -183,6 +180,10 @@ class GenerationConfig(PushToHubMixin):
            A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
            forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token
            of index 123.
        sequence_bias (`Dict[Tuple[int], float]`, *optional*):
            Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
            sequence being selected, while negative biases do the opposite. Check
            [`~generation.SequenceBiasLogitsProcessor`] for further documentation and examples.

    > Parameters that define the output variables of `generate`

@@ -262,6 +263,7 @@ class GenerationConfig(PushToHubMixin):
        self.suppress_tokens = kwargs.pop("suppress_tokens", None)
        self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
        self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
        self.sequence_bias = kwargs.pop("sequence_bias", None)

        # Parameters that define the output variables of `generate`
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)

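The new `sequence_bias` field above means the bias can also be routed through a `GenerationConfig` rather than passed directly to `generate`. A minimal sketch of that path, assuming a GPT-2 checkpoint and an illustrative bias value:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Illustrative setup; any causal LM works. add_prefix_space=True so "Trump" tokenizes as it appears mid-sentence.
tokenizer = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Keys are tuples of token ids, values are the additive bias applied to the final token of each sequence
trump_ids = tuple(tokenizer(["Trump"], add_special_tokens=False).input_ids[0])
generation_config = GenerationConfig(max_new_tokens=4, num_beams=4, sequence_bias={trump_ids: -10.0})

inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt")
out = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```
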
@@ -15,7 +15,7 @@

import inspect
import math
from typing import Callable, Iterable, List, Optional, Tuple, Union
from typing import Callable, Dict, Iterable, List, Tuple, Union

import numpy as np
import torch

@@ -539,23 +539,208 @@ class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
        return scores


class NoBadWordsLogitsProcessor(LogitsProcessor):
class SequenceBiasLogitsProcessor(LogitsProcessor):
    """
    [`LogitsProcessor`] that enforces that specified sequences will never be sampled.
    [`LogitsProcessor`] that applies an additive bias on sequences. The bias is applied to the last token of a sequence
    when the next generated token can complete it. Consequently, to take the most of biasing sequences with more than
    one token, consider using beam methods (to gracefully work around partially completed sequences that have a
    negative bias) and applying the bias to their prefixes (to ensure the bias is applied earlier).

    <Tip>

    In order to get the token ids of the sequences that you want to bias, make sure to set `add_prefix_space=True` when
    initializing the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The
    `add_prefix_space` argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours
    come from `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).

    </Tip>

    Args:
        sequence_bias (`Dict[Tuple[int], float]`):
            Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
            sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias
            will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to be
            completed (in the token selection step after this processor is applied).

    Examples:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
    >>> inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt")

    >>> summary_ids = model.generate(inputs["input_ids"], max_new_tokens=4)
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald J. Trump Jr

    >>> # Now let's control generation through a bias. Please note that the tokenizer is initialized differently!
    >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)


    >>> def get_tokens_as_tuple(word):
    ...     return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0])


    >>> # If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations
    >>> sequence_bias = {get_tokens_as_tuple("Trump"): -10.0}
    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald J. Donald,

    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald Rumsfeld,

    >>> # We can also add a positive bias to nudge the model towards specific tokens or continuations
    >>> sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0}
    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald Duck.
    ```
    """

    def __init__(self, sequence_bias: Dict[Tuple[int], float]):
        self.sequence_bias = sequence_bias
        self._validate_arguments()

        # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
        # is inferred in the first usage, which inhibits initializing here)
        self.sequences_length_greater_than_1 = []
        self.length_1_bias = None
        self.length_greather_than_1_bias = None
        self.prepared_bias_variables = False

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 1 - Prepares the bias tensors. This is only needed the first time the logit processor is called.
        if not self.prepared_bias_variables:
            self._prepare_bias_variables(scores)

        # 2 - prepares an empty bias to add
        bias = torch.zeros_like(scores)

        # 3 - include the bias from length = 1
        bias += self.length_1_bias

        # 4 - include the bias from length > 1, after determining which biased sequences may be completed.
        # `matching_mask` is a (batch_size, vocab_size) boolean mask that is True for all tokens whose corresponding
        # bias should be applied. The bias is applied on the last token of the sequence, if (and only if) the sequence
        # may become complete this iteration.
        matching_mask = torch.zeros_like(scores, dtype=torch.bool)
        for sequence_ids in self.sequences_length_greater_than_1:
            if len(sequence_ids) > input_ids.shape[1]:  # the sequence is longer than the context, ignore
                continue
            prefix_length = len(sequence_ids) - 1
            last_token = sequence_ids[-1]
            matching_rows = torch.eq(
                input_ids[:, -prefix_length:],
                torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype, device=input_ids.device),
            ).prod(dim=1)
            matching_mask[:, last_token] |= matching_rows.bool()
        bias += torch.where(matching_mask, self.length_greather_than_1_bias, 0.0)

        # 5 - apply the bias to the scores
        scores = scores + bias
        return scores

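As an aside on step 4 of `__call__` above: the "may become complete this iteration" check is a plain tensor comparison of each biased sequence's prefix against the trailing tokens of `input_ids`. A self-contained toy sketch of that idea (toy token ids, a toy vocabulary of 5, and an assumed bias of -100.0; not the transformers API itself):

```python
import torch

input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]])  # (batch_size, seq_len)
scores = torch.zeros((2, 5))  # (batch_size, vocab_size); zeros so the effect is easy to see
sequence_ids, bias_value = (1, 0), -100.0  # bias the sequence (1, 0) on its last token

prefix = torch.tensor(sequence_ids[:-1])  # tokens that must already have been generated: (1,)
last_token = sequence_ids[-1]  # the token whose score receives the bias: 0

# A row qualifies when its trailing tokens equal the prefix of the biased sequence
matching_rows = torch.eq(input_ids[:, -len(prefix) :], prefix).prod(dim=1).bool()
scores[matching_rows, last_token] += bias_value
print(scores)  # both rows end in token 1, so column 0 is biased by -100 in both rows
```
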
    def _prepare_bias_variables(self, scores: torch.FloatTensor):
        vocabulary_size = scores.shape[-1]
        sequence_bias = self.sequence_bias
        tokens_with_bias = []

        # Check biased tokens out of bounds
        invalid_biases = []
        for sequence_ids in sequence_bias:
            for token_id in sequence_ids:
                if token_id >= vocabulary_size:
                    invalid_biases.append(token_id)
        if len(invalid_biases) > 0:
            raise ValueError(
                f"The model vocabulary size is {vocabulary_size}, but the following tokens were being biased: "
                f"{invalid_biases}"
            )

        # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied
        # with simpler logic.
        self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
        self.length_greather_than_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
        for sequence_ids, bias in sequence_bias.items():
            if len(sequence_ids) == 1:
                self.length_1_bias[sequence_ids[-1]] = bias
            else:
                self.sequences_length_greater_than_1.append(sequence_ids)
                if self.length_greather_than_1_bias[sequence_ids[-1]] != 0.0:
                    raise ValueError(
                        "Setting a bias on sequences that share a common token termination is not yet supported. "
                        "Please open an issue if you see this error message (after checking that it doesn't already "
                        "exist)."
                    )
                self.length_greather_than_1_bias[sequence_ids[-1]] = bias
            tokens_with_bias.append(sequence_ids[-1])

        self.prepared_bias_variables = True

    def _validate_arguments(self):
        sequence_bias = self.sequence_bias
        if not isinstance(sequence_bias, dict) or len(sequence_bias) == 0:
            raise ValueError(f"`sequence_bias` has to be a non-empty dictionary, but is {sequence_bias}.")
        if any(not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()):
            raise ValueError(f"`sequence_bias` has to be a dict with tuples as keys, but is {sequence_bias}.")
        if any(
            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in sequence_ids)
            or len(sequence_ids) == 0
            for sequence_ids in sequence_bias.keys()
        ):
            raise ValueError(
                f"Each key in `sequence_bias` has to be a non-empty tuple of positive integers, but is "
                f"{sequence_bias}."
            )
        if any(not isinstance(bias, float) for bias in sequence_bias.values()):
            raise ValueError(f"`sequence_bias` has to be a dict with floats as values, but is {sequence_bias}.")

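Besides passing `sequence_bias` to `generate`, the processor can be instantiated directly and handed over through the `logits_processor` argument, where it is merged with the default processor list. A short sketch, with an illustrative model choice and bias value:

```python
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessorList,
    SequenceBiasLogitsProcessor,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt")

# Keys are tuples of token ids, values are additive biases on the final token of each sequence
trump_ids = tuple(tokenizer(["Trump"], add_special_tokens=False).input_ids[0])
processors = LogitsProcessorList([SequenceBiasLogitsProcessor(sequence_bias={trump_ids: -10.0})])

out = model.generate(**inputs, max_new_tokens=4, num_beams=4, logits_processor=processors)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```
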
class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
    """
    [`LogitsProcessor`] that enforces that specified sequences will never be selected.

    <Tip>

    In order to get the token ids of the words that should not appear in the generated text, make sure to set
    `add_prefix_space=True` when initializing the tokenizer, and use `tokenizer(bad_words,
    add_special_tokens=False).input_ids`. The `add_prefix_space` argument is only supported for some slow tokenizers,
    as fast tokenizers' prefixing behaviours come from `pre tokenizers`. Read more
    [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).

    </Tip>

    Args:
        bad_words_ids (`List[List[int]]`):
            List of list of token ids that are not allowed to be generated. In order to get the token ids of the words
            that should not appear in the generated text, make sure to set `add_prefix_space=True` when initializing
            the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The `add_prefix_space`
            argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours come from
            `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
            List of list of token ids that are not allowed to be generated.
        eos_token_id (`Union[int, List[int]]`):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
    """

    def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
        if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
        self.bad_word_ids = bad_words_ids
        self._validate_arguments()

        # Filter EOS token from bad_words_ids
        if eos_token_id is None:
            eos_token_id = []
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        bad_words_ids = list(
            filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids)
        )

        # Forbidding a sequence is equivalent to setting its bias to -inf
        sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
        super().__init__(sequence_bias=sequence_bias)

    def _validate_arguments(self):
        bad_words_ids = self.bad_word_ids
        if not isinstance(bad_words_ids, list) or len(bad_words_ids) == 0:
            raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
        if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
            raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")

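With this refactor, banning words becomes a special case of sequence biasing: every entry of `bad_words_ids` is mapped to a `-inf` bias on its final token, as the comment above notes. A rough sketch of the equivalence on toy scores (toy token ids; `eos_token_id=0` is just a placeholder):

```python
import torch

from transformers import NoBadWordsLogitsProcessor, SequenceBiasLogitsProcessor

input_ids = torch.tensor([[0, 1, 3, 1]])  # toy prompt over a vocabulary of 5 tokens
scores = torch.zeros((1, 5))

# Banning the single token [4] and the two-token sequence [1, 0] ...
banned = NoBadWordsLogitsProcessor(bad_words_ids=[[4], [1, 0]], eos_token_id=0)
# ... should match giving those same sequences a -inf bias
biased = SequenceBiasLogitsProcessor(sequence_bias={(4,): float("-inf"), (1, 0): float("-inf")})

print(torch.equal(banned(input_ids, scores.clone()), biased(input_ids, scores.clone())))  # True
```
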
@@ -567,113 +752,6 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
                f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
            )

        if eos_token_id is None:
            eos_token_id = []
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]

        bad_words_ids = list(
            filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids)
        )
        self.bad_words_id_length_1 = []
        self.bad_words_id_length_greater_than_1 = []
        for word in bad_words_ids:
            if len(word) == 1:
                self.bad_words_id_length_1.append(word[0])
            else:
                self.bad_words_id_length_greater_than_1.append(word)

        self.static_bad_words_mask: Optional[torch.LongTensor] = None

        for banned_token_seq in self.bad_words_id_length_greater_than_1:
            if len(banned_token_seq) == 0:
                raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if self.static_bad_words_mask is None and len(self.bad_words_id_length_1) > 0:
            self.static_bad_words_mask = self._calc_static_bad_word_mask(scores)

        dynamic_banned_tokens = self._calc_banned_bad_words_ids(input_ids.tolist())
        scores = self._set_scores_to_inf_for_banned_tokens(scores, dynamic_banned_tokens)

        return scores

    def _calc_static_bad_word_mask(self, scores: torch.FloatTensor) -> torch.BoolTensor:
        static_bad_words_mask = torch.zeros(scores.shape[1])
        static_bad_words_mask[self.bad_words_id_length_1] = 1
        return static_bad_words_mask.unsqueeze(0).to(scores.device).bool()

    def _tokens_match(self, prev_tokens: List[int], tokens: List[int]) -> bool:
        if len(tokens) == 0:
            # if bad word tokens is just one token always ban it
            return True
        elif len(tokens) > len(prev_tokens):
            # if bad word tokens are longer then prev input_ids they can't be equal
            return False
        else:
            return prev_tokens[-len(tokens) :] == tokens

    def _calc_banned_bad_words_ids(self, prev_input_ids: List[List[int]]) -> Iterable[int]:
        banned_tokens = []
        for prev_input_ids_slice in prev_input_ids:
            banned_tokens_slice = []
            for banned_token_seq in self.bad_words_id_length_greater_than_1:
                if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]):
                    banned_tokens_slice.append(banned_token_seq[-1])

            banned_tokens.append(banned_tokens_slice)

        return banned_tokens

    def _set_scores_to_inf_for_banned_tokens(
        self, scores: torch.Tensor, banned_tokens: List[List[int]]
    ) -> torch.Tensor:
        """
        Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a
        list of list of banned tokens to ban in the format [[batch index, vocabulary position],...

        Args:
            scores: logits distribution of shape (batch size, vocabulary size)
            banned_tokens: list of list of tokens to ban of length (batch_size)
        """
        banned_mask_list = []
        for idx, batch_banned_tokens in enumerate(banned_tokens):
            for token in batch_banned_tokens:
                # Eliminates invalid bad word IDs that are over the vocabulary size.
                if token <= scores.shape[1]:
                    banned_mask_list.append([idx, token])
                else:
                    logger.error(
                        f"An invalid bad word ID is defined: {token}. This ID is not contained in the "
                        "vocabulary, and is therefore ignored."
                    )
        if not banned_mask_list and self.static_bad_words_mask is None:
            return scores

        else:
            if banned_mask_list:
                indices = torch.ones(len(banned_mask_list))
                banned_mask = torch.LongTensor(banned_mask_list, device=indices.device)
                # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
                # [ 0 1 1 ]
                # [ 0 0 0 ]
                # [ 1 0 0 ]

                banned_mask = (
                    torch.sparse.LongTensor(banned_mask.t(), indices, scores.size())
                    .to(scores.device)
                    .to_dense()
                    .bool()
                )

                if self.static_bad_words_mask is not None:
                    banned_mask = torch.bitwise_or(banned_mask, self.static_bad_words_mask)
            else:
                banned_mask = self.static_bad_words_mask

            scores = scores.masked_fill(banned_mask, -float("inf"))
            return scores


class PrefixConstrainedLogitsProcessor(LogitsProcessor):
    r"""

@@ -56,6 +56,7 @@ from .logits_process import (
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    SequenceBiasLogitsProcessor,
    SuppressTokensAtBeginLogitsProcessor,
    SuppressTokensLogitsProcessor,
    TemperatureLogitsWarper,

@@ -842,8 +843,9 @@ class GenerationMixin:
        # instantiate processors list
        processors = LogitsProcessorList()

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if generation_config.sequence_bias is not None:
            processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))

        if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0:
            processors.append(
                HammingDiversityLogitsProcessor(

@@ -240,6 +240,13 @@ class RepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
        requires_backends(self, ["torch"])


class SequenceBiasLogitsProcessor(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class StoppingCriteria(metaclass=DummyObject):
    _backends = ["torch"]

@@ -46,6 +46,7 @@ if is_torch_available():
        NoRepeatNGramLogitsProcessor,
        PrefixConstrainedLogitsProcessor,
        RepetitionPenaltyLogitsProcessor,
        SequenceBiasLogitsProcessor,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,

@@ -512,6 +513,27 @@ class LogitsProcessorTest(unittest.TestCase):
        filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())
        self.assertTrue(torch.allclose(scores, filtered_scores, atol=1e-3))

    def test_bias_dist_processor(self):
        vocab_size = 5
        batch_size = 2

        input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
        positive_bias = {(1,): 100.0, (4,): 100.0}
        negative_bias = {(1, 0): -100.0, (0, 1, 2): -100.0, (1, 3, 1, 3): -100.0}
        sequence_bias = {**positive_bias, **negative_bias}

        # scores = 0 to facilitate checks
        scores = torch.zeros((batch_size, vocab_size), dtype=torch.float, device=torch_device)

        bias_dist_proc = SequenceBiasLogitsProcessor(sequence_bias=sequence_bias)
        filtered_scores = bias_dist_proc(input_ids, scores.clone())

        # batch 1: positive bias: tokens (1, 4); negative bias: tokens (0, 3); neutral: tokens (2)
        # batch 2: positive bias: tokens (1, 4); negative bias: tokens (0, 2); neutral: tokens (3)
        self.assertListEqual(
            filtered_scores.tolist(), [[-100.0, 100.0, 0.0, -100.0, 100.0], [-100.0, 100.0, -100.0, 0.0, 100.0]]
        )

    def test_processor_list(self):
        batch_size = 4
        sequence_length = 10