Generate: fix speculative decoding (#28166)

Co-authored-by: Merve Noyan <merveenoyan@gmail.com>
Joao Gante 2023-12-20 18:55:35 +00:00 committed by GitHub
parent 01c081d138
commit 45b70384a7
5 changed files with 90 additions and 72 deletions

View File

@@ -82,7 +82,7 @@ Even if the default decoding strategy mostly works for your task, you can still
 commonly adjusted parameters include:

 - `max_new_tokens`: the maximum number of tokens to generate. In other words, the size of the output sequence, not
 including the tokens in the prompt. As an alternative to using the output's length as a stopping criteria, you can choose
 to stop generation whenever the full generation exceeds some amount of time. To learn more, check [`StoppingCriteria`].
 - `num_beams`: by specifying a number of beams higher than 1, you are effectively switching from greedy search to
 beam search. This strategy evaluates several hypotheses at each time step and eventually chooses the hypothesis that
@@ -339,13 +339,16 @@ This guide illustrates the main parameters that enable various decoding strategi
 [`generate`] method, which gives you even further control over the [`generate`] method's behavior.
 For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.md).

-### Assisted Decoding
+### Speculative Decoding

-Assisted decoding is a modification of the decoding strategies above that uses an assistant model with the same
-tokenizer (ideally a much smaller model) to greedily generate a few candidate tokens. The main model then validates
-the candidate tokens in a single forward pass, which speeds up the decoding process. Currently, only greedy search
-and sampling are supported with assisted decoding, and doesn't support batched inputs. To learn more about assisted
-decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).
+Speculative decoding (also known as assisted decoding) is a modification of the decoding strategies above, that uses an
+assistant model (ideally a much smaller one) with the same tokenizer, to generate a few candidate tokens. The main
+model then validates the candidate tokens in a single forward pass, which speeds up the decoding process. If
+`do_sample=True`, then the token validation with resampling introduced in the
+[speculative decoding paper](https://arxiv.org/pdf/2211.17192.pdf) is used.
+
+Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs.
+To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).

 To enable assisted decoding, set the `assistant_model` argument with a model.
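In plain terms, the renamed section still boils down to passing a second, smaller model through the `assistant_model` argument of `generate`. A minimal usage sketch in the docs' doctest style; the Pythia checkpoints are placeholder choices, and any main/assistant pair sharing a tokenizer works the same way:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "Alice and Bob"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"  # main model (placeholder choice)
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"  # smaller assistant sharing the tokenizer

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```

Adding `do_sample=True` (optionally with a low `temperature`) switches the validation step to the resampling scheme from the paper, as the rewritten paragraph above describes.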
@@ -366,8 +369,8 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
 ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
 ```

-When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness
-just like in multinomial sampling. However, in assisted decoding, reducing the temperature will help improving latency.
+When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
+just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.

 ```python
 >>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

View File

@@ -14,14 +14,14 @@
 # limitations under the License.

 import copy
-import warnings
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

 import torch

 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
+    from .configuration_utils import GenerationConfig
     from .logits_process import LogitsProcessorList
@@ -66,14 +66,17 @@ class CandidateGenerator:
 class AssistedCandidateGenerator(CandidateGenerator):
     """
-    `CandidateGenerator` class to be used for assisted generation. This class generates candidates through the use of
-    a smaller model. Read the following blog post for more information: https://huggingface.co/blog/assisted-generation
+    `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
+    candidates through the use of a smaller model. Read the following blog post for more information:
+    https://huggingface.co/blog/assisted-generation

     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
         assistant_model (`PreTrainedModel`):
             The model to be used for generating candidates. This model should be smaller than the main model.
+        generation_config (`~generation.GenerationConfig`, *optional*):
+            The generation configuration to be used as base parametrization for the generation call.
         logits_processor (`LogitsProcessorList`):
             An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
             used to modify the prediction scores of the language modeling head applied at each generation step.
@@ -82,31 +85,20 @@ class AssistedCandidateGenerator(CandidateGenerator):
             model as well.
         inputs_tensor (`torch.Tensor`, *optional*):
             The model input tensor. In encoder-decoder models, this is the encoder input.
-        eos_token_id (`Union[int, List[int]]`, *optional*):
-            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
     """

     def __init__(
         self,
         input_ids: torch.LongTensor,
         assistant_model: "PreTrainedModel",
+        generation_config: "GenerationConfig",
         logits_processor: "LogitsProcessorList",
         model_kwargs: Dict,
         inputs_tensor: Optional[torch.Tensor] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
     ):
+        # Prepare the assistant and the starting number of candidate tokens
         self.assistant_model = assistant_model
-
-        # Prepare the number of candidate tokens
-        if hasattr(assistant_model, "num_assistant_tokens"):
-            warnings.warn(
-                "Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be "
-                "removed in v4.37. Make sure to set `num_assistant_tokens` via the generation_config instead.",
-                FutureWarning,
-            )
-            self.num_assistant_tokens = assistant_model.num_assistant_tokens
-        else:
-            self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
+        self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens

         # Prepare the kwargs for the assistant model
         assistant_kwargs = {}
@@ -145,13 +137,17 @@ class AssistedCandidateGenerator(CandidateGenerator):
             self.input_ids_key = "input_ids"
             self.attention_key = "attention_mask"

-        # Prepare other attributes
+        # Prepare generation-related options.
+        eos_token_id = generation_config.eos_token_id
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         self.eos_token_id_tensor = (
             torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         )
         self.logits_processor = logits_processor
+        self.generation_config = copy.deepcopy(generation_config)
+        self.generation_config.return_dict_in_generate = True
+        self.generation_config.output_scores = True

     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
@@ -185,12 +181,11 @@ class AssistedCandidateGenerator(CandidateGenerator):
         # 2. Forecast next N tokens using the assistant model.
         assistant_generation_kwargs = {
             self.input_ids_key: input_ids,
-            "do_sample": False,
-            "num_beams": 1,
             "max_new_tokens": int(self.num_assistant_tokens),
-            "return_dict_in_generate": True,
-            "output_scores": True,
+            "generation_config": self.generation_config,
+            "logits_processor": self.logits_processor,
         }

         assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)

         # 3. Update variables for the next round of candidate generation
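The practical effect of this hunk is that the assistant is no longer forced into greedy decoding; it inherits the caller's generation settings and always returns scores. A rough standalone illustration of the same idea using only the public `generate` API (the tiny checkpoint is a placeholder, and `5` stands in for `num_assistant_tokens`):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint standing in for the small assistant model.
assistant = AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids

# Ask the assistant for a few candidate tokens and keep its scores, mirroring the
# `return_dict_in_generate=True` / `output_scores=True` options forced on the copied config.
assistant_output = assistant.generate(
    input_ids,
    max_new_tokens=5,  # plays the role of num_assistant_tokens
    do_sample=False,
    return_dict_in_generate=True,
    output_scores=True,
)
candidate_ids = assistant_output.sequences                      # prompt + candidate tokens
candidate_logits = torch.stack(assistant_output.scores, dim=1)  # (batch, new_tokens, vocab)
print(candidate_ids.shape, candidate_logits.shape)
```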

View File

@@ -911,10 +911,10 @@ class GenerationMixin:
         candidate_generator = AssistedCandidateGenerator(
             input_ids=input_ids,
             assistant_model=assistant_model,
+            generation_config=generation_config,
             logits_processor=logits_processor,
             model_kwargs=model_kwargs,
             inputs_tensor=inputs_tensor,
-            eos_token_id=generation_config.eos_token_id,
         )
         return candidate_generator
@@ -1673,7 +1673,7 @@ class GenerationMixin:
         )

         # 8. prepare distribution pre_processing samplers
-        logits_processor = self._get_logits_processor(
+        prepared_logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_length,
             encoder_input_ids=inputs_tensor,
@@ -1685,7 +1685,7 @@ class GenerationMixin:
         )

         # 9. prepare stopping criteria
-        stopping_criteria = self._get_stopping_criteria(
+        prepared_stopping_criteria = self._get_stopping_criteria(
             generation_config=generation_config, stopping_criteria=stopping_criteria
         )

         # 10. go into different generation modes
@@ -1715,9 +1715,9 @@ class GenerationMixin:
                 input_ids,
                 candidate_generator=candidate_generator,
                 do_sample=generation_config.do_sample,
-                logits_processor=logits_processor,
+                logits_processor=prepared_logits_processor,
                 logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,
-                stopping_criteria=stopping_criteria,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -1730,8 +1730,8 @@ class GenerationMixin:
             # 11. run greedy search
             return self.greedy_search(
                 input_ids,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -1749,8 +1749,8 @@ class GenerationMixin:
                 input_ids,
                 top_k=generation_config.top_k,
                 penalty_alpha=generation_config.penalty_alpha,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -1776,9 +1776,9 @@ class GenerationMixin:
             # 13. run sample
             return self.sample(
                 input_ids,
-                logits_processor=logits_processor,
+                logits_processor=prepared_logits_processor,
                 logits_warper=logits_warper,
-                stopping_criteria=stopping_criteria,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -1810,8 +1810,8 @@ class GenerationMixin:
             return self.beam_search(
                 input_ids,
                 beam_scorer,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -1847,9 +1847,9 @@ class GenerationMixin:
             return self.beam_sample(
                 input_ids,
                 beam_scorer,
-                logits_processor=logits_processor,
+                logits_processor=prepared_logits_processor,
                 logits_warper=logits_warper,
-                stopping_criteria=stopping_criteria,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -1881,8 +1881,8 @@ class GenerationMixin:
             return self.group_beam_search(
                 input_ids,
                 beam_scorer,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -1954,8 +1954,8 @@ class GenerationMixin:
             return self.constrained_beam_search(
                 input_ids,
                 constrained_beam_scorer=constrained_beam_scorer,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -4629,7 +4629,7 @@ class GenerationMixin:
             # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
             max_matches = max_len - cur_len - 1
             if do_sample and candidate_logits is not None:
-                next_sampled_tokens, n_matches = _speculative_sampling(
+                valid_tokens, n_matches = _speculative_sampling(
                     candidate_input_ids,
                     candidate_logits,
                     candidate_length,
@@ -4637,8 +4637,6 @@ class GenerationMixin:
                     last_assistant_token_is_eos,
                     max_matches,
                 )
-                # The selected tokens include the matches plus the next sampled tokens
-                selected_tokens = torch.cat((candidate_input_ids[:, :n_matches], next_sampled_tokens), dim=-1)

             # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
             # original model logits with the candidate tokens. We can keep the candidate tokens until the first
@@ -4657,6 +4655,7 @@ class GenerationMixin:
                 if last_assistant_token_is_eos and n_matches == candidate_length:
                     n_matches -= 1
                 n_matches = min(n_matches, max_matches)
+                valid_tokens = selected_tokens[:, : n_matches + 1]

             # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
             # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
@@ -4664,7 +4663,6 @@ class GenerationMixin:
             # is no match.

             # 4.1. Get the valid continuation, after the matching tokens
-            valid_tokens = selected_tokens[:, : n_matches + 1]
             input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
             if streamer is not None:
                 streamer.put(valid_tokens.cpu())
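Case 2 above keeps candidate tokens only up to the first position where the assistant disagrees with the main model's greedy pick, plus the model's own token at that position. A small self-contained sketch of that counting rule (illustrative, not the library's exact code):

```python
import torch

def greedy_match_length(candidate_tokens: torch.LongTensor, selected_tokens: torch.LongTensor) -> int:
    """Count how many leading candidate tokens agree with the main model's picks."""
    # candidate_tokens: (1, candidate_length) tokens proposed by the assistant
    # selected_tokens:  (1, candidate_length + 1) argmax tokens from the main model's logits
    mismatches = candidate_tokens != selected_tokens[:, :-1]
    # A position only counts if every earlier position also matched.
    return int((mismatches.int().cumsum(dim=-1) < 1).sum())

candidates = torch.tensor([[5, 8, 8, 3]])      # assistant proposed 4 tokens
model_picks = torch.tensor([[5, 8, 9, 7, 2]])  # main model's choice at those positions, plus one extra
n_matches = greedy_match_length(candidates, model_picks)
print(n_matches)                        # 2
print(model_picks[:, : n_matches + 1])  # accepted continuation: tensor([[5, 8, 9]])
```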
@@ -4782,24 +4780,16 @@ def _speculative_sampling(
 ):
     """
     Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
-    the next selected token, as well as the number of candidate matches.
+    the selected tokens, as well as the number of candidate matches.

     NOTE: Unless otherwise stated, the variable names match those in the paper.
     """
     # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
     # selected by the assistant, respectively.
     q = candidate_logits.softmax(dim=-1)
-    q_i = q[
-        :,
-        torch.range(0, candidate_length - 1, dtype=torch.int),
-        candidate_input_ids[:, -candidate_length:],
-    ].squeeze(0, 1)
+    q_i = q[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
     p = new_logits.softmax(dim=-1)
-    p_i = p[
-        :,
-        torch.range(0, candidate_length - 1, dtype=torch.int),
-        candidate_input_ids[:, -candidate_length:],
-    ].squeeze(0, 1)
+    p_i = p[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
     probability_ratio = p_i / q_i

     # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
@@ -4824,7 +4814,13 @@ def _speculative_sampling(
         p_prime = p_n_plus_1
     t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]

-    return t, n_matches
+    # The selected tokens include the matches (if any) plus the next sampled tokens
+    if n_matches > 0:
+        valid_tokens = torch.cat((candidate_input_ids[:, -n_matches:], t), dim=-1)
+    else:
+        valid_tokens = t
+
+    return valid_tokens, n_matches


 def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
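For Case 1, the fixed `_speculative_sampling` above implements the accept/resample rule from algorithm 1 of the paper. A condensed, self-contained sketch of that rule under simplifying assumptions (batch size 1, no EOS handling); it mirrors the structure of the function but is not a drop-in replacement:

```python
import torch

def speculative_sampling_sketch(candidate_input_ids, candidate_logits, candidate_length, new_logits, max_matches):
    # q_i / p_i: assistant / main-model probabilities of the candidate tokens.
    q = candidate_logits.softmax(dim=-1)  # (1, candidate_length, vocab)
    p = new_logits.softmax(dim=-1)        # (1, candidate_length + 1, vocab)
    new_candidates = candidate_input_ids[:, -candidate_length:]
    idx = torch.arange(candidate_length)
    q_i = q[:, idx, new_candidates].squeeze(0, 1)
    p_i = p[:, idx, new_candidates].squeeze(0, 1)

    # Accept candidate i with probability min(1, p_i / q_i); stop at the first rejection.
    probability_ratio = p_i / q_i
    is_rejected = torch.rand_like(probability_ratio) > probability_ratio
    n_matches = int((is_rejected.int().cumsum(dim=-1) < 1).sum())
    n_matches = min(n_matches, max_matches)

    # Resample the token after the last accepted candidate from max(0, p - q), renormalized.
    p_next = p[:, n_matches, :]
    if n_matches < candidate_length:
        p_prime = torch.clamp(p_next - q[:, n_matches, :], min=0)
        p_prime = p_prime / p_prime.sum()
    else:
        p_prime = p_next
    t = torch.multinomial(p_prime, num_samples=1)

    # Valid continuation: the accepted candidates (if any) plus the freshly sampled token.
    valid_tokens = torch.cat((new_candidates[:, :n_matches], t), dim=-1) if n_matches > 0 else t
    return valid_tokens, n_matches

# Toy shapes: 4 candidate tokens over a 10-token vocabulary.
torch.manual_seed(0)
cand_ids = torch.randint(0, 10, (1, 8))
print(speculative_sampling_sketch(cand_ids, torch.randn(1, 4, 10), 4, torch.randn(1, 5, 10), max_matches=4))
```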

View File

@@ -2037,15 +2037,15 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
                 FutureWarning,
             )

+        if generation_config is None:
+            generation_config = copy.deepcopy(self.generation_config)
+
         return_dict_in_generate = (
             return_dict_in_generate
             if return_dict_in_generate is not None
-            else self.generation_config.return_dict_in_generate
+            else generation_config.return_dict_in_generate
         )

-        if generation_config is None:
-            generation_config = copy.deepcopy(self.generation_config)
-
         input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
         if num_segment_frames is None:
             num_segment_frames = input_stride * self.config.max_source_positions

View File

@@ -21,7 +21,7 @@ import unittest

 import pytest

-from transformers import AutoTokenizer, MistralConfig, is_torch_available
+from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
     backend_empty_cache,
     require_bitsandbytes,
@@ -527,3 +527,27 @@ class MistralIntegrationTest(unittest.TestCase):
         del model
         backend_empty_cache(torch_device)
         gc.collect()
+
+    @slow
+    def test_speculative_generation(self):
+        EXPECTED_TEXT_COMPLETION = (
+            "My favourite condiment is 100% Sriracha. I love the heat, the tang and the fact costs"
+        )
+        prompt = "My favourite condiment is "
+        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
+        model = MistralForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
+        )
+        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
+
+        # greedy generation outputs
+        set_seed(0)
+        generated_ids = model.generate(
+            input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=model
+        )
+        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
+
+        del model
+        backend_empty_cache(torch_device)
+        gc.collect()