prune LM Head for USD (#36695)

* initial commit

* fix

* fix style

* set default to prune

* add tests

* comment

* remove prune flag from generate

* address Joao's comments

* deprecate_kwarg

* add doc

* fix target_vocab_size

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* fix deprecated argument assistant_model_device

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Jonathan Mamou 2025-04-08 18:44:10 +03:00 committed by GitHub
parent 4321b0648c
commit 121f91d36c
3 changed files with 173 additions and 39 deletions

src/transformers/generation/candidate_generator.py

@ -19,7 +19,9 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from ..pytorch_utils import prune_linear_layer
from ..utils import is_sklearn_available
@ -36,6 +38,8 @@ if TYPE_CHECKING:
from ..tokenization_utils_base import PreTrainedTokenizerBase
from .configuration_utils import GenerationConfig
from ..utils.deprecation import deprecate_kwarg
class CandidateGenerator:
"""Abstract base class for all candidate generators that can be applied during assisted generation."""
@ -612,6 +616,63 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
return new_target_ids
class _PruneReindexingLMHead(nn.Module):
"""
A class to prune and reindex the language model head.
This class prunes the language model head to only include the specified token IDs and reindexes the logits
to map back to the original vocabulary.
Args:
original_lm_head (nn.Module): The original language model head.
assistant_overlap_token_ids (torch.LongTensor): The assistant token IDs to keep; the head is pruned to these rows.
"""
def __init__(self, original_lm_head, assistant_overlap_token_ids):
super().__init__()
self.pruned_lm_head = prune_linear_layer(original_lm_head, assistant_overlap_token_ids).to(
original_lm_head.weight.dtype
)
def forward(self, hidden_states):
pruned_logits = self.pruned_lm_head(hidden_states)
return pruned_logits
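
For readers unfamiliar with `prune_linear_layer`, here is a minimal, standalone sketch of what the pruning amounts to, written in plain PyTorch with invented sizes and ids: only the rows of the head weight that correspond to kept token ids survive, so logit index i of the pruned head is the logit of token kept_ids[i].

import torch
import torch.nn as nn

torch.manual_seed(0)
hidden, vocab_size = 4, 10
lm_head = nn.Linear(hidden, vocab_size, bias=False)            # original head: hidden -> full assistant vocab

kept_ids = torch.tensor([1, 3, 7])                              # assistant tokens that also exist in the target vocab
pruned_head = nn.Linear(hidden, len(kept_ids), bias=False)
pruned_head.weight.data = lm_head.weight.data[kept_ids].clone() # row-wise (dim=0) pruning, as prune_linear_layer does

h = torch.randn(1, hidden)
full_logits = lm_head(h)
pruned_logits = pruned_head(h)
# logit i of the pruned head equals the full-vocab logit of token kept_ids[i]
assert torch.allclose(pruned_logits[0], full_logits[0, kept_ids])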
class _MapInputEmbedding(nn.Module):
def __init__(self, original_embedding: nn.Embedding, assistant_overlap_token_ids):
"""
Wraps an existing embedding layer and remaps token IDs before lookup.
Args:
original_embedding (nn.Embedding): Pre-trained or existing embedding layer.
assistant_overlap_token_ids (torch.LongTensor): Tensor mapping pruned-vocabulary indices back to the assistant
model's original token IDs (entry i holds the assistant id of the i-th kept token).
"""
super().__init__()
self.original_embedding = original_embedding
self.weight = original_embedding.weight
self.assistant_overlap_token_ids = assistant_overlap_token_ids
self.map = False
def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
"""
Args:
input_ids (torch.LongTensor): Tensor of token IDs (batch_size, seq_len).
Returns:
torch.FloatTensor: Corresponding input embeddings.
"""
if self.map:
# Get the last item from input_ids
my_input_ids = self.assistant_overlap_token_ids[input_ids[0, -1]].unsqueeze(0).unsqueeze(0)
else:
self.map = True
my_input_ids = input_ids
return self.original_embedding(my_input_ids)
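
A toy walk-through of the wrapper's two modes, assuming the `_MapInputEmbedding` class above is in scope; the embedding size, ids, and overlap table are invented for illustration. On the first call the ids are already assistant-vocabulary ids and pass through untouched; on later calls the single fed-back id is an index into the pruned logits and is mapped back to an assistant id before lookup.

import torch
import torch.nn as nn

embedding = nn.Embedding(10, 4)                   # toy assistant embedding table
overlap_ids = torch.tensor([1, 3, 7])             # pruned index i -> assistant token id overlap_ids[i]
wrapper = _MapInputEmbedding(embedding, overlap_ids)

prompt = torch.tensor([[5, 1, 7]])                # first pass: already assistant ids, no remapping
first = wrapper(prompt)                           # shape (1, 3, 4); the map flag flips to True afterwards

fed_back = torch.tensor([[2]])                    # later pass: pruned index 2 -> assistant id overlap_ids[2] == 7
second = wrapper(fed_back)                        # embeds token id 7, shape (1, 1, 4)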
class AssistantToTargetTranslator:
"""
Translates token ids and logits between assistant and target model vocabularies. This class is used to handle
@ -625,36 +686,74 @@ class AssistantToTargetTranslator:
The tokenizer used by the target (main) model.
assistant_tokenizer (`PreTrainedTokenizerBase`):
The tokenizer used by the assistant model.
assistant_model_device (`str`, defaults to "cpu"):
The device where the assistant model is located. Used for placing tensors.
target_vocab_size (`int`, *optional*):
target_vocab_size (`int`):
The size of the target model's vocabulary (it can differ from the length of the target tokenizer's vocab).
assistant_model_device (str, optional): The device on which the assistant model is loaded.
Defaults to "cpu".
assistant_model_device (`str`, defaults to "cpu"): The device where the assistant model is located. Used for placing tensors.
assistant_model (Optional[PreTrainedModel], optional): The assistant model to be used. Defaults to None for backward compatibility.
assistant_prune_lm_head (bool): Whether to prune the assistant model's language model
head to match the target vocabulary. This is only applicable if `assistant_model` is provided.
Defaults to False for backward compatibility.
"""
FILTER_VALUE: float = -float("Inf") # The value used to filter out unmapped tokens in the logits.
SUPPRESS_TOKEN_ID: int = -1 # The ID used to mark suppressed tokens in the mapping.
@deprecate_kwarg("assistant_model_device", version="4.53")
def __init__(
self,
target_tokenizer: "PreTrainedTokenizerBase",
assistant_tokenizer: "PreTrainedTokenizerBase",
target_vocab_size: int, # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab()
assistant_model_device: str = "cpu",
assistant_model: Optional["PreTrainedModel"] = None,
assistant_prune_lm_head: bool = False,
):
self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
self._assistant_model_device: str = assistant_model_device
self._assistant_model_device: str = (
assistant_model_device if assistant_model is None else assistant_model.device
)
self.target_vocab_size: int = target_vocab_size
self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
self._get_assistant_to_target_input_ids()
)
self._suppress_input_ids: list[int] = self._get_suppress_input_ids()
self.logits_processors: Optional[LogitsProcessorList] = None
self.assistant_prune_lm_head = assistant_prune_lm_head and assistant_model is not None
if len(self._suppress_input_ids) > 0:
# len(self._suppress_input_ids) = 0 if the assistant vocab is a subset of the target vocab
self.logits_processors = LogitsProcessorList(
[SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)]
)
# the assistant vocab is not a subset of the target vocab
if self.assistant_prune_lm_head:
self.assistant_overlap_token_ids = torch.tensor(
list(self.target_to_assistant_input_ids.values()),
dtype=torch.long,
device=self._assistant_model_device,
)
original_lm_head = assistant_model.get_output_embeddings()
pruned_lm_head = _PruneReindexingLMHead(original_lm_head, self.assistant_overlap_token_ids)
del original_lm_head
assistant_model.set_output_embeddings(pruned_lm_head)
original_input_embeddings = assistant_model.get_input_embeddings()
map_input_embeddings = _MapInputEmbedding(original_input_embeddings, self.assistant_overlap_token_ids)
del original_input_embeddings
assistant_model.set_input_embeddings(map_input_embeddings)
self.map_input_embeddings = map_input_embeddings
else:
self.logits_processors = LogitsProcessorList(
[SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)]
)
def unmap_input_ids(self):
"""
Disables the remapping of input ids even when pruning of the assistant LM head is enabled.
This is needed for the first forward pass of `_MapInputEmbedding`, where the input ids are already expressed in the assistant vocabulary and must not be remapped.
"""
if self.assistant_prune_lm_head:
self.map_input_embeddings.map = False
def _get_assistant_to_target_input_ids(self):
target_vocab = self._target_tokenizer.get_vocab()
@ -710,7 +809,12 @@ class AssistantToTargetTranslator:
if num_new_tokens == 0:
return target_input_ids
else:
transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]]
# Get last `num_new_tokens` candidate IDs
last_candidate_ids = assistant_candidate_ids[0, -num_new_tokens:]
if self.assistant_prune_lm_head:
# With a pruned LM head, candidate ids are pruned-vocab indices; map them back to assistant token ids first
last_candidate_ids = self.assistant_overlap_token_ids[last_candidate_ids]
transformed_slice = self._assistant_to_target_input_ids[last_candidate_ids]
return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1)
def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor:
@ -726,10 +830,12 @@ class AssistantToTargetTranslator:
assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID
# Exclude invalid indices
target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask]
valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]
target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]
if self.assistant_prune_lm_head:
target_logits[..., target_logits_supported_indices] = assistant_logits
else:
valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]
target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]
return target_logits
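
A small numeric sketch of the scatter performed above in the pruned case (sizes, ids, and logit values are invented; the ordering of the pruned logits is assumed to line up with the valid entries of the assistant-to-target map, as arranged in `__init__`):

import torch

FILTER_VALUE = -float("Inf")
target_vocab_size = 6
# assistant_to_target[i] = target id of assistant token i, or -1 if the token has no target counterpart
assistant_to_target = torch.tensor([0, 2, -1, 5])

# With a pruned LM head the assistant emits one logit per overlapping token only
pruned_assistant_logits = torch.tensor([[[0.1, 0.7, 0.2]]])

target_logits = torch.full((1, 1, target_vocab_size), FILTER_VALUE)
mask = assistant_to_target != -1
target_logits[..., assistant_to_target[mask]] = pruned_assistant_logits
# target ids 0, 2 and 5 now hold 0.1, 0.7 and 0.2; every other position stays at -inf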
@ -742,12 +848,15 @@ class AssistantVocabTranslatorCache:
_cache = weakref.WeakKeyDictionary()
@classmethod
@deprecate_kwarg("assistant_model_device", version="4.53")
def get_translator(
cls,
target_tokenizer: "PreTrainedTokenizerBase",
assistant_tokenizer: "PreTrainedTokenizerBase",
target_vocab_size: int,
assistant_model_device: str = "cpu",
assistant_model: Optional["PreTrainedModel"] = None,
assistant_prune_lm_head: bool = False,
) -> AssistantToTargetTranslator:
assistant_dict = cls._cache.get(target_tokenizer)
if assistant_dict is None:
@ -757,7 +866,12 @@ class AssistantVocabTranslatorCache:
mapping = assistant_dict.get(assistant_tokenizer)
if mapping is None:
mapping = AssistantToTargetTranslator(
target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model_device
target_tokenizer,
assistant_tokenizer,
target_vocab_size,
assistant_model_device,
assistant_model,
assistant_prune_lm_head,
)
assistant_dict[assistant_tokenizer] = mapping
@ -894,7 +1008,7 @@ class UniversalSpeculativeDecodingGenerator(AssistedCandidateGeneratorDifferentT
self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove]
assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
assistant_input_ids = assistant_input_ids.to(dtype=torch.long)
self._atm_translator.unmap_input_ids()
return assistant_input_ids, len(assistant_new_ids[0])
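
As a usage illustration, a hedged sketch of obtaining a translator with LM-head pruning through the cache; the checkpoint names are placeholders, not part of this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.candidate_generator import AssistantVocabTranslatorCache

target_model = AutoModelForCausalLM.from_pretrained("target-checkpoint")        # placeholder name
target_tokenizer = AutoTokenizer.from_pretrained("target-checkpoint")
assistant_model = AutoModelForCausalLM.from_pretrained("assistant-checkpoint")  # placeholder name
assistant_tokenizer = AutoTokenizer.from_pretrained("assistant-checkpoint")

translator = AssistantVocabTranslatorCache.get_translator(
    target_tokenizer,
    assistant_tokenizer,
    target_vocab_size=target_model.config.vocab_size,
    assistant_model=assistant_model,      # needed so its LM head and input embeddings can be wrapped
    assistant_prune_lm_head=True,         # installs _PruneReindexingLMHead and _MapInputEmbedding
)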

src/transformers/generation/utils.py

@ -962,8 +962,14 @@ class GenerationMixin:
elif different_tokenizers:
if generation_config.do_sample is True:
atm_translator = AssistantVocabTranslatorCache.get_translator(
target_tokenizer, assistant_tokenizer, self.config.vocab_size, assistant_model.device
target_tokenizer,
assistant_tokenizer,
self.config.vocab_size,
assistant_model=assistant_model,
assistant_prune_lm_head=True, # prune LM head of assistant model
)
# Since we prune the LM head, we cannot use the repetition penalty on the assistant model due to mismatches between token ids and logit indices
assistant_model.generation_config.repetition_penalty = None
candidate_generator = UniversalSpeculativeDecodingGenerator(
input_ids=input_ids,
assistant_model=assistant_model,
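
And at the `generate` level, a hedged end-to-end sketch of universal speculative decoding that exercises the code path above (different tokenizers plus `do_sample=True`); the model names are placeholders:

from transformers import AutoModelForCausalLM, AutoTokenizer

target_name, assistant_name = "target-checkpoint", "assistant-checkpoint"  # placeholders
target_tokenizer = AutoTokenizer.from_pretrained(target_name)
assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_name)
target_model = AutoModelForCausalLM.from_pretrained(target_name)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_name)

inputs = target_tokenizer("Test text", return_tensors="pt")
output = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    tokenizer=target_tokenizer,             # both tokenizers are needed to translate between vocabularies
    assistant_tokenizer=assistant_tokenizer,
    do_sample=True,                          # sampling + different tokenizers selects the pruned-LM-head path
    max_new_tokens=20,
)
print(target_tokenizer.decode(output[0], skip_special_tokens=True))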

tests/generation/test_candidate_generator.py

@ -20,6 +20,7 @@ class TestAssistantToTargetTranslator(unittest.TestCase):
# Create mock tokenizers with predefined vocabularies
self.target_tokenizer = MagicMock()
self.assistant_tokenizer = MagicMock()
self.assistant_model = MagicMock(device=torch_device)
# Define mock vocabularies for the tokenizers
self.target_vocab = {"hello": 0, "world": 1, "foo": 2, "bar": 3}
@ -27,15 +28,15 @@ class TestAssistantToTargetTranslator(unittest.TestCase):
self.target_tokenizer.get_vocab.return_value = self.target_vocab
self.assistant_tokenizer.get_vocab.return_value = self.assistant_vocab
self.assistant_model_device = torch_device
self.target_vocab_size = 6
# Instantiate the class under test
self.translator = AssistantToTargetTranslator(
target_tokenizer=self.target_tokenizer,
assistant_tokenizer=self.assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
assistant_model=self.assistant_model,
assistant_prune_lm_head=False,
)
def test_get_assistant_to_target_input_ids(self):
@ -53,19 +54,19 @@ class TestAssistantToTargetTranslator(unittest.TestCase):
def test_get_target_ids(self):
"""Test the translation of assistant candidate IDs to target candidate IDs."""
assistant_input_ids = torch.LongTensor([[0, 1, 2]]).to(
self.assistant_model_device
self.assistant_model.device
) # 'hello world foo' in assistant tokenizer
target_input_ids = torch.LongTensor([[0, 1, 2]]).to(
self.assistant_model_device
self.assistant_model.device
) # 'hello world foo' in target tokenizer
assistant_candidate_ids = torch.LongTensor([[0, 1, 2, 4]]).to(
self.assistant_model_device
self.assistant_model.device
) # 'hello world foo baz' in assistant tokenizer
expected_target_ids = torch.LongTensor(
[[0, 1, 2, self.translator.SUPPRESS_TOKEN_ID]]
).to(
self.assistant_model_device
self.assistant_model.device
) # 'hello world foo baz' in target tokenizer (baz is mapped to self.translator.SUPPRESS_TOKEN_ID since it does not exist in the target vocab)
actual_target_ids = self.translator.get_target_ids(
@ -77,12 +78,12 @@ class TestAssistantToTargetTranslator(unittest.TestCase):
"""Test the conversion of assistant logits to target logits."""
# Assistant logits for IDs 0, 1, 2
assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, self.translator.FILTER_VALUE]]]).to(
self.assistant_model_device
self.assistant_model.device
) # Shape (1, 1, 5)
# Expected target logits (target_vocab_size = 4)
expected_target_logits = torch.full((1, 1, self.target_vocab_size), self.translator.FILTER_VALUE).to(
self.assistant_model_device
self.assistant_model.device
)
expected_target_logits[0, 0, 0] = 0.1 # 'hello'
expected_target_logits[0, 0, 1] = 0.2 # 'world'
@ -119,7 +120,8 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase):
self.assistant_tokenizer = MockTokenizer({"hello": 0, "world": 1, "foo": 2})
self.other_target_tokenizer = MockTokenizer({"foo": 2, "bar": 3})
self.other_assistant_tokenizer = MockTokenizer({"baz": 4, "qux": 5})
self.assistant_model_device = torch_device
self.assistant_model = MagicMock(device=torch_device)
self.target_vocab_size = 6
def test_same_instance_for_same_tokenizers(self):
@ -127,14 +129,16 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase):
translator1 = AssistantVocabTranslatorCache.get_translator(
self.target_tokenizer,
self.assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
assistant_model=self.assistant_model,
assistant_prune_lm_head=False,
)
translator2 = AssistantVocabTranslatorCache.get_translator(
self.target_tokenizer,
self.assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
assistant_model=self.assistant_model,
assistant_prune_lm_head=False,
)
self.assertIs(translator1, translator2, "Translators should be cached and identical")
@ -143,14 +147,16 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase):
translator1 = AssistantVocabTranslatorCache.get_translator(
self.target_tokenizer,
self.assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
assistant_model=self.assistant_model,
assistant_prune_lm_head=False,
)
translator2 = AssistantVocabTranslatorCache.get_translator(
self.other_target_tokenizer,
self.other_assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
assistant_model=self.assistant_model,
assistant_prune_lm_head=False,
)
self.assertIsNot(translator1, translator2, "Translators should differ for different tokenizers")
@ -164,8 +170,9 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase):
translator = AssistantVocabTranslatorCache.get_translator(
target_tokenizer,
assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
assistant_model=self.assistant_model,
assistant_prune_lm_head=False,
)
self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1)
@ -192,8 +199,9 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase):
translator = AssistantVocabTranslatorCache.get_translator(
target_tokenizer,
assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
assistant_model=self.assistant_model,
assistant_prune_lm_head=False,
)
# Create weak references before returning
refs = (weakref.ref(translator), weakref.ref(target_tokenizer), weakref.ref(assistant_tokenizer))
@ -239,16 +247,18 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase):
self.target_tokenizer.bos_token_id = self.target_tokenizer.eos_token_id
if self.assistant_tokenizer.pad_token_id is None:
self.assistant_tokenizer.pad_token_id = self.assistant_tokenizer.eos_token_id
if self.target_tokenizer.bos_token_id is None:
if self.assistant_tokenizer.bos_token_id is None:
self.assistant_tokenizer.bos_token_id = self.assistant_tokenizer.eos_token_id
self.input_ids = torch.tensor([[1, 2, 3]]).to(torch_device)
self.model_kwargs = {
"attention_mask": torch.ones_like(self.input_ids).to(torch_device),
}
atm_translator = AssistantVocabTranslatorCache.get_translator(
self.target_tokenizer, self.assistant_tokenizer, self.target_config.vocab_size, torch_device
target_tokenizer=self.target_tokenizer,
assistant_tokenizer=self.assistant_tokenizer,
assistant_model=self.assistant_model,
target_vocab_size=self.target_config.vocab_size,
)
self.generator = UniversalSpeculativeDecodingGenerator(
input_ids=self.input_ids,
@ -286,7 +296,7 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase):
)
input_ids = torch.tensor([[self.target_tokenizer.convert_tokens_to_ids(missing_token)]])
self.generator.input_ids = input_ids
candidates, scores = self.generator.get_candidates(input_ids)
candidates, _ = self.generator.get_candidates(input_ids)
self.assertIsNotNone(candidates)
def test_speculation_depth(self):
@ -296,7 +306,7 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase):
for depth in [1, 8, 17]:
self.generator.num_assistant_tokens = depth
candidates, scores = self.generator.get_candidates(input_ids)
candidates, _ = self.generator.get_candidates(input_ids)
self.assertLessEqual(candidates.shape[1] - input_ids.shape[1], depth)
def test_device_consistency(self):
@ -310,10 +320,6 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase):
"""Test that USD matches vanilla sampling with temperature set to nearly 0"""
prompt = "Test text"
pipe_usd = pipeline("text-generation", model=cls.target_name, assistant_model=cls.assistant_name)
pipe_usd_output = pipe_usd(prompt, max_new_tokens=5, do_sample=True, temperature=1e-9) # Nearly 0 temperature
usd_text = pipe_usd_output[0]["generated_text"]
pipe_vanilla = pipeline(
"text-generation",
model=cls.target_name,
@ -321,5 +327,13 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase):
pipe_vanilla_output = pipe_vanilla(prompt, max_new_tokens=5, do_sample=False)
vanilla_text = pipe_vanilla_output[0]["generated_text"]
pipe_usd = pipeline(
"text-generation",
model=cls.target_name,
assistant_model=cls.assistant_name,
)
pipe_usd_output = pipe_usd(prompt, max_new_tokens=5, do_sample=True, temperature=1e-9) # Nearly 0 temperature
usd_text = pipe_usd_output[0]["generated_text"]
# Assert that the outputs match
cls.assertEqual(usd_text, vanilla_text)