mirror of https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
change sequence_bias type of SequenceBiasLogitsProcessor to list, add… (#33375)
* change sequence_bias type of SequenceBiasLogitsProcessor to list, add config tests for all processors
* fix format
* small fix for all_token_bias_pairs_are_valid internal func
* small typo fix in description
* improve test impl, some SequenceBiasLogitsProcessor refactoring
This commit is contained in:
parent d9d59e7bac
commit 162056a3f4
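At a glance, the commit replaces the dict-based `sequence_bias` format with a JSON-friendly list format. A minimal sketch of the two shapes, reusing the token ids from the updated docstring example (the ids themselves are illustrative only):

```python
# Old format: a dict mapping a tuple of token ids to a float bias (tuple keys are not JSON-serializable)
sequence_bias_old = {(10, 45): -2.0, (64,): -7.5}

# New format: a list of [token_id_list, bias] pairs, which can be stored in generation_config.json
sequence_bias_new = [[[10, 45], -2.0], [[64], -7.5]]
```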
@@ -15,7 +15,7 @@
 import inspect
 import math
-from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -1064,8 +1064,9 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
     </Tip>

     Args:
-        sequence_bias (`Dict[Tuple[int], float]`):
-            Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
+        sequence_bias (`List[List[Union[List[int], float]]]`):
+            List of lists that maps a sequence of tokens to its bias term (e.g. `[[[10, 45], -2.0],
+            [[64], -7.5]]`). Positive biases increase the odds of the
             sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias
             will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to be
             completed (in the token selection step after this processor is applied).
@@ -1087,12 +1088,12 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
     >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("openai-community/gpt2", add_prefix_space=True)


-    >>> def get_tokens_as_tuple(word):
-    ...     return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0])
+    >>> def get_tokens(word):
+    ...     return tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]


     >>> # If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations
-    >>> sequence_bias = {get_tokens_as_tuple("Trump"): -10.0}
+    >>> sequence_bias = [get_tokens("Trump"), -10.0]
     >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias)
     >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
     The full name of Donald is Donald J. Donald,
@@ -1102,16 +1103,17 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
     The full name of Donald is Donald Rumsfeld,

     >>> # We can also add a positive bias to nudge the model towards specific tokens or continuations
-    >>> sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0}
+    >>> sequence_bias = [get_tokens("Donald Duck"), 10.0]
     >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
     >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
     The full name of Donald is Donald Duck.
     ```
     """

-    def __init__(self, sequence_bias: Dict[Tuple[int], float]):
+    def __init__(self, sequence_bias: List[List[Union[List[int], float]]]):
         self.sequence_bias = sequence_bias
         self._validate_arguments()
+        self._convert_list_arguments_into_dict()

         # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
         # is infered in the first usage, which inhibits initializing here)
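Since `_validate_arguments` still accepts the dict form and the new `_convert_list_arguments_into_dict` call normalizes list input, both constructor calls below should yield the same internal state. A small sketch with illustrative token ids:

```python
from transformers.generation import SequenceBiasLogitsProcessor

legacy = SequenceBiasLogitsProcessor({(10, 45): -2.0})    # old dict form, kept for backward compatibility
modern = SequenceBiasLogitsProcessor([[[10, 45], -2.0]])  # new list form, converted to the dict internally
assert legacy.sequence_bias == modern.sequence_bias == {(10, 45): -2.0}
```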
@@ -1178,11 +1180,15 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):

     def _validate_arguments(self):
         sequence_bias = self.sequence_bias
-        if not isinstance(sequence_bias, dict) or len(sequence_bias) == 0:
-            raise ValueError(f"`sequence_bias` has to be a non-empty dictionary, but is {sequence_bias}.")
-        if any(not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()):
+        if not isinstance(sequence_bias, dict) and not isinstance(sequence_bias, list) or len(sequence_bias) == 0:
+            raise ValueError(
+                f"`sequence_bias` has to be a non-empty dictionary, or non-empty list of lists but is {sequence_bias}."
+            )
+        if isinstance(sequence_bias, dict) and any(
+            not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()
+        ):
             raise ValueError(f"`sequence_bias` has to be a dict with tuples as keys, but is {sequence_bias}.")
-        if any(
+        if isinstance(sequence_bias, dict) and any(
             any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in sequence_ids)
             or len(sequence_ids) == 0
             for sequence_ids in sequence_bias.keys()
@@ -1191,9 +1197,30 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
                 f"Each key in `sequence_bias` has to be a non-empty tuple of positive integers, but is "
                 f"{sequence_bias}."
             )
-        if any(not isinstance(bias, float) for bias in sequence_bias.values()):
+
+        def all_token_bias_pairs_are_valid(sequence):
+            return (
+                isinstance(sequence[0], list)
+                and all(isinstance(token_id, (int, np.integer)) and token_id > 0 for token_id in sequence[0])
+                and isinstance(sequence[1], float)
+            )
+
+        if isinstance(sequence_bias, list) and any(
+            (not all_token_bias_pairs_are_valid(sequence)) or len(sequence) == 0 for sequence in sequence_bias
+        ):
+            raise ValueError(
+                f"Each element in `sequence_bias` has to be a non-empty list of lists of positive integers and float, but is "
+                f"{sequence_bias}."
+            )
+        if isinstance(sequence_bias, dict) and any(not isinstance(bias, float) for bias in sequence_bias.values()):
             raise ValueError(f"`sequence_bias` has to be a dict with floats as values, but is {sequence_bias}.")

+    def _convert_list_arguments_into_dict(self):
+        """BC: we used to accept `dict{tuple of tokens: float}` directly, now we expect a list"""
+        if isinstance(self.sequence_bias, list):
+            temp_sequence = self.sequence_bias
+            self.sequence_bias = {tuple(sublist[0]): sublist[1] for sublist in temp_sequence}
+

 class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
     """
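The dict comprehension in `_convert_list_arguments_into_dict` is the whole of the normalization. A standalone illustration of it, together with the serialization motivation (tuple-keyed dicts cannot round-trip through JSON, plain lists can); the values are the docstring's example ids:

```python
import json

sequence_bias = [[[10, 45], -2.0], [[64], -7.5]]

# The list format survives a JSON round trip unchanged...
assert json.loads(json.dumps(sequence_bias)) == sequence_bias

# ...while the legacy dict format does not even serialize (tuple keys are rejected)
try:
    json.dumps({(10, 45): -2.0})
except TypeError:
    pass

# Internally the processor rebuilds the dict with the same comprehension as in the diff
as_dict = {tuple(sublist[0]): sublist[1] for sublist in sequence_bias}
assert as_dict == {(10, 45): -2.0, (64,): -7.5}
```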
@@ -23,9 +23,41 @@ from pathlib import Path
 from huggingface_hub import HfFolder, delete_repo
 from parameterized import parameterized

-from transformers import AutoConfig, GenerationConfig
-from transformers.generation import GenerationMode
-from transformers.testing_utils import TOKEN, USER, is_staging_test
+from transformers import AutoConfig, GenerationConfig, WatermarkingConfig, is_torch_available
+
+
+if is_torch_available():
+    import torch
+
+    from transformers.generation import (
+        ClassifierFreeGuidanceLogitsProcessor,
+        EncoderNoRepeatNGramLogitsProcessor,
+        EncoderRepetitionPenaltyLogitsProcessor,
+        EpsilonLogitsWarper,
+        EtaLogitsWarper,
+        ExponentialDecayLengthPenalty,
+        ForcedBOSTokenLogitsProcessor,
+        ForcedEOSTokenLogitsProcessor,
+        GenerationMode,
+        HammingDiversityLogitsProcessor,
+        MinLengthLogitsProcessor,
+        MinNewTokensLengthLogitsProcessor,
+        MinPLogitsWarper,
+        NoBadWordsLogitsProcessor,
+        NoRepeatNGramLogitsProcessor,
+        PrefixConstrainedLogitsProcessor,
+        RepetitionPenaltyLogitsProcessor,
+        SequenceBiasLogitsProcessor,
+        SuppressTokensAtBeginLogitsProcessor,
+        SuppressTokensLogitsProcessor,
+        TemperatureLogitsWarper,
+        TopKLogitsWarper,
+        TopPLogitsWarper,
+        TypicalLogitsWarper,
+        UnbatchedClassifierFreeGuidanceLogitsProcessor,
+        WatermarkLogitsProcessor,
+    )
+from transformers.testing_utils import TOKEN, USER, is_staging_test, torch_device


 class GenerationConfigTest(unittest.TestCase):
@@ -225,6 +257,417 @@ class GenerationConfigTest(unittest.TestCase):
         self.assertEqual(config.get_generation_mode(assistant_model="foo"), GenerationMode.ASSISTED_GENERATION)


+class GenerationConfigSerializationTest(unittest.TestCase):
+    def test_serialize_generation_sequence_bias(self):
+        """Tests that GenerationConfig is serialized and SequenceBiasLogitsProcessor is initialized with sequence_bias parameter"""
+        generation_config = GenerationConfig()
+        sequence_bias = [[[45, 67], -0.6], [[89], 1.2]]
+        generation_config.sequence_bias = sequence_bias
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertSequenceEqual(new_config.sequence_bias, sequence_bias)
+
+        expected_sequence_bias = {(45, 67): -0.6, (89,): 1.2}
+        bias_logits_processor = SequenceBiasLogitsProcessor(new_config.sequence_bias)
+        self.assertDictEqual(bias_logits_processor.sequence_bias, expected_sequence_bias)
+
+    def test_serialize_generation_min_length_eos_token(self):
+        """Tests that GenerationConfig is serialized and MinLengthLogitsProcessor is initialized with min_length and eos_token_id"""
+        eos_token_id = 0
+        min_length = 10
+
+        generation_config = GenerationConfig(min_length=min_length, eos_token_id=eos_token_id)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.min_length, min_length)
+        self.assertEqual(new_config.eos_token_id, eos_token_id)
+
+        min_dist_processor = MinLengthLogitsProcessor(
+            min_length=new_config.min_length, eos_token_id=new_config.eos_token_id
+        )
+        self.assertEqual(min_dist_processor.min_length, min_length)
+        self.assertEqual(min_dist_processor.eos_token_id, eos_token_id)
+
+    def test_serialize_generation_min_new_tokens(self):
+        """Tests that GenerationConfig is serialized and MinNewTokensLengthLogitsProcessor is initialized with min_new_tokens"""
+        eos_token_id = 0
+        min_new_tokens = 5
+        prompt_length_to_skip = 2
+
+        generation_config = GenerationConfig(min_new_tokens=min_new_tokens)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.min_new_tokens, min_new_tokens)
+
+        min_new_tokens_processor = MinNewTokensLengthLogitsProcessor(
+            prompt_length_to_skip=prompt_length_to_skip,
+            min_new_tokens=new_config.min_new_tokens,
+            eos_token_id=eos_token_id,
+        )
+        self.assertEqual(min_new_tokens_processor.min_new_tokens, min_new_tokens)
+
+    def test_serialize_generation_temperature(self):
+        """Tests that GenerationConfig is serialized and TemperatureLogitsWarper is initialized with temperature"""
+        temperature = 2.0
+
+        generation_config = GenerationConfig(temperature=temperature, do_sample=True)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.temperature, temperature)
+
+        temperature_logits_warper = TemperatureLogitsWarper(temperature=new_config.temperature)
+        self.assertEqual(temperature_logits_warper.temperature, temperature)
+
+    def test_serialize_generation_repetition_penalty(self):
+        """Tests that GenerationConfig is serialized and RepetitionPenaltyLogitsProcessor is initialized with repetition_penalty"""
+        penalty = 2.0
+
+        generation_config = GenerationConfig(repetition_penalty=penalty)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.repetition_penalty, penalty)
+
+        rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=new_config.repetition_penalty)
+        self.assertEqual(rep_penalty_proc.penalty, penalty)
+
+    def test_serialize_generation_encoder_repetition_penalty(self):
+        """Tests that GenerationConfig is serialized and EncoderRepetitionPenaltyLogitsProcessor is initialized with penalty and input_ids"""
+        penalty = 2.0
+        input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
+
+        generation_config = GenerationConfig(encoder_repetition_penalty=penalty)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.encoder_repetition_penalty, penalty)
+
+        rep_penalty_proc = EncoderRepetitionPenaltyLogitsProcessor(
+            penalty=new_config.encoder_repetition_penalty, encoder_input_ids=input_ids
+        )
+        self.assertEqual(rep_penalty_proc.penalty, 1 / penalty)
+        torch.testing.assert_close(rep_penalty_proc.encoder_input_ids, input_ids)
+
+    def test_serialize_generation_top_p(self):
+        """Tests that GenerationConfig is serialized and TopPLogitsWarper is initialized with top_p"""
+        top_p = 0.8
+
+        generation_config = GenerationConfig(top_p=top_p, do_sample=True)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.top_p, top_p)
+
+        rep_penalty_proc = TopPLogitsWarper(top_p=new_config.top_p)
+        self.assertEqual(rep_penalty_proc.top_p, top_p)
+
+    def test_serialize_generation_top_k(self):
+        """Tests that GenerationConfig is serialized and TopKLogitsWarper is initialized with top_k"""
+        top_k = 2
+
+        generation_config = GenerationConfig(top_k=top_k, do_sample=True)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.top_k, top_k)
+
+        top_k_logits_wrap = TopKLogitsWarper(top_k=new_config.top_k)
+        self.assertEqual(top_k_logits_wrap.top_k, top_k)
+
+    def test_serialize_generation_min_p(self):
+        """Tests that GenerationConfig is serialized and MinPLogitsWarper is initialized with min_p"""
+        min_p = 0.8
+
+        generation_config = GenerationConfig(min_p=min_p, do_sample=True)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.min_p, min_p)
+
+        min_k_logits_wrap = MinPLogitsWarper(min_p=new_config.min_p)
+        self.assertEqual(min_k_logits_wrap.min_p, min_p)
+
+    def test_serialize_generation_typical_p(self):
+        """Tests that GenerationConfig is serialized and TypicalLogitsWarper is initialized with mass"""
+        mass = 0.8
+
+        generation_config = GenerationConfig(typical_p=mass, do_sample=True)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.typical_p, mass)
+
+        typical_p_logits_wrap = TypicalLogitsWarper(mass=new_config.typical_p)
+        self.assertEqual(typical_p_logits_wrap.mass, mass)
+
+    def test_serialize_generation_epsilon_cutoff(self):
+        """Tests that GenerationConfig is serialized and EpsilonLogitsWarper is initialized with epsilon"""
+        epsilon = 0.8
+
+        generation_config = GenerationConfig(epsilon_cutoff=epsilon, do_sample=True)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.epsilon_cutoff, epsilon)
+
+        epsilon_logits_wrap = EpsilonLogitsWarper(epsilon=new_config.epsilon_cutoff)
+        self.assertEqual(epsilon_logits_wrap.epsilon, epsilon)
+
+    def test_serialize_generation_eta_cutoff(self):
+        """Tests that GenerationConfig is serialized and EtaLogitsWarper is initialized with epsilon"""
+        epsilon = 0.8
+
+        generation_config = GenerationConfig(eta_cutoff=epsilon, do_sample=True)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.eta_cutoff, epsilon)
+
+        eta_logits_wrap = EtaLogitsWarper(epsilon=new_config.eta_cutoff)
+        self.assertEqual(eta_logits_wrap.epsilon, epsilon)
+
+    def test_serialize_generation_ngram_size(self):
+        """Tests that GenerationConfig is serialized and NoRepeatNGramLogitsProcessor is initialized with ngram_size"""
+        ngram_size = 2
+
+        generation_config = GenerationConfig(no_repeat_ngram_size=ngram_size, do_sample=True)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.no_repeat_ngram_size, ngram_size)
+
+        no_repeat_ngram_proc = NoRepeatNGramLogitsProcessor(ngram_size=new_config.no_repeat_ngram_size)
+        self.assertEqual(no_repeat_ngram_proc.ngram_size, ngram_size)
+
+    def test_serialize_generation_encoder_ngram_size(self):
+        """Tests that GenerationConfig is serialized and EncoderNoRepeatNGramLogitsProcessor is initialized with ngram_size"""
+        ngram_size = 2
+        input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
+
+        generation_config = GenerationConfig(encoder_no_repeat_ngram_size=ngram_size, do_sample=True)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.encoder_no_repeat_ngram_size, ngram_size)
+
+        encoder_no_repeat_ngram_proc = EncoderNoRepeatNGramLogitsProcessor(
+            encoder_ngram_size=new_config.encoder_no_repeat_ngram_size, encoder_input_ids=input_ids
+        )
+        self.assertEqual(encoder_no_repeat_ngram_proc.ngram_size, ngram_size)
+
+    def test_serialize_generation_bad_words_ids(self):
+        """Tests that GenerationConfig is serialized and NoBadWordsLogitsProcessor is initialized with bad_words_ids"""
+        bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]]
+
+        generation_config = GenerationConfig(bad_words_ids=bad_word_tokens)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertSequenceEqual(new_config.bad_words_ids, bad_word_tokens)
+
+        no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=new_config.bad_words_ids)
+        self.assertSequenceEqual(no_bad_words_dist_proc.bad_word_ids, bad_word_tokens)
+
+    def test_serialize_generation_num_beams(self):
+        """Tests that GenerationConfig is serialized and PrefixConstrainedLogitsProcessor is initialized with num_beams"""
+        num_beams = 1
+
+        def prefix_allowed_tokens_fn(batch_id, inputs_ids):
+            return [[0, 1], [2, 3]][batch_id]
+
+        generation_config = GenerationConfig(num_beams=num_beams)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.num_beams, num_beams)
+
+        prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(
+            prefix_allowed_tokens_fn, num_beams=new_config.num_beams
+        )
+        self.assertEqual(prefix_constrained_logits_proc._num_beams, num_beams)
+
+    def test_serialize_generation_diversity_penalty_and_num_bean_groups(self):
+        """Tests that GenerationConfig is serialized and HammingDiversityLogitsProcessor is initialized with diversity_penalty_and_num_bean_groups"""
+        num_beams = 2
+        num_beam_groups = 2
+        diversity_penalty = 1.0
+
+        generation_config = GenerationConfig(
+            num_beams=num_beams, diversity_penalty=diversity_penalty, num_beam_groups=num_beam_groups
+        )
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.num_beams, num_beams)
+        self.assertEqual(new_config.diversity_penalty, diversity_penalty)
+        self.assertEqual(new_config.num_beam_groups, num_beam_groups)
+
+        diversity_logits_processor = HammingDiversityLogitsProcessor(
+            diversity_penalty=new_config.diversity_penalty,
+            num_beams=new_config.num_beams,
+            num_beam_groups=new_config.num_beam_groups,
+        )
+        self.assertEqual(diversity_logits_processor._num_beams, num_beams)
+        self.assertEqual(diversity_logits_processor._diversity_penalty, diversity_penalty)
+        self.assertEqual(diversity_logits_processor._num_sub_beams, num_beams // num_beam_groups)
+
+    def test_serialize_generation_bos_token_id(self):
+        """Tests that GenerationConfig is serialized and ForcedBOSTokenLogitsProcessor is initialized with bos_token_id"""
+        bos_token_id = 0
+
+        generation_config = GenerationConfig(bos_token_id=bos_token_id)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.bos_token_id, bos_token_id)
+
+        logits_processor = ForcedBOSTokenLogitsProcessor(bos_token_id=new_config.bos_token_id)
+        self.assertEqual(logits_processor.bos_token_id, bos_token_id)
+
+    def test_serialize_generation_eos_token_id(self):
+        """Tests that GenerationConfig is serialized and ForcedEOSTokenLogitsProcessor is initialized with eos_token_id"""
+        eos_token_id = 0
+        max_length = 5
+
+        generation_config = GenerationConfig(eos_token_id=eos_token_id)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.eos_token_id, eos_token_id)
+
+        logits_processor = ForcedEOSTokenLogitsProcessor(
+            max_length=max_length, eos_token_id=new_config.eos_token_id, device=torch_device
+        )
+        self.assertEqual(logits_processor.eos_token_id, eos_token_id)
+
+    def test_serialize_generation_exponential_decay_length_penalty(self):
+        """Tests that GenerationConfig is serialized and ExponentialDecayLengthPenalty is initialized with regulation_start and regulation_factor"""
+        eos_token_id = 0
+        penalty_start = 5
+        penalty_factor = 1.1
+        input_ids_seq_length = 10
+        exponential_decay_length_penalty = (penalty_start, penalty_factor)
+
+        generation_config = GenerationConfig(exponential_decay_length_penalty=exponential_decay_length_penalty)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.exponential_decay_length_penalty, [penalty_start, penalty_factor])
+
+        exponential_decay_processor = ExponentialDecayLengthPenalty(
+            exponential_decay_length_penalty=new_config.exponential_decay_length_penalty,
+            eos_token_id=eos_token_id,
+            input_ids_seq_length=input_ids_seq_length,
+        )
+        self.assertEqual(
+            exponential_decay_processor.regulation_start, exponential_decay_length_penalty[0] + input_ids_seq_length
+        )
+        self.assertEqual(exponential_decay_processor.regulation_factor, exponential_decay_length_penalty[1])
+
+    def test_serialize_generation_begin_suppress_tokens(self):
+        """Tests that GenerationConfig is serialized and SuppressTokensAtBeginLogitsProcessor is initialized with begin_suppress_token and begin_index"""
+
+        begin_suppress_tokens = [220, 50256]
+        begin_index = 0
+        generation_config = GenerationConfig(begin_suppress_tokens=begin_suppress_tokens)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertSequenceEqual(new_config.begin_suppress_tokens, begin_suppress_tokens)
+
+        suppress_processor = SuppressTokensAtBeginLogitsProcessor(
+            begin_suppress_tokens=new_config.begin_suppress_tokens, begin_index=begin_index
+        )
+        self.assertSequenceEqual(suppress_processor.begin_suppress_tokens, begin_suppress_tokens)
+        self.assertEqual(suppress_processor.begin_index, begin_index)
+
+    def test_serialize_generation_suppress_tokens(self):
+        """Tests that GenerationConfig is serialized and SuppressTokensLogitsProcessor is initialized with suppress_token"""
+        suppress_tokens = [220, 50256]
+
+        generation_config = GenerationConfig(suppress_tokens=suppress_tokens)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertSequenceEqual(new_config.suppress_tokens, suppress_tokens)
+
+        suppress_processor = SuppressTokensLogitsProcessor(suppress_tokens=new_config.suppress_tokens)
+        self.assertSequenceEqual(suppress_processor.suppress_tokens, suppress_tokens)
+
+    def test_serialize_generation_guidance_scale(self):
+        """Tests that GenerationConfig is serialized and ClassifierFreeGuidanceLogitsProcessor is initialized with guidance_scale"""
+        guidance_scale = 2.0
+        generation_config = GenerationConfig(guidance_scale=guidance_scale)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.guidance_scale, guidance_scale)
+
+        classifier_processor = ClassifierFreeGuidanceLogitsProcessor(guidance_scale=new_config.guidance_scale)
+        self.assertEqual(classifier_processor.guidance_scale, guidance_scale)
+
+    def test_serialize_generation_guidance_scale_unbatched(self):
+        """Tests that GenerationConfig is serialized and UnbatchedClassifierFreeGuidanceLogitsProcessor is initialized with guidance_scale"""
+        guidance_scale = 2.0
+
+        input_ids = torch.LongTensor([[0]])
+
+        generation_config = GenerationConfig(guidance_scale=guidance_scale)
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.guidance_scale, guidance_scale)
+
+        cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(new_config.guidance_scale, {}, input_ids)
+        self.assertEqual(cfg.guidance_scale, guidance_scale)
+
+    def test_serialize_generation_watermarking_config(self):
+        """Tests that GenerationConfig is serialized and WatermarkLogitsProcessor is initialized with WatermarkingConfig parameters"""
+
+        vocab_size = 20
+        bias = 2.0
+        greenlist_ratio = 0.5
+        hashing_key = 10
+        seeding_scheme = "lefthash"
+        context_width = 10
+        watermarking_config = WatermarkingConfig(
+            bias=bias,
+            greenlist_ratio=greenlist_ratio,
+            hashing_key=hashing_key,
+            seeding_scheme=seeding_scheme,
+            context_width=context_width,
+        )
+        generation_config = GenerationConfig(watermarking_config=watermarking_config)
+
+        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
+            generation_config.save_pretrained(tmp_dir)
+            new_config = GenerationConfig.from_pretrained(tmp_dir)
+        self.assertEqual(new_config.watermarking_config.bias, bias)
+        self.assertEqual(new_config.watermarking_config.greenlist_ratio, greenlist_ratio)
+        self.assertEqual(new_config.watermarking_config.hashing_key, hashing_key)
+        self.assertEqual(new_config.watermarking_config.seeding_scheme, seeding_scheme)
+        self.assertEqual(new_config.watermarking_config.context_width, context_width)
+
+        watermark = WatermarkLogitsProcessor(
+            vocab_size=vocab_size,
+            device=torch_device,
+            greenlist_ratio=new_config.watermarking_config.greenlist_ratio,
+            bias=new_config.watermarking_config.bias,
+            hashing_key=new_config.watermarking_config.hashing_key,
+            seeding_scheme=new_config.watermarking_config.seeding_scheme,
+            context_width=new_config.watermarking_config.context_width,
+        )
+        self.assertEqual(watermark.bias, bias)
+        self.assertEqual(watermark.greenlist_size, int(vocab_size * greenlist_ratio))
+        self.assertEqual(watermark.hash_key, hashing_key)
+        self.assertEqual(watermark.seeding_scheme, seeding_scheme)
+        self.assertEqual(watermark.context_width, context_width)
+
+
 @is_staging_test
 class ConfigPushToHubTester(unittest.TestCase):
     @classmethod
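All of the new tests follow the same save/load round trip. Condensed into one runnable sketch (parameter and attribute names are taken from the tests above, using the repetition-penalty case as the representative):

```python
import tempfile

from transformers import GenerationConfig
from transformers.generation import RepetitionPenaltyLogitsProcessor

generation_config = GenerationConfig(repetition_penalty=2.0)
with tempfile.TemporaryDirectory() as tmp_dir:
    generation_config.save_pretrained(tmp_dir)  # writes generation_config.json
    new_config = GenerationConfig.from_pretrained(tmp_dir)
assert new_config.repetition_penalty == 2.0

# The deserialized value can seed the corresponding processor directly
processor = RepetitionPenaltyLogitsProcessor(penalty=new_config.repetition_penalty)
assert processor.penalty == 2.0
```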