mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 21:30:07 +06:00
prune LM Head for USD (#36695)
* initial commit
* fix
* fix style
* set default to prune
* add tests
* comment
* remove prune flag from generate
* address Joao's comments
* deprecate_kwarg
* add doc
* fix target_vocab_size
* Update src/transformers/generation/candidate_generator.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Update src/transformers/generation/candidate_generator.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Update src/transformers/generation/candidate_generator.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Update src/transformers/generation/candidate_generator.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* fix deprecated argument assistant_model_device

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

This commit is contained in:
parent 4321b0648c
commit 121f91d36c
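With this change, universal assisted generation (USD) with sampling prunes the assistant model's LM head down to the tokens it shares with the target vocabulary, so the assistant only scores tokens the target model can accept. A minimal usage sketch, assuming placeholder checkpoint names (the pruning is applied internally by `generate` when an assistant model with a different tokenizer is passed and `do_sample=True`):

from transformers import AutoModelForCausalLM, AutoTokenizer

target_name = "<target-checkpoint>"        # placeholder, substitute a real model id
assistant_name = "<assistant-checkpoint>"  # placeholder, smaller model with a different tokenizer

target_tokenizer = AutoTokenizer.from_pretrained(target_name)
assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_name)
target_model = AutoModelForCausalLM.from_pretrained(target_name)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_name)

inputs = target_tokenizer("Test text", return_tensors="pt")
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    tokenizer=target_tokenizer,            # both tokenizers are required for USD
    assistant_tokenizer=assistant_tokenizer,
    do_sample=True,
    max_new_tokens=20,
)
print(target_tokenizer.decode(outputs[0], skip_special_tokens=True))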
src/transformers/generation/candidate_generator.py

@@ -19,7 +19,9 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 
 import numpy as np
 import torch
+import torch.nn as nn
 
+from ..pytorch_utils import prune_linear_layer
 from ..utils import is_sklearn_available
 
 
@@ -36,6 +38,8 @@ if TYPE_CHECKING:
     from ..tokenization_utils_base import PreTrainedTokenizerBase
     from .configuration_utils import GenerationConfig
 
+from ..utils.deprecation import deprecate_kwarg
+
 
 class CandidateGenerator:
     """Abstract base class for all candidate generators that can be applied during assisted generation."""
@@ -612,6 +616,63 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
         return new_target_ids
 
 
+class _PruneReindexingLMHead(nn.Module):
+    """
+    A class to prune and reindex the language model head.
+
+    This class prunes the language model head to only include the specified token IDs and reindexes the logits
+    to map back to the original vocabulary.
+
+    Args:
+        original_lm_head (nn.Module): The original language model head.
+        assistant_overlap_token_ids (torch.LongTensor): The assistant token IDs to keep in the pruned head.
+    """
+
+    def __init__(self, original_lm_head, assistant_overlap_token_ids):
+        super().__init__()
+        self.pruned_lm_head = prune_linear_layer(original_lm_head, assistant_overlap_token_ids).to(
+            original_lm_head.weight.dtype
+        )
+
+    def forward(self, hidden_states):
+        pruned_logits = self.pruned_lm_head(hidden_states)
+        return pruned_logits
+
+
+class _MapInputEmbedding(nn.Module):
+    def __init__(self, original_embedding: nn.Embedding, assistant_overlap_token_ids):
+        """
+        Wraps an existing embedding layer and remaps token IDs before lookup.
+
+        Args:
+            original_embedding (nn.Embedding): Pre-trained or existing embedding layer.
+            assistant_overlap_token_ids (torch.LongTensor): Maps positions in the pruned vocabulary back to the
+                original assistant token IDs.
+        """
+        super().__init__()
+        self.original_embedding = original_embedding
+        self.weight = original_embedding.weight
+        self.assistant_overlap_token_ids = assistant_overlap_token_ids
+        self.map = False
+
+    def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
+        """
+        Args:
+            input_ids (torch.LongTensor): Tensor of token IDs (batch_size, seq_len).
+
+        Returns:
+            torch.FloatTensor: Corresponding input embeddings.
+        """
+        if self.map:
+            # Get the last item from input_ids and map it back to an assistant vocabulary ID
+            my_input_ids = self.assistant_overlap_token_ids[input_ids[0, -1]].unsqueeze(0).unsqueeze(0)
+        else:
+            self.map = True
+            my_input_ids = input_ids
+
+        return self.original_embedding(my_input_ids)
+
+
 class AssistantToTargetTranslator:
     """
     Translates token ids and logits between assistant and target model vocabularies. This class is used to handle
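For reference, a small illustrative sketch (not part of the commit) of what `prune_linear_layer` does here: the LM head is shrunk to the rows for the overlapping token ids, so position `i` of the pruned logits corresponds to original token id `overlap_ids[i]` and can be scattered back into a full-vocabulary tensor. Sizes and ids below are made up:

import torch
import torch.nn as nn
from transformers.pytorch_utils import prune_linear_layer

vocab_size, hidden = 10, 4
lm_head = nn.Linear(hidden, vocab_size, bias=False)

# Assistant token ids that also exist in the target vocabulary (assumed values).
overlap_ids = torch.tensor([0, 2, 5, 7])

pruned_head = prune_linear_layer(lm_head, overlap_ids)   # out_features == len(overlap_ids) == 4
hidden_states = torch.randn(1, 1, hidden)
pruned_logits = pruned_head(hidden_states)               # shape (1, 1, 4)

# Scatter the pruned logits back to their original vocabulary positions.
full_logits = torch.full((1, 1, vocab_size), -float("inf"))
full_logits[..., overlap_ids] = pruned_logits
print(full_logits.shape)  # torch.Size([1, 1, 10])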
@@ -625,37 +686,75 @@ class AssistantToTargetTranslator:
             The tokenizer used by the target (main) model.
         assistant_tokenizer (`PreTrainedTokenizerBase`):
             The tokenizer used by the assistant model.
-        assistant_model_device (`str`, defaults to "cpu"):
-            The device where the assistant model is located. Used for placing tensors.
-        target_vocab_size (`int`, *optional*):
-            The size of the target model's vocabulary. If not provided, will be inferred from the target tokenizer.
+        target_vocab_size (`int`):
+            The size of the target model's vocabulary (it can differ from the length of `target_tokenizer.get_vocab()`).
+        assistant_model_device (`str`, *optional*, defaults to `"cpu"`):
+            The device where the assistant model is located. Used for placing tensors. Deprecated; the device is
+            taken from `assistant_model` when it is provided.
+        assistant_model (`Optional[PreTrainedModel]`, *optional*):
+            The assistant model to be used. Defaults to `None` for backward compatibility.
+        assistant_prune_lm_head (`bool`, *optional*, defaults to `False`):
+            Whether to prune the assistant model's language model head to match the target vocabulary. Only
+            applicable if `assistant_model` is provided. Defaults to `False` for backward compatibility.
     """
 
     FILTER_VALUE: float = -float("Inf")  # The value used to filter out unmapped tokens in the logits.
     SUPPRESS_TOKEN_ID: int = -1  # The ID used to mark suppressed tokens in the mapping.
 
+    @deprecate_kwarg("assistant_model_device", version="4.53")
     def __init__(
         self,
         target_tokenizer: "PreTrainedTokenizerBase",
         assistant_tokenizer: "PreTrainedTokenizerBase",
         target_vocab_size: int,  # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab()
         assistant_model_device: str = "cpu",
+        assistant_model: Optional["PreTrainedModel"] = None,
+        assistant_prune_lm_head: bool = False,
     ):
         self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
         self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
-        self._assistant_model_device: str = assistant_model_device
+        self._assistant_model_device: str = (
+            assistant_model_device if assistant_model is None else assistant_model.device
+        )
         self.target_vocab_size: int = target_vocab_size
         self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
             self._get_assistant_to_target_input_ids()
         )
         self._suppress_input_ids: list[int] = self._get_suppress_input_ids()
         self.logits_processors: Optional[LogitsProcessorList] = None
+        self.assistant_prune_lm_head = assistant_prune_lm_head and assistant_model is not None
         if len(self._suppress_input_ids) > 0:
-            # len(self._suppress_input_ids) = 0 if the assistant vocab is a subset of the target vocab
-            self.logits_processors = LogitsProcessorList(
-                [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)]
-            )
+            # the assistant vocab is not a subset of the target vocab
+            if self.assistant_prune_lm_head:
+                self.assistant_overlap_token_ids = torch.tensor(
+                    list(self.target_to_assistant_input_ids.values()),
+                    dtype=torch.long,
+                    device=self._assistant_model_device,
+                )
+                original_lm_head = assistant_model.get_output_embeddings()
+                pruned_lm_head = _PruneReindexingLMHead(original_lm_head, self.assistant_overlap_token_ids)
+                del original_lm_head
+                assistant_model.set_output_embeddings(pruned_lm_head)
+
+                original_input_embeddings = assistant_model.get_input_embeddings()
+                map_input_embeddings = _MapInputEmbedding(original_input_embeddings, self.assistant_overlap_token_ids)
+                del original_input_embeddings
+                assistant_model.set_input_embeddings(map_input_embeddings)
+                self.map_input_embeddings = map_input_embeddings
+            else:
+                self.logits_processors = LogitsProcessorList(
+                    [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)]
+                )
+
+    def unmap_input_ids(self):
+        """
+        Disables input id remapping even though LM head pruning is enabled for the assistant.
+
+        This is required for the first forward pass of `_MapInputEmbedding`, where the input ids are already in the
+        assistant vocabulary space; disabling the mapping ensures they are processed without remapping.
+        """
+        if self.assistant_prune_lm_head:
+            self.map_input_embeddings.map = False
 
     def _get_assistant_to_target_input_ids(self):
         target_vocab = self._target_tokenizer.get_vocab()
         assistant_vocab = self._assistant_tokenizer.get_vocab()
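As an aside, a toy sketch (illustrative only, not the exact implementation) of the two structures `_get_assistant_to_target_input_ids` returns, using the mock vocabularies from the tests further below:

import torch

SUPPRESS_TOKEN_ID = -1
target_vocab = {"hello": 0, "world": 1, "foo": 2, "bar": 3}
assistant_vocab = {"hello": 0, "world": 1, "foo": 2, "baz": 4}

# assistant id -> target id (SUPPRESS_TOKEN_ID where no target token exists)
max_assistant_index = max(assistant_vocab.values())
assistant_to_target = torch.full((max_assistant_index + 1,), SUPPRESS_TOKEN_ID, dtype=torch.long)
# target id -> assistant id, only for overlapping tokens
target_to_assistant = {}
for tok, assistant_id in assistant_vocab.items():
    target_id = target_vocab.get(tok)
    if target_id is not None:
        assistant_to_target[assistant_id] = target_id
        target_to_assistant[target_id] = assistant_id

print(assistant_to_target.tolist())  # [0, 1, 2, -1, -1] ("baz" has no target counterpart)
print(target_to_assistant)           # {0: 0, 1: 1, 2: 2}
overlap_ids = torch.tensor(list(target_to_assistant.values()))  # rows kept when pruning the LM head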
@@ -710,7 +809,12 @@ class AssistantToTargetTranslator:
         if num_new_tokens == 0:
             return target_input_ids
         else:
-            transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]]
+            # Get last `num_new_tokens` candidate IDs
+            last_candidate_ids = assistant_candidate_ids[0, -num_new_tokens:]
+            if self.assistant_prune_lm_head:
+                # Map assistant IDs -> target input IDs
+                last_candidate_ids = self.assistant_overlap_token_ids[last_candidate_ids]
+            transformed_slice = self._assistant_to_target_input_ids[last_candidate_ids]
             return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1)
 
     def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor:
@@ -726,10 +830,12 @@ class AssistantToTargetTranslator:
 
         assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID
         # Exclude invalid indices
         target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask]
-        valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]
 
-        target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]
+        if self.assistant_prune_lm_head:
+            target_logits[..., target_logits_supported_indices] = assistant_logits
+        else:
+            valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]
+            target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]
 
         return target_logits
 
@@ -742,12 +848,15 @@ class AssistantVocabTranslatorCache:
     _cache = weakref.WeakKeyDictionary()
 
     @classmethod
+    @deprecate_kwarg("assistant_model_device", version="4.53")
     def get_translator(
         cls,
         target_tokenizer: "PreTrainedTokenizerBase",
         assistant_tokenizer: "PreTrainedTokenizerBase",
         target_vocab_size: int,
         assistant_model_device: str = "cpu",
+        assistant_model: Optional["PreTrainedModel"] = None,
+        assistant_prune_lm_head: bool = False,
     ) -> AssistantToTargetTranslator:
         assistant_dict = cls._cache.get(target_tokenizer)
         if assistant_dict is None:
@@ -757,7 +866,12 @@ class AssistantVocabTranslatorCache:
         mapping = assistant_dict.get(assistant_tokenizer)
         if mapping is None:
             mapping = AssistantToTargetTranslator(
-                target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model_device
+                target_tokenizer,
+                assistant_tokenizer,
+                target_vocab_size,
+                assistant_model_device,
+                assistant_model,
+                assistant_prune_lm_head,
             )
             assistant_dict[assistant_tokenizer] = mapping
 
@@ -894,7 +1008,7 @@ class UniversalSpeculativeDecodingGenerator(AssistedCandidateGeneratorDifferentTokenizers):
             self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove]
         assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
         assistant_input_ids = assistant_input_ids.to(dtype=torch.long)
-
+        self._atm_translator.unmap_input_ids()
         return assistant_input_ids, len(assistant_new_ids[0])
 
 

src/transformers/generation/utils.py

@@ -962,8 +962,14 @@ class GenerationMixin:
         elif different_tokenizers:
             if generation_config.do_sample is True:
                 atm_translator = AssistantVocabTranslatorCache.get_translator(
-                    target_tokenizer, assistant_tokenizer, self.config.vocab_size, assistant_model.device
+                    target_tokenizer,
+                    assistant_tokenizer,
+                    self.config.vocab_size,
+                    assistant_model=assistant_model,
+                    assistant_prune_lm_head=True,  # prune LM head of assistant model
                 )
+                # Since we prune the LM head, we cannot use the repetition penalty on the assistant model due to mismatches between token ids and logits index
+                assistant_model.generation_config.repetition_penalty = None
                 candidate_generator = UniversalSpeculativeDecodingGenerator(
                     input_ids=input_ids,
                     assistant_model=assistant_model,
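A short sketch of the updated `get_translator` call as exercised by the tests below; the tokenizers are mocks with tiny vocabularies, repeated calls with the same tokenizer pair return the cached translator, and pruning stays disabled since the mock has no real LM head (imports from the internal module, which may change):

import torch
from unittest.mock import MagicMock
from transformers.generation.candidate_generator import AssistantVocabTranslatorCache

target_tokenizer = MagicMock()
assistant_tokenizer = MagicMock()
target_tokenizer.get_vocab.return_value = {"hello": 0, "world": 1, "foo": 2, "bar": 3}
assistant_tokenizer.get_vocab.return_value = {"hello": 0, "world": 1, "foo": 2, "baz": 4}
assistant_model = MagicMock(device=torch.device("cpu"))

translator1 = AssistantVocabTranslatorCache.get_translator(
    target_tokenizer,
    assistant_tokenizer,
    target_vocab_size=6,
    assistant_model=assistant_model,
    assistant_prune_lm_head=False,  # pruning needs a real model exposing an LM head
)
translator2 = AssistantVocabTranslatorCache.get_translator(
    target_tokenizer,
    assistant_tokenizer,
    target_vocab_size=6,
    assistant_model=assistant_model,
    assistant_prune_lm_head=False,
)
assert translator1 is translator2  # cached per (target_tokenizer, assistant_tokenizer) pair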
tests/generation/test_candidate_generator.py

@@ -20,6 +20,7 @@ class TestAssistantToTargetTranslator(unittest.TestCase):
         # Create mock tokenizers with predefined vocabularies
         self.target_tokenizer = MagicMock()
         self.assistant_tokenizer = MagicMock()
+        self.assistant_model = MagicMock(device=torch_device)
 
         # Define mock vocabularies for the tokenizers
         self.target_vocab = {"hello": 0, "world": 1, "foo": 2, "bar": 3}
@@ -27,15 +28,15 @@ class TestAssistantToTargetTranslator(unittest.TestCase):
 
         self.target_tokenizer.get_vocab.return_value = self.target_vocab
         self.assistant_tokenizer.get_vocab.return_value = self.assistant_vocab
-        self.assistant_model_device = torch_device
         self.target_vocab_size = 6
 
         # Instantiate the class under test
         self.translator = AssistantToTargetTranslator(
             target_tokenizer=self.target_tokenizer,
             assistant_tokenizer=self.assistant_tokenizer,
-            assistant_model_device=self.assistant_model_device,
             target_vocab_size=self.target_vocab_size,
+            assistant_model=self.assistant_model,
+            assistant_prune_lm_head=False,
         )
 
     def test_get_assistant_to_target_input_ids(self):
@@ -53,19 +54,19 @@ class TestAssistantToTargetTranslator(unittest.TestCase):
     def test_get_target_ids(self):
         """Test the translation of assistant candidate IDs to target candidate IDs."""
         assistant_input_ids = torch.LongTensor([[0, 1, 2]]).to(
-            self.assistant_model_device
+            self.assistant_model.device
         )  # 'hello world foo' in assistant tokenizer
         target_input_ids = torch.LongTensor([[0, 1, 2]]).to(
-            self.assistant_model_device
+            self.assistant_model.device
         )  # 'hello world foo' in target tokenizer
         assistant_candidate_ids = torch.LongTensor([[0, 1, 2, 4]]).to(
-            self.assistant_model_device
+            self.assistant_model.device
         )  # 'hello world foo baz' in assistant tokenizer
 
         expected_target_ids = torch.LongTensor(
             [[0, 1, 2, self.translator.SUPPRESS_TOKEN_ID]]
         ).to(
-            self.assistant_model_device
+            self.assistant_model.device
         )  # 'hello world foo baz' in target tokenizer (baz is mapped to self.translator.suppress_tokens_id since it does not exist in target vocab)
 
         actual_target_ids = self.translator.get_target_ids(
@@ -77,12 +78,12 @@ class TestAssistantToTargetTranslator(unittest.TestCase):
         """Test the conversion of assistant logits to target logits."""
         # Assistant logits for IDs 0, 1, 2
         assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, self.translator.FILTER_VALUE]]]).to(
-            self.assistant_model_device
+            self.assistant_model.device
         )  # Shape (1, 1, 5)
 
         # Expected target logits (target_vocab_size = 4)
         expected_target_logits = torch.full((1, 1, self.target_vocab_size), self.translator.FILTER_VALUE).to(
-            self.assistant_model_device
+            self.assistant_model.device
         )
         expected_target_logits[0, 0, 0] = 0.1  # 'hello'
         expected_target_logits[0, 0, 1] = 0.2  # 'world'
@@ -119,7 +120,8 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase):
         self.assistant_tokenizer = MockTokenizer({"hello": 0, "world": 1, "foo": 2})
         self.other_target_tokenizer = MockTokenizer({"foo": 2, "bar": 3})
         self.other_assistant_tokenizer = MockTokenizer({"baz": 4, "qux": 5})
-        self.assistant_model_device = torch_device
+        self.assistant_model = MagicMock(device=torch_device)
+
         self.target_vocab_size = 6
 
     def test_same_instance_for_same_tokenizers(self):
@@ -127,14 +129,16 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase):
         translator1 = AssistantVocabTranslatorCache.get_translator(
             self.target_tokenizer,
             self.assistant_tokenizer,
-            assistant_model_device=self.assistant_model_device,
             target_vocab_size=self.target_vocab_size,
+            assistant_model=self.assistant_model,
+            assistant_prune_lm_head=False,
         )
         translator2 = AssistantVocabTranslatorCache.get_translator(
             self.target_tokenizer,
             self.assistant_tokenizer,
-            assistant_model_device=self.assistant_model_device,
             target_vocab_size=self.target_vocab_size,
+            assistant_model=self.assistant_model,
+            assistant_prune_lm_head=False,
         )
         self.assertIs(translator1, translator2, "Translators should be cached and identical")
 
@@ -143,14 +147,16 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase):
         translator1 = AssistantVocabTranslatorCache.get_translator(
             self.target_tokenizer,
             self.assistant_tokenizer,
-            assistant_model_device=self.assistant_model_device,
             target_vocab_size=self.target_vocab_size,
+            assistant_model=self.assistant_model,
+            assistant_prune_lm_head=False,
         )
         translator2 = AssistantVocabTranslatorCache.get_translator(
             self.other_target_tokenizer,
             self.other_assistant_tokenizer,
-            assistant_model_device=self.assistant_model_device,
             target_vocab_size=self.target_vocab_size,
+            assistant_model=self.assistant_model,
+            assistant_prune_lm_head=False,
         )
         self.assertIsNot(translator1, translator2, "Translators should differ for different tokenizers")
 
@@ -164,8 +170,9 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase):
         translator = AssistantVocabTranslatorCache.get_translator(
             target_tokenizer,
             assistant_tokenizer,
-            assistant_model_device=self.assistant_model_device,
             target_vocab_size=self.target_vocab_size,
+            assistant_model=self.assistant_model,
+            assistant_prune_lm_head=False,
         )
         self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1)
 
@@ -192,8 +199,9 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase):
         translator = AssistantVocabTranslatorCache.get_translator(
             target_tokenizer,
             assistant_tokenizer,
-            assistant_model_device=self.assistant_model_device,
             target_vocab_size=self.target_vocab_size,
+            assistant_model=self.assistant_model,
+            assistant_prune_lm_head=False,
         )
         # Create weak references before returning
         refs = (weakref.ref(translator), weakref.ref(target_tokenizer), weakref.ref(assistant_tokenizer))
@@ -239,16 +247,18 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase):
             self.target_tokenizer.bos_token_id = self.target_tokenizer.eos_token_id
         if self.assistant_tokenizer.pad_token_id is None:
             self.assistant_tokenizer.pad_token_id = self.assistant_tokenizer.eos_token_id
-        if self.target_tokenizer.bos_token_id is None:
+        if self.assistant_tokenizer.bos_token_id is None:
             self.assistant_tokenizer.bos_token_id = self.assistant_tokenizer.eos_token_id
 
         self.input_ids = torch.tensor([[1, 2, 3]]).to(torch_device)
         self.model_kwargs = {
             "attention_mask": torch.ones_like(self.input_ids).to(torch_device),
         }
 
         atm_translator = AssistantVocabTranslatorCache.get_translator(
-            self.target_tokenizer, self.assistant_tokenizer, self.target_config.vocab_size, torch_device
+            target_tokenizer=self.target_tokenizer,
+            assistant_tokenizer=self.assistant_tokenizer,
+            assistant_model=self.assistant_model,
+            target_vocab_size=self.target_config.vocab_size,
         )
         self.generator = UniversalSpeculativeDecodingGenerator(
             input_ids=self.input_ids,
@@ -286,7 +296,7 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase):
         )
         input_ids = torch.tensor([[self.target_tokenizer.convert_tokens_to_ids(missing_token)]])
         self.generator.input_ids = input_ids
-        candidates, scores = self.generator.get_candidates(input_ids)
+        candidates, _ = self.generator.get_candidates(input_ids)
         self.assertIsNotNone(candidates)
 
     def test_speculation_depth(self):
@@ -296,7 +306,7 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase):
 
         for depth in [1, 8, 17]:
             self.generator.num_assistant_tokens = depth
-            candidates, scores = self.generator.get_candidates(input_ids)
+            candidates, _ = self.generator.get_candidates(input_ids)
             self.assertLessEqual(candidates.shape[1] - input_ids.shape[1], depth)
 
     def test_device_consistency(self):
@@ -310,10 +320,6 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase):
         """Test that USD matches vanilla sampling with temperature set to nearly 0"""
         prompt = "Test text"
 
-        pipe_usd = pipeline("text-generation", model=cls.target_name, assistant_model=cls.assistant_name)
-        pipe_usd_output = pipe_usd(prompt, max_new_tokens=5, do_sample=True, temperature=1e-9)  # Nearly 0 temperature
-        usd_text = pipe_usd_output[0]["generated_text"]
-
         pipe_vanilla = pipeline(
             "text-generation",
             model=cls.target_name,
@@ -321,5 +327,13 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase):
         pipe_vanilla_output = pipe_vanilla(prompt, max_new_tokens=5, do_sample=False)
         vanilla_text = pipe_vanilla_output[0]["generated_text"]
 
+        pipe_usd = pipeline(
+            "text-generation",
+            model=cls.target_name,
+            assistant_model=cls.assistant_name,
+        )
+        pipe_usd_output = pipe_usd(prompt, max_new_tokens=5, do_sample=True, temperature=1e-9)  # Nearly 0 temperature
+        usd_text = pipe_usd_output[0]["generated_text"]
+
         # Assert that the outputs match
         cls.assertEqual(usd_text, vanilla_text)