Dynamic number of speculative tokens in order to accelerate speculative decoding (#33258)

* optimal Speculation Lookahead based on probability

* update peer finished condition

* add support for do_sample=True

* add stopping criteria

* gitignore

* add print

* remove prints

* minor

* minor

* git ignore

* add test for the ConfidenceCriteria stopping criterion

* doc + format

* add doc

* Update .gitignore

* update docstring and default value of assistant_confidence_threshold

* add docstring

* Update src/transformers/generation/configuration_utils.py

implicit default value (None)

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* style fix

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Jonathan Mamou 2024-09-11 15:22:28 +03:00 committed by GitHub
parent 42babe8548
commit 7a51cbc65f
6 changed files with 57 additions and 0 deletions

src/transformers/generation/__init__.py

@@ -83,6 +83,7 @@ else:
        "MaxNewTokensCriteria",
        "MaxLengthCriteria",
        "MaxTimeCriteria",
        "ConfidenceCriteria",
        "EosTokenCriteria",
        "StoppingCriteria",
        "StoppingCriteriaList",
@@ -225,6 +226,7 @@ if TYPE_CHECKING:
        WhisperTimeStampLogitsProcessor,
    )
    from .stopping_criteria import (
        ConfidenceCriteria,
        EosTokenCriteria,
        MaxLengthCriteria,
        MaxNewTokensCriteria,

src/transformers/generation/candidate_generator.py

@@ -108,6 +108,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
        # Prepare the assistant and the starting number of candidate tokens
        self.assistant_model = assistant_model
        self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
        self.assistant_confidence_threshold = assistant_model.generation_config.assistant_confidence_threshold

        # Set eos in assistant same as in target model
        self.assistant_model.generation_config.eos_token_id = generation_config.eos_token_id
@@ -157,6 +158,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
        self.generation_config = copy.deepcopy(generation_config)
        self.generation_config.return_dict_in_generate = True
        self.generation_config.output_scores = True
        self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold

        # Disable sampling -- this implementation of assisted generation/speculative decoding uses the assistant
        # greedily to maximize matches. Disables sampling-related flags to prevent warnings
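
This hunk copies the threshold from the assistant's generation config onto the config that drives the assistant's own generate call, so drafting can stop early once confidence drops. A simplified greedy sketch of that drafting loop (illustrative pseudocode; `draft_tokens` is not a function in the codebase):

import torch

def draft_tokens(assistant, input_ids, num_assistant_tokens, threshold):
    # Greedily draft up to `num_assistant_tokens` candidate tokens, stopping early
    # once the assistant's confidence in its own pick falls below `threshold`.
    for _ in range(num_assistant_tokens):
        logits = assistant(input_ids).logits[:, -1, :]
        next_token = logits.argmax(-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        confidence = logits.softmax(-1)[0, next_token[0, 0]].item()
        if threshold is not None and confidence < threshold:
            break  # low confidence: hand fewer candidates to the target model
    return input_ids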

src/transformers/generation/configuration_utils.py

@@ -350,6 +350,11 @@ class GenerationConfig(PushToHubMixin):
            reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model.
            - `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call.
            - `"constant"`: `num_assistant_tokens` stays unchanged during generation
        assistant_confidence_threshold (`float`, *optional*):
            The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower
            than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_
            (defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead
            from [Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models](https://arxiv.org/abs/2405.04304).
        prompt_lookup_num_tokens (`int`, *optional*, defaults to `None`):
            The number of tokens to be output as candidate tokens.
        max_matching_ngram_size (`int`, *optional*, defaults to `None`):
@@ -449,6 +454,7 @@
        # Assistant generation
        self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
        self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
        self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", None)

        # Prompt lookup decoding
        self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
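
End to end, the new kwarg behaves like any other GenerationConfig field. A minimal usage sketch of assisted generation with the new threshold (the checkpoint names below are placeholders, not part of this change):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
assistant = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Stop drafting as soon as the assistant's confidence drops below 0.4;
# left at the default (None), behavior is unchanged.
assistant.generation_config.assistant_confidence_threshold = 0.4

inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))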

src/transformers/generation/stopping_criteria.py

@@ -467,6 +467,27 @@ class EosTokenCriteria(StoppingCriteria):
        return is_done


class ConfidenceCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the assistant model's confidence in its prediction for the current token is lower than the threshold
    `model.generation_config.assistant_confidence_threshold`, even if the number of speculative tokens (defined by `num_assistant_tokens`) is not yet reached.

    Args:
        assistant_confidence_threshold (`float`):
            The value of the threshold.
    """

    def __init__(self, assistant_confidence_threshold):
        self.assistant_confidence_threshold = assistant_confidence_threshold

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # `scores[-1]` holds the logits for the most recently generated token; softmax converts them to probabilities
        probs = scores[-1].softmax(-1)
        # Confidence is the probability the model assigned to the token it actually emitted
        p = probs[0, input_ids[0, -1]].item()
        if p < self.assistant_confidence_threshold:
            return True
        return False


class StoppingCriteriaList(list):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
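
The confidence check is easy to exercise in isolation. A self-contained sketch (the toy vocabulary size and token id are made up; the import relies on the export added by this commit):

import torch
from transformers.generation import ConfidenceCriteria

criteria = ConfidenceCriteria(assistant_confidence_threshold=0.5)
input_ids = torch.tensor([[42]])       # pretend token 42 was just drafted
logits = torch.zeros((1, 100))

logits[0, 42] = 5.0                    # softmax probability of token 42 ~= 0.60 > 0.5
print(criteria(input_ids, (logits,)))  # False -> keep drafting

logits[0, 42] = -5.0                   # softmax probability of token 42 ~= 7e-5 < 0.5
print(criteria(input_ids, (logits,)))  # True -> stop drafting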

src/transformers/generation/utils.py

@@ -97,6 +97,7 @@ from .logits_process import (
    WatermarkLogitsProcessor,
)
from .stopping_criteria import (
    ConfidenceCriteria,
    EosTokenCriteria,
    MaxLengthCriteria,
    MaxTimeCriteria,
@@ -958,6 +959,13 @@ class GenerationMixin:
            criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
        if generation_config._eos_token_tensor is not None:
            criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
        if (
            generation_config.assistant_confidence_threshold is not None
            and generation_config.assistant_confidence_threshold > 0
        ):
            criteria.append(
                ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold)
            )
        criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
        return criteria
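
Note the guard above: the criterion is registered only when the threshold is set and strictly positive, so the default None (and an explicit 0) leaves generation untouched. A quick illustration:

from transformers import GenerationConfig

config = GenerationConfig()
print(config.assistant_confidence_threshold)  # None -> no ConfidenceCriteria is added

config.assistant_confidence_threshold = 0.0   # not > 0 -> still skipped
config.assistant_confidence_threshold = 0.4   # > 0 -> ConfidenceCriteria(0.4) is appended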

tests/generation/test_stopping_criteria.py

@@ -26,6 +26,7 @@ if is_torch_available():
    import torch

    from transformers.generation import (
        ConfidenceCriteria,
        EosTokenCriteria,
        MaxLengthCriteria,
        MaxTimeCriteria,
@@ -100,6 +101,23 @@ class StoppingCriteriaTestCase(unittest.TestCase):
        input_ids[:, -1] = 1
        self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False])

    def test_confidence_criteria(self):
        criteria = ConfidenceCriteria(assistant_confidence_threshold=0.5)

        vocab_size = 250
        length = 5

        input_ids = ids_tensor((1, length), vocab_size)
        scores = (torch.randn((1, vocab_size)),)

        # Simulate high confidence by giving the last input token a large logit (high probability after softmax)
        scores[0][0, input_ids[0, -1]] = 10.0
        self.assertFalse(criteria(input_ids, scores))

        # Simulate low confidence by giving the last input token a very small logit (low probability after softmax)
        scores[0][0, input_ids[0, -1]] = -10.0
        self.assertTrue(criteria(input_ids, scores))

def test_validate_stopping_criteria(self):
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)