Generate: fix speculative decoding (#28166)

Co-authored-by: Merve Noyan <merveenoyan@gmail.com>

parent 01c081d138
commit 45b70384a7
@@ -339,13 +339,16 @@ This guide illustrates the main parameters that enable various decoding strategies
[`generate`] method, which gives you even further control over the [`generate`] method's behavior.
For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.md).

-### Assisted Decoding
+### Speculative Decoding

-Assisted decoding is a modification of the decoding strategies above that uses an assistant model with the same
-tokenizer (ideally a much smaller model) to greedily generate a few candidate tokens. The main model then validates
-the candidate tokens in a single forward pass, which speeds up the decoding process. Currently, only greedy search
-and sampling are supported with assisted decoding, and doesn't support batched inputs. To learn more about assisted
-decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).
+Speculative decoding (also known as assisted decoding) is a modification of the decoding strategies above, that uses an
+assistant model (ideally a much smaller one) with the same tokenizer, to generate a few candidate tokens. The main
+model then validates the candidate tokens in a single forward pass, which speeds up the decoding process. If
+`do_sample=True`, then the token validation with resampling introduced in the
+[speculative decoding paper](https://arxiv.org/pdf/2211.17192.pdf) is used.
+
+Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs.
+To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).

To enable assisted decoding, set the `assistant_model` argument with a model.
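Note: the runnable snippet these docs refer to sits just below this hunk and is unchanged by the commit. A minimal sketch of the pattern, assuming an illustrative pair of checkpoints that share a tokenizer:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed, illustrative checkpoints; any main/assistant pair sharing a tokenizer works.
checkpoint = "EleutherAI/pythia-1.4b-deduped"
assistant_checkpoint = "EleutherAI/pythia-160m-deduped"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("Alice and Bob", return_tensors="pt")

model = AutoModelForCausalLM.from_pretrained(checkpoint)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)

# Passing `assistant_model` is what switches `generate` into assisted (speculative) decoding.
outputs = model.generate(**inputs, assistant_model=assistant_model)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```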
@@ -366,8 +369,8 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```

-When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness
-just like in multinomial sampling. However, in assisted decoding, reducing the temperature will help improving latency.
+When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
+just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
@@ -14,14 +14,14 @@
# limitations under the License.

import copy
-import warnings
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

import torch


if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from .configuration_utils import GenerationConfig
from .logits_process import LogitsProcessorList

@@ -66,14 +66,17 @@ class CandidateGenerator:

class AssistedCandidateGenerator(CandidateGenerator):
"""
-`CandidateGenerator` class to be used for assisted generation. This class generates candidates through the use of
-a smaller model. Read the following blog post for more information: https://huggingface.co/blog/assisted-generation
+`CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
+candidates through the use of a smaller model. Read the following blog post for more information:
+https://huggingface.co/blog/assisted-generation

Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
assistant_model (`PreTrainedModel`):
The model to be used for generating candidates. This model should be smaller than the main model.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call.
logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
@@ -82,30 +85,19 @@ class AssistedCandidateGenerator(CandidateGenerator):
model as well.
inputs_tensor (`torch.Tensor`, *optional*):
The model input tensor. In encoder-decoder models, this is the encoder input.
-eos_token_id (`Union[int, List[int]]`, *optional*):
-The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
"""

def __init__(
self,
input_ids: torch.LongTensor,
assistant_model: "PreTrainedModel",
generation_config: "GenerationConfig",
logits_processor: "LogitsProcessorList",
model_kwargs: Dict,
inputs_tensor: Optional[torch.Tensor] = None,
-eos_token_id: Optional[Union[int, List[int]]] = None,
):
+# Prepare the assistant and the starting number of candidate tokens
self.assistant_model = assistant_model

-# Prepare the number of candidate tokens
-if hasattr(assistant_model, "num_assistant_tokens"):
-warnings.warn(
-"Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be "
-"removed in v4.37. Make sure to set `num_assistant_tokens` via the generation_config instead.",
-FutureWarning,
-)
-self.num_assistant_tokens = assistant_model.num_assistant_tokens
-else:
self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens

# Prepare the kwargs for the assistant model
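Note: with the deprecated `assistant_model.num_assistant_tokens` path removed, the candidate budget comes from the assistant's generation config. A minimal sketch of configuring it, assuming an illustrative checkpoint:

```python
from transformers import AutoModelForCausalLM

# Illustrative checkpoint (assumption); any causal LM used as an assistant works the same way.
assistant_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m-deduped")

# `AssistedCandidateGenerator.__init__` above reads this value from the generation config,
# and `get_candidates` later passes it to the assistant as `max_new_tokens`.
assistant_model.generation_config.num_assistant_tokens = 8
```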
@@ -145,13 +137,17 @@ class AssistedCandidateGenerator(CandidateGenerator):
self.input_ids_key = "input_ids"
self.attention_key = "attention_mask"

-# Prepare other attributes
+# Prepare generation-related options.
+eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id_tensor = (
torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
)
self.logits_processor = logits_processor
self.generation_config = copy.deepcopy(generation_config)
+self.generation_config.return_dict_in_generate = True
+self.generation_config.output_scores = True

def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
@@ -185,12 +181,11 @@ class AssistedCandidateGenerator(CandidateGenerator):
# 2. Forecast next N tokens using the assistant model.
assistant_generation_kwargs = {
self.input_ids_key: input_ids,
-"do_sample": False,
-"num_beams": 1,
"max_new_tokens": int(self.num_assistant_tokens),
-"return_dict_in_generate": True,
-"output_scores": True,
+"generation_config": self.generation_config,
+"logits_processor": self.logits_processor,
}

assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)

# 3. Update variables for the next round of candidate generation
@@ -911,10 +911,10 @@ class GenerationMixin:
candidate_generator = AssistedCandidateGenerator(
input_ids=input_ids,
assistant_model=assistant_model,
generation_config=generation_config,
logits_processor=logits_processor,
model_kwargs=model_kwargs,
inputs_tensor=inputs_tensor,
-eos_token_id=generation_config.eos_token_id,
)
return candidate_generator
@@ -1673,7 +1673,7 @@ class GenerationMixin:
)

# 8. prepare distribution pre_processing samplers
-logits_processor = self._get_logits_processor(
+prepared_logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,

@@ -1685,7 +1685,7 @@ class GenerationMixin:
)

# 9. prepare stopping criteria
-stopping_criteria = self._get_stopping_criteria(
+prepared_stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria
)
# 10. go into different generation modes
@@ -1715,9 +1715,9 @@ class GenerationMixin:
input_ids,
candidate_generator=candidate_generator,
do_sample=generation_config.do_sample,
-logits_processor=logits_processor,
+logits_processor=prepared_logits_processor,
logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,
-stopping_criteria=stopping_criteria,
+stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,

@@ -1730,8 +1730,8 @@ class GenerationMixin:
# 11. run greedy search
return self.greedy_search(
input_ids,
-logits_processor=logits_processor,
-stopping_criteria=stopping_criteria,
+logits_processor=prepared_logits_processor,
+stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,

@@ -1749,8 +1749,8 @@ class GenerationMixin:
input_ids,
top_k=generation_config.top_k,
penalty_alpha=generation_config.penalty_alpha,
-logits_processor=logits_processor,
-stopping_criteria=stopping_criteria,
+logits_processor=prepared_logits_processor,
+stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,

@@ -1776,9 +1776,9 @@ class GenerationMixin:
# 13. run sample
return self.sample(
input_ids,
-logits_processor=logits_processor,
+logits_processor=prepared_logits_processor,
logits_warper=logits_warper,
-stopping_criteria=stopping_criteria,
+stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,

@@ -1810,8 +1810,8 @@ class GenerationMixin:
return self.beam_search(
input_ids,
beam_scorer,
-logits_processor=logits_processor,
-stopping_criteria=stopping_criteria,
+logits_processor=prepared_logits_processor,
+stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,

@@ -1847,9 +1847,9 @@ class GenerationMixin:
return self.beam_sample(
input_ids,
beam_scorer,
-logits_processor=logits_processor,
+logits_processor=prepared_logits_processor,
logits_warper=logits_warper,
-stopping_criteria=stopping_criteria,
+stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,

@@ -1881,8 +1881,8 @@ class GenerationMixin:
return self.group_beam_search(
input_ids,
beam_scorer,
-logits_processor=logits_processor,
-stopping_criteria=stopping_criteria,
+logits_processor=prepared_logits_processor,
+stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,

@@ -1954,8 +1954,8 @@ class GenerationMixin:
return self.constrained_beam_search(
input_ids,
constrained_beam_scorer=constrained_beam_scorer,
-logits_processor=logits_processor,
-stopping_criteria=stopping_criteria,
+logits_processor=prepared_logits_processor,
+stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
@@ -4629,7 +4629,7 @@ class GenerationMixin:
# 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
max_matches = max_len - cur_len - 1
if do_sample and candidate_logits is not None:
-next_sampled_tokens, n_matches = _speculative_sampling(
+valid_tokens, n_matches = _speculative_sampling(
candidate_input_ids,
candidate_logits,
candidate_length,

@@ -4637,8 +4637,6 @@ class GenerationMixin:
last_assistant_token_is_eos,
max_matches,
)
-# The selected tokens include the matches plus the next sampled tokens
-selected_tokens = torch.cat((candidate_input_ids[:, :n_matches], next_sampled_tokens), dim=-1)

# Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
# original model logits with the candidate tokens. We can keep the candidate tokens until the first

@@ -4657,6 +4655,7 @@ class GenerationMixin:
if last_assistant_token_is_eos and n_matches == candidate_length:
n_matches -= 1
n_matches = min(n_matches, max_matches)
+valid_tokens = selected_tokens[:, : n_matches + 1]

# 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
# by the model after the last candidate match is also valid, as it is generated from a correct sequence.

@@ -4664,7 +4663,6 @@ class GenerationMixin:
# is no match.

# 4.1. Get the valid continuation, after the matching tokens
-valid_tokens = selected_tokens[:, : n_matches + 1]
input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
if streamer is not None:
streamer.put(valid_tokens.cpu())
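Note on the bookkeeping above: with `n_matches` accepted candidates, `n_matches + 1` tokens are appended, because the main model's token at the first mismatch (or after the last candidate) is also valid. A small sketch with hypothetical numbers; the leading-run computation below is a simplified stand-in for the comparison `_assisted_decoding` performs:

```python
import torch

# Hypothetical toy values: the assistant proposed 5 candidate tokens, and the main
# model's picks (one per candidate position, plus one extra token) agree on the first 3.
candidate_new_tokens = torch.tensor([[11, 12, 13, 99, 98]])  # assistant proposals
selected_tokens = torch.tensor([[11, 12, 13, 42, 77, 55]])   # main model picks

# Length of the leading run of agreements between the two sequences.
matches = (candidate_new_tokens == selected_tokens[:, :-1]).int()
n_matches = int(matches.cumprod(dim=-1).sum())  # -> 3: the 99 vs 42 mismatch ends the run

valid_tokens = selected_tokens[:, : n_matches + 1]
# -> tensor([[11, 12, 13, 42]]): the 3 accepted candidates plus the main model's own token.
```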
@@ -4782,24 +4780,16 @@ def _speculative_sampling(
):
"""
Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
-the next selected token, as well as the number of candidate matches.
+the selected tokens, as well as the number of candidate matches.

NOTE: Unless otherwise stated, the variable names match those in the paper.
"""
# Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
# selected by the assistant, respectively.
q = candidate_logits.softmax(dim=-1)
-q_i = q[
-:,
-torch.range(0, candidate_length - 1, dtype=torch.int),
-candidate_input_ids[:, -candidate_length:],
-].squeeze(0, 1)
+q_i = q[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
p = new_logits.softmax(dim=-1)
-p_i = p[
-:,
-torch.range(0, candidate_length - 1, dtype=torch.int),
-candidate_input_ids[:, -candidate_length:],
-].squeeze(0, 1)
+p_i = p[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
probability_ratio = p_i / q_i

# When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller

@@ -4824,7 +4814,13 @@ def _speculative_sampling(
p_prime = p_n_plus_1
t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]

-return t, n_matches
+# The selected tokens include the matches (if any) plus the next sampled tokens
+if n_matches > 0:
+valid_tokens = torch.cat((candidate_input_ids[:, -n_matches:], t), dim=-1)
+else:
+valid_tokens = t
+
+return valid_tokens, n_matches


def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
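Note: for orientation, here is a minimal, self-contained sketch of the acceptance-and-resampling rule (algorithm 1 of the paper) that `_speculative_sampling` implements. The helper below is illustrative only (batch size 1, no EOS handling) and is not the library code:

```python
import torch


def toy_speculative_sampling(candidate_tokens, q_logits, p_logits, max_matches):
    """Toy version of the speculative sampling rule.

    candidate_tokens: (candidate_length,) tokens proposed by the assistant.
    q_logits: (candidate_length, vocab_size) assistant logits at those positions.
    p_logits: (candidate_length + 1, vocab_size) main-model logits at the same
        positions, plus one extra position after the last candidate.
    """
    candidate_length = candidate_tokens.shape[-1]
    q = q_logits.softmax(dim=-1)
    p = p_logits.softmax(dim=-1)
    pos = torch.arange(candidate_length)
    q_i = q[pos, candidate_tokens]  # assistant probability of each proposed token
    p_i = p[pos, candidate_tokens]  # main-model probability of the same token

    # Accept token i with probability min(1, p_i / q_i); keep the leading run of accepts.
    accept = torch.rand(candidate_length) < (p_i / q_i).clamp(max=1.0)
    n_matches = min(int(accept.int().cumprod(dim=-1).sum()), max_matches)

    if n_matches < candidate_length:
        # First rejection: resample from the normalized residual max(0, p - q).
        residual = (p[n_matches] - q[n_matches]).clamp(min=0)
        residual = residual / residual.sum()
    else:
        # Every candidate was accepted: the bonus token is sampled from p directly.
        residual = p[candidate_length]
    next_token = torch.multinomial(residual, num_samples=1)

    # The decoding loop appends the accepted candidates plus this freshly sampled token,
    # which matches the `valid_tokens, n_matches` return value introduced above.
    return torch.cat((candidate_tokens[:n_matches], next_token)), n_matches
```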
@@ -2037,15 +2037,15 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
FutureWarning,
)

+if generation_config is None:
+generation_config = copy.deepcopy(self.generation_config)
+
return_dict_in_generate = (
return_dict_in_generate
if return_dict_in_generate is not None
-else self.generation_config.return_dict_in_generate
+else generation_config.return_dict_in_generate
)

-if generation_config is None:
-generation_config = copy.deepcopy(self.generation_config)
-
input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
if num_segment_frames is None:
num_segment_frames = input_stride * self.config.max_source_positions
@@ -21,7 +21,7 @@ import unittest

import pytest

-from transformers import AutoTokenizer, MistralConfig, is_torch_available
+from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed
from transformers.testing_utils import (
backend_empty_cache,
require_bitsandbytes,
@@ -527,3 +527,27 @@ class MistralIntegrationTest(unittest.TestCase):
del model
backend_empty_cache(torch_device)
gc.collect()
+
+@slow
+def test_speculative_generation(self):
+EXPECTED_TEXT_COMPLETION = (
+"My favourite condiment is 100% Sriracha. I love the heat, the tang and the fact costs"
+)
+prompt = "My favourite condiment is "
+tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
+model = MistralForCausalLM.from_pretrained(
+"mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
+)
+input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
+
+# greedy generation outputs
+set_seed(0)
+generated_ids = model.generate(
+input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=model
+)
+text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
+
+del model
+backend_empty_cache(torch_device)
+gc.collect()