Mirror of https://github.com/huggingface/transformers.git
Dynamic number of speculative tokens in order to accelerate speculative decoding (#33258)
* optimal Speculation Lookahead based on probability
* update peer finished condition
* add support to do_sample True
* add stopping criteria
* gitignore
* add print
* remove prints
* minor
* minor
* git ignore
* adding test to stopping ConfidenceCriteria
* doc + format
* add doc
* Update .gitignore
* update docstring and default value of assistant_confidence_threshold
* add docstring
* Update src/transformers/generation/configuration_utils.py

  implicit default value (None)

  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* style fix

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
parent 42babe8548 · commit 7a51cbc65f
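The feature added here is driven entirely by the assistant model's generation config: once `assistant_confidence_threshold` is set, the draft model stops speculating as soon as its probability for the token it just drafted falls below the threshold. A minimal usage sketch under assumed checkpoints (the model names and the 0.4 threshold are illustrative, not part of this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
assistant = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# The candidate generator reads the threshold off the assistant's config
# (see the AssistedCandidateGenerator hunk below).
assistant.generation_config.assistant_confidence_threshold = 0.4

inputs = tokenizer("The quick brown fox", return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))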
src/transformers/generation/__init__.py
@@ -83,6 +83,7 @@ else:
         "MaxNewTokensCriteria",
         "MaxLengthCriteria",
         "MaxTimeCriteria",
+        "ConfidenceCriteria",
         "EosTokenCriteria",
         "StoppingCriteria",
         "StoppingCriteriaList",
@@ -225,6 +226,7 @@ if TYPE_CHECKING:
         WhisperTimeStampLogitsProcessor,
     )
     from .stopping_criteria import (
+        ConfidenceCriteria,
         EosTokenCriteria,
         MaxLengthCriteria,
         MaxNewTokensCriteria,
src/transformers/generation/candidate_generator.py
@@ -108,6 +108,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
         # Prepare the assistant and the starting number of candidate tokens
         self.assistant_model = assistant_model
         self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
+        self.assistant_confidence_threshold = assistant_model.generation_config.assistant_confidence_threshold

         # Set eos in assistant same as in target model
         self.assistant_model.generation_config.eos_token_id = generation_config.eos_token_id
@@ -157,6 +158,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
         self.generation_config = copy.deepcopy(generation_config)
         self.generation_config.return_dict_in_generate = True
         self.generation_config.output_scores = True
+        self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold

         # Disable sampling -- this implementation of assisted generation/speculative decoding uses the assistant
         # greedily to maximize matches. Disables sampling-related flags to prevent warnings
src/transformers/generation/configuration_utils.py
@@ -350,6 +350,11 @@ class GenerationConfig(PushToHubMixin):
               reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model.
             - `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call.
             - `"constant"`: `num_assistant_tokens` stays unchanged during generation
+        assistant_confidence_threshold (`float`, *optional*):
+            The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower
+            than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_
+            (defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead
+            from Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models <https://arxiv.org/abs/2405.04304>.
         prompt_lookup_num_tokens (`int`, *optional*, default to `None`):
             The number of tokens to be output as candidate tokens.
         max_matching_ngram_size (`int`, *optional*, default to `None`):
@@ -449,6 +454,7 @@ class GenerationConfig(PushToHubMixin):
         # Assistant generation
         self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
         self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
+        self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", None)

         # Prompt lookup decoding
         self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
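Because the new kwarg is popped with a default of `None`, omitting it leaves assisted generation unchanged; a short sketch of both cases:

from transformers import GenerationConfig

# Omitted: the attribute exists but is None, so no ConfidenceCriteria is added.
print(GenerationConfig().assistant_confidence_threshold)  # None

# Supplied: stored alongside the other assistant-generation settings.
config = GenerationConfig(assistant_confidence_threshold=0.4, num_assistant_tokens=20)
print(config.assistant_confidence_threshold)  # 0.4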
src/transformers/generation/stopping_criteria.py
@@ -467,6 +467,27 @@ class EosTokenCriteria(StoppingCriteria):
         return is_done


+class ConfidenceCriteria(StoppingCriteria):
+    """
+    This class can be used to stop generation whenever the assistant model's confidence in its prediction for the current token is lower than the threshold
+    `model.generation_config.assistant_confidence_threshold`, even if the number of speculative tokens (defined by `num_assistant_tokens`) is not yet reached.
+
+    Args:
+        assistant_confidence_threshold (`float`):
+            The value of the threshold.
+    """
+
+    def __init__(self, assistant_confidence_threshold):
+        self.assistant_confidence_threshold = assistant_confidence_threshold
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
+        probs = scores[-1].softmax(-1)
+        p = probs[0, input_ids[0, -1]].item()
+        if p < self.assistant_confidence_threshold:
+            return True
+        return False
+
+
 class StoppingCriteriaList(list):
     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
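A self-contained toy run of the new criterion's decision rule (shapes follow `__call__` above; the vocabulary size and values are made up):

import torch

from transformers.generation import ConfidenceCriteria

criteria = ConfidenceCriteria(assistant_confidence_threshold=0.5)

input_ids = torch.tensor([[3]])  # the draft model just produced token id 3
vocab_size = 8

# `scores` is a tuple of per-step logits; only the last step is inspected.
confident = torch.full((1, vocab_size), -5.0)
confident[0, 3] = 5.0  # softmax probability of token 3 is ~0.9997
print(criteria(input_ids, (confident,)))  # False -> keep drafting

unsure = torch.zeros((1, vocab_size))  # uniform: probability of token 3 is 1/8
print(criteria(input_ids, (unsure,)))  # True -> stop drafting early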
src/transformers/generation/utils.py
@@ -97,6 +97,7 @@ from .logits_process import (
     WatermarkLogitsProcessor,
 )
 from .stopping_criteria import (
+    ConfidenceCriteria,
     EosTokenCriteria,
     MaxLengthCriteria,
     MaxTimeCriteria,
@@ -958,6 +959,13 @@ class GenerationMixin:
             criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
         if generation_config._eos_token_tensor is not None:
             criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
+        if (
+            generation_config.assistant_confidence_threshold is not None
+            and generation_config.assistant_confidence_threshold > 0
+        ):
+            criteria.append(
+                ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold)
+            )
         criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
         return criteria
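Note the guard: the criterion is attached only when the threshold is both set and strictly positive, so `None` or `0` preserves stock behavior. A hedged sketch of that selection logic in isolation (the helper name is hypothetical, not from this commit):

from transformers import GenerationConfig
from transformers.generation import ConfidenceCriteria


def maybe_confidence_criterion(generation_config):
    # Hypothetical helper mirroring the guard in the hunk above.
    threshold = generation_config.assistant_confidence_threshold
    if threshold is not None and threshold > 0:
        return ConfidenceCriteria(assistant_confidence_threshold=threshold)
    return None


print(maybe_confidence_criterion(GenerationConfig()))  # None: disabled by default
print(maybe_confidence_criterion(GenerationConfig(assistant_confidence_threshold=0.4)))  # criterion instance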
tests/generation/test_stopping_criteria.py
@@ -26,6 +26,7 @@ if is_torch_available():
     import torch

     from transformers.generation import (
+        ConfidenceCriteria,
         EosTokenCriteria,
         MaxLengthCriteria,
         MaxTimeCriteria,
@@ -100,6 +101,23 @@ class StoppingCriteriaTestCase(unittest.TestCase):
         input_ids[:, -1] = 1
         self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False])

+    def test_confidence_criteria(self):
+        criteria = ConfidenceCriteria(assistant_confidence_threshold=0.5)
+
+        vocab_size = 250
+        length = 5
+
+        input_ids = ids_tensor((1, length), vocab_size)
+        scores = (torch.randn((1, vocab_size)),)
+
+        # Simulate high confidence by setting the probability of the last token to be high
+        scores[0][0, input_ids[0, -1]] = 10.0  # Logits before softmax
+        self.assertFalse(criteria(input_ids, scores))
+
+        # Simulate low confidence by setting the probability of the last token to be low
+        scores[0][0, input_ids[0, -1]] = -10.0  # Logits before softmax
+        self.assertTrue(criteria(input_ids, scores))
+
     def test_validate_stopping_criteria(self):
         validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
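Why logits of ±10.0 are safe margins in this test: against ~250 roughly standard-normal logits (whose exponentials sum to about 250·e^0.5 ≈ 410), a single logit of 10 contributes e^10 ≈ 22026 of the softmax mass (probability ≈ 0.98), while −10 contributes e^−10 ≈ 4.5e−5 (probability ≈ 1e−7). A quick sanity check under the same assumptions:

import torch

vocab_size = 250
logits = torch.randn(1, vocab_size)

logits[0, 0] = 10.0
print(logits.softmax(-1)[0, 0].item())  # ~0.98, comfortably above the 0.5 threshold

logits[0, 0] = -10.0
print(logits.softmax(-1)[0, 0].item())  # ~1e-7, far below the 0.5 threshold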