Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Universal Assisted Generation: Assisted generation with any assistant model (by Intel Labs) (#33383)
* Update candidate_generator.py
* Update utils.py
* add lookbehind params to _get_candidate_generator
* make fixup
* add unit tests
* fix failing tests
* add docstrings
* fix docstrings; remove non-optimized AnyTokenizer
* added any tokenizer generation correctness test
* make fixup
* fix assertion syntax
* PR review fixes
* address additional PR comments
* fix tests
* remove stopping criteria arg
* make fixup
* add AssistantConfig
* fix prev_tokens branching
* pass tokenizers through `generate()` kwargs
* fix lookbehind values; tokenizer params WIP
* fixup
* AssistantConfig
* remove AssistantConfig; apply PR suggestions
* restructure tests
* fixup
* fix assistant_tokenizer arg validation
* fixup
* fix tests in TestAssistedCandidateGeneratorDifferentTokenizers
* fix class docstring
* PR suggestions
* doc
* doc update and improvements to `_validate_assistant()`

---------

Co-authored-by: mosheber <moshe.berchansky@intel.com>
This commit is contained in: parent dda3f91d06, commit fb0c6b521d
@@ -408,14 +408,24 @@ For the complete list of the available parameters, refer to the [API documentati
### Speculative Decoding

Speculative decoding (also known as assisted decoding) is a modification of the decoding strategies above, that uses an
assistant model (ideally a much smaller one) with the same tokenizer, to generate a few candidate tokens. The main
model then validates the candidate tokens in a single forward pass, which speeds up the decoding process. If
`do_sample=True`, then the token validation with resampling introduced in the
[speculative decoding paper](https://arxiv.org/pdf/2211.17192.pdf) is used.
assistant model (ideally a much smaller one), to generate a few candidate tokens. The main model then validates the candidate
tokens in a single forward pass, which speeds up the decoding process. If `do_sample=True`, then the token validation with
resampling introduced in the [speculative decoding paper](https://arxiv.org/pdf/2211.17192.pdf) is used.
Assisted decoding assumes the main and assistant models have the same tokenizer, otherwise, see Universal Assisted Decoding below.

Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs.
To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).

#### Universal Assisted Decoding

Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers.
To use it, simply pass the tokenizers using the `tokenizer` and `assistant_tokenizer` arguments (see below).
Internally, the main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
to ensure the new tokens include the correct prompt suffix.
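
As a concrete illustration of this re-encoding step (an editorial sketch, not part of the library API), the snippet below uses the same two checkpoints as the example further down and shows that a round trip through text generally produces different token ids:

```python
>>> from transformers import AutoTokenizer

>>> main_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
>>> assistant_tokenizer = AutoTokenizer.from_pretrained("double7/vicuna-68m")

>>> main_ids = main_tokenizer("Alice and Bob", return_tensors="pt").input_ids
>>> # decode the main-model ids into text, then encode that text with the assistant tokenizer
>>> text = main_tokenizer.batch_decode(main_ids, skip_special_tokens=True)
>>> assistant_ids = assistant_tokenizer(text, return_tensors="pt").input_ids
>>> # the two id sequences generally differ in length and values, hence the suffix alignment described above
```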

To enable assisted decoding, set the `assistant_model` argument with a model.

```python
@@ -435,6 +445,26 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```

If the main and assistant models have different tokenizers, use Universal Assisted Decoding.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "Alice and Bob"
>>> checkpoint = "google/gemma-2-9b"
>>> assistant_checkpoint = "double7/vicuna-68m"

>>> assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_checkpoint)
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```

When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.
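
For instance, a minimal sketch continuing the Universal Assisted Decoding example above (output omitted, since sampling is stochastic):

```python
>>> outputs = model.generate(
...     **inputs,
...     assistant_model=assistant_model,
...     tokenizer=tokenizer,
...     assistant_tokenizer=assistant_tokenizer,
...     do_sample=True,
...     temperature=0.5,  # a lower temperature may improve the latency, as noted above
... )
```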

@@ -458,6 +488,7 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t

Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
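
A minimal sketch, reusing `model` and `inputs` from the example above (the value 3 is only an illustration):

```python
>>> # candidates come from n-grams found in the prompt itself, so no assistant model is needed
>>> outputs = model.generate(**inputs, prompt_lookup_num_tokens=3)
```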

### DoLa Decoding

**D**ecoding by C**o**ntrasting **La**yers (DoLa) is a contrastive decoding strategy to improve the factuality and reduce the
@@ -16,6 +16,7 @@
import copy
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

import numpy as np
import torch

from ..cache_utils import DynamicCache
@@ -25,6 +26,7 @@ from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor

if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel
    from ..tokenization_utils_base import PreTrainedTokenizerBase
    from .configuration_utils import GenerationConfig


@@ -156,6 +158,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
        # Prepare generation-related options.
        self.logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        self.generation_config = copy.deepcopy(generation_config)

        self.generation_config.return_dict_in_generate = True
        self.generation_config.output_scores = True
        self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold
@@ -258,6 +261,303 @@ class AssistedCandidateGenerator(CandidateGenerator):
                self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)


class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
    """
    `CandidateGenerator` class to be used for Universal Assisted Generation (UAD): assisted generation with different tokenizers
    for the assistant and main models. This class generates candidates through the use of a smaller
    model.

    The main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
    in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
    The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
    Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
    to ensure the new tokens include the correct prompt suffix.

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
        assistant_model (`PreTrainedModel`):
            The model to be used for generating candidates. This model should be smaller than the main model.
        target_tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer used for the target model.
        assistant_tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer used for the assistant model.
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        model_kwargs (`Dict`):
            The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
            model as well.
        inputs_tensor (`torch.Tensor`, *optional*):
            The model input tensor. In encoder-decoder models, this is the encoder input.
    """

    def __init__(
        self,
        input_ids: torch.LongTensor,
        assistant_model: "PreTrainedModel",
        target_tokenizer: "PreTrainedTokenizerBase",
        assistant_tokenizer: "PreTrainedTokenizerBase",
        generation_config: "GenerationConfig",
        model_kwargs: Dict,
        inputs_tensor: Optional[torch.Tensor] = None,
        logits_processor: "LogitsProcessorList" = None,
    ):
        super().__init__(input_ids, assistant_model, generation_config, model_kwargs, inputs_tensor, logits_processor)

        self.target_tokenizer = target_tokenizer
        self.assistant_tokenizer = assistant_tokenizer
        self.prev_tokens = None
        self.prev_assistant_ids = None
        self.target_lookbehind = 10
        self.assistant_lookbehind = 10

    @staticmethod
    def _get_longest_diag_dict(input_matrix, nonzero_idx):
        """
        Calculates the length of the longest diagonal sequence in a given matrix.
        Args:
            input_matrix (torch.Tensor): The input matrix.
            nonzero_idx (torch.Tensor): The indices of the non-zero elements in the matrix.
        Returns:
            dict: A dictionary where the keys are the indices of the non-zero elements and the values are the lengths of the longest diagonal sequences starting from those indices.
        """

        visited = set()
        diags = {}
        for idx in nonzero_idx:
            start_idx = torch.clone(idx)
            tuple_start_idx = tuple(start_idx.tolist())

            if tuple_start_idx in visited:
                continue

            visited.add(tuple_start_idx)
            cur_diag_len = 1
            start_idx += 1
            while start_idx[0] < input_matrix.shape[0] and start_idx[1] < input_matrix.shape[1]:
                tuple_start_idx = tuple(start_idx.tolist())
                visited.add(tuple_start_idx)

                if input_matrix[start_idx[0], start_idx[1]] == 1:
                    cur_diag_len += 1
                    start_idx += 1
                else:
                    break

            diags[idx] = cur_diag_len
        return diags

    @staticmethod
    def _get_longest_diag_index(input_matrix):
        """
        Returns the start index and length of the longest diagonal in the given input.
        Args:
            input_matrix (numpy.ndarray): The input matrix.
        Returns:
            tuple: A tuple containing the start index and length of the longest diagonal.
        """

        diags = AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_dict(
            input_matrix, input_matrix.nonzero()
        )
        diags_values = list(diags.values())
        diags_keys = list(diags.keys())
        best_diag = np.argmax(diags_values)
        diag_start_index = diags_keys[best_diag]
        diag_start_length = diags_values[best_diag]
        return diag_start_index, diag_start_length
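
As an informal check of the two helpers above (an editorial sketch, not part of the committed file): in the 0/1 comparison matrix below, the ones at (0, 1) and (1, 2) lie on the same diagonal, so the longest diagonal starts at (0, 1) with length 2.

```python
import torch

from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers

compare_mat = torch.tensor(
    [
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
    ]
)
start, length = AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_index(compare_mat)
print(start.tolist(), length)  # [0, 1] 2
```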

    @staticmethod
    def _get_tokens_diag(prompt, prompt_plus_new_tokens):
        """
        Input:
            prompt: 2D array of shape (batch_size, prompt_length), represents the original prompt tokens
            prompt_plus_new_tokens: 2D array of shape (batch_size, prompt_suffix_length + new_token_length), represents a suffix of the original prompt followed by the newly generated tokens.
        Output:
            discrepancy_length: int, represents the number of tokens that need to be replaced from prompt
            new_tokens_only: 2D array of shape (batch_size, new_token_length), represents the new tokens that are not in prompt
            discrepancy_only: 2D array of shape (batch_size, discrepancy_length), represents the new tokens that are in prompt but not in prompt_plus_new_tokens
        """
        compare_mat = prompt_plus_new_tokens.T == prompt
        if not torch.is_tensor(compare_mat):
            compare_mat = torch.tensor(compare_mat)

        compare_mat_int = compare_mat.to(int)

        if not compare_mat_int.any().item():
            # empty intersection between prompt and prompt_plus_new_tokens
            return None, None, None

        longest_location, longest_diag_length = AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_index(
            compare_mat_int
        )
        new_token_start_index = longest_location[0] + longest_diag_length
        discrepancy_with_old = longest_location[1] + longest_diag_length
        discrepancy_length = (prompt.shape[1] - discrepancy_with_old).item()
        new_tokens_only = prompt_plus_new_tokens[:, new_token_start_index + discrepancy_length :]
        discrepancy_only = prompt_plus_new_tokens[
            :, new_token_start_index : new_token_start_index + discrepancy_length
        ]
        return discrepancy_length, new_tokens_only, discrepancy_only
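
A quick illustration of the alignment helper above (an editorial sketch mirroring the unit tests added at the bottom of this diff): with a partial overlap, only the genuinely new tokens are returned and no prompt tokens need replacing.

```python
import torch

from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers

prompt = torch.tensor([[1, 2, 3]])
prompt_plus_new_tokens = torch.tensor([[2, 3, 4, 5]])  # re-encoded window: prompt suffix [2, 3] plus new tokens [4, 5]

discrepancy_length, new_tokens_only, discrepancy_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
    prompt, prompt_plus_new_tokens
)
print(discrepancy_length)         # 0  -> no prompt tokens need to be replaced
print(new_tokens_only.tolist())   # [[4, 5]]
print(discrepancy_only.tolist())  # [[]]
```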

    def convert_source_tokens_to_target_tokens(
        self,
        input_ids,
        source_tokenizer,
        destination_tokenizer,
    ):
        """
        Convert token IDs from one tokenizer to another.
        Args:
            input_ids: The input token IDs.
            source_tokenizer: The source tokenizer.
            destination_tokenizer: The destination tokenizer.
        Returns:
            The converted token IDs.
        """
        text = source_tokenizer.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        dest_ids = destination_tokenizer(text, add_special_tokens=True, return_tensors="pt")["input_ids"]
        return dest_ids.to(input_ids.device)

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
            assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
            vocabulary_size)` containing the logits associated to each candidate.
        """
        max_new_tokens = int(self.num_assistant_tokens)
        if max_new_tokens == 0:
            return input_ids, None

        input_ids = input_ids.to(self.assistant_model.device)
        convert_kwargs = {
            "source_tokenizer": self.target_tokenizer,
            "destination_tokenizer": self.assistant_tokenizer,
        }
        remove_from_pkv = 0

        # Since re-encoding the tokens may result in tokenization discrepancies, we use 2 look behind values
        # (one for each conversion) which mark where to start looking for the overlap between the
        # source and target encodings, to ensure the new tokens include the correct prompt suffix.
        if self.prev_tokens is not None and self.prev_target_ids.shape[1] > self.target_lookbehind:
            # input_ids contains all target prompt input ids and some new target input ids
            start_index_in_target_window = self.prev_target_ids.shape[1] - self.target_lookbehind

            new_assistant_ids = self.convert_source_tokens_to_target_tokens(
                input_ids[:, start_index_in_target_window:], **convert_kwargs
            )
            prompt_use_length = new_assistant_ids.shape[1]
            prompt_use = self.prev_assistant_ids[:, -prompt_use_length:]

            discrepancy_length, new_tokens_only, discrepancy_only = (
                AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt_use, new_assistant_ids)
            )
            assistant_input_ids = self.prev_assistant_ids

            if new_tokens_only is not None:
                if discrepancy_length > 0 and discrepancy_only.shape[1] > 0:
                    if discrepancy_length == discrepancy_only.shape[1]:
                        assistant_input_ids[:, -discrepancy_length:] = discrepancy_only

                    elif discrepancy_length > discrepancy_only.shape[1]:
                        discrepancy_length_diff = discrepancy_length - discrepancy_only.shape[1]
                        assistant_input_ids = assistant_input_ids[:, :-discrepancy_length_diff]
                        assistant_input_ids[:, -discrepancy_only.shape[1] :] = discrepancy_only

                    remove_from_pkv = discrepancy_length

                if new_tokens_only.shape[1] > 0:
                    assistant_input_ids = torch.cat([assistant_input_ids, new_tokens_only], dim=-1)
            else:
                # edge case: in case of no intersection between prompt and new_assistant_ids
                assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1)

        else:
            assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
            self.prev_target_ids = input_ids

        self.prev_assistant_ids = assistant_input_ids
        new_cur_len = assistant_input_ids.shape[-1]
        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)

        # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
        # (which implicitly contains the number of accepted candidates from the previous round)
        has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
        if has_past_key_values:
            new_cache_size = new_cur_len - 1 - remove_from_pkv
            self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
            )  # the assistant does not have the token after the last match, hence the -1

            self.assistant_kwargs = _prepare_attention_mask(
                self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
            )
            self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)

        # 2. Forecast next N tokens using the assistant model.
        assistant_generation_kwargs = {
            self.input_ids_key: assistant_input_ids,
            "min_new_tokens": min_new_tokens,
            "max_new_tokens": max_new_tokens,
            "generation_config": self.generation_config,
            "logits_processor": self.logits_processor,
        }

        self.assistant_kwargs.pop("attention_mask", None)

        assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)

        num_prev_assistant = self.prev_assistant_ids.shape[1]
        start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind

        new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
            assistant_output.sequences[:, start_assistant_look_index:],
            source_tokenizer=self.assistant_tokenizer,
            destination_tokenizer=self.target_tokenizer,
        )
        target_prompt_use_length = new_target_ids_from_window.shape[1]

        target_prompt_use = input_ids[:, -target_prompt_use_length:]

        _, target_new_tokens_only, _ = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
            target_prompt_use, new_target_ids_from_window
        )

        new_target_ids = input_ids

        if target_new_tokens_only is not None:
            if target_new_tokens_only.shape[1] > 0:
                new_target_ids = torch.cat([new_target_ids, target_new_tokens_only], dim=-1)
        else:
            # edge case: in case of no intersection between prompt and new_target_ids
            new_target_ids = torch.cat([new_target_ids, new_target_ids_from_window], dim=-1)

        self.prev_target_ids = input_ids

        if hasattr(self.generation_config, "max_length"):
            new_target_ids = new_target_ids[:, : self.generation_config.max_length]

        # 3. Update variables for the next round of candidate generation
        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
        self.prev_tokens = assistant_output.sequences

        # 4. Prepare variables for output
        if input_ids.shape[1] >= new_target_ids.shape[1]:
            return input_ids, None

        return new_target_ids, None


class PromptLookupCandidateGenerator(CandidateGenerator):
    """
    `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
@@ -51,6 +51,7 @@ from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .candidate_generator import (
    AssistedCandidateGenerator,
    AssistedCandidateGeneratorDifferentTokenizers,
    CandidateGenerator,
    PromptLookupCandidateGenerator,
    _crop_past_key_values,
@@ -617,7 +618,7 @@ class GenerationMixin:
        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
        encoder_kwargs["return_dict"] = True
        encoder_kwargs[model_input_name] = inputs_tensor
        model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
        model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)  # type: ignore

        return model_kwargs

@@ -787,11 +788,15 @@ class GenerationMixin:
        inputs_tensor: torch.Tensor,
        assistant_model: "PreTrainedModel",
        logits_processor: LogitsProcessorList,
        target_tokenizer: "PreTrainedTokenizerBase",
        assistant_tokenizer: "PreTrainedTokenizerBase",
        model_kwargs: Dict,
    ) -> CandidateGenerator:
        """
        Returns the candidate generator to be used in `assisted_generation`
        """
        different_tokenizers = all(v is not None for v in (assistant_model, target_tokenizer, assistant_tokenizer))

        if generation_config.prompt_lookup_num_tokens is not None:
            candidate_generator = PromptLookupCandidateGenerator(
                eos_token_id=generation_config._eos_token_tensor,
@@ -799,6 +804,17 @@ class GenerationMixin:
                max_matching_ngram_size=generation_config.max_matching_ngram_size,
                max_length=generation_config.max_length,
            )
        elif different_tokenizers:
            candidate_generator = AssistedCandidateGeneratorDifferentTokenizers(
                input_ids=input_ids,
                assistant_model=assistant_model,
                generation_config=generation_config,
                model_kwargs=model_kwargs,
                inputs_tensor=inputs_tensor,
                logits_processor=logits_processor,
                target_tokenizer=target_tokenizer,
                assistant_tokenizer=assistant_tokenizer,
            )
        else:
            candidate_generator = AssistedCandidateGenerator(
                input_ids=input_ids,
@@ -1250,7 +1266,7 @@ class GenerationMixin:
                f"names: {terminations_with_generation_support}."
            )

    def _validate_assistant(self, assistant_model):
    def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer):
        if assistant_model is None:
            return

@@ -1266,8 +1282,19 @@ class GenerationMixin:
                    "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper."
                )

        if not self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size:
            raise ValueError("Make sure the main and assistant model use the same tokenizer")
        doc_reference = (
            "(see https://huggingface.co/docs/transformers/en/generation_strategies#universal-assisted-decoding)"
        )
        if self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size:
            if assistant_tokenizer is not None:
                raise ValueError(
                    f"`assistant_tokenizer` is not required when the main and assistant models use the same tokenizer. Please omit `assistant_tokenizer` from `generate()` {doc_reference}."
                )
        else:
            if tokenizer is None or assistant_tokenizer is None:
                raise ValueError(
f"The main and assistant moedels have different tokenizers. Please provide `tokenizer` and `assistant_tokenizer` to `generate()` {doc_reference}."
|
||||
)
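
To make the two branches above concrete, here is an informal sketch (not part of the diff) of how the mismatched-tokenizer check surfaces to users, assuming the same small checkpoints used in the new test below:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

target_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
draft_model = AutoModelForCausalLM.from_pretrained("double7/vicuna-68m")  # different vocabulary/tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
inputs = tokenizer("Alice and Bob", return_tensors="pt")

try:
    # Vocab sizes differ and no tokenizers are passed, so `_validate_assistant` raises.
    target_model.generate(**inputs, assistant_model=draft_model, max_new_tokens=5)
except ValueError as err:
    print(err)  # the message points the user at `tokenizer` and `assistant_tokenizer`
```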

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
@@ -1923,12 +1950,15 @@ class GenerationMixin:
                    - [`~generation.GenerateEncoderDecoderOutput`],
                    - [`~generation.GenerateBeamEncoderDecoderOutput`]
        """

        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()
        tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
        assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)  # only used for assisted generation

        generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
        self._validate_model_kwargs(model_kwargs.copy())
        self._validate_assistant(assistant_model)
        self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)

        # 2. Set generation parameters if not already defined
        if synced_gpus is None:
@@ -2110,6 +2140,8 @@ class GenerationMixin:
            inputs_tensor=inputs_tensor,
            assistant_model=assistant_model,
            logits_processor=logits_processor,
            target_tokenizer=tokenizer,
            assistant_tokenizer=assistant_tokenizer,
            model_kwargs=model_kwargs,
        )

@@ -4138,7 +4170,7 @@ class GenerationMixin:

            # 1. Fetch candidate sequences from a `CandidateGenerator`
            candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
            candidate_input_ids = candidate_input_ids.to(self.device)

            if candidate_logits is not None:
                candidate_logits = candidate_logits.to(self.device)
@@ -88,6 +88,7 @@ if is_torch_available():
        WatermarkDetector,
        WatermarkingConfig,
    )
    from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers
    from transformers.generation.utils import _speculative_sampling


@@ -3510,6 +3511,34 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
        self.assertTrue(test_bos_id == gen_output[0, 0])
        self.assertTrue(generation_config.bos_token_id is None)

    def test_speculative_decoding_equals_regular_decoding(self):
        draft_name = "double7/vicuna-68m"
        target_name = "Qwen/Qwen2-0.5B-Instruct"

        draft_model = AutoModelForCausalLM.from_pretrained(draft_name)
        target_model = AutoModelForCausalLM.from_pretrained(target_name)

        assistant_tokenizer = AutoTokenizer.from_pretrained(draft_name)
        target_tokenizer = AutoTokenizer.from_pretrained(target_name)

        prompt_size = torch.randint(low=20, high=100, size=(1,))
        max_new_tokens = torch.randint(low=10, high=50, size=(1,))
        input_ids = (torch.rand(1, prompt_size[0]) * 100).to(int) + 50

        max_new_tokens_item = max_new_tokens[0].item()
        expected_out = target_model.generate(input_ids, do_sample=False, max_new_tokens=max_new_tokens_item)
        predicted_out = target_model.generate(
            input_ids,
            do_sample=False,
            max_new_tokens=max_new_tokens_item,
            assistant_model=draft_model,
            tokenizer=target_tokenizer,
            assistant_tokenizer=assistant_tokenizer,
        )

        self.assertEqual(expected_out.shape, predicted_out.shape)
        self.assertTrue((expected_out == predicted_out).all().item())

    @pytest.mark.generate
    @require_torch_multi_gpu
    def test_generate_with_static_cache_multi_gpu(self):
@@ -3884,3 +3913,41 @@ class TokenHealingTestCase(unittest.TestCase):
        # bos_token_id is required when no input ids nor inputs_embeds is passed
        with self.assertRaises(ValueError):
            model.generate(max_length=20, bos_token_id=None)


class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
    def test_no_intersection(self):
        prompt = np.array([[1, 2, 3]])
        prompt_plus_new_tokens = np.array([[4, 5, 6]])
        result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens)
        self.assertEqual(result, (None, None, None))

    def test_complete_overlap(self):
        prompt = np.array([[1, 2, 3]])
        prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]])
        discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
            prompt, prompt_plus_new_tokens
        )
        self.assertEqual(discrep_length, 0)
        np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
        np.testing.assert_array_equal(discrep_only, np.array([[]]))

    def test_partial_overlap(self):
        prompt = np.array([[1, 2, 3]])
        prompt_plus_new_tokens = np.array([[2, 3, 4, 5]])
        discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
            prompt, prompt_plus_new_tokens
        )
        self.assertEqual(discrep_length, 0)
        np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
        np.testing.assert_array_equal(discrep_only, np.array([[]]))

    def test_no_new_tokens(self):
        prompt = np.array([[1, 2, 3]])
        prompt_plus_new_tokens = np.array([[1, 2, 3]])
        discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
            prompt, prompt_plus_new_tokens
        )
        self.assertEqual(discrep_length, 0)
        np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
        np.testing.assert_array_equal(discrep_only, np.array([[]]))