Adaptive dynamic number of speculative tokens (#34156)

* initial commit

* update strategy

* add tradeoff FPR TPR with cost

* all probs

* fix

* fix

* fix style

* Update src/transformers/generation/configuration_utils.py

shorter docstring

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* import guard

* fix style

* add is_sklearn_available condition

* vectorizing to flatten the for-loop

* fix style

* disable adaptation for UAG

* update doc

* add TestAssistedCandidateGeneratorUpdateStrategy

* fix style

* protect import

* fix style

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Jonathan Mamou 2024-12-05 18:07:33 +02:00 committed by GitHub
parent b0a51e5cff
commit e27465c801
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 177 additions and 2 deletions

View File

@ -456,6 +456,8 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t
['Alice and Bob, a couple of friends of mine, who are both in the same office as']
```
We recommend to install `scikit-learn` library to enhance the candidate generation strategy and achieve additional speedup.
#### Universal Assisted Decoding
Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers.

View File

@ -19,6 +19,12 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import numpy as np
import torch
from ..utils import is_sklearn_available
if is_sklearn_available():
from sklearn.metrics import roc_curve
from ..cache_utils import DynamicCache
from ..pytorch_utils import isin_mps_friendly
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
@ -180,6 +186,14 @@ class AssistedCandidateGenerator(CandidateGenerator):
# We need to roll back the cache in assisted generation, only DynamicCache is supported
self.generation_config.cache_implementation = None
if (
is_sklearn_available()
and self.assistant_model.generation_config.assistant_confidence_threshold
and type(self) is AssistedCandidateGenerator
):
self.probs = []
self.matches = []
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.
@ -230,6 +244,17 @@ class AssistedCandidateGenerator(CandidateGenerator):
# 3. Update variables for the next round of candidate generation
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
if (
is_sklearn_available()
and self.assistant_model.generation_config.assistant_confidence_threshold
and type(self) is AssistedCandidateGenerator
):
scores_tensor = torch.cat(assistant_output.scores, dim=0)
scores_softmax = torch.softmax(scores_tensor, dim=-1)
ids = assistant_output.sequences[-1, -len(assistant_output.scores) :]
p = scores_softmax[range(len(ids)), ids]
self.probs.extend(p.tolist())
# 4. Prepare variables for output
candidate_logits = torch.stack(assistant_output.scores, dim=1)
candidate_ids = assistant_output.sequences
@ -261,6 +286,38 @@ class AssistedCandidateGenerator(CandidateGenerator):
else:
self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
# The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes. The costs are estimated based on the ROC curve, which considers the probability of the draft token and its match with the target. A cost of 25% is assigned to false positives and 75% to false negatives.
# This adaptation is not compatible with UAG, as it relies on the number of matched tokens based on the draft vocabulary, which is unavailable in UAG.
if (
is_sklearn_available()
and self.assistant_model.generation_config.assistant_confidence_threshold
and type(self) is AssistedCandidateGenerator
):
# update self.matches
self.matches.extend([1] * num_matches)
if len(self.probs) > len(self.matches):
self.matches.append(0)
# update self.probs
excess_length = len(self.probs) - len(self.matches)
if excess_length > 0:
del self.probs[-excess_length:]
if (
len(self.probs) > 5 and {0, 1}.issubset(self.matches)
): # require at least 5 samples to calculate the ROC curve and at least one positive and one negative sample
fpr, tpr, thresholds = roc_curve(self.matches, self.probs)
fnr = 1 - tpr
# Calculate the cost for each threshold
costs = fpr + 3 * fnr
# Find the threshold that minimizes the cost
optimal_threshold_index = np.argmin(costs)
best_threshold = thresholds[optimal_threshold_index]
self.assistant_model.generation_config.assistant_confidence_threshold = best_threshold
class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
"""

View File

@ -353,7 +353,9 @@ class GenerationConfig(PushToHubMixin):
assistant_confidence_threshold (`float`, *optional*, defaults to 0.4):
The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower
than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_
(defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead
(defined by `num_assistant_tokens`) is not yet reached. The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes, biased towards avoiding false negatives.
`assistant_confidence_threshold` value is persistent over multiple generation calls with the same assistant model.
It is an unsupervised version of the dynamic speculation lookahead
from Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models <https://arxiv.org/abs/2405.04304>.
prompt_lookup_num_tokens (`int`, *optional*):
The number of tokens to be output as candidate tokens.

View File

@ -92,9 +92,16 @@ if is_torch_available():
WatermarkDetector,
WatermarkingConfig,
)
from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers
from transformers.generation.candidate_generator import (
AssistedCandidateGenerator,
AssistedCandidateGeneratorDifferentTokenizers,
)
from transformers.generation.utils import _speculative_sampling
from unittest.mock import patch
from transformers.utils import is_sklearn_available
class GenerationTesterMixin:
input_name = "input_ids"
@ -4312,3 +4319,110 @@ class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))
class TestAssistedCandidateGeneratorUpdateStrategy(unittest.TestCase):
def setUp(self):
checkpoint = "EleutherAI/pythia-160m-deduped"
self.assistant_model = AutoModelForCausalLM.from_pretrained(checkpoint)
self.assistant_model.generation_config.assistant_confidence_threshold = 0.4
self.model_kwargs = {}
self.input_ids = torch.randint(1, 10, (1, 9))
self.candidate_generator = AssistedCandidateGenerator(
input_ids=self.input_ids,
assistant_model=self.assistant_model,
generation_config=self.assistant_model.generation_config,
model_kwargs=self.model_kwargs,
)
self.candidate_generator.probs = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
self.original_probs = self.candidate_generator.probs
self.original_threshold = self.assistant_model.generation_config.assistant_confidence_threshold
def assert_no_sklearn(self):
with patch("transformers.utils.import_utils._sklearn_available", False):
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
self.assertEqual(self.candidate_generator.matches, self.original_matches)
self.assertEqual(self.candidate_generator.probs, self.original_probs)
self.assertEqual(
self.assistant_model.generation_config.assistant_confidence_threshold, self.original_threshold
)
@parameterized.expand([(is_sklearn_available(),), (False,)])
def test_update_candidate_strategy_no_matches_short(self, sklearn_available):
print("test_update_candidate_strategy_no_matches_short")
self.original_matches = []
self.candidate_generator.matches = self.original_matches
self.num_matches = 0
if sklearn_available:
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
self.assertEqual(self.candidate_generator.matches, [0])
self.assertEqual(self.candidate_generator.probs, [0.9])
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4)
else:
self.assert_no_sklearn()
@parameterized.expand([(is_sklearn_available(),), (False,)])
def test_update_candidate_strategy_with_mix_matches_3(self, sklearn_available):
self.original_matches = [1, 0, 1, 0, 1]
self.candidate_generator.matches = self.original_matches
self.num_matches = 3
if sklearn_available:
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
self.assertEqual(self.candidate_generator.matches, [1, 0, 1, 0, 1, 1, 1, 1, 0])
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.2)
else:
self.assert_no_sklearn()
@parameterized.expand([(is_sklearn_available(),), (False,)])
def test_update_candidate_strategy_with_matches_4(self, sklearn_available):
self.original_matches = [1, 1, 1, 1, 1]
self.candidate_generator.matches = self.original_matches
self.num_matches = 4
if sklearn_available:
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 1, 1])
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4)
else:
self.assert_no_sklearn()
@parameterized.expand([(is_sklearn_available(),), (False,)])
def test_update_candidate_strategy_with_matches_3(self, sklearn_available):
self.original_matches = [1, 1, 1, 1, 1]
self.candidate_generator.matches = self.original_matches
self.num_matches = 3
if sklearn_available:
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 1, 0])
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.2)
else:
self.assert_no_sklearn()
@parameterized.expand([(is_sklearn_available(),), (False,)])
def test_update_candidate_strategy_with_matches_2(self, sklearn_available):
self.original_matches = [1, 1, 1, 1, 1]
self.candidate_generator.matches = self.original_matches
self.num_matches = 2
if sklearn_available:
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 0])
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2])
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.3)
else:
self.assert_no_sklearn()
@parameterized.expand([(is_sklearn_available(),), (False,)])
def test_update_candidate_strategy_with_matches_1(self, sklearn_available):
self.original_matches = [1, 1, 1, 1, 1]
self.candidate_generator.matches = self.original_matches
self.num_matches = 1
if sklearn_available:
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 0])
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3])
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4)
else:
self.assert_no_sklearn()