From 121f91d36c171b67c62320507dfaa460eab7657c Mon Sep 17 00:00:00 2001 From: Jonathan Mamou Date: Tue, 8 Apr 2025 18:44:10 +0300 Subject: [PATCH] prune LM Head for USD (#36695) * initial commit * fix * fix style * set default to prune * add tests * comment * remove prune flag from generate * address Joao's comments * deprecate_kwarg * add doc * fix target_vocab_size * Update src/transformers/generation/candidate_generator.py Co-authored-by: Joao Gante * Update src/transformers/generation/candidate_generator.py Co-authored-by: Joao Gante * Update src/transformers/generation/candidate_generator.py Co-authored-by: Joao Gante * Update src/transformers/generation/candidate_generator.py Co-authored-by: Joao Gante * fix deprecated argument assistant_model_device --------- Co-authored-by: Joao Gante --- .../generation/candidate_generator.py | 142 ++++++++++++++++-- src/transformers/generation/utils.py | 8 +- tests/generation/test_candidate_generator.py | 62 +++++--- 3 files changed, 173 insertions(+), 39 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index fe57f532e68..3425a0234b4 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -19,7 +19,9 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple import numpy as np import torch +import torch.nn as nn +from ..pytorch_utils import prune_linear_layer from ..utils import is_sklearn_available @@ -36,6 +38,8 @@ if TYPE_CHECKING: from ..tokenization_utils_base import PreTrainedTokenizerBase from .configuration_utils import GenerationConfig +from ..utils.deprecation import deprecate_kwarg + class CandidateGenerator: """Abstract base class for all candidate generators that can be applied during assisted generation.""" @@ -612,6 +616,63 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator): return new_target_ids +class _PruneReindexingLMHead(nn.Module): + """ + A class to prune and reindex the language model head. + + This class prunes the language model head to only include the specified token IDs and reindexes the logits + to map back to the original vocabulary. + + Args: + original_lm_head (nn.Module): The original language model head. + token_ids (list[int]): The list of token IDs to keep. + """ + + def __init__(self, original_lm_head, assistant_overlap_token_ids): + super().__init__() + self.pruned_lm_head = prune_linear_layer(original_lm_head, assistant_overlap_token_ids).to( + original_lm_head.weight.dtype + ) + + def forward(self, hidden_states): + pruned_logits = self.pruned_lm_head(hidden_states) + return pruned_logits + + +class _MapInputEmbedding(nn.Module): + def __init__(self, original_embedding: nn.Embedding, assistant_overlap_token_ids): + """ + Wraps an existing embedding layer and remaps token IDs before lookup. + + Args: + original_embedding (nn.Embedding): Pre-trained or existing embedding layer. + assistant_overlap_token_ids (dict): Mapping from original token IDs to new token IDs. + Example: {old_id: new_id} + """ + super().__init__() + self.original_embedding = original_embedding + self.weight = original_embedding.weight + self.assistant_overlap_token_ids = assistant_overlap_token_ids + self.map = False + + def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor: + """ + Args: + input_ids (torch.LongTensor): Tensor of token IDs (batch_size, seq_len). + + Returns: + torch.FloatTensor: Corresponding input embeddings. 
+ """ + if self.map: + # Get the last item from input_ids + my_input_ids = self.assistant_overlap_token_ids[input_ids[0, -1]].unsqueeze(0).unsqueeze(0) + else: + self.map = True + my_input_ids = input_ids + + return self.original_embedding(my_input_ids) + + class AssistantToTargetTranslator: """ Translates token ids and logits between assistant and target model vocabularies. This class is used to handle @@ -625,36 +686,74 @@ class AssistantToTargetTranslator: The tokenizer used by the target (main) model. assistant_tokenizer (`PreTrainedTokenizerBase`): The tokenizer used by the assistant model. - assistant_model_device (`str`, defaults to "cpu"): - The device where the assistant model is located. Used for placing tensors. - target_vocab_size (`int`, *optional*): + target_vocab_size (`int`): The size of the target model's vocabulary. If not provided, will be inferred from the target tokenizer. + assistant_model_device (str, optional): The device on which the assistant model is loaded. + Defaults to "cpu". + assistant_model_device (`str`, defaults to "cpu"): The device where the assistant model is located. Used for placing tensors. + assistant_model (Optional[PreTrainedModel], optional): The assistant model to be used. Defaults to None for backward compatibility. + assistant_prune_lm_head (bool): Whether to prune the assistant model's language model + head to match the target vocabulary. This is only applicable if `assistant_model` is provided. + Defaults to False for backward compatibility. """ FILTER_VALUE: float = -float("Inf") # The value used to filter out unmapped tokens in the logits. SUPPRESS_TOKEN_ID: int = -1 # The ID used to mark suppressed tokens in the mapping. + @deprecate_kwarg("assistant_model_device", version="4.53") def __init__( self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", target_vocab_size: int, # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab() assistant_model_device: str = "cpu", + assistant_model: Optional["PreTrainedModel"] = None, + assistant_prune_lm_head: bool = False, ): self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer - self._assistant_model_device: str = assistant_model_device + self._assistant_model_device: str = ( + assistant_model_device if assistant_model is None else assistant_model.device + ) self.target_vocab_size: int = target_vocab_size self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = ( self._get_assistant_to_target_input_ids() ) self._suppress_input_ids: list[int] = self._get_suppress_input_ids() self.logits_processors: Optional[LogitsProcessorList] = None + self.assistant_prune_lm_head = assistant_prune_lm_head and assistant_model is not None if len(self._suppress_input_ids) > 0: - # len(self._suppress_input_ids) = 0 if the assistant vocab is a subset of the target vocab - self.logits_processors = LogitsProcessorList( - [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)] - ) + # the assistant vocab is not a subset of the target vocab + if self.assistant_prune_lm_head: + self.assistant_overlap_token_ids = torch.tensor( + list(self.target_to_assistant_input_ids.values()), + dtype=torch.long, + device=self._assistant_model_device, + ) + original_lm_head = assistant_model.get_output_embeddings() + pruned_lm_head = _PruneReindexingLMHead(original_lm_head, self.assistant_overlap_token_ids) + del 
original_lm_head + assistant_model.set_output_embeddings(pruned_lm_head) + + original_input_embeddings = assistant_model.get_input_embeddings() + map_input_embeddings = _MapInputEmbedding(original_input_embeddings, self.assistant_overlap_token_ids) + del original_input_embeddings + assistant_model.set_input_embeddings(map_input_embeddings) + self.map_input_embeddings = map_input_embeddings + else: + self.logits_processors = LogitsProcessorList( + [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)] + ) + + def unmap_input_ids(self): + """ + Disables the mapping of input ids despite the assistant pruning for the language model head being enabled. + + This method is required for the first forward pass of `_MapInputEmbedding` where input ids are already in the assistant vocabulary space. By disabling the mapping, it ensures that the input ids are processed correctly without remapping. + + """ + if self.assistant_prune_lm_head: + self.map_input_embeddings.map = False def _get_assistant_to_target_input_ids(self): target_vocab = self._target_tokenizer.get_vocab() @@ -710,7 +809,12 @@ class AssistantToTargetTranslator: if num_new_tokens == 0: return target_input_ids else: - transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]] + # Get last `num_new_tokens` candidate IDs + last_candidate_ids = assistant_candidate_ids[0, -num_new_tokens:] + if self.assistant_prune_lm_head: + # Map assistant IDs -> target input IDs + last_candidate_ids = self.assistant_overlap_token_ids[last_candidate_ids] + transformed_slice = self._assistant_to_target_input_ids[last_candidate_ids] return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1) def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor: @@ -726,10 +830,12 @@ class AssistantToTargetTranslator: assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID # Exclude invalid indices target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask] - valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]] - - target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask] + if self.assistant_prune_lm_head: + target_logits[..., target_logits_supported_indices] = assistant_logits + else: + valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]] + target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask] return target_logits @@ -742,12 +848,15 @@ class AssistantVocabTranslatorCache: _cache = weakref.WeakKeyDictionary() @classmethod + @deprecate_kwarg("assistant_model_device", version="4.53") def get_translator( cls, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", target_vocab_size: int, assistant_model_device: str = "cpu", + assistant_model: Optional["PreTrainedModel"] = None, + assistant_prune_lm_head: bool = False, ) -> AssistantToTargetTranslator: assistant_dict = cls._cache.get(target_tokenizer) if assistant_dict is None: @@ -757,7 +866,12 @@ class AssistantVocabTranslatorCache: mapping = assistant_dict.get(assistant_tokenizer) if mapping is None: mapping = AssistantToTargetTranslator( - target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model_device + target_tokenizer, + assistant_tokenizer, + target_vocab_size, + assistant_model_device, + 
assistant_model, + assistant_prune_lm_head, ) assistant_dict[assistant_tokenizer] = mapping @@ -894,7 +1008,7 @@ class UniversalSpeculativeDecodingGenerator(AssistedCandidateGeneratorDifferentT self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove] assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1) assistant_input_ids = assistant_input_ids.to(dtype=torch.long) - + self._atm_translator.unmap_input_ids() return assistant_input_ids, len(assistant_new_ids[0]) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 09cbc9c446b..4ea1f88136d 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -962,8 +962,14 @@ class GenerationMixin: elif different_tokenizers: if generation_config.do_sample is True: atm_translator = AssistantVocabTranslatorCache.get_translator( - target_tokenizer, assistant_tokenizer, self.config.vocab_size, assistant_model.device + target_tokenizer, + assistant_tokenizer, + self.config.vocab_size, + assistant_model=assistant_model, + assistant_prune_lm_head=True, # prune LM head of assistant model ) + # Since we prune the LM head, we cannot use the repetition penalty on the assistant model due to mismaches between token ids and logits index + assistant_model.generation_config.repetition_penalty = None candidate_generator = UniversalSpeculativeDecodingGenerator( input_ids=input_ids, assistant_model=assistant_model, diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py index 38df48ab08d..3a50a963a9a 100644 --- a/tests/generation/test_candidate_generator.py +++ b/tests/generation/test_candidate_generator.py @@ -20,6 +20,7 @@ class TestAssistantToTargetTranslator(unittest.TestCase): # Create mock tokenizers with predefined vocabularies self.target_tokenizer = MagicMock() self.assistant_tokenizer = MagicMock() + self.assistant_model = MagicMock(device=torch_device) # Define mock vocabularies for the tokenizers self.target_vocab = {"hello": 0, "world": 1, "foo": 2, "bar": 3} @@ -27,15 +28,15 @@ class TestAssistantToTargetTranslator(unittest.TestCase): self.target_tokenizer.get_vocab.return_value = self.target_vocab self.assistant_tokenizer.get_vocab.return_value = self.assistant_vocab - self.assistant_model_device = torch_device self.target_vocab_size = 6 # Instantiate the class under test self.translator = AssistantToTargetTranslator( target_tokenizer=self.target_tokenizer, assistant_tokenizer=self.assistant_tokenizer, - assistant_model_device=self.assistant_model_device, target_vocab_size=self.target_vocab_size, + assistant_model=self.assistant_model, + assistant_prune_lm_head=False, ) def test_get_assistant_to_target_input_ids(self): @@ -53,19 +54,19 @@ class TestAssistantToTargetTranslator(unittest.TestCase): def test_get_target_ids(self): """Test the translation of assistant candidate IDs to target candidate IDs.""" assistant_input_ids = torch.LongTensor([[0, 1, 2]]).to( - self.assistant_model_device + self.assistant_model.device ) # 'hello world foo' in assistant tokenizer target_input_ids = torch.LongTensor([[0, 1, 2]]).to( - self.assistant_model_device + self.assistant_model.device ) # 'hello world foo' in target tokenizer assistant_candidate_ids = torch.LongTensor([[0, 1, 2, 4]]).to( - self.assistant_model_device + self.assistant_model.device ) # 'hello world foo baz' in assistant tokenizer expected_target_ids = torch.LongTensor( [[0, 1, 2, self.translator.SUPPRESS_TOKEN_ID]] ).to( - 
self.assistant_model_device + self.assistant_model.device ) # 'hello world foo baz' in target tokenizer (baz is mapped to self.translator.suppress_tokens_id since it does not exist in target vocab) actual_target_ids = self.translator.get_target_ids( @@ -77,12 +78,12 @@ class TestAssistantToTargetTranslator(unittest.TestCase): """Test the conversion of assistant logits to target logits.""" # Assistant logits for IDs 0, 1, 2 assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, self.translator.FILTER_VALUE]]]).to( - self.assistant_model_device + self.assistant_model.device ) # Shape (1, 1, 5) # Expected target logits (target_vocab_size = 4) expected_target_logits = torch.full((1, 1, self.target_vocab_size), self.translator.FILTER_VALUE).to( - self.assistant_model_device + self.assistant_model.device ) expected_target_logits[0, 0, 0] = 0.1 # 'hello' expected_target_logits[0, 0, 1] = 0.2 # 'world' @@ -119,7 +120,8 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase): self.assistant_tokenizer = MockTokenizer({"hello": 0, "world": 1, "foo": 2}) self.other_target_tokenizer = MockTokenizer({"foo": 2, "bar": 3}) self.other_assistant_tokenizer = MockTokenizer({"baz": 4, "qux": 5}) - self.assistant_model_device = torch_device + self.assistant_model = MagicMock(device=torch_device) + self.target_vocab_size = 6 def test_same_instance_for_same_tokenizers(self): @@ -127,14 +129,16 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase): translator1 = AssistantVocabTranslatorCache.get_translator( self.target_tokenizer, self.assistant_tokenizer, - assistant_model_device=self.assistant_model_device, target_vocab_size=self.target_vocab_size, + assistant_model=self.assistant_model, + assistant_prune_lm_head=False, ) translator2 = AssistantVocabTranslatorCache.get_translator( self.target_tokenizer, self.assistant_tokenizer, - assistant_model_device=self.assistant_model_device, target_vocab_size=self.target_vocab_size, + assistant_model=self.assistant_model, + assistant_prune_lm_head=False, ) self.assertIs(translator1, translator2, "Translators should be cached and identical") @@ -143,14 +147,16 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase): translator1 = AssistantVocabTranslatorCache.get_translator( self.target_tokenizer, self.assistant_tokenizer, - assistant_model_device=self.assistant_model_device, target_vocab_size=self.target_vocab_size, + assistant_model=self.assistant_model, + assistant_prune_lm_head=False, ) translator2 = AssistantVocabTranslatorCache.get_translator( self.other_target_tokenizer, self.other_assistant_tokenizer, - assistant_model_device=self.assistant_model_device, target_vocab_size=self.target_vocab_size, + assistant_model=self.assistant_model, + assistant_prune_lm_head=False, ) self.assertIsNot(translator1, translator2, "Translators should differ for different tokenizers") @@ -164,8 +170,9 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase): translator = AssistantVocabTranslatorCache.get_translator( target_tokenizer, assistant_tokenizer, - assistant_model_device=self.assistant_model_device, target_vocab_size=self.target_vocab_size, + assistant_model=self.assistant_model, + assistant_prune_lm_head=False, ) self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1) @@ -192,8 +199,9 @@ class TestAssistantVocabTranslatorCache(unittest.TestCase): translator = AssistantVocabTranslatorCache.get_translator( target_tokenizer, assistant_tokenizer, - assistant_model_device=self.assistant_model_device, 
target_vocab_size=self.target_vocab_size, + assistant_model=self.assistant_model, + assistant_prune_lm_head=False, ) # Create weak references before returning refs = (weakref.ref(translator), weakref.ref(target_tokenizer), weakref.ref(assistant_tokenizer)) @@ -239,16 +247,18 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase): self.target_tokenizer.bos_token_id = self.target_tokenizer.eos_token_id if self.assistant_tokenizer.pad_token_id is None: self.assistant_tokenizer.pad_token_id = self.assistant_tokenizer.eos_token_id - if self.target_tokenizer.bos_token_id is None: + if self.assistant_tokenizer.bos_token_id is None: self.assistant_tokenizer.bos_token_id = self.assistant_tokenizer.eos_token_id self.input_ids = torch.tensor([[1, 2, 3]]).to(torch_device) self.model_kwargs = { "attention_mask": torch.ones_like(self.input_ids).to(torch_device), } - atm_translator = AssistantVocabTranslatorCache.get_translator( - self.target_tokenizer, self.assistant_tokenizer, self.target_config.vocab_size, torch_device + target_tokenizer=self.target_tokenizer, + assistant_tokenizer=self.assistant_tokenizer, + assistant_model=self.assistant_model, + target_vocab_size=self.target_config.vocab_size, ) self.generator = UniversalSpeculativeDecodingGenerator( input_ids=self.input_ids, @@ -286,7 +296,7 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase): ) input_ids = torch.tensor([[self.target_tokenizer.convert_tokens_to_ids(missing_token)]]) self.generator.input_ids = input_ids - candidates, scores = self.generator.get_candidates(input_ids) + candidates, _ = self.generator.get_candidates(input_ids) self.assertIsNotNone(candidates) def test_speculation_depth(self): @@ -296,7 +306,7 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase): for depth in [1, 8, 17]: self.generator.num_assistant_tokens = depth - candidates, scores = self.generator.get_candidates(input_ids) + candidates, _ = self.generator.get_candidates(input_ids) self.assertLessEqual(candidates.shape[1] - input_ids.shape[1], depth) def test_device_consistency(self): @@ -310,10 +320,6 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase): """Test that USD matches vanilla sampling with temperature set to nearly 0""" prompt = "Test text" - pipe_usd = pipeline("text-generation", model=cls.target_name, assistant_model=cls.assistant_name) - pipe_usd_output = pipe_usd(prompt, max_new_tokens=5, do_sample=True, temperature=1e-9) # Nearly 0 temperature - usd_text = pipe_usd_output[0]["generated_text"] - pipe_vanilla = pipeline( "text-generation", model=cls.target_name, @@ -321,5 +327,13 @@ class TestUniversalSpeculativeDecoding(unittest.TestCase): pipe_vanilla_output = pipe_vanilla(prompt, max_new_tokens=5, do_sample=False) vanilla_text = pipe_vanilla_output[0]["generated_text"] + pipe_usd = pipeline( + "text-generation", + model=cls.target_name, + assistant_model=cls.assistant_name, + ) + pipe_usd_output = pipe_usd(prompt, max_new_tokens=5, do_sample=True, temperature=1e-9) # Nearly 0 temperature + usd_text = pipe_usd_output[0]["generated_text"] + # Assert that the outputs match cls.assertEqual(usd_text, vanilla_text)
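
To make the mechanism in this patch easier to follow, below is a minimal plain-PyTorch sketch of the core idea behind `_PruneReindexingLMHead` and `AssistantToTargetTranslator.get_target_logits`: keep only the LM-head rows for tokens the assistant shares with the target vocabulary, then scatter the pruned logits back into their target-vocabulary positions. The vocabulary sizes and id mappings here are toy assumptions, and the sketch indexes the weight matrix directly instead of calling `prune_linear_layer` as the patch does.

# Minimal sketch (not the transformers implementation) of LM-head pruning + reindexing.
import torch
import torch.nn as nn

FILTER_VALUE = -float("inf")

# Toy setup: assistant vocab has 8 tokens, target vocab has 10, hidden size 16.
assistant_vocab_size, target_vocab_size, hidden_size = 8, 10, 16

# Assumed overlap: assistant ids 1, 3, 4, 6 correspond to target ids 0, 2, 5, 7.
assistant_overlap_ids = torch.tensor([1, 3, 4, 6])  # rows kept in the pruned head
overlap_target_ids = torch.tensor([0, 2, 5, 7])     # where those tokens live in the target vocab

original_lm_head = nn.Linear(hidden_size, assistant_vocab_size, bias=False)

# Prune: keep only the weight rows for overlapping tokens, so the head emits
# len(overlap) logits per position instead of assistant_vocab_size.
pruned_lm_head = nn.Linear(hidden_size, len(assistant_overlap_ids), bias=False)
with torch.no_grad():
    pruned_lm_head.weight.copy_(original_lm_head.weight[assistant_overlap_ids])

hidden_states = torch.randn(1, 1, hidden_size)
pruned_logits = pruned_lm_head(hidden_states)  # shape (1, 1, 4)

# Reindex into the target vocabulary; tokens with no assistant counterpart stay filtered.
target_logits = torch.full((1, 1, target_vocab_size), FILTER_VALUE)
target_logits[..., overlap_target_ids] = pruned_logits
print(target_logits)

After pruning, index i of the assistant's logits no longer equals assistant token id i, which is why the patch also disables `repetition_penalty` on the assistant model in `utils.py` and remaps candidate ids through `assistant_overlap_token_ids` before translating them to target ids.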