Fixes DynamicCache export issues due to control flow and inplace modifications (#36652)

* Remove unnecessary masked_fill in deberta models

* Enable some code when exporting but not compiling

* add missing import

* style

* replace if by torch.cond

* style

* use numel

* style

* add unit tests

* style

* change empty value for dynamic cache

* replace != [] by numel()

* fix import issue

* style
Xavier Dupré 2025-04-02 13:04:40 +02:00 committed by GitHub
parent a165458901
commit 6f5dc9c82e
6 changed files with 180 additions and 26 deletions


@@ -79,10 +79,10 @@ class Cache:
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
         for layer_idx in range(len(self.key_cache)):
-            if self.key_cache[layer_idx] != []:
+            if self.key_cache[layer_idx].numel():
                 device = self.key_cache[layer_idx].device
                 self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            if self.value_cache[layer_idx] != []:
+            if self.value_cache[layer_idx].numel():
                 device = self.value_cache[layer_idx].device
                 self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

@@ -433,12 +433,12 @@ class DynamicCache(Cache):
         if len(self.key_cache) <= layer_idx:
             # There may be skipped layers, fill them with empty lists
             for _ in range(len(self.key_cache), layer_idx):
-                self.key_cache.append([])
-                self.value_cache.append([])
+                self.key_cache.append(torch.tensor([]))
+                self.value_cache.append(torch.tensor([]))
             self.key_cache.append(key_states)
             self.value_cache.append(value_states)
         elif (
-            len(self.key_cache[layer_idx]) == 0
+            not self.key_cache[layer_idx].numel()  # prefers not t.numel() to len(t) == 0 to export the model
         ):  # fills previously skipped layers; checking for tensor causes errors
             self.key_cache[layer_idx] = key_states
             self.value_cache[layer_idx] = value_states

@@ -454,7 +454,7 @@ class DynamicCache(Cache):
         is_empty_layer = (
             len(self.key_cache) == 0  # no cache in any layer
             or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
-            or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
+            or not self.key_cache[layer_idx].numel()  # the layer has no cache
         )
         layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
         return layer_seq_length

@@ -494,7 +494,7 @@ class DynamicCache(Cache):
         self._seen_tokens = max_length
         for idx in range(len(self.key_cache)):
-            if self.key_cache[idx] != []:
+            if self.key_cache[idx].numel():
                 self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
                 self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

@@ -516,8 +516,8 @@ class DynamicCache(Cache):
         `generation.utils`"""
         cache = cls()
         for idx in range(len(splits[0])):
-            key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
-            value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx] != []]
+            key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx].numel()]
+            value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx].numel()]
             if key_cache != []:
                 layer_keys = torch.cat(key_cache, dim=0)
                 layer_values = torch.cat(value_cache, dim=0)
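
The changes above matter for torch.export: skipped layers are now represented by empty tensors rather than Python lists, and emptiness is tested with numel(), which yields a plain integer instead of mixing list/tensor comparisons. A minimal sketch of the intended update logic (not the library code; tensor sizes are made up):

    import torch

    key_cache = [torch.tensor([])]        # placeholder for a layer that was skipped
    new_keys = torch.rand(1, 4, 3, 8)     # (batch, heads, seq_len, head_dim), made-up sizes

    if key_cache[0].numel():              # 0 for the empty placeholder, so this branch is skipped
        key_cache[0] = torch.cat([key_cache[0], new_keys], dim=-2)
    else:
        key_cache[0] = new_keys

    print(key_cache[0].shape)             # torch.Size([1, 4, 3, 8])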


@@ -48,6 +48,7 @@ from ..utils import (
     is_accelerate_available,
     is_hqq_available,
     is_optimum_quanto_available,
+    is_torchdynamo_exporting,
     logging,
 )
 from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
@@ -374,6 +375,102 @@ class GenerationMixin:
     To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
     """

+    def _cache_dependant_input_preparation(
+        self,
+        input_ids: torch.LongTensor,
+        inputs_embeds: Optional[torch.FloatTensor],
+        cache_position: Optional[torch.LongTensor],
+    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
+        """
+        Generic cache-dependent input preparation.
+        The code is put in a separate function to allow granular unit testing,
+        as it needs a different implementation to be exportable.
+
+        If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+        - Exception 1: when passing input_embeds, input_ids may be missing entries
+        - Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        - Exception 3: with synced GPUs cache_position may go out of bounds, but we only want a dummy token in that case.
+        - Exception 4: if input_embeds are passed, slice them through `cache_position` to keep only the unprocessed tokens,
+          generate the first token for each sequence, and later use the generated input ids for continuation.
+
+        The current implementation does not rely on ``self`` and could be
+        a class method. It is left as a standard method to be easily rewritten.
+        """
+        if is_torchdynamo_exporting():
+            return self._cache_dependant_input_preparation_exporting(input_ids, inputs_embeds, cache_position)
+        if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
+            inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+        elif (
+            inputs_embeds is not None  # Exception 1
+            or (cache_position[-1] >= input_ids.shape[1])  # Exception 3
+        ):
+            input_ids = input_ids[:, -cache_position.shape[0] :]
+        elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no-op, is Exception 2)
+            input_ids = input_ids[:, cache_position]
+        return inputs_embeds, input_ids
+
+    def _cache_dependant_input_preparation_exporting(
+        self,
+        input_ids: torch.LongTensor,
+        inputs_embeds: Optional[torch.FloatTensor],
+        cache_position: Optional[torch.LongTensor],
+    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
+        """
+        This method implements ``_cache_dependant_input_preparation``
+        with :func:`torch.cond` to make it exportable with :func:`torch.export.export`.
+        The code is put in a separate function to allow granular unit testing.
+        """
+        if inputs_embeds is None:
+            input_ids = input_ids[:, cache_position]
+        else:
+            # This is the code we need to implement with torch.cond:
+            # if input_ids.shape[1] == 0:
+            #     inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+            # else:
+            #     if cache_position[-1] >= input_ids.shape[1]:
+            #         input_ids = input_ids[:, -cache_position.shape[0] :]
+            #     else:
+            #         if input_ids.shape[1] != cache_position.shape[0]:
+            #             input_ids = input_ids[:, cache_position]
+            def branch_1(inputs_embeds, cache_position):
+                return inputs_embeds[:, -cache_position.shape[0] :]
+
+            def branch_2(input_ids, cache_position):
+                return input_ids[:, -cache_position.shape[0] :]
+
+            def branch_3(input_ids, cache_position):
+                return input_ids[:, cache_position]
+
+            inputs_embeds, input_ids = torch.cond(
+                input_ids.shape[1] == 0,
+                (
+                    lambda input_ids, inputs_embeds, cache_position: (
+                        branch_1(inputs_embeds, cache_position),
+                        input_ids,
+                    )
+                ),
+                (
+                    lambda input_ids, inputs_embeds, cache_position: (
+                        inputs_embeds,
+                        torch.cond(
+                            cache_position[-1] >= input_ids.shape[1],
+                            branch_2,
+                            lambda input_ids, cache_position: (
+                                torch.cond(
+                                    input_ids.shape[1] != cache_position.shape[0],
+                                    branch_3,
+                                    (lambda input_ids, cache_position: input_ids),
+                                    [input_ids, cache_position],
+                                )
+                            ),
+                            [input_ids, cache_position],
+                        ),
+                    )
+                ),
+                [input_ids, inputs_embeds, cache_position],
+            )
+        return inputs_embeds, input_ids
+
     def prepare_inputs_for_generation(
         self,
         input_ids: torch.LongTensor,
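
The nested torch.cond calls above replace data-dependent Python if/elif branches, which torch.export.export cannot trace, with branch functions selected by a boolean predicate. A self-contained sketch of the pattern (a toy module with made-up shapes, not the transformers code; both branches return tensors of the same shape, which is what torch.cond expects; exporting it should work with a recent PyTorch):

    import torch

    def take_first(x, cache_position):
        return x[:, : cache_position.shape[0]]

    def take_last(x, cache_position):
        return x[:, -cache_position.shape[0] :]

    class Slicer(torch.nn.Module):
        def forward(self, x, cache_position):
            # pick a branch at runtime instead of using a Python `if`
            return torch.cond(
                cache_position[-1] >= x.shape[1],
                take_last,
                take_first,
                (x, cache_position),
            )

    x = torch.arange(20).reshape(2, 10)
    print(Slicer()(x, torch.arange(4)).shape)  # predicate is False -> take_first -> torch.Size([2, 4])
    exported = torch.export.export(Slicer(), (x, torch.arange(4)))  # both branches are captured in the graph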
@@ -404,23 +501,11 @@ class GenerationMixin:
             cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)

         # 2. Generic cache-dependent input preparation
-        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
-        # Exception 1: when passing input_embeds, input_ids may be missing entries
-        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
-        # Excpetion 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
-        #              generate the first token for each sequence. Later use the generated Input ids for continuation.
         if past_key_values is not None:
             model_inputs["past_key_values"] = past_key_values
-            if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
-                inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
-            elif (
-                inputs_embeds is not None  # Exception 1
-                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
-            ):
-                input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
-                input_ids = input_ids[:, cache_position]
+            inputs_embeds, input_ids = self._cache_dependant_input_preparation(
+                input_ids, inputs_embeds, cache_position
+            )

         # 3. Prepare base model inputs
         input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
@@ -1590,6 +1675,8 @@ class GenerationMixin:
             generation_config = self.generation_config
             using_model_generation_config = True

+        # `torch.export.export` usually raises an exception if it is called
+        # with ``strict=True``. deepcopy can only be processed if ``strict=False``.
         generation_config = copy.deepcopy(generation_config)

         if not using_model_generation_config:
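
The comment above refers to the non-strict export path. A toy illustration (not the generate() code) of the kind of Python-level work, such as copy.deepcopy, that non-strict export simply executes while strict (TorchDynamo-traced) export may reject:

    import copy
    import torch

    class WithDeepcopy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.config = {"temperature": 1.0}   # stand-in for a generation config

        def forward(self, x):
            cfg = copy.deepcopy(self.config)     # Python-side work, no tensors involved
            return x * cfg["temperature"]

    # non-strict export runs the Python code and only records the tensor operations
    ep = torch.export.export(WithDeepcopy(), (torch.ones(2, 2),), strict=False)
    print(ep.graph)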


@@ -2047,7 +2047,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
         cross_attention_mask: Optional[torch.Tensor] = None,
         cross_attention_states: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
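
The widened annotation reflects that a Cache instance (typically a DynamicCache) can be passed as past_key_values. A small usage sketch with made-up tensor sizes:

    import torch
    from transformers import DynamicCache

    past_key_values = DynamicCache()
    # key/value states for layer 0: (batch, num_heads, seq_len, head_dim)
    keys = torch.rand(1, 8, 5, 64)
    values = torch.rand(1, 8, 5, 64)
    past_key_values.update(keys, values, layer_idx=0)
    print(past_key_values.get_seq_length())  # 5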


@@ -236,6 +236,7 @@ from .import_utils import (
     is_torchdistx_available,
     is_torchdynamo_available,
     is_torchdynamo_compiling,
+    is_torchdynamo_exporting,
     is_torchvision_available,
     is_torchvision_v2_available,
     is_training_run_on_sagemaker,


@@ -866,6 +866,23 @@ def is_torchdynamo_compiling():
             return False


+def is_torchdynamo_exporting():
+    if not is_torch_available():
+        return False
+
+    try:
+        import torch
+
+        return torch.compiler.is_exporting()
+    except Exception:
+        try:
+            import torch._dynamo as dynamo  # noqa: F401
+
+            return dynamo.is_exporting()
+        except Exception:
+            return False
+
+
 def is_torch_tensorrt_fx_available():
     if importlib.util.find_spec("torch_tensorrt") is None:
         return False
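
The new helper mirrors is_torchdynamo_compiling and is used above to dispatch to the torch.cond-based implementation only while torch.export is tracing. A small sketch of that dispatch pattern (toy functions, not the library code; the helper is re-exported from transformers.utils, as the other hunks in this commit show):

    import torch
    from transformers.utils import is_torchdynamo_exporting

    def eager_impl(x):
        return x[:, -1] if x.shape[1] > 1 else x[:, 0]  # ordinary Python branching

    def export_impl(x):
        # export-friendly variant of the same choice, expressed with torch.cond
        return torch.cond(x.shape[1] > 1, lambda t: t[:, -1], lambda t: t[:, 0], (x,))

    def last_token(x):
        return export_impl(x) if is_torchdynamo_exporting() else eager_impl(x)

    print(last_token(torch.arange(6).reshape(2, 3)))  # tensor([2, 5]) via the eager path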


@@ -47,7 +47,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.utils import is_ipex_available
+from transformers.utils import is_ipex_available, is_torchdynamo_exporting


 if is_torch_available():
@@ -87,6 +87,7 @@ if is_torch_available():
         GenerateDecoderOnlyOutput,
         GenerateEncoderDecoderOutput,
         GenerationConfig,
+        GenerationMixin,
         GreedySearchDecoderOnlyOutput,
         GreedySearchEncoderDecoderOutput,
         LogitsProcessorList,
@@ -2703,6 +2704,54 @@ class UtilsFunctionsTest(unittest.TestCase):
         self.assertTrue(last_token_counts[1] > last_token_counts[3] > last_token_counts[7] > 0)
         self.assertTrue(last_token_counts[8] > last_token_counts[3])

+    def test_cache_dependant_input_preparation_exporting(self):
+        self.assertFalse(
+            is_torchdynamo_exporting()
+        )  # otherwise this test does not compare two different implementations
+        # Case 1
+        input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)[:, :0]
+        inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
+        cache_position = torch.range(0, 7, dtype=torch.int64)
+        eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
+        export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
+            input_ids, inputs_embeds, cache_position
+        )
+        torch.testing.assert_close(eager1, export1)
+        torch.testing.assert_close(eager2, export2)
+
+        # Case 2
+        input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
+        inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
+        cache_position = torch.range(0, 7, dtype=torch.int64)
+        eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
+        export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
+            input_ids, inputs_embeds, cache_position
+        )
+        torch.testing.assert_close(eager1, export1)
+        torch.testing.assert_close(eager2, export2)
+
+        # Case 3
+        input_ids = torch.randint(0, 16, (2, 12), dtype=torch.int64)
+        inputs_embeds = None
+        cache_position = torch.range(0, 7, dtype=torch.int64)
+        eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
+        export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
+            input_ids, inputs_embeds, cache_position
+        )
+        torch.testing.assert_close(eager1, export1)
+        torch.testing.assert_close(eager2, export2)
+
+        # Case 4
+        input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
+        inputs_embeds = None
+        cache_position = torch.range(0, 7, dtype=torch.int64)
+        eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
+        export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
+            input_ids, inputs_embeds, cache_position
+        )
+        torch.testing.assert_close(eager1, export1)
+        torch.testing.assert_close(eager2, export2)
+

 global_rng = random.Random()