mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Adaptive dynamic number of speculative tokens (#34156)
* initial commit * update strategy * add tradeoff FPR TPR with cost * all probs * fix * fix * fix style * Update src/transformers/generation/configuration_utils.py shorter docstring Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * import guard * fix style * add is_sklearn_available condition * vectorizing to flatten the for-loop * fix style * disable adaptation for UAG * update doc * add TestAssistedCandidateGeneratorUpdateStrategy * fix style * protect import * fix style --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
parent
b0a51e5cff
commit
e27465c801
@ -456,6 +456,8 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t
|
||||
['Alice and Bob, a couple of friends of mine, who are both in the same office as']
|
||||
```
|
||||
|
||||
We recommend to install `scikit-learn` library to enhance the candidate generation strategy and achieve additional speedup.
|
||||
|
||||
#### Universal Assisted Decoding
|
||||
|
||||
Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers.
|
||||
|
@ -19,6 +19,12 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..utils import is_sklearn_available
|
||||
|
||||
|
||||
if is_sklearn_available():
|
||||
from sklearn.metrics import roc_curve
|
||||
|
||||
from ..cache_utils import DynamicCache
|
||||
from ..pytorch_utils import isin_mps_friendly
|
||||
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
|
||||
@ -180,6 +186,14 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
||||
# We need to roll back the cache in assisted generation, only DynamicCache is supported
|
||||
self.generation_config.cache_implementation = None
|
||||
|
||||
if (
|
||||
is_sklearn_available()
|
||||
and self.assistant_model.generation_config.assistant_confidence_threshold
|
||||
and type(self) is AssistedCandidateGenerator
|
||||
):
|
||||
self.probs = []
|
||||
self.matches = []
|
||||
|
||||
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
|
||||
"""
|
||||
Fetches the candidates to be tried for the current input.
|
||||
@ -230,6 +244,17 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
||||
# 3. Update variables for the next round of candidate generation
|
||||
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
|
||||
|
||||
if (
|
||||
is_sklearn_available()
|
||||
and self.assistant_model.generation_config.assistant_confidence_threshold
|
||||
and type(self) is AssistedCandidateGenerator
|
||||
):
|
||||
scores_tensor = torch.cat(assistant_output.scores, dim=0)
|
||||
scores_softmax = torch.softmax(scores_tensor, dim=-1)
|
||||
ids = assistant_output.sequences[-1, -len(assistant_output.scores) :]
|
||||
p = scores_softmax[range(len(ids)), ids]
|
||||
self.probs.extend(p.tolist())
|
||||
|
||||
# 4. Prepare variables for output
|
||||
candidate_logits = torch.stack(assistant_output.scores, dim=1)
|
||||
candidate_ids = assistant_output.sequences
|
||||
@ -261,6 +286,38 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
||||
else:
|
||||
self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
|
||||
|
||||
# The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes. The costs are estimated based on the ROC curve, which considers the probability of the draft token and its match with the target. A cost of 25% is assigned to false positives and 75% to false negatives.
|
||||
# This adaptation is not compatible with UAG, as it relies on the number of matched tokens based on the draft vocabulary, which is unavailable in UAG.
|
||||
if (
|
||||
is_sklearn_available()
|
||||
and self.assistant_model.generation_config.assistant_confidence_threshold
|
||||
and type(self) is AssistedCandidateGenerator
|
||||
):
|
||||
# update self.matches
|
||||
self.matches.extend([1] * num_matches)
|
||||
if len(self.probs) > len(self.matches):
|
||||
self.matches.append(0)
|
||||
|
||||
# update self.probs
|
||||
excess_length = len(self.probs) - len(self.matches)
|
||||
if excess_length > 0:
|
||||
del self.probs[-excess_length:]
|
||||
|
||||
if (
|
||||
len(self.probs) > 5 and {0, 1}.issubset(self.matches)
|
||||
): # require at least 5 samples to calculate the ROC curve and at least one positive and one negative sample
|
||||
fpr, tpr, thresholds = roc_curve(self.matches, self.probs)
|
||||
fnr = 1 - tpr
|
||||
|
||||
# Calculate the cost for each threshold
|
||||
costs = fpr + 3 * fnr
|
||||
|
||||
# Find the threshold that minimizes the cost
|
||||
optimal_threshold_index = np.argmin(costs)
|
||||
best_threshold = thresholds[optimal_threshold_index]
|
||||
|
||||
self.assistant_model.generation_config.assistant_confidence_threshold = best_threshold
|
||||
|
||||
|
||||
class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
||||
"""
|
||||
|
@ -353,7 +353,9 @@ class GenerationConfig(PushToHubMixin):
|
||||
assistant_confidence_threshold (`float`, *optional*, defaults to 0.4):
|
||||
The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower
|
||||
than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_
|
||||
(defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead
|
||||
(defined by `num_assistant_tokens`) is not yet reached. The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes, biased towards avoiding false negatives.
|
||||
`assistant_confidence_threshold` value is persistent over multiple generation calls with the same assistant model.
|
||||
It is an unsupervised version of the dynamic speculation lookahead
|
||||
from Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models <https://arxiv.org/abs/2405.04304>.
|
||||
prompt_lookup_num_tokens (`int`, *optional*):
|
||||
The number of tokens to be output as candidate tokens.
|
||||
|
@ -92,9 +92,16 @@ if is_torch_available():
|
||||
WatermarkDetector,
|
||||
WatermarkingConfig,
|
||||
)
|
||||
from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers
|
||||
from transformers.generation.candidate_generator import (
|
||||
AssistedCandidateGenerator,
|
||||
AssistedCandidateGeneratorDifferentTokenizers,
|
||||
)
|
||||
from transformers.generation.utils import _speculative_sampling
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers.utils import is_sklearn_available
|
||||
|
||||
|
||||
class GenerationTesterMixin:
|
||||
input_name = "input_ids"
|
||||
@ -4312,3 +4319,110 @@ class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
|
||||
self.assertEqual(discrep_length, 0)
|
||||
np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
|
||||
np.testing.assert_array_equal(discrep_only, np.array([[]]))
|
||||
|
||||
|
||||
class TestAssistedCandidateGeneratorUpdateStrategy(unittest.TestCase):
|
||||
def setUp(self):
|
||||
checkpoint = "EleutherAI/pythia-160m-deduped"
|
||||
self.assistant_model = AutoModelForCausalLM.from_pretrained(checkpoint)
|
||||
self.assistant_model.generation_config.assistant_confidence_threshold = 0.4
|
||||
self.model_kwargs = {}
|
||||
self.input_ids = torch.randint(1, 10, (1, 9))
|
||||
self.candidate_generator = AssistedCandidateGenerator(
|
||||
input_ids=self.input_ids,
|
||||
assistant_model=self.assistant_model,
|
||||
generation_config=self.assistant_model.generation_config,
|
||||
model_kwargs=self.model_kwargs,
|
||||
)
|
||||
self.candidate_generator.probs = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
|
||||
self.original_probs = self.candidate_generator.probs
|
||||
self.original_threshold = self.assistant_model.generation_config.assistant_confidence_threshold
|
||||
|
||||
def assert_no_sklearn(self):
|
||||
with patch("transformers.utils.import_utils._sklearn_available", False):
|
||||
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
|
||||
self.assertEqual(self.candidate_generator.matches, self.original_matches)
|
||||
self.assertEqual(self.candidate_generator.probs, self.original_probs)
|
||||
self.assertEqual(
|
||||
self.assistant_model.generation_config.assistant_confidence_threshold, self.original_threshold
|
||||
)
|
||||
|
||||
@parameterized.expand([(is_sklearn_available(),), (False,)])
|
||||
def test_update_candidate_strategy_no_matches_short(self, sklearn_available):
|
||||
print("test_update_candidate_strategy_no_matches_short")
|
||||
self.original_matches = []
|
||||
self.candidate_generator.matches = self.original_matches
|
||||
self.num_matches = 0
|
||||
|
||||
if sklearn_available:
|
||||
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
|
||||
self.assertEqual(self.candidate_generator.matches, [0])
|
||||
self.assertEqual(self.candidate_generator.probs, [0.9])
|
||||
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4)
|
||||
else:
|
||||
self.assert_no_sklearn()
|
||||
|
||||
@parameterized.expand([(is_sklearn_available(),), (False,)])
|
||||
def test_update_candidate_strategy_with_mix_matches_3(self, sklearn_available):
|
||||
self.original_matches = [1, 0, 1, 0, 1]
|
||||
self.candidate_generator.matches = self.original_matches
|
||||
self.num_matches = 3
|
||||
if sklearn_available:
|
||||
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
|
||||
self.assertEqual(self.candidate_generator.matches, [1, 0, 1, 0, 1, 1, 1, 1, 0])
|
||||
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
|
||||
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.2)
|
||||
else:
|
||||
self.assert_no_sklearn()
|
||||
|
||||
@parameterized.expand([(is_sklearn_available(),), (False,)])
|
||||
def test_update_candidate_strategy_with_matches_4(self, sklearn_available):
|
||||
self.original_matches = [1, 1, 1, 1, 1]
|
||||
self.candidate_generator.matches = self.original_matches
|
||||
self.num_matches = 4
|
||||
if sklearn_available:
|
||||
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
|
||||
self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 1, 1])
|
||||
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
|
||||
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4)
|
||||
else:
|
||||
self.assert_no_sklearn()
|
||||
|
||||
@parameterized.expand([(is_sklearn_available(),), (False,)])
|
||||
def test_update_candidate_strategy_with_matches_3(self, sklearn_available):
|
||||
self.original_matches = [1, 1, 1, 1, 1]
|
||||
self.candidate_generator.matches = self.original_matches
|
||||
self.num_matches = 3
|
||||
if sklearn_available:
|
||||
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
|
||||
self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 1, 0])
|
||||
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
|
||||
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.2)
|
||||
else:
|
||||
self.assert_no_sklearn()
|
||||
|
||||
@parameterized.expand([(is_sklearn_available(),), (False,)])
|
||||
def test_update_candidate_strategy_with_matches_2(self, sklearn_available):
|
||||
self.original_matches = [1, 1, 1, 1, 1]
|
||||
self.candidate_generator.matches = self.original_matches
|
||||
self.num_matches = 2
|
||||
if sklearn_available:
|
||||
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
|
||||
self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 0])
|
||||
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2])
|
||||
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.3)
|
||||
else:
|
||||
self.assert_no_sklearn()
|
||||
|
||||
@parameterized.expand([(is_sklearn_available(),), (False,)])
|
||||
def test_update_candidate_strategy_with_matches_1(self, sklearn_available):
|
||||
self.original_matches = [1, 1, 1, 1, 1]
|
||||
self.candidate_generator.matches = self.original_matches
|
||||
self.num_matches = 1
|
||||
if sklearn_available:
|
||||
self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches)
|
||||
self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 0])
|
||||
self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3])
|
||||
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4)
|
||||
else:
|
||||
self.assert_no_sklearn()
|
||||
|
Loading…
Reference in New Issue
Block a user