change sequence_bias type of SequenceBiasLogitsProcessor to list, add… (#33375)

* change sequence_bias type of SequenceBiasLogitsProcessor to list, add config tests for all processors

* fix format

* small fix for all_token_bias_pairs_are_valid internal func

* small typo fix in description

* improve test impl, some SequenceBiasLogitsProcessor refactoring
Vladislav Bronzov 2024-09-19 18:35:44 +02:00 committed by GitHub
parent d9d59e7bac
commit 162056a3f4
2 changed files with 486 additions and 16 deletions
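For context, a minimal sketch of the interface change described in the commit message (illustrative only; the token ids are arbitrary values reused from the new tests, and the import path follows the test file in this diff). The processor keeps accepting the old tuple-keyed dict, while the new JSON-serializable list of [token_ids, bias] pairs is the form a saved GenerationConfig can carry:

from transformers.generation import SequenceBiasLogitsProcessor

# Old-style input, still accepted: dict mapping tuples of token ids to a float bias
processor = SequenceBiasLogitsProcessor(sequence_bias={(45, 67): -0.6, (89,): 1.2})

# New-style input: list of [token_ids, bias] pairs, which survives JSON serialization;
# internally it is converted back to the tuple-keyed dict form
processor = SequenceBiasLogitsProcessor(sequence_bias=[[[45, 67], -0.6], [[89], 1.2]])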

@@ -15,7 +15,7 @@
import inspect
import math
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Callable, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -1064,8 +1064,9 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
</Tip>
Args:
sequence_bias (`Dict[Tuple[int], float]`):
Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
sequence_bias (`List[List[Union[List[int], float]]]`):
List of lists that maps a sequence of tokens to its bias term (e.g. `[[[10, 45], -2.0],
[[64], -7.5]]`). Positive biases increase the odds of the
sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias
will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to be
completed (in the token selection step after this processor is applied).
@@ -1087,12 +1088,12 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
>>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("openai-community/gpt2", add_prefix_space=True)
>>> def get_tokens_as_tuple(word):
... return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0])
>>> def get_tokens(word):
... return tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]
>>> # If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations
>>> sequence_bias = {get_tokens_as_tuple("Trump"): -10.0}
>>> sequence_bias = [[get_tokens("Trump"), -10.0]]
>>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias)
>>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
The full name of Donald is Donald J. Donald,
@@ -1102,16 +1103,17 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
The full name of Donald is Donald Rumsfeld,
>>> # We can also add a positive bias to nudge the model towards specific tokens or continuations
>>> sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0}
>>> sequence_bias = [[get_tokens("Donald Duck"), 10.0]]
>>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
>>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
The full name of Donald is Donald Duck.
```
"""
def __init__(self, sequence_bias: Dict[Tuple[int], float]):
def __init__(self, sequence_bias: List[List[Union[List[int], float]]]):
self.sequence_bias = sequence_bias
self._validate_arguments()
self._convert_list_arguments_into_dict()
# Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
# is inferred in the first usage, which inhibits initializing here)
@@ -1178,11 +1180,15 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
def _validate_arguments(self):
sequence_bias = self.sequence_bias
if not isinstance(sequence_bias, dict) or len(sequence_bias) == 0:
raise ValueError(f"`sequence_bias` has to be a non-empty dictionary, but is {sequence_bias}.")
if any(not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()):
if not isinstance(sequence_bias, dict) and not isinstance(sequence_bias, list) or len(sequence_bias) == 0:
raise ValueError(
f"`sequence_bias` has to be a non-empty dictionary, or non-empty list of lists but is {sequence_bias}."
)
if isinstance(sequence_bias, dict) and any(
not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()
):
raise ValueError(f"`sequence_bias` has to be a dict with tuples as keys, but is {sequence_bias}.")
if any(
if isinstance(sequence_bias, dict) and any(
any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in sequence_ids)
or len(sequence_ids) == 0
for sequence_ids in sequence_bias.keys()
@@ -1191,9 +1197,30 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
f"Each key in `sequence_bias` has to be a non-empty tuple of positive integers, but is "
f"{sequence_bias}."
)
if any(not isinstance(bias, float) for bias in sequence_bias.values()):
def all_token_bias_pairs_are_valid(sequence):
return (
isinstance(sequence[0], list)
and all(isinstance(token_id, (int, np.integer)) and token_id > 0 for token_id in sequence[0])
and isinstance(sequence[1], float)
)
if isinstance(sequence_bias, list) and any(
(not all_token_bias_pairs_are_valid(sequence)) or len(sequence) == 0 for sequence in sequence_bias
):
raise ValueError(
f"Each element in `sequence_bias` has to be a non-empty list of lists of positive integers and float, but is "
f"{sequence_bias}."
)
if isinstance(sequence_bias, dict) and any(not isinstance(bias, float) for bias in sequence_bias.values()):
raise ValueError(f"`sequence_bias` has to be a dict with floats as values, but is {sequence_bias}.")
def _convert_list_arguments_into_dict(self):
"""BC: we used to accept `dict{tuple of tokens: float}` directly, now we expect a list"""
if isinstance(self.sequence_bias, list):
temp_sequence = self.sequence_bias
self.sequence_bias = {tuple(sublist[0]): sublist[1] for sublist in temp_sequence}
class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
"""

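The remaining changes add GenerationConfig serialization tests for the logits processors. As a rough sketch of the round-trip the new sequence_bias test exercises (assuming only the APIs imported below; directory handling is illustrative), the list form is saved and reloaded with GenerationConfig and then converted into the processor's internal tuple-keyed dict:

import tempfile

from transformers import GenerationConfig
from transformers.generation import SequenceBiasLogitsProcessor

generation_config = GenerationConfig()
generation_config.sequence_bias = [[[45, 67], -0.6], [[89], 1.2]]  # JSON-friendly list form

with tempfile.TemporaryDirectory() as tmp_dir:
    # save_pretrained writes generation_config.json; from_pretrained reads it back
    generation_config.save_pretrained(tmp_dir)
    reloaded = GenerationConfig.from_pretrained(tmp_dir)

# _convert_list_arguments_into_dict restores the tuple-keyed dict used at runtime
processor = SequenceBiasLogitsProcessor(reloaded.sequence_bias)
assert processor.sequence_bias == {(45, 67): -0.6, (89,): 1.2}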

@@ -23,9 +23,41 @@ from pathlib import Path
from huggingface_hub import HfFolder, delete_repo
from parameterized import parameterized
from transformers import AutoConfig, GenerationConfig
from transformers.generation import GenerationMode
from transformers.testing_utils import TOKEN, USER, is_staging_test
from transformers import AutoConfig, GenerationConfig, WatermarkingConfig, is_torch_available
if is_torch_available():
import torch
from transformers.generation import (
ClassifierFreeGuidanceLogitsProcessor,
EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
EpsilonLogitsWarper,
EtaLogitsWarper,
ExponentialDecayLengthPenalty,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
GenerationMode,
HammingDiversityLogitsProcessor,
MinLengthLogitsProcessor,
MinNewTokensLengthLogitsProcessor,
MinPLogitsWarper,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
SequenceBiasLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
)
from transformers.testing_utils import TOKEN, USER, is_staging_test, torch_device
class GenerationConfigTest(unittest.TestCase):
@@ -225,6 +257,417 @@ class GenerationConfigTest(unittest.TestCase):
self.assertEqual(config.get_generation_mode(assistant_model="foo"), GenerationMode.ASSISTED_GENERATION)
class GenerationConfigSerializationTest(unittest.TestCase):
def test_serialize_generation_sequence_bias(self):
"""Tests that GenerationConfig is serialized and SequenceBiasLogitsProcessor is initialized with sequence_bias parameter"""
generation_config = GenerationConfig()
sequence_bias = [[[45, 67], -0.6], [[89], 1.2]]
generation_config.sequence_bias = sequence_bias
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertSequenceEqual(new_config.sequence_bias, sequence_bias)
expected_sequence_bias = {(45, 67): -0.6, (89,): 1.2}
bias_logits_processor = SequenceBiasLogitsProcessor(new_config.sequence_bias)
self.assertDictEqual(bias_logits_processor.sequence_bias, expected_sequence_bias)
def test_serialize_generation_min_length_eos_token(self):
"""Tests that GenerationConfig is serialized and MinLengthLogitsProcessor is initialized with min_length and eos_token_id"""
eos_token_id = 0
min_length = 10
generation_config = GenerationConfig(min_length=min_length, eos_token_id=eos_token_id)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.min_length, min_length)
self.assertEqual(new_config.eos_token_id, eos_token_id)
min_dist_processor = MinLengthLogitsProcessor(
min_length=new_config.min_length, eos_token_id=new_config.eos_token_id
)
self.assertEqual(min_dist_processor.min_length, min_length)
self.assertEqual(min_dist_processor.eos_token_id, eos_token_id)
def test_serialize_generation_min_new_tokens(self):
"""Tests that GenerationConfig is serialized and MinNewTokensLengthLogitsProcessor is initialized with min_new_tokens"""
eos_token_id = 0
min_new_tokens = 5
prompt_length_to_skip = 2
generation_config = GenerationConfig(min_new_tokens=min_new_tokens)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.min_new_tokens, min_new_tokens)
min_new_tokens_processor = MinNewTokensLengthLogitsProcessor(
prompt_length_to_skip=prompt_length_to_skip,
min_new_tokens=new_config.min_new_tokens,
eos_token_id=eos_token_id,
)
self.assertEqual(min_new_tokens_processor.min_new_tokens, min_new_tokens)
def test_serialize_generation_temperature(self):
"""Tests that GenerationConfig is serialized and TemperatureLogitsWarper is initialized with temperature"""
temperature = 2.0
generation_config = GenerationConfig(temperature=temperature, do_sample=True)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.temperature, temperature)
temperature_logits_warper = TemperatureLogitsWarper(temperature=new_config.temperature)
self.assertEqual(temperature_logits_warper.temperature, temperature)
def test_serialize_generation_repetition_penalty(self):
"""Tests that GenerationConfig is serialized and RepetitionPenaltyLogitsProcessor is initialized with repetition_penalty"""
penalty = 2.0
generation_config = GenerationConfig(repetition_penalty=penalty)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.repetition_penalty, penalty)
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=new_config.repetition_penalty)
self.assertEqual(rep_penalty_proc.penalty, penalty)
def test_serialize_generation_encoder_repetition_penalty(self):
"""Tests that GenerationConfig is serialized and EncoderRepetitionPenaltyLogitsProcessor is initialized with penalty and input_ids"""
penalty = 2.0
input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
generation_config = GenerationConfig(encoder_repetition_penalty=penalty)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.encoder_repetition_penalty, penalty)
rep_penalty_proc = EncoderRepetitionPenaltyLogitsProcessor(
penalty=new_config.encoder_repetition_penalty, encoder_input_ids=input_ids
)
self.assertEqual(rep_penalty_proc.penalty, 1 / penalty)
torch.testing.assert_close(rep_penalty_proc.encoder_input_ids, input_ids)
def test_serialize_generation_top_p(self):
"""Tests that GenerationConfig is serialized and TopPLogitsWarper is initialized with top_p"""
top_p = 0.8
generation_config = GenerationConfig(top_p=top_p, do_sample=True)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.top_p, top_p)
rep_penalty_proc = TopPLogitsWarper(top_p=new_config.top_p)
self.assertEqual(rep_penalty_proc.top_p, top_p)
def test_serialize_generation_top_k(self):
"""Tests that GenerationConfig is serialized and TopKLogitsWarper is initialized with top_k"""
top_k = 2
generation_config = GenerationConfig(top_k=top_k, do_sample=True)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.top_k, top_k)
top_k_logits_wrap = TopKLogitsWarper(top_k=new_config.top_k)
self.assertEqual(top_k_logits_wrap.top_k, top_k)
def test_serialize_generation_min_p(self):
"""Tests that GenerationConfig is serialized and MinPLogitsWarper is initialized with min_p"""
min_p = 0.8
generation_config = GenerationConfig(min_p=min_p, do_sample=True)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.min_p, min_p)
min_k_logits_wrap = MinPLogitsWarper(min_p=new_config.min_p)
self.assertEqual(min_k_logits_wrap.min_p, min_p)
def test_serialize_generation_typical_p(self):
"""Tests that GenerationConfig is serialized and TypicalLogitsWarper is initialized with mass"""
mass = 0.8
generation_config = GenerationConfig(typical_p=mass, do_sample=True)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.typical_p, mass)
typical_p_logits_wrap = TypicalLogitsWarper(mass=new_config.typical_p)
self.assertEqual(typical_p_logits_wrap.mass, mass)
def test_serialize_generation_epsilon_cutoff(self):
"""Tests that GenerationConfig is serialized and EpsilonLogitsWarper is initialized with epsilon"""
epsilon = 0.8
generation_config = GenerationConfig(epsilon_cutoff=epsilon, do_sample=True)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.epsilon_cutoff, epsilon)
epsilon_logits_wrap = EpsilonLogitsWarper(epsilon=new_config.epsilon_cutoff)
self.assertEqual(epsilon_logits_wrap.epsilon, epsilon)
def test_serialize_generation_eta_cutoff(self):
"""Tests that GenerationConfig is serialized and EtaLogitsWarper is initialized with epsilon"""
epsilon = 0.8
generation_config = GenerationConfig(eta_cutoff=epsilon, do_sample=True)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.eta_cutoff, epsilon)
eta_logits_wrap = EtaLogitsWarper(epsilon=new_config.eta_cutoff)
self.assertEqual(eta_logits_wrap.epsilon, epsilon)
def test_serialize_generation_ngram_size(self):
"""Tests that GenerationConfig is serialized and NoRepeatNGramLogitsProcessor is initialized with ngram_size"""
ngram_size = 2
generation_config = GenerationConfig(no_repeat_ngram_size=ngram_size, do_sample=True)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.no_repeat_ngram_size, ngram_size)
no_repeat_ngram_proc = NoRepeatNGramLogitsProcessor(ngram_size=new_config.no_repeat_ngram_size)
self.assertEqual(no_repeat_ngram_proc.ngram_size, ngram_size)
def test_serialize_generation_encoder_ngram_size(self):
"""Tests that GenerationConfig is serialized and EncoderNoRepeatNGramLogitsProcessor is initialized with ngram_size"""
ngram_size = 2
input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
generation_config = GenerationConfig(encoder_no_repeat_ngram_size=ngram_size, do_sample=True)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.encoder_no_repeat_ngram_size, ngram_size)
encoder_no_repeat_ngram_proc = EncoderNoRepeatNGramLogitsProcessor(
encoder_ngram_size=new_config.encoder_no_repeat_ngram_size, encoder_input_ids=input_ids
)
self.assertEqual(encoder_no_repeat_ngram_proc.ngram_size, ngram_size)
def test_serialize_generation_bad_words_ids(self):
"""Tests that GenerationConfig is serialized and NoBadWordsLogitsProcessor is initialized with bad_words_ids"""
bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]]
generation_config = GenerationConfig(bad_words_ids=bad_word_tokens)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertSequenceEqual(new_config.bad_words_ids, bad_word_tokens)
no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=new_config.bad_words_ids)
self.assertSequenceEqual(no_bad_words_dist_proc.bad_word_ids, bad_word_tokens)
def test_serialize_generation_num_beams(self):
"""Tests that GenerationConfig is serialized and PrefixConstrainedLogitsProcessor is initialized with num_beams"""
num_beams = 1
def prefix_allowed_tokens_fn(batch_id, inputs_ids):
return [[0, 1], [2, 3]][batch_id]
generation_config = GenerationConfig(num_beams=num_beams)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.num_beams, num_beams)
prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(
prefix_allowed_tokens_fn, num_beams=new_config.num_beams
)
self.assertEqual(prefix_constrained_logits_proc._num_beams, num_beams)
def test_serialize_generation_diversity_penalty_and_num_beam_groups(self):
"""Tests that GenerationConfig is serialized and HammingDiversityLogitsProcessor is initialized with diversity_penalty, num_beams and num_beam_groups"""
num_beams = 2
num_beam_groups = 2
diversity_penalty = 1.0
generation_config = GenerationConfig(
num_beams=num_beams, diversity_penalty=diversity_penalty, num_beam_groups=num_beam_groups
)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.num_beams, num_beams)
self.assertEqual(new_config.diversity_penalty, diversity_penalty)
self.assertEqual(new_config.num_beam_groups, num_beam_groups)
diversity_logits_processor = HammingDiversityLogitsProcessor(
diversity_penalty=new_config.diversity_penalty,
num_beams=new_config.num_beams,
num_beam_groups=new_config.num_beam_groups,
)
self.assertEqual(diversity_logits_processor._num_beams, num_beams)
self.assertEqual(diversity_logits_processor._diversity_penalty, diversity_penalty)
self.assertEqual(diversity_logits_processor._num_sub_beams, num_beams // num_beam_groups)
def test_serialize_generation_bos_token_id(self):
"""Tests that GenerationConfig is serialized and ForcedBOSTokenLogitsProcessor is initialized with bos_token_id"""
bos_token_id = 0
generation_config = GenerationConfig(bos_token_id=bos_token_id)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.bos_token_id, bos_token_id)
logits_processor = ForcedBOSTokenLogitsProcessor(bos_token_id=new_config.bos_token_id)
self.assertEqual(logits_processor.bos_token_id, bos_token_id)
def test_serialize_generation_eos_token_id(self):
"""Tests that GenerationConfig is serialized and ForcedEOSTokenLogitsProcessor is initialized with eos_token_id"""
eos_token_id = 0
max_length = 5
generation_config = GenerationConfig(eos_token_id=eos_token_id)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.eos_token_id, eos_token_id)
logits_processor = ForcedEOSTokenLogitsProcessor(
max_length=max_length, eos_token_id=new_config.eos_token_id, device=torch_device
)
self.assertEqual(logits_processor.eos_token_id, eos_token_id)
def test_serialize_generation_exponential_decay_length_penalty(self):
"""Tests that GenerationConfig is serialized and ExponentialDecayLengthPenalty is initialized with regulation_start and regulation_factor"""
eos_token_id = 0
penalty_start = 5
penalty_factor = 1.1
input_ids_seq_length = 10
exponential_decay_length_penalty = (penalty_start, penalty_factor)
generation_config = GenerationConfig(exponential_decay_length_penalty=exponential_decay_length_penalty)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.exponential_decay_length_penalty, [penalty_start, penalty_factor])
exponential_decay_processor = ExponentialDecayLengthPenalty(
exponential_decay_length_penalty=new_config.exponential_decay_length_penalty,
eos_token_id=eos_token_id,
input_ids_seq_length=input_ids_seq_length,
)
self.assertEqual(
exponential_decay_processor.regulation_start, exponential_decay_length_penalty[0] + input_ids_seq_length
)
self.assertEqual(exponential_decay_processor.regulation_factor, exponential_decay_length_penalty[1])
def test_serialize_generation_begin_suppress_tokens(self):
"""Tests that GenerationConfig is serialized and SuppressTokensAtBeginLogitsProcessor is initialized with begin_suppress_token and begin_index"""
begin_suppress_tokens = [220, 50256]
begin_index = 0
generation_config = GenerationConfig(begin_suppress_tokens=begin_suppress_tokens)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertSequenceEqual(new_config.begin_suppress_tokens, begin_suppress_tokens)
suppress_processor = SuppressTokensAtBeginLogitsProcessor(
begin_suppress_tokens=new_config.begin_suppress_tokens, begin_index=begin_index
)
self.assertSequenceEqual(suppress_processor.begin_suppress_tokens, begin_suppress_tokens)
self.assertEqual(suppress_processor.begin_index, begin_index)
def test_serialize_generation_suppress_tokens(self):
"""Tests that GenerationConfig is serialized and SuppressTokensLogitsProcessor is initialized with suppress_token"""
suppress_tokens = [220, 50256]
generation_config = GenerationConfig(suppress_tokens=suppress_tokens)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertSequenceEqual(new_config.suppress_tokens, suppress_tokens)
suppress_processor = SuppressTokensLogitsProcessor(suppress_tokens=new_config.suppress_tokens)
self.assertSequenceEqual(suppress_processor.suppress_tokens, suppress_tokens)
def test_serialize_generation_guidance_scale(self):
"""Tests that GenerationConfig is serialized and ClassifierFreeGuidanceLogitsProcessor is initialized with guidance_scale"""
guidance_scale = 2.0
generation_config = GenerationConfig(guidance_scale=guidance_scale)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.guidance_scale, guidance_scale)
classifier_processor = ClassifierFreeGuidanceLogitsProcessor(guidance_scale=new_config.guidance_scale)
self.assertEqual(classifier_processor.guidance_scale, guidance_scale)
def test_serialize_generation_guidance_scale_unbatched(self):
"""Tests that GenerationConfig is serialized and UnbatchedClassifierFreeGuidanceLogitsProcessor is initialized with guidance_scale"""
guidance_scale = 2.0
input_ids = torch.LongTensor([[0]])
generation_config = GenerationConfig(guidance_scale=guidance_scale)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.guidance_scale, guidance_scale)
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(new_config.guidance_scale, {}, input_ids)
self.assertEqual(cfg.guidance_scale, guidance_scale)
def test_serialize_generation_watermarking_config(self):
"""Tests that GenerationConfig is serialized and WatermarkLogitsProcessor is initialized with WatermarkingConfig parameters"""
vocab_size = 20
bias = 2.0
greenlist_ratio = 0.5
hashing_key = 10
seeding_scheme = "lefthash"
context_width = 10
watermarking_config = WatermarkingConfig(
bias=bias,
greenlist_ratio=greenlist_ratio,
hashing_key=hashing_key,
seeding_scheme=seeding_scheme,
context_width=context_width,
)
generation_config = GenerationConfig(watermarking_config=watermarking_config)
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
self.assertEqual(new_config.watermarking_config.bias, bias)
self.assertEqual(new_config.watermarking_config.greenlist_ratio, greenlist_ratio)
self.assertEqual(new_config.watermarking_config.hashing_key, hashing_key)
self.assertEqual(new_config.watermarking_config.seeding_scheme, seeding_scheme)
self.assertEqual(new_config.watermarking_config.context_width, context_width)
watermark = WatermarkLogitsProcessor(
vocab_size=vocab_size,
device=torch_device,
greenlist_ratio=new_config.watermarking_config.greenlist_ratio,
bias=new_config.watermarking_config.bias,
hashing_key=new_config.watermarking_config.hashing_key,
seeding_scheme=new_config.watermarking_config.seeding_scheme,
context_width=new_config.watermarking_config.context_width,
)
self.assertEqual(watermark.bias, bias)
self.assertEqual(watermark.greenlist_size, int(vocab_size * greenlist_ratio))
self.assertEqual(watermark.hash_key, hashing_key)
self.assertEqual(watermark.seeding_scheme, seeding_scheme)
self.assertEqual(watermark.context_width, context_width)
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
@classmethod