Universal Assisted Generation: Assisted generation with any assistant model (by Intel Labs) (#33383)

* Update candidate_generator.py

* Update utils.py

* add lookbehind params to _get_candidate_generator

* make fixup

* add unit tests

* fix failing tests

* add docstrings

* fix docstrings; remove non-optimized AnyTokenizer

* added any tokenizer generation correctness test

* make fixup

* fix assertion syntax

* PR review fixes

* address additional PR comments

* fix tests

* remove stopping criteria arg

* make fixup

* add AssistantConfig

* fix prev_tokens branching

* pass tokenizers through `generate()` kwargs

* fix lookbehind values; tokenizer params WIP

* fixup

* AssistantConfig

* remove AssistantConfig; apply PR suggestions

* restructure tests

* fixup

* fix assistant_tokenizer arg validation

* fixup

* fix tests in TestAssistedCandidateGeneratorDifferentTokenizers

* fix class docstring

* PR suggestions

* doc

* doc update and improvements to `_validate_assistant()`

---------

Co-authored-by: mosheber <moshe.berchansky@intel.com>
Daniel Korat 2024-10-10 15:41:53 +03:00 committed by GitHub
parent dda3f91d06
commit fb0c6b521d
4 changed files with 440 additions and 10 deletions


@@ -408,14 +408,24 @@ For the complete list of the available parameters, refer to the [API documentati
### Speculative Decoding
Speculative decoding (also known as assisted decoding) is a modification of the decoding strategies above, that uses an
assistant model (ideally a much smaller one), to generate a few candidate tokens. The main model then validates the candidate
tokens in a single forward pass, which speeds up the decoding process. If `do_sample=True`, then the token validation with
resampling introduced in the [speculative decoding paper](https://arxiv.org/pdf/2211.17192.pdf) is used.
Assisted decoding assumes the main and assistant models have the same tokenizer; otherwise, see Universal Assisted Decoding below.
Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs.
To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).
#### Universal Assisted Decoding
Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers.
To use it, simply pass the tokenizers using the `tokenizer` and `assistant_tokenizer` arguments (see below).
Internally, the main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
to ensure the new tokens include the correct prompt suffix.
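Conceptually, each re-encoding step is just a decode-then-encode round trip between the two tokenizers. Below is a minimal sketch of that single step, reusing the checkpoints from the example further down purely for illustration:
```python
>>> from transformers import AutoTokenizer
>>> main_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
>>> assistant_tokenizer = AutoTokenizer.from_pretrained("double7/vicuna-68m")
>>> main_ids = main_tokenizer("Alice and Bob", return_tensors="pt").input_ids
>>> text = main_tokenizer.batch_decode(main_ids, skip_special_tokens=True)  # main model ids -> text
>>> assistant_ids = assistant_tokenizer(text, return_tensors="pt").input_ids  # text -> assistant model ids
```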
To enable assisted decoding, set the `assistant_model` argument with a model.
```python
@@ -435,6 +445,26 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```
If the main and assistant models have different tokenizers, use Universal Assisted Decoding.
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> prompt = "Alice and Bob"
>>> checkpoint = "google/gemma-2-9b"
>>> assistant_checkpoint = "double7/vicuna-68m"
>>> assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_checkpoint)
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```
When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.
@@ -458,6 +488,7 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t
Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
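For example, assuming the same `model`, `tokenizer`, and `inputs` as in the assisted decoding examples above, prompt lookup decoding only needs this one extra argument (the value 10 here is an arbitrary choice):
```python
>>> outputs = model.generate(**inputs, prompt_lookup_num_tokens=10)
```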
### DoLa Decoding
**D**ecoding by C**o**ntrasting **La**yers (DoLa) is a contrastive decoding strategy to improve the factuality and reduce the


@@ -16,6 +16,7 @@
import copy
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import numpy as np
import torch
from ..cache_utils import DynamicCache
@@ -25,6 +26,7 @@ from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from ..tokenization_utils_base import PreTrainedTokenizerBase
from .configuration_utils import GenerationConfig
@@ -156,6 +158,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
# Prepare generation-related options.
self.logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
self.generation_config = copy.deepcopy(generation_config)
self.generation_config.return_dict_in_generate = True
self.generation_config.output_scores = True
self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold
@@ -258,6 +261,303 @@
self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
"""
`CandidateGenerator` class to be used for Universal Assisted Generation (UAD): assisted generation with different tokenizers
for the assistant and main models. This class generates candidates through the use of a smaller
model.
The main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
to ensure the new tokens include the correct prompt suffix.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
assistant_model (`PreTrainedModel`):
The model to be used for generating candidates. This model should be smaller than the main model.
target_tokenizer (`PreTrainedTokenizerBase`):
The tokenizer used for the target model.
assistant_tokenizer (`PreTrainedTokenizerBase`):
The tokenizer used for the assistant model.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call.
logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
model_kwargs (`Dict`):
The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
model as well.
inputs_tensor (`torch.Tensor`, *optional*):
The model input tensor. In encoder-decoder models, this is the encoder input.
"""
def __init__(
self,
input_ids: torch.LongTensor,
assistant_model: "PreTrainedModel",
target_tokenizer: "PreTrainedTokenizerBase",
assistant_tokenizer: "PreTrainedTokenizerBase",
generation_config: "GenerationConfig",
model_kwargs: Dict,
inputs_tensor: Optional[torch.Tensor] = None,
logits_processor: "LogitsProcessorList" = None,
):
super().__init__(input_ids, assistant_model, generation_config, model_kwargs, inputs_tensor, logits_processor)
self.target_tokenizer = target_tokenizer
self.assistant_tokenizer = assistant_tokenizer
self.prev_tokens = None
self.prev_assistant_ids = None
self.target_lookbehind = 10
self.assistant_lookbehind = 10
@staticmethod
def _get_longest_diag_dict(input_matrix, nonzero_idx):
"""
Calculates the length of the longest diagonal sequence in a given matrix.
Args:
input_matrix (torch.Tensor): The input matrix.
nonzero_idx (torch.Tensor): The indices of the non-zero elements in the matrix.
Returns:
dict: A dictionary where the keys are the indices of the non-zero elements and the values are the lengths of the longest diagonal sequences starting from those indices.
"""
visited = set()
diags = {}
for idx in nonzero_idx:
start_idx = torch.clone(idx)
tuple_start_idx = tuple(start_idx.tolist())
if tuple_start_idx in visited:
continue
visited.add(tuple_start_idx)
cur_diag_len = 1
start_idx += 1
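# walk down-right along the diagonal, extending the run while consecutive matches continue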
while start_idx[0] < input_matrix.shape[0] and start_idx[1] < input_matrix.shape[1]:
tuple_start_idx = tuple(start_idx.tolist())
visited.add(tuple_start_idx)
if input_matrix[start_idx[0], start_idx[1]] == 1:
cur_diag_len += 1
start_idx += 1
else:
break
diags[idx] = cur_diag_len
return diags
@staticmethod
def _get_longest_diag_index(input_matrix):
"""
Returns the start index and length of the longest diagonal in the given input.
Args:
input_matrix (numpy.ndarray): The input matrix.
Returns:
tuple: A tuple containing the start index and length of the longest diagonal.
"""
diags = AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_dict(
input_matrix, input_matrix.nonzero()
)
diags_values = list(diags.values())
diags_keys = list(diags.keys())
best_diag = np.argmax(diags_values)
diag_start_index = diags_keys[best_diag]
diag_start_length = diags_values[best_diag]
return diag_start_index, diag_start_length
@staticmethod
def _get_tokens_diag(prompt, prompt_plus_new_tokens):
"""
Input:
prompt: 2D array of shape (batch_size, prompt_length), represents the original prompt tokens
prompt_plus_new_tokens: 2D array of shape (batch_size, prompt_length), represents the suffix of the original prompt, with additional new tokens.
Output:
discrepancy_length: int, represents the number of tokens that need to be replaced from prompt
new_tokens_only: 2D array of shape (batch_size, new_token_length), represents the new tokens that are not in prompt
discrepancy_only: 2D array of shape (batch_size, discrepancy_length), represents the new tokens that are in prompt but not in prompt_plus_new_tokens
"""
compare_mat = prompt_plus_new_tokens.T == prompt
if not torch.is_tensor(compare_mat):
compare_mat = torch.tensor(compare_mat)
compare_mat_int = compare_mat.to(int)
if not compare_mat_int.any().item():
# empty intersection between prompt and prompt_plus_new_tokens
return None, None, None
longest_location, longest_diag_length = AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_index(
compare_mat_int
)
new_token_start_index = longest_location[0] + longest_diag_length
discrepancy_with_old = longest_location[1] + longest_diag_length
discrepancy_length = (prompt.shape[1] - discrepancy_with_old).item()
new_tokens_only = prompt_plus_new_tokens[:, new_token_start_index + discrepancy_length :]
discrepancy_only = prompt_plus_new_tokens[
:, new_token_start_index : new_token_start_index + discrepancy_length
]
return discrepancy_length, new_tokens_only, discrepancy_only
def convert_source_tokens_to_target_tokens(
self,
input_ids,
source_tokenizer,
destination_tokenizer,
):
"""
Convert token IDs from one tokenizer to another.
Args:
input_ids: The input token IDs.
source_tokenizer: The source tokenizer.
destination_tokenizer: The destination tokenizer.
Returns:
The converted token IDs.
"""
text = source_tokenizer.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
dest_ids = destination_tokenizer(text, add_special_tokens=True, return_tensors="pt")["input_ids"]
return dest_ids.to(input_ids.device)
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
Return:
`torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
vocabulary_size)` containing the logits associated to each candidate.
"""
max_new_tokens = int(self.num_assistant_tokens)
if max_new_tokens == 0:
return input_ids, None
input_ids = input_ids.to(self.assistant_model.device)
convert_kwargs = {
"source_tokenizer": self.target_tokenizer,
"destination_tokenizer": self.assistant_tokenizer,
}
remove_from_pkv = 0
# Since re-encoding the tokens may result in tokenization discrepancies, we use 2 look behind values
# (one for each conversion) which mark where to start looking for the overlap between the
# source and target encodings, to ensure the new tokens include the correct prompt suffix.
if self.prev_tokens is not None and self.prev_target_ids.shape[1] > self.target_lookbehind:
# input_ids contains all target prompt input ids and some new target input ids
start_index_in_target_window = self.prev_target_ids.shape[1] - self.target_lookbehind
new_assistant_ids = self.convert_source_tokens_to_target_tokens(
input_ids[:, start_index_in_target_window:], **convert_kwargs
)
prompt_use_length = new_assistant_ids.shape[1]
prompt_use = self.prev_assistant_ids[:, -prompt_use_length:]
discrepancy_length, new_tokens_only, discrepancy_only = (
AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt_use, new_assistant_ids)
)
assistant_input_ids = self.prev_assistant_ids
if new_tokens_only is not None:
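# the freshly re-encoded window may disagree with the assistant ids kept from the previous round;
# overwrite that stale suffix and record how many past key/value entries to drop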
if discrepancy_length > 0 and discrepancy_only.shape[1] > 0:
if discrepancy_length == discrepancy_only.shape[1]:
assistant_input_ids[:, -discrepancy_length:] = discrepancy_only
elif discrepancy_length > discrepancy_only.shape[1]:
discrepancy_length_diff = discrepancy_length - discrepancy_only.shape[1]
assistant_input_ids = assistant_input_ids[:, :-discrepancy_length_diff]
assistant_input_ids[:, -discrepancy_only.shape[1] :] = discrepancy_only
remove_from_pkv = discrepancy_length
if new_tokens_only.shape[1] > 0:
assistant_input_ids = torch.cat([assistant_input_ids, new_tokens_only], dim=-1)
else:
# edge case: in case of no intersection between prompt and new_assistant_ids
assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1)
else:
assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
self.prev_target_ids = input_ids
self.prev_assistant_ids = assistant_input_ids
new_cur_len = assistant_input_ids.shape[-1]
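# request enough assistant tokens to reach the main model's minimum length, capped at max_new_tokens and never negative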
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# (which implicitly contains the number of accepted candidates from the previous round)
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
if has_past_key_values:
new_cache_size = new_cur_len - 1 - remove_from_pkv
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
) # the assistant does not have the token after the last match, hence the -1
self.assistant_kwargs = _prepare_attention_mask(
self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
)
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)
# 2. Forecast next N tokens using the assistant model.
assistant_generation_kwargs = {
self.input_ids_key: assistant_input_ids,
"min_new_tokens": min_new_tokens,
"max_new_tokens": max_new_tokens,
"generation_config": self.generation_config,
"logits_processor": self.logits_processor,
}
self.assistant_kwargs.pop("attention_mask", None)
assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
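# re-encode the assistant's continuation back into the main model's vocabulary, starting from a lookbehind
# window so the overlap with the existing target ids can be located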
num_prev_assistant = self.prev_assistant_ids.shape[1]
start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind
new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
assistant_output.sequences[:, start_assistant_look_index:],
source_tokenizer=self.assistant_tokenizer,
destination_tokenizer=self.target_tokenizer,
)
target_prompt_use_length = new_target_ids_from_window.shape[1]
target_prompt_use = input_ids[:, -target_prompt_use_length:]
_, target_new_tokens_only, _ = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
target_prompt_use, new_target_ids_from_window
)
new_target_ids = input_ids
if target_new_tokens_only is not None:
if target_new_tokens_only.shape[1] > 0:
new_target_ids = torch.cat([new_target_ids, target_new_tokens_only], dim=-1)
else:
# edge case: in case of no intersection between prompt and new_target_ids
new_target_ids = torch.cat([new_target_ids, new_target_ids_from_window], dim=-1)
self.prev_target_ids = input_ids
if hasattr(self.generation_config, "max_length"):
new_target_ids = new_target_ids[:, : self.generation_config.max_length]
# 3. Update variables for the next round of candidate generation
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
self.prev_tokens = assistant_output.sequences
# 4. Prepare variables for output
if input_ids.shape[1] >= new_target_ids.shape[1]:
return input_ids, None
return new_target_ids, None
class PromptLookupCandidateGenerator(CandidateGenerator):
"""
`CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up


@@ -51,6 +51,7 @@ from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .candidate_generator import (
AssistedCandidateGenerator,
AssistedCandidateGeneratorDifferentTokenizers,
CandidateGenerator,
PromptLookupCandidateGenerator,
_crop_past_key_values,
@@ -617,7 +618,7 @@ class GenerationMixin:
model_input_name = model_input_name if model_input_name is not None else self.main_input_name
encoder_kwargs["return_dict"] = True
encoder_kwargs[model_input_name] = inputs_tensor
model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) # type: ignore
return model_kwargs
@@ -787,11 +788,15 @@
inputs_tensor: torch.Tensor,
assistant_model: "PreTrainedModel",
logits_processor: LogitsProcessorList,
target_tokenizer: "PreTrainedTokenizerBase",
assistant_tokenizer: "PreTrainedTokenizerBase",
model_kwargs: Dict,
) -> CandidateGenerator:
"""
Returns the candidate generator to be used in `assisted_generation`
"""
different_tokenizers = all(v is not None for v in (assistant_model, target_tokenizer, assistant_tokenizer))
if generation_config.prompt_lookup_num_tokens is not None:
candidate_generator = PromptLookupCandidateGenerator(
eos_token_id=generation_config._eos_token_tensor,
@@ -799,6 +804,17 @@
max_matching_ngram_size=generation_config.max_matching_ngram_size,
max_length=generation_config.max_length,
)
elif different_tokenizers:
candidate_generator = AssistedCandidateGeneratorDifferentTokenizers(
input_ids=input_ids,
assistant_model=assistant_model,
generation_config=generation_config,
model_kwargs=model_kwargs,
inputs_tensor=inputs_tensor,
logits_processor=logits_processor,
target_tokenizer=target_tokenizer,
assistant_tokenizer=assistant_tokenizer,
)
else:
candidate_generator = AssistedCandidateGenerator(
input_ids=input_ids,
@@ -1250,7 +1266,7 @@
f"names: {terminations_with_generation_support}."
)
def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer):
if assistant_model is None:
return
@@ -1266,8 +1282,19 @@
"Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper."
)
doc_reference = (
"(see https://huggingface.co/docs/transformers/en/generation_strategies#universal-assisted-decoding)"
)
if self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size:
if assistant_tokenizer is not None:
raise ValueError(
f"`assistant_tokenizer` is not required when the main and assistant models use the same tokenizer. Please omit `assistant_tokenizer` from `generate()` {doc_reference}."
)
else:
if tokenizer is None or assistant_tokenizer is None:
raise ValueError(
f"The main and assistant moedels have different tokenizers. Please provide `tokenizer` and `assistant_tokenizer` to `generate()` {doc_reference}."
)
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
@@ -1923,12 +1950,15 @@
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
self._validate_model_kwargs(model_kwargs.copy())
self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
# 2. Set generation parameters if not already defined
if synced_gpus is None:
@@ -2110,6 +2140,8 @@
inputs_tensor=inputs_tensor,
assistant_model=assistant_model,
logits_processor=logits_processor,
target_tokenizer=tokenizer,
assistant_tokenizer=assistant_tokenizer,
model_kwargs=model_kwargs,
)
@@ -4138,7 +4170,7 @@
# 1. Fetch candidate sequences from a `CandidateGenerator`
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
candidate_input_ids = candidate_input_ids.to(self.device)
if candidate_logits is not None:
candidate_logits = candidate_logits.to(self.device)


@@ -88,6 +88,7 @@ if is_torch_available():
WatermarkDetector,
WatermarkingConfig,
)
from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers
from transformers.generation.utils import _speculative_sampling
@@ -3510,6 +3511,34 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertTrue(test_bos_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id is None)
def test_speculative_decoding_equals_regular_decoding(self):
draft_name = "double7/vicuna-68m"
target_name = "Qwen/Qwen2-0.5B-Instruct"
draft_model = AutoModelForCausalLM.from_pretrained(draft_name)
target_model = AutoModelForCausalLM.from_pretrained(target_name)
assistant_tokenizer = AutoTokenizer.from_pretrained(draft_name)
target_tokenizer = AutoTokenizer.from_pretrained(target_name)
prompt_size = torch.randint(low=20, high=100, size=(1,))
max_new_tokens = torch.randint(low=10, high=50, size=(1,))
input_ids = (torch.rand(1, prompt_size[0]) * 100).to(int) + 50
max_new_tokens_item = max_new_tokens[0].item()
expected_out = target_model.generate(input_ids, do_sample=False, max_new_tokens=max_new_tokens_item)
predicted_out = target_model.generate(
input_ids,
do_sample=False,
max_new_tokens=max_new_tokens_item,
assistant_model=draft_model,
tokenizer=target_tokenizer,
assistant_tokenizer=assistant_tokenizer,
)
self.assertEqual(expected_out.shape, predicted_out.shape)
self.assertTrue((expected_out == predicted_out).all().item())
@pytest.mark.generate
@require_torch_multi_gpu
def test_generate_with_static_cache_multi_gpu(self):
@@ -3884,3 +3913,41 @@ class TokenHealingTestCase(unittest.TestCase):
# bos_token_id is required when no input ids nor inputs_embeds is passed
with self.assertRaises(ValueError):
model.generate(max_length=20, bos_token_id=None)
class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
def test_no_intersection(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[4, 5, 6]])
result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens)
self.assertEqual(result, (None, None, None))
def test_complete_overlap(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))
def test_partial_overlap(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[2, 3, 4, 5]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))
def test_no_new_tokens(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[1, 2, 3]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))