mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
Fixes DynamicCache export issues due to control flow and inplace modifications (#36652)
* Remove unnecessary masked_fill in deberta models * Enable some code when exporting but not compiling * add missing import * style * replace if by torch.cond * style * use numel * style * add unit tests * style * change empty value for dynamic cache * replace != [] by numel() * fix import issue * style
This commit is contained in:
parent
a165458901
commit
6f5dc9c82e
@ -79,10 +79,10 @@ class Cache:
|
||||
def reorder_cache(self, beam_idx: torch.LongTensor):
|
||||
"""Reorders the cache for beam search, given the selected beam indices."""
|
||||
for layer_idx in range(len(self.key_cache)):
|
||||
if self.key_cache[layer_idx] != []:
|
||||
if self.key_cache[layer_idx].numel():
|
||||
device = self.key_cache[layer_idx].device
|
||||
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
|
||||
if self.value_cache[layer_idx] != []:
|
||||
if self.value_cache[layer_idx].numel():
|
||||
device = self.value_cache[layer_idx].device
|
||||
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
|
||||
|
||||
@ -433,12 +433,12 @@ class DynamicCache(Cache):
|
||||
if len(self.key_cache) <= layer_idx:
|
||||
# There may be skipped layers, fill them with empty lists
|
||||
for _ in range(len(self.key_cache), layer_idx):
|
||||
self.key_cache.append([])
|
||||
self.value_cache.append([])
|
||||
self.key_cache.append(torch.tensor([]))
|
||||
self.value_cache.append(torch.tensor([]))
|
||||
self.key_cache.append(key_states)
|
||||
self.value_cache.append(value_states)
|
||||
elif (
|
||||
len(self.key_cache[layer_idx]) == 0
|
||||
not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model
|
||||
): # fills previously skipped layers; checking for tensor causes errors
|
||||
self.key_cache[layer_idx] = key_states
|
||||
self.value_cache[layer_idx] = value_states
|
||||
@ -454,7 +454,7 @@ class DynamicCache(Cache):
|
||||
is_empty_layer = (
|
||||
len(self.key_cache) == 0 # no cache in any layer
|
||||
or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
|
||||
or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
|
||||
or not self.key_cache[layer_idx].numel() # the layer has no cache
|
||||
)
|
||||
layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
|
||||
return layer_seq_length
|
||||
@ -494,7 +494,7 @@ class DynamicCache(Cache):
|
||||
|
||||
self._seen_tokens = max_length
|
||||
for idx in range(len(self.key_cache)):
|
||||
if self.key_cache[idx] != []:
|
||||
if self.key_cache[idx].numel():
|
||||
self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
|
||||
self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
|
||||
|
||||
@ -516,8 +516,8 @@ class DynamicCache(Cache):
|
||||
`generation.utils`"""
|
||||
cache = cls()
|
||||
for idx in range(len(splits[0])):
|
||||
key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
|
||||
value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx] != []]
|
||||
key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx].numel()]
|
||||
value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx].numel()]
|
||||
if key_cache != []:
|
||||
layer_keys = torch.cat(key_cache, dim=0)
|
||||
layer_values = torch.cat(value_cache, dim=0)
|
||||
|
@ -48,6 +48,7 @@ from ..utils import (
|
||||
is_accelerate_available,
|
||||
is_hqq_available,
|
||||
is_optimum_quanto_available,
|
||||
is_torchdynamo_exporting,
|
||||
logging,
|
||||
)
|
||||
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
||||
@ -374,6 +375,102 @@ class GenerationMixin:
|
||||
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
|
||||
"""
|
||||
|
||||
def _cache_dependant_input_preparation(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
inputs_embeds: Optional[torch.FloatTensor],
|
||||
cache_position: Optional[torch.LongTensor],
|
||||
) -> Tuple[torch.FloatTensor, torch.LongTensor]:
|
||||
"""
|
||||
Generic cache-dependent input preparation
|
||||
The code is put in a separate function to allow granular unit testing
|
||||
as it needs a different implementation to be exportable.
|
||||
|
||||
If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
- Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
- Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
- Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
|
||||
- Excpetion 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
|
||||
generate the first token for each sequence. Later use the generated Input ids for continuation.
|
||||
|
||||
The current implementation does not rely on ``self`` and could be
|
||||
a class method. It is left as a standard method to be easily rewritten.
|
||||
"""
|
||||
if is_torchdynamo_exporting():
|
||||
return self._cache_dependant_input_preparation_exporting(input_ids, inputs_embeds, cache_position)
|
||||
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
|
||||
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
|
||||
elif (
|
||||
inputs_embeds is not None # Exception 1
|
||||
or (cache_position[-1] >= input_ids.shape[1]) # Exception 3
|
||||
):
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
return inputs_embeds, input_ids
|
||||
|
||||
def _cache_dependant_input_preparation_exporting(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
inputs_embeds: Optional[torch.FloatTensor],
|
||||
cache_position: Optional[torch.LongTensor],
|
||||
) -> Tuple[torch.FloatTensor, torch.LongTensor]:
|
||||
"""
|
||||
This method implements method ``_cache_dependant_input_preparation``
|
||||
with :func:`torch.cond` to make it exportable with :func:`torch.export.export`.
|
||||
The code is put in a separate function to allow granular unit testing.
|
||||
"""
|
||||
if inputs_embeds is None:
|
||||
input_ids = input_ids[:, cache_position]
|
||||
else:
|
||||
# This is the code we need to implemented with torch.cond.
|
||||
# if input_ids.shape[1] == 0:
|
||||
# inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
|
||||
# else:
|
||||
# if cache_position[-1] >= input_ids.shape[1]:
|
||||
# input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
# else:
|
||||
# if input_ids.shape[1] != cache_position.shape[0]:
|
||||
# input_ids = input_ids[:, cache_position]
|
||||
def branch_1(inputs_embeds, cache_position):
|
||||
return inputs_embeds[:, -cache_position.shape[0] :]
|
||||
|
||||
def branch_2(input_ids, cache_position):
|
||||
return input_ids[:, -cache_position.shape[0] :]
|
||||
|
||||
def branch_3(input_ids, cache_position):
|
||||
return input_ids[:, cache_position]
|
||||
|
||||
inputs_embeds, input_ids = torch.cond(
|
||||
input_ids.shape[1] == 0,
|
||||
(
|
||||
lambda input_ids, inputs_embeds, cache_position: (
|
||||
branch_1(inputs_embeds, cache_position),
|
||||
input_ids,
|
||||
)
|
||||
),
|
||||
(
|
||||
lambda input_ids, inputs_embeds, cache_position: (
|
||||
inputs_embeds,
|
||||
torch.cond(
|
||||
cache_position[-1] >= input_ids.shape[1],
|
||||
branch_2,
|
||||
lambda input_ids, cache_position: (
|
||||
torch.cond(
|
||||
input_ids.shape[1] != cache_position.shape[0],
|
||||
branch_3,
|
||||
(lambda input_ids, cache_position: input_ids),
|
||||
[input_ids, cache_position],
|
||||
)
|
||||
),
|
||||
[input_ids, cache_position],
|
||||
),
|
||||
)
|
||||
),
|
||||
[input_ids, inputs_embeds, cache_position],
|
||||
)
|
||||
return inputs_embeds, input_ids
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
@ -404,23 +501,11 @@ class GenerationMixin:
|
||||
cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
|
||||
|
||||
# 2. Generic cache-dependent input preparation
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
|
||||
# Excpetion 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
|
||||
# generate the first token for each sequence. Later use the generated Input ids for continuation.
|
||||
if past_key_values is not None:
|
||||
model_inputs["past_key_values"] = past_key_values
|
||||
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
|
||||
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
|
||||
elif (
|
||||
inputs_embeds is not None # Exception 1
|
||||
or cache_position[-1] >= input_ids.shape[1] # Exception 3
|
||||
):
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
inputs_embeds, input_ids = self._cache_dependant_input_preparation(
|
||||
input_ids, inputs_embeds, cache_position
|
||||
)
|
||||
|
||||
# 3. Prepare base model inputs
|
||||
input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
||||
@ -1590,6 +1675,8 @@ class GenerationMixin:
|
||||
generation_config = self.generation_config
|
||||
using_model_generation_config = True
|
||||
|
||||
# `torch.export.export` usually raises an exception if it is called
|
||||
# with ``strict=True``. deepcopy can only be processed if ``strict=False``.
|
||||
generation_config = copy.deepcopy(generation_config)
|
||||
|
||||
if not using_model_generation_config:
|
||||
|
@ -2047,7 +2047,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
|
||||
cross_attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_states: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
|
@ -236,6 +236,7 @@ from .import_utils import (
|
||||
is_torchdistx_available,
|
||||
is_torchdynamo_available,
|
||||
is_torchdynamo_compiling,
|
||||
is_torchdynamo_exporting,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
is_training_run_on_sagemaker,
|
||||
|
@ -866,6 +866,23 @@ def is_torchdynamo_compiling():
|
||||
return False
|
||||
|
||||
|
||||
def is_torchdynamo_exporting():
|
||||
if not is_torch_available():
|
||||
return False
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
return torch.compiler.is_exporting()
|
||||
except Exception:
|
||||
try:
|
||||
import torch._dynamo as dynamo # noqa: F401
|
||||
|
||||
return dynamo.is_exporting()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_torch_tensorrt_fx_available():
|
||||
if importlib.util.find_spec("torch_tensorrt") is None:
|
||||
return False
|
||||
|
@ -47,7 +47,7 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_ipex_available
|
||||
from transformers.utils import is_ipex_available, is_torchdynamo_exporting
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@ -87,6 +87,7 @@ if is_torch_available():
|
||||
GenerateDecoderOnlyOutput,
|
||||
GenerateEncoderDecoderOutput,
|
||||
GenerationConfig,
|
||||
GenerationMixin,
|
||||
GreedySearchDecoderOnlyOutput,
|
||||
GreedySearchEncoderDecoderOutput,
|
||||
LogitsProcessorList,
|
||||
@ -2703,6 +2704,54 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
self.assertTrue(last_token_counts[1] > last_token_counts[3] > last_token_counts[7] > 0)
|
||||
self.assertTrue(last_token_counts[8] > last_token_counts[3])
|
||||
|
||||
def test_cache_dependant_input_preparation_exporting(self):
|
||||
self.assertFalse(
|
||||
is_torchdynamo_exporting()
|
||||
) # otherwise this test does not compare two different implementation
|
||||
# Case 1
|
||||
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)[:, :0]
|
||||
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
|
||||
cache_position = torch.range(0, 7, dtype=torch.int64)
|
||||
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
|
||||
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
|
||||
input_ids, inputs_embeds, cache_position
|
||||
)
|
||||
torch.testing.assert_close(eager1, export1)
|
||||
torch.testing.assert_close(eager2, export2)
|
||||
|
||||
# Case 2
|
||||
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
|
||||
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
|
||||
cache_position = torch.range(0, 7, dtype=torch.int64)
|
||||
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
|
||||
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
|
||||
input_ids, inputs_embeds, cache_position
|
||||
)
|
||||
torch.testing.assert_close(eager1, export1)
|
||||
torch.testing.assert_close(eager2, export2)
|
||||
|
||||
# Case 3
|
||||
input_ids = torch.randint(0, 16, (2, 12), dtype=torch.int64)
|
||||
inputs_embeds = None
|
||||
cache_position = torch.range(0, 7, dtype=torch.int64)
|
||||
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
|
||||
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
|
||||
input_ids, inputs_embeds, cache_position
|
||||
)
|
||||
torch.testing.assert_close(eager1, export1)
|
||||
torch.testing.assert_close(eager2, export2)
|
||||
|
||||
# Case 4
|
||||
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
|
||||
inputs_embeds = None
|
||||
cache_position = torch.range(0, 7, dtype=torch.int64)
|
||||
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
|
||||
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
|
||||
input_ids, inputs_embeds, cache_position
|
||||
)
|
||||
torch.testing.assert_close(eager1, export1)
|
||||
torch.testing.assert_close(eager2, export2)
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user