diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 6fa96e5c8e4..3f2e02a703b 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -79,10 +79,10 @@ class Cache:
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
         for layer_idx in range(len(self.key_cache)):
-            if self.key_cache[layer_idx] != []:
+            if self.key_cache[layer_idx].numel():
                 device = self.key_cache[layer_idx].device
                 self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            if self.value_cache[layer_idx] != []:
+            if self.value_cache[layer_idx].numel():
                 device = self.value_cache[layer_idx].device
                 self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
 
@@ -433,12 +433,12 @@ class DynamicCache(Cache):
         if len(self.key_cache) <= layer_idx:
             # There may be skipped layers, fill them with empty lists
             for _ in range(len(self.key_cache), layer_idx):
-                self.key_cache.append([])
-                self.value_cache.append([])
+                self.key_cache.append(torch.tensor([]))
+                self.value_cache.append(torch.tensor([]))
             self.key_cache.append(key_states)
             self.value_cache.append(value_states)
         elif (
-            len(self.key_cache[layer_idx]) == 0
+            not self.key_cache[layer_idx].numel()  # prefers not t.numel() to len(t) == 0 to export the model
         ):  # fills previously skipped layers; checking for tensor causes errors
             self.key_cache[layer_idx] = key_states
             self.value_cache[layer_idx] = value_states
@@ -454,7 +454,7 @@ class DynamicCache(Cache):
         is_empty_layer = (
             len(self.key_cache) == 0  # no cache in any layer
             or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
-            or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
+            or not self.key_cache[layer_idx].numel()  # the layer has no cache
         )
         layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
         return layer_seq_length
@@ -494,7 +494,7 @@ class DynamicCache(Cache):
             self._seen_tokens = max_length
 
         for idx in range(len(self.key_cache)):
-            if self.key_cache[idx] != []:
+            if self.key_cache[idx].numel():
                 self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
                 self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
 
@@ -516,8 +516,8 @@ class DynamicCache(Cache):
         `generation.utils`"""
         cache = cls()
         for idx in range(len(splits[0])):
-            key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
-            value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx] != []]
+            key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx].numel()]
+            value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx].numel()]
             if key_cache != []:
                 layer_keys = torch.cat(key_cache, dim=0)
                 layer_values = torch.cat(value_cache, dim=0)
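Note on the `cache_utils.py` hunks above: skipped layers now hold `torch.tensor([])` placeholders instead of plain `[]` lists, so every entry in `key_cache`/`value_cache` is a tensor and emptiness is tested with `numel()`. The snippet below is a minimal illustration of that convention (not part of the patch; the cache shape is made up for the example):

```python
import torch

placeholder = torch.tensor([])        # what a skipped layer now stores
filled = torch.zeros(2, 8, 16, 64)    # an illustrative (batch, heads, seq_len, head_dim) entry

# `numel()` is falsy only for the empty placeholder, so it plays the role of the old
# `!= []` / `len(...) == 0` checks while keeping every cache entry a tensor.
assert not placeholder.numel()
assert filled.numel()
```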
""" + def _cache_dependant_input_preparation( + self, + input_ids: torch.LongTensor, + inputs_embeds: Optional[torch.FloatTensor], + cache_position: Optional[torch.LongTensor], + ) -> Tuple[torch.FloatTensor, torch.LongTensor]: + """ + Generic cache-dependent input preparation + The code is put in a separate function to allow granular unit testing + as it needs a different implementation to be exportable. + + If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + - Exception 1: when passing input_embeds, input_ids may be missing entries + - Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + - Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + - Excpetion 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and + generate the first token for each sequence. Later use the generated Input ids for continuation. + + The current implementation does not rely on ``self`` and could be + a class method. It is left as a standard method to be easily rewritten. + """ + if is_torchdynamo_exporting(): + return self._cache_dependant_input_preparation_exporting(input_ids, inputs_embeds, cache_position) + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif ( + inputs_embeds is not None # Exception 1 + or (cache_position[-1] >= input_ids.shape[1]) # Exception 3 + ): + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + return inputs_embeds, input_ids + + def _cache_dependant_input_preparation_exporting( + self, + input_ids: torch.LongTensor, + inputs_embeds: Optional[torch.FloatTensor], + cache_position: Optional[torch.LongTensor], + ) -> Tuple[torch.FloatTensor, torch.LongTensor]: + """ + This method implements method ``_cache_dependant_input_preparation`` + with :func:`torch.cond` to make it exportable with :func:`torch.export.export`. + The code is put in a separate function to allow granular unit testing. + """ + if inputs_embeds is None: + input_ids = input_ids[:, cache_position] + else: + # This is the code we need to implemented with torch.cond. 
+            # if input_ids.shape[1] == 0:
+            #     inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+            # else:
+            #     if cache_position[-1] >= input_ids.shape[1]:
+            #         input_ids = input_ids[:, -cache_position.shape[0] :]
+            #     else:
+            #         if input_ids.shape[1] != cache_position.shape[0]:
+            #             input_ids = input_ids[:, cache_position]
+            def branch_1(inputs_embeds, cache_position):
+                return inputs_embeds[:, -cache_position.shape[0] :]
+
+            def branch_2(input_ids, cache_position):
+                return input_ids[:, -cache_position.shape[0] :]
+
+            def branch_3(input_ids, cache_position):
+                return input_ids[:, cache_position]
+
+            inputs_embeds, input_ids = torch.cond(
+                input_ids.shape[1] == 0,
+                (
+                    lambda input_ids, inputs_embeds, cache_position: (
+                        branch_1(inputs_embeds, cache_position),
+                        input_ids,
+                    )
+                ),
+                (
+                    lambda input_ids, inputs_embeds, cache_position: (
+                        inputs_embeds,
+                        torch.cond(
+                            cache_position[-1] >= input_ids.shape[1],
+                            branch_2,
+                            lambda input_ids, cache_position: (
+                                torch.cond(
+                                    input_ids.shape[1] != cache_position.shape[0],
+                                    branch_3,
+                                    (lambda input_ids, cache_position: input_ids),
+                                    [input_ids, cache_position],
+                                )
+                            ),
+                            [input_ids, cache_position],
+                        ),
+                    )
+                ),
+                [input_ids, inputs_embeds, cache_position],
+            )
+        return inputs_embeds, input_ids
+
     def prepare_inputs_for_generation(
         self,
         input_ids: torch.LongTensor,
@@ -404,23 +501,11 @@ class GenerationMixin:
             cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
 
         # 2. Generic cache-dependent input preparation
-        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
-        # Exception 1: when passing input_embeds, input_ids may be missing entries
-        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
-        # Excpetion 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
-        # generate the first token for each sequence. Later use the generated Input ids for continuation.
         if past_key_values is not None:
             model_inputs["past_key_values"] = past_key_values
-            if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
-                inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
-            elif (
-                inputs_embeds is not None  # Exception 1
-                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
-            ):
-                input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
-                input_ids = input_ids[:, cache_position]
+            inputs_embeds, input_ids = self._cache_dependant_input_preparation(
+                input_ids, inputs_embeds, cache_position
+            )
 
         # 3. Prepare base model inputs
         input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
@@ -1590,6 +1675,8 @@ class GenerationMixin:
             generation_config = self.generation_config
             using_model_generation_config = True
 
+        # `torch.export.export` usually raises an exception if it is called with `strict=True`;
+        # the `deepcopy` below can only be handled when exporting with `strict=False`.
         generation_config = copy.deepcopy(generation_config)
 
         if not using_model_generation_config:
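The nested `torch.cond` calls in `_cache_dependant_input_preparation_exporting` above replace Python `if`/`elif` branching on tensor values, which `torch.export.export` cannot trace. As background, here is a minimal, self-contained sketch of the same pattern on a toy module (illustrative only, not part of the patch): both branches are captured in the exported graph and the predicate is evaluated at run time.

```python
import torch


def true_fn(x):
    return torch.sin(x)


def false_fn(x):
    return torch.cos(x)


class CondModel(torch.nn.Module):
    def forward(self, x):
        # Data-dependent branch expressed with torch.cond so export keeps both branches.
        return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))


model = CondModel()
x = torch.randn(4)
exported = torch.export.export(model, (x,))
torch.testing.assert_close(model(x), exported.module()(x))
```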
diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py
index 0818d90a57b..096db2d11bf 100644
--- a/src/transformers/models/mllama/modeling_mllama.py
+++ b/src/transformers/models/mllama/modeling_mllama.py
@@ -2047,7 +2047,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
         cross_attention_mask: Optional[torch.Tensor] = None,
         cross_attention_states: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 0209c85b85f..567fa499208 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -236,6 +236,7 @@ from .import_utils import (
     is_torchdistx_available,
     is_torchdynamo_available,
     is_torchdynamo_compiling,
+    is_torchdynamo_exporting,
     is_torchvision_available,
     is_torchvision_v2_available,
     is_training_run_on_sagemaker,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index f4f3fa4ae46..1ac109c4773 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -866,6 +866,23 @@ def is_torchdynamo_compiling():
     return False
 
 
+def is_torchdynamo_exporting():
+    if not is_torch_available():
+        return False
+
+    try:
+        import torch
+
+        return torch.compiler.is_exporting()
+    except Exception:
+        try:
+            import torch._dynamo as dynamo  # noqa: F401
+
+            return dynamo.is_exporting()
+        except Exception:
+            return False
+
+
 def is_torch_tensorrt_fx_available():
     if importlib.util.find_spec("torch_tensorrt") is None:
         return False
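`is_torchdynamo_exporting()` mirrors the existing `is_torchdynamo_compiling()` helper: it is intended to return `True` only while `torch.export` is tracing the model, and falls back to `False` when the check is unavailable. A hypothetical usage sketch, in the spirit of the dispatch added to `GenerationMixin` (the `eager_impl`/`exporting_impl` callables here are made up for illustration):

```python
from transformers.utils import is_torchdynamo_exporting


def prepare_step(input_ids, inputs_embeds, cache_position, eager_impl, exporting_impl):
    # Pick the torch.cond-based implementation only while torch.export is tracing;
    # plain Python control flow is fine (and simpler) in eager mode.
    if is_torchdynamo_exporting():
        return exporting_impl(input_ids, inputs_embeds, cache_position)
    return eager_impl(input_ids, inputs_embeds, cache_position)
```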
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 3b43fddf548..7d3e1a6e622 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -47,7 +47,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.utils import is_ipex_available
+from transformers.utils import is_ipex_available, is_torchdynamo_exporting
 
 
 if is_torch_available():
@@ -87,6 +87,7 @@ if is_torch_available():
         GenerateDecoderOnlyOutput,
         GenerateEncoderDecoderOutput,
         GenerationConfig,
+        GenerationMixin,
         GreedySearchDecoderOnlyOutput,
         GreedySearchEncoderDecoderOutput,
         LogitsProcessorList,
@@ -2703,6 +2704,54 @@ class UtilsFunctionsTest(unittest.TestCase):
         self.assertTrue(last_token_counts[1] > last_token_counts[3] > last_token_counts[7] > 0)
         self.assertTrue(last_token_counts[8] > last_token_counts[3])
 
+    def test_cache_dependant_input_preparation_exporting(self):
+        self.assertFalse(
+            is_torchdynamo_exporting()
+        )  # otherwise this test does not compare two different implementations
+        # Case 1
+        input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)[:, :0]
+        inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
+        cache_position = torch.arange(0, 8, dtype=torch.int64)
+        eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
+        export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
+            input_ids, inputs_embeds, cache_position
+        )
+        torch.testing.assert_close(eager1, export1)
+        torch.testing.assert_close(eager2, export2)
+
+        # Case 2
+        input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
+        inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
+        cache_position = torch.arange(0, 8, dtype=torch.int64)
+        eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
+        export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
+            input_ids, inputs_embeds, cache_position
+        )
+        torch.testing.assert_close(eager1, export1)
+        torch.testing.assert_close(eager2, export2)
+
+        # Case 3
+        input_ids = torch.randint(0, 16, (2, 12), dtype=torch.int64)
+        inputs_embeds = None
+        cache_position = torch.arange(0, 8, dtype=torch.int64)
+        eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
+        export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
+            input_ids, inputs_embeds, cache_position
+        )
+        torch.testing.assert_close(eager1, export1)
+        torch.testing.assert_close(eager2, export2)
+
+        # Case 4
+        input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
+        inputs_embeds = None
+        cache_position = torch.arange(0, 8, dtype=torch.int64)
+        eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
+        export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
+            input_ids, inputs_embeds, cache_position
+        )
+        torch.testing.assert_close(eager1, export1)
+        torch.testing.assert_close(eager2, export2)
+
 
 
 global_rng = random.Random()
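For completeness, a hedged, standalone sketch of what the new test exercises once this patch is applied: the eager and the `torch.cond`-based implementations must agree on the same inputs (the shapes below are arbitrary):

```python
import torch
from transformers import GenerationMixin

mixin = GenerationMixin()
input_ids = torch.randint(0, 16, (2, 12), dtype=torch.int64)
cache_position = torch.arange(4, 12, dtype=torch.int64)  # 8 positions still to process

_, ids_eager = mixin._cache_dependant_input_preparation(input_ids, None, cache_position)
_, ids_export = mixin._cache_dependant_input_preparation_exporting(input_ids, None, cache_position)

torch.testing.assert_close(ids_eager, ids_export)  # both reduce to input_ids[:, cache_position]
assert ids_eager.shape == (2, 8)
```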