From 45b70384a7d6692a8304f34a981a5ff020918b82 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 20 Dec 2023 18:55:35 +0000 Subject: [PATCH] Generate: fix speculative decoding (#28166) Co-authored-by: Merve Noyan --- docs/source/en/generation_strategies.md | 21 +++--- .../generation/candidate_generator.py | 41 +++++------- src/transformers/generation/utils.py | 66 +++++++++---------- .../models/whisper/modeling_whisper.py | 8 +-- tests/models/mistral/test_modeling_mistral.py | 26 +++++++- 5 files changed, 90 insertions(+), 72 deletions(-) diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md index 51f92e06103..df91c36c610 100644 --- a/docs/source/en/generation_strategies.md +++ b/docs/source/en/generation_strategies.md @@ -82,7 +82,7 @@ Even if the default decoding strategy mostly works for your task, you can still commonly adjusted parameters include: - `max_new_tokens`: the maximum number of tokens to generate. In other words, the size of the output sequence, not -including the tokens in the prompt. As an alternative to using the output's length as a stopping criteria, you can choose +including the tokens in the prompt. As an alternative to using the output's length as a stopping criteria, you can choose to stop generation whenever the full generation exceeds some amount of time. To learn more, check [`StoppingCriteria`]. - `num_beams`: by specifying a number of beams higher than 1, you are effectively switching from greedy search to beam search. This strategy evaluates several hypotheses at each time step and eventually chooses the hypothesis that @@ -339,13 +339,16 @@ This guide illustrates the main parameters that enable various decoding strategi [`generate`] method, which gives you even further control over the [`generate`] method's behavior. For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.md). -### Assisted Decoding +### Speculative Decoding -Assisted decoding is a modification of the decoding strategies above that uses an assistant model with the same -tokenizer (ideally a much smaller model) to greedily generate a few candidate tokens. The main model then validates -the candidate tokens in a single forward pass, which speeds up the decoding process. Currently, only greedy search -and sampling are supported with assisted decoding, and doesn't support batched inputs. To learn more about assisted -decoding, check [this blog post](https://huggingface.co/blog/assisted-generation). +Speculative decoding (also known as assisted decoding) is a modification of the decoding strategies above, that uses an +assistant model (ideally a much smaller one) with the same tokenizer, to generate a few candidate tokens. The main +model then validates the candidate tokens in a single forward pass, which speeds up the decoding process. If +`do_sample=True`, then the token validation with resampling introduced in the +[speculative decoding paper](https://arxiv.org/pdf/2211.17192.pdf) is used. + +Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs. +To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation). To enable assisted decoding, set the `assistant_model` argument with a model. @@ -366,8 +369,8 @@ To enable assisted decoding, set the `assistant_model` argument with a model. ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a'] ``` -When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness -just like in multinomial sampling. However, in assisted decoding, reducing the temperature will help improving latency. +When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness, +just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency. ```python >>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index ccfd4cfad71..bb82b852f00 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -14,14 +14,14 @@ # limitations under the License. import copy -import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple import torch if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel + from .configuration_utils import GenerationConfig from .logits_process import LogitsProcessorList @@ -66,14 +66,17 @@ class CandidateGenerator: class AssistedCandidateGenerator(CandidateGenerator): """ - `CandidateGenerator` class to be used for assisted generation. This class generates candidates through the use of - a smaller model. Read the following blog post for more information: https://huggingface.co/blog/assisted-generation + `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates + candidates through the use of a smaller model. Read the following blog post for more information: + https://huggingface.co/blog/assisted-generation Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) assistant_model (`PreTrainedModel`): The model to be used for generating candidates. This model should be smaller than the main model. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. logits_processor (`LogitsProcessorList`): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. @@ -82,31 +85,20 @@ class AssistedCandidateGenerator(CandidateGenerator): model as well. inputs_tensor (`torch.Tensor`, *optional*): The model input tensor. In encoder-decoder models, this is the encoder input. - eos_token_id (`Union[int, List[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. """ def __init__( self, input_ids: torch.LongTensor, assistant_model: "PreTrainedModel", + generation_config: "GenerationConfig", logits_processor: "LogitsProcessorList", model_kwargs: Dict, inputs_tensor: Optional[torch.Tensor] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, ): + # Prepare the assistant and the starting number of candidate tokens self.assistant_model = assistant_model - - # Prepare the number of candidate tokens - if hasattr(assistant_model, "num_assistant_tokens"): - warnings.warn( - "Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be " - "removed in v4.37. Make sure to set `num_assistant_tokens` via the generation_config instead.", - FutureWarning, - ) - self.num_assistant_tokens = assistant_model.num_assistant_tokens - else: - self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens + self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens # Prepare the kwargs for the assistant model assistant_kwargs = {} @@ -145,13 +137,17 @@ class AssistedCandidateGenerator(CandidateGenerator): self.input_ids_key = "input_ids" self.attention_key = "attention_mask" - # Prepare other attributes + # Prepare generation-related options. + eos_token_id = generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] self.eos_token_id_tensor = ( torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None ) self.logits_processor = logits_processor + self.generation_config = copy.deepcopy(generation_config) + self.generation_config.return_dict_in_generate = True + self.generation_config.output_scores = True def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: """ @@ -185,12 +181,11 @@ class AssistedCandidateGenerator(CandidateGenerator): # 2. Forecast next N tokens using the assistant model. assistant_generation_kwargs = { self.input_ids_key: input_ids, - "do_sample": False, - "num_beams": 1, "max_new_tokens": int(self.num_assistant_tokens), - "return_dict_in_generate": True, - "output_scores": True, + "generation_config": self.generation_config, + "logits_processor": self.logits_processor, } + assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs) # 3. Update variables for the next round of candidate generation diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index b3bc4cd8d87..c7ae4aee7f8 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -911,10 +911,10 @@ class GenerationMixin: candidate_generator = AssistedCandidateGenerator( input_ids=input_ids, assistant_model=assistant_model, + generation_config=generation_config, logits_processor=logits_processor, model_kwargs=model_kwargs, inputs_tensor=inputs_tensor, - eos_token_id=generation_config.eos_token_id, ) return candidate_generator @@ -1673,7 +1673,7 @@ class GenerationMixin: ) # 8. prepare distribution pre_processing samplers - logits_processor = self._get_logits_processor( + prepared_logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_length, encoder_input_ids=inputs_tensor, @@ -1685,7 +1685,7 @@ class GenerationMixin: ) # 9. prepare stopping criteria - stopping_criteria = self._get_stopping_criteria( + prepared_stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria ) # 10. go into different generation modes @@ -1715,9 +1715,9 @@ class GenerationMixin: input_ids, candidate_generator=candidate_generator, do_sample=generation_config.do_sample, - logits_processor=logits_processor, + logits_processor=prepared_logits_processor, logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None, - stopping_criteria=stopping_criteria, + stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, @@ -1730,8 +1730,8 @@ class GenerationMixin: # 11. run greedy search return self.greedy_search( input_ids, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, @@ -1749,8 +1749,8 @@ class GenerationMixin: input_ids, top_k=generation_config.top_k, penalty_alpha=generation_config.penalty_alpha, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, @@ -1776,9 +1776,9 @@ class GenerationMixin: # 13. run sample return self.sample( input_ids, - logits_processor=logits_processor, + logits_processor=prepared_logits_processor, logits_warper=logits_warper, - stopping_criteria=stopping_criteria, + stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, @@ -1810,8 +1810,8 @@ class GenerationMixin: return self.beam_search( input_ids, beam_scorer, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, @@ -1847,9 +1847,9 @@ class GenerationMixin: return self.beam_sample( input_ids, beam_scorer, - logits_processor=logits_processor, + logits_processor=prepared_logits_processor, logits_warper=logits_warper, - stopping_criteria=stopping_criteria, + stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, @@ -1881,8 +1881,8 @@ class GenerationMixin: return self.group_beam_search( input_ids, beam_scorer, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, @@ -1954,8 +1954,8 @@ class GenerationMixin: return self.constrained_beam_search( input_ids, constrained_beam_scorer=constrained_beam_scorer, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, @@ -4629,7 +4629,7 @@ class GenerationMixin: # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf). max_matches = max_len - cur_len - 1 if do_sample and candidate_logits is not None: - next_sampled_tokens, n_matches = _speculative_sampling( + valid_tokens, n_matches = _speculative_sampling( candidate_input_ids, candidate_logits, candidate_length, @@ -4637,8 +4637,6 @@ class GenerationMixin: last_assistant_token_is_eos, max_matches, ) - # The selected tokens include the matches plus the next sampled tokens - selected_tokens = torch.cat((candidate_input_ids[:, :n_matches], next_sampled_tokens), dim=-1) # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the # original model logits with the candidate tokens. We can keep the candidate tokens until the first @@ -4657,6 +4655,7 @@ class GenerationMixin: if last_assistant_token_is_eos and n_matches == candidate_length: n_matches -= 1 n_matches = min(n_matches, max_matches) + valid_tokens = selected_tokens[:, : n_matches + 1] # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated # by the model after the last candidate match is also valid, as it is generated from a correct sequence. @@ -4664,7 +4663,6 @@ class GenerationMixin: # is no match. # 4.1. Get the valid continuation, after the matching tokens - valid_tokens = selected_tokens[:, : n_matches + 1] input_ids = torch.cat((input_ids, valid_tokens), dim=-1) if streamer is not None: streamer.put(valid_tokens.cpu()) @@ -4782,24 +4780,16 @@ def _speculative_sampling( ): """ Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns - the next selected token, as well as the number of candidate matches. + the selected tokens, as well as the number of candidate matches. NOTE: Unless otherwise stated, the variable names match those in the paper. """ # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens # selected by the assistant, respectively. q = candidate_logits.softmax(dim=-1) - q_i = q[ - :, - torch.range(0, candidate_length - 1, dtype=torch.int), - candidate_input_ids[:, -candidate_length:], - ].squeeze(0, 1) + q_i = q[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1) p = new_logits.softmax(dim=-1) - p_i = p[ - :, - torch.range(0, candidate_length - 1, dtype=torch.int), - candidate_input_ids[:, -candidate_length:], - ].squeeze(0, 1) + p_i = p[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1) probability_ratio = p_i / q_i # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller @@ -4824,7 +4814,13 @@ def _speculative_sampling( p_prime = p_n_plus_1 t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :] - return t, n_matches + # The selected tokens include the matches (if any) plus the next sampled tokens + if n_matches > 0: + valid_tokens = torch.cat((candidate_input_ids[:, -n_matches:], t), dim=-1) + else: + valid_tokens = t + + return valid_tokens, n_matches def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False): diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index e7bcb47acdf..fb2bce476ea 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2037,15 +2037,15 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): FutureWarning, ) + if generation_config is None: + generation_config = copy.deepcopy(self.generation_config) + return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate + else generation_config.return_dict_in_generate ) - if generation_config is None: - generation_config = copy.deepcopy(self.generation_config) - input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0] if num_segment_frames is None: num_segment_frames = input_stride * self.config.max_source_positions diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 35a2341b4e6..ed14bdfeb58 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -21,7 +21,7 @@ import unittest import pytest -from transformers import AutoTokenizer, MistralConfig, is_torch_available +from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed from transformers.testing_utils import ( backend_empty_cache, require_bitsandbytes, @@ -527,3 +527,27 @@ class MistralIntegrationTest(unittest.TestCase): del model backend_empty_cache(torch_device) gc.collect() + + @slow + def test_speculative_generation(self): + EXPECTED_TEXT_COMPLETION = ( + "My favourite condiment is 100% Sriracha. I love the heat, the tang and the fact costs" + ) + prompt = "My favourite condiment is " + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False) + model = MistralForCausalLM.from_pretrained( + "mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16 + ) + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) + + # greedy generation outputs + set_seed(0) + generated_ids = model.generate( + input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=model + ) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + del model + backend_empty_cache(torch_device) + gc.collect()