mirror of https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00

Generate: fix speculative decoding (#28166)

Co-authored-by: Merve Noyan <merveenoyan@gmail.com>

parent 01c081d138
commit 45b70384a7
@@ -82,7 +82,7 @@ Even if the default decoding strategy mostly works for your task, you can still
 commonly adjusted parameters include:

 - `max_new_tokens`: the maximum number of tokens to generate. In other words, the size of the output sequence, not
 including the tokens in the prompt. As an alternative to using the output's length as a stopping criteria, you can choose
 to stop generation whenever the full generation exceeds some amount of time. To learn more, check [`StoppingCriteria`].
 - `num_beams`: by specifying a number of beams higher than 1, you are effectively switching from greedy search to
 beam search. This strategy evaluates several hypotheses at each time step and eventually chooses the hypothesis that
@@ -339,13 +339,16 @@ This guide illustrates the main parameters that enable various decoding strategies
 [`generate`] method, which gives you even further control over the [`generate`] method's behavior.
 For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.md).

-### Assisted Decoding
+### Speculative Decoding

-Assisted decoding is a modification of the decoding strategies above that uses an assistant model with the same
-tokenizer (ideally a much smaller model) to greedily generate a few candidate tokens. The main model then validates
-the candidate tokens in a single forward pass, which speeds up the decoding process. Currently, only greedy search
-and sampling are supported with assisted decoding, and doesn't support batched inputs. To learn more about assisted
-decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).
+Speculative decoding (also known as assisted decoding) is a modification of the decoding strategies above, that uses an
+assistant model (ideally a much smaller one) with the same tokenizer, to generate a few candidate tokens. The main
+model then validates the candidate tokens in a single forward pass, which speeds up the decoding process. If
+`do_sample=True`, then the token validation with resampling introduced in the
+[speculative decoding paper](https://arxiv.org/pdf/2211.17192.pdf) is used.
+
+Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs.
+To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).

 To enable assisted decoding, set the `assistant_model` argument with a model.

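The runnable example that produces the output quoted in the next hunk lies outside the diff context. As a minimal sketch, assuming the pythia checkpoint pair commonly used in this guide (the exact model names are not shown in these hunks), enabling assisted decoding looks like this:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "Alice and Bob"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"            # assumed main model
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"  # assumed smaller assistant sharing the tokenizer

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```

The sampling variant discussed in the next hunk differs only in the final call, e.g. `model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5)`, where the temperature value is likewise illustrative.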
@@ -366,8 +369,8 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
 ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
 ```

-When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness
-just like in multinomial sampling. However, in assisted decoding, reducing the temperature will help improving latency.
+When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
+just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.

 ```python
 >>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
@@ -14,14 +14,14 @@
 # limitations under the License.

 import copy
-import warnings
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

 import torch


 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
+    from .configuration_utils import GenerationConfig
     from .logits_process import LogitsProcessorList


@@ -66,14 +66,17 @@ class CandidateGenerator:

 class AssistedCandidateGenerator(CandidateGenerator):
     """
-    `CandidateGenerator` class to be used for assisted generation. This class generates candidates through the use of
-    a smaller model. Read the following blog post for more information: https://huggingface.co/blog/assisted-generation
+    `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
+    candidates through the use of a smaller model. Read the following blog post for more information:
+    https://huggingface.co/blog/assisted-generation

     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
         assistant_model (`PreTrainedModel`):
             The model to be used for generating candidates. This model should be smaller than the main model.
+        generation_config (`~generation.GenerationConfig`, *optional*):
+            The generation configuration to be used as base parametrization for the generation call.
         logits_processor (`LogitsProcessorList`):
             An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
             used to modify the prediction scores of the language modeling head applied at each generation step.
@@ -82,31 +85,20 @@ class AssistedCandidateGenerator(CandidateGenerator):
             model as well.
         inputs_tensor (`torch.Tensor`, *optional*):
             The model input tensor. In encoder-decoder models, this is the encoder input.
-        eos_token_id (`Union[int, List[int]]`, *optional*):
-            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
     """

     def __init__(
         self,
         input_ids: torch.LongTensor,
         assistant_model: "PreTrainedModel",
+        generation_config: "GenerationConfig",
         logits_processor: "LogitsProcessorList",
         model_kwargs: Dict,
         inputs_tensor: Optional[torch.Tensor] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
     ):
+        # Prepare the assistant and the starting number of candidate tokens
         self.assistant_model = assistant_model
+        self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
-        # Prepare the number of candidate tokens
-        if hasattr(assistant_model, "num_assistant_tokens"):
-            warnings.warn(
-                "Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be "
-                "removed in v4.37. Make sure to set `num_assistant_tokens` via the generation_config instead.",
-                FutureWarning,
-            )
-            self.num_assistant_tokens = assistant_model.num_assistant_tokens
-        else:
-            self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens

         # Prepare the kwargs for the assistant model
         assistant_kwargs = {}
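A practical consequence of this hunk is that the number of drafted candidate tokens is now always read from the assistant's `generation_config`, never from an attribute patched onto the model itself. A hedged usage sketch, assuming `assistant_model` is an already-loaded causal LM:

```python
# `num_assistant_tokens` is the GenerationConfig field that controls how many tokens
# the assistant drafts per step before the main model validates them in one forward pass.
assistant_model.generation_config.num_assistant_tokens = 8
```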
@@ -145,13 +137,17 @@ class AssistedCandidateGenerator(CandidateGenerator):
             self.input_ids_key = "input_ids"
             self.attention_key = "attention_mask"

-        # Prepare other attributes
+        # Prepare generation-related options.
+        eos_token_id = generation_config.eos_token_id
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         self.eos_token_id_tensor = (
             torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         )
         self.logits_processor = logits_processor
+        self.generation_config = copy.deepcopy(generation_config)
+        self.generation_config.return_dict_in_generate = True
+        self.generation_config.output_scores = True

     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
@@ -185,12 +181,11 @@ class AssistedCandidateGenerator(CandidateGenerator):
         # 2. Forecast next N tokens using the assistant model.
         assistant_generation_kwargs = {
             self.input_ids_key: input_ids,
-            "do_sample": False,
-            "num_beams": 1,
             "max_new_tokens": int(self.num_assistant_tokens),
-            "return_dict_in_generate": True,
-            "output_scores": True,
+            "generation_config": self.generation_config,
+            "logits_processor": self.logits_processor,
         }

         assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)

         # 3. Update variables for the next round of candidate generation
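For orientation, after this hunk the drafting call no longer hard-codes greedy decoding; roughly, and only as an illustrative paraphrase of the kwargs above rather than the library's exact internals, the assistant step amounts to:

```python
# self.generation_config is the deep copy prepared in __init__, which already has
# return_dict_in_generate=True and output_scores=True set, so candidate scores are returned.
assistant_output = self.assistant_model.generate(
    **{self.input_ids_key: input_ids},
    max_new_tokens=int(self.num_assistant_tokens),
    generation_config=self.generation_config,
    logits_processor=self.logits_processor,
    **self.assistant_kwargs,
)
```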
@@ -911,10 +911,10 @@ class GenerationMixin:
         candidate_generator = AssistedCandidateGenerator(
             input_ids=input_ids,
             assistant_model=assistant_model,
+            generation_config=generation_config,
             logits_processor=logits_processor,
             model_kwargs=model_kwargs,
             inputs_tensor=inputs_tensor,
-            eos_token_id=generation_config.eos_token_id,
         )
         return candidate_generator

@@ -1673,7 +1673,7 @@ class GenerationMixin:
         )

         # 8. prepare distribution pre_processing samplers
-        logits_processor = self._get_logits_processor(
+        prepared_logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_length,
             encoder_input_ids=inputs_tensor,
@@ -1685,7 +1685,7 @@ class GenerationMixin:
         )

         # 9. prepare stopping criteria
-        stopping_criteria = self._get_stopping_criteria(
+        prepared_stopping_criteria = self._get_stopping_criteria(
             generation_config=generation_config, stopping_criteria=stopping_criteria
         )
         # 10. go into different generation modes
@@ -1715,9 +1715,9 @@ class GenerationMixin:
                 input_ids,
                 candidate_generator=candidate_generator,
                 do_sample=generation_config.do_sample,
-                logits_processor=logits_processor,
+                logits_processor=prepared_logits_processor,
                 logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,
-                stopping_criteria=stopping_criteria,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -1730,8 +1730,8 @@ class GenerationMixin:
             # 11. run greedy search
             return self.greedy_search(
                 input_ids,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -1749,8 +1749,8 @@ class GenerationMixin:
                 input_ids,
                 top_k=generation_config.top_k,
                 penalty_alpha=generation_config.penalty_alpha,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -1776,9 +1776,9 @@ class GenerationMixin:
             # 13. run sample
             return self.sample(
                 input_ids,
-                logits_processor=logits_processor,
+                logits_processor=prepared_logits_processor,
                 logits_warper=logits_warper,
-                stopping_criteria=stopping_criteria,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -1810,8 +1810,8 @@ class GenerationMixin:
             return self.beam_search(
                 input_ids,
                 beam_scorer,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -1847,9 +1847,9 @@ class GenerationMixin:
             return self.beam_sample(
                 input_ids,
                 beam_scorer,
-                logits_processor=logits_processor,
+                logits_processor=prepared_logits_processor,
                 logits_warper=logits_warper,
-                stopping_criteria=stopping_criteria,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -1881,8 +1881,8 @@ class GenerationMixin:
             return self.group_beam_search(
                 input_ids,
                 beam_scorer,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -1954,8 +1954,8 @@ class GenerationMixin:
             return self.constrained_beam_search(
                 input_ids,
                 constrained_beam_scorer=constrained_beam_scorer,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -4629,7 +4629,7 @@ class GenerationMixin:
            # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
            max_matches = max_len - cur_len - 1
            if do_sample and candidate_logits is not None:
-               next_sampled_tokens, n_matches = _speculative_sampling(
+               valid_tokens, n_matches = _speculative_sampling(
                    candidate_input_ids,
                    candidate_logits,
                    candidate_length,
@@ -4637,8 +4637,6 @@ class GenerationMixin:
                    last_assistant_token_is_eos,
                    max_matches,
                )
-               # The selected tokens include the matches plus the next sampled tokens
-               selected_tokens = torch.cat((candidate_input_ids[:, :n_matches], next_sampled_tokens), dim=-1)

            # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
            # original model logits with the candidate tokens. We can keep the candidate tokens until the first
@@ -4657,6 +4655,7 @@ class GenerationMixin:
                if last_assistant_token_is_eos and n_matches == candidate_length:
                    n_matches -= 1
                n_matches = min(n_matches, max_matches)
+               valid_tokens = selected_tokens[:, : n_matches + 1]

            # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
            # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
@@ -4664,7 +4663,6 @@ class GenerationMixin:
            # is no match.

            # 4.1. Get the valid continuation, after the matching tokens
-           valid_tokens = selected_tokens[:, : n_matches + 1]
            input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
            if streamer is not None:
                streamer.put(valid_tokens.cpu())
@@ -4782,24 +4780,16 @@ def _speculative_sampling(
 ):
     """
     Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
-    the next selected token, as well as the number of candidate matches.
+    the selected tokens, as well as the number of candidate matches.

     NOTE: Unless otherwise stated, the variable names match those in the paper.
     """
     # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
     # selected by the assistant, respectively.
     q = candidate_logits.softmax(dim=-1)
-    q_i = q[
-        :,
-        torch.range(0, candidate_length - 1, dtype=torch.int),
-        candidate_input_ids[:, -candidate_length:],
-    ].squeeze(0, 1)
+    q_i = q[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
     p = new_logits.softmax(dim=-1)
-    p_i = p[
-        :,
-        torch.range(0, candidate_length - 1, dtype=torch.int),
-        candidate_input_ids[:, -candidate_length:],
-    ].squeeze(0, 1)
+    p_i = p[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
     probability_ratio = p_i / q_i

     # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
@@ -4824,7 +4814,13 @@ def _speculative_sampling(
         p_prime = p_n_plus_1
     t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]

-    return t, n_matches
+    # The selected tokens include the matches (if any) plus the next sampled tokens
+    if n_matches > 0:
+        valid_tokens = torch.cat((candidate_input_ids[:, -n_matches:], t), dim=-1)
+    else:
+        valid_tokens = t
+
+    return valid_tokens, n_matches


 def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
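As background for the `_speculative_sampling` hunks above, a standalone toy sketch of algorithm 1 from the paper is shown below. The function name and signature are illustrative (this is not the library's `_speculative_sampling`), and `q`, `p`, and `p_next` are assumed to be already-normalized probability rows:

```python
import torch


def toy_speculative_acceptance(candidate_tokens, q, p, p_next):
    # candidate_tokens: (candidate_length,) long tensor of tokens drafted by the assistant
    # q, p: (candidate_length, vocab_size) assistant / main-model probabilities at those positions
    # p_next: (vocab_size,) main-model probabilities for the position after the last candidate
    candidate_length = candidate_tokens.shape[0]
    positions = torch.arange(candidate_length)
    q_i = q[positions, candidate_tokens]  # assistant probability of each drafted token
    p_i = p[positions, candidate_tokens]  # main-model probability of each drafted token

    # Accept token i with probability min(1, p_i / q_i); keep only the accepted prefix.
    accepted = torch.rand(candidate_length) < torch.clamp(p_i / q_i, max=1.0)
    n_matches = int((~accepted).int().cumsum(0).eq(0).sum())

    if n_matches < candidate_length:
        # First rejection: resample that position from the residual distribution max(0, p - q),
        # renormalized (assumes p and q differ at the rejected position).
        residual = torch.clamp(p[n_matches] - q[n_matches], min=0.0)
        next_token = torch.multinomial(residual / residual.sum(), num_samples=1)
    else:
        # Every draft accepted: sample one bonus token from the main model's next-position distribution.
        next_token = torch.multinomial(p_next, num_samples=1)

    # The accepted prefix plus the bonus token roughly corresponds to the `valid_tokens` assembled above.
    return torch.cat((candidate_tokens[:n_matches], next_token)), n_matches
```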
@@ -2037,15 +2037,15 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
                 FutureWarning,
             )

+        if generation_config is None:
+            generation_config = copy.deepcopy(self.generation_config)
+
         return_dict_in_generate = (
             return_dict_in_generate
             if return_dict_in_generate is not None
-            else self.generation_config.return_dict_in_generate
+            else generation_config.return_dict_in_generate
         )

-        if generation_config is None:
-            generation_config = copy.deepcopy(self.generation_config)
-
         input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
         if num_segment_frames is None:
             num_segment_frames = input_stride * self.config.max_source_positions
@@ -21,7 +21,7 @@ import unittest

 import pytest

-from transformers import AutoTokenizer, MistralConfig, is_torch_available
+from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
     backend_empty_cache,
     require_bitsandbytes,
@@ -527,3 +527,27 @@ class MistralIntegrationTest(unittest.TestCase):
         del model
         backend_empty_cache(torch_device)
         gc.collect()
+
+    @slow
+    def test_speculative_generation(self):
+        EXPECTED_TEXT_COMPLETION = (
+            "My favourite condiment is 100% Sriracha. I love the heat, the tang and the fact costs"
+        )
+        prompt = "My favourite condiment is "
+        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
+        model = MistralForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
+        )
+        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
+
+        # greedy generation outputs
+        set_seed(0)
+        generated_ids = model.generate(
+            input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=model
+        )
+        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
+
+        del model
+        backend_empty_cache(torch_device)
+        gc.collect()