Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
Fixes DynamicCache export issues due to control flow and inplace modifications (#36652)
* Remove unnecessary masked_fill in deberta models
* Enable some code when exporting but not compiling
* add missing import
* style
* replace if by torch.cond
* style
* use numel
* style
* add unit tests
* style
* change empty value for dynamic cache
* replace != [] by numel()
* fix import issue
* style
This commit is contained in: parent a165458901, commit 6f5dc9c82e
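Note on the "change empty value for dynamic cache" and "replace != [] by numel()" items above — a minimal, illustrative sketch (the variable names below are not from the diff): skipped layers are now stored as empty tensors rather than Python lists, so emptiness is tested on tensor metadata with `numel()` instead of comparing a tensor against `[]`.

    import torch

    # Hypothetical, simplified cache layout after this commit: a skipped layer is
    # an empty tensor placeholder rather than a Python list, so the emptiness
    # check stays on tensor metadata.
    key_cache = [torch.tensor([]), torch.rand(1, 2, 4, 8)]  # layer 0 skipped, layer 1 filled

    for layer_idx, keys in enumerate(key_cache):
        if keys.numel():  # export-friendly replacement for comparing `keys` against []
            print(f"layer {layer_idx}: seq_len = {keys.shape[-2]}")
        else:
            print(f"layer {layer_idx}: empty placeholder")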
@@ -79,10 +79,10 @@ class Cache:
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
         for layer_idx in range(len(self.key_cache)):
-            if self.key_cache[layer_idx] != []:
+            if self.key_cache[layer_idx].numel():
                 device = self.key_cache[layer_idx].device
                 self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            if self.value_cache[layer_idx] != []:
+            if self.value_cache[layer_idx].numel():
                 device = self.value_cache[layer_idx].device
                 self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

@@ -433,12 +433,12 @@ class DynamicCache(Cache):
         if len(self.key_cache) <= layer_idx:
             # There may be skipped layers, fill them with empty lists
             for _ in range(len(self.key_cache), layer_idx):
-                self.key_cache.append([])
-                self.value_cache.append([])
+                self.key_cache.append(torch.tensor([]))
+                self.value_cache.append(torch.tensor([]))
             self.key_cache.append(key_states)
             self.value_cache.append(value_states)
         elif (
-            len(self.key_cache[layer_idx]) == 0
+            not self.key_cache[layer_idx].numel()  # prefers not t.numel() to len(t) == 0 to export the model
         ):  # fills previously skipped layers; checking for tensor causes errors
             self.key_cache[layer_idx] = key_states
             self.value_cache[layer_idx] = value_states
@@ -454,7 +454,7 @@ class DynamicCache(Cache):
         is_empty_layer = (
             len(self.key_cache) == 0  # no cache in any layer
             or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
-            or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
+            or not self.key_cache[layer_idx].numel()  # the layer has no cache
         )
         layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
         return layer_seq_length
@@ -494,7 +494,7 @@ class DynamicCache(Cache):

         self._seen_tokens = max_length
         for idx in range(len(self.key_cache)):
-            if self.key_cache[idx] != []:
+            if self.key_cache[idx].numel():
                 self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
                 self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

@@ -516,8 +516,8 @@ class DynamicCache(Cache):
         `generation.utils`"""
         cache = cls()
         for idx in range(len(splits[0])):
-            key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
-            value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx] != []]
+            key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx].numel()]
+            value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx].numel()]
             if key_cache != []:
                 layer_keys = torch.cat(key_cache, dim=0)
                 layer_values = torch.cat(value_cache, dim=0)
@@ -48,6 +48,7 @@ from ..utils import (
     is_accelerate_available,
     is_hqq_available,
     is_optimum_quanto_available,
+    is_torchdynamo_exporting,
     logging,
 )
 from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
@@ -374,6 +375,102 @@ class GenerationMixin:
         To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
         """

+    def _cache_dependant_input_preparation(
+        self,
+        input_ids: torch.LongTensor,
+        inputs_embeds: Optional[torch.FloatTensor],
+        cache_position: Optional[torch.LongTensor],
+    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
+        """
+        Generic cache-dependent input preparation
+        The code is put in a separate function to allow granular unit testing
+        as it needs a different implementation to be exportable.
+
+        If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+        - Exception 1: when passing input_embeds, input_ids may be missing entries
+        - Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        - Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        - Excpetion 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
+          generate the first token for each sequence. Later use the generated Input ids for continuation.
+
+        The current implementation does not rely on ``self`` and could be
+        a class method. It is left as a standard method to be easily rewritten.
+        """
+        if is_torchdynamo_exporting():
+            return self._cache_dependant_input_preparation_exporting(input_ids, inputs_embeds, cache_position)
+        if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
+            inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+        elif (
+            inputs_embeds is not None  # Exception 1
+            or (cache_position[-1] >= input_ids.shape[1])  # Exception 3
+        ):
+            input_ids = input_ids[:, -cache_position.shape[0] :]
+        elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+            input_ids = input_ids[:, cache_position]
+        return inputs_embeds, input_ids
+
+    def _cache_dependant_input_preparation_exporting(
+        self,
+        input_ids: torch.LongTensor,
+        inputs_embeds: Optional[torch.FloatTensor],
+        cache_position: Optional[torch.LongTensor],
+    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
+        """
+        This method implements method ``_cache_dependant_input_preparation``
+        with :func:`torch.cond` to make it exportable with :func:`torch.export.export`.
+        The code is put in a separate function to allow granular unit testing.
+        """
+        if inputs_embeds is None:
+            input_ids = input_ids[:, cache_position]
+        else:
+            # This is the code we need to implemented with torch.cond.
+            # if input_ids.shape[1] == 0:
+            #     inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+            # else:
+            #     if cache_position[-1] >= input_ids.shape[1]:
+            #         input_ids = input_ids[:, -cache_position.shape[0] :]
+            #     else:
+            #         if input_ids.shape[1] != cache_position.shape[0]:
+            #             input_ids = input_ids[:, cache_position]
+            def branch_1(inputs_embeds, cache_position):
+                return inputs_embeds[:, -cache_position.shape[0] :]
+
+            def branch_2(input_ids, cache_position):
+                return input_ids[:, -cache_position.shape[0] :]
+
+            def branch_3(input_ids, cache_position):
+                return input_ids[:, cache_position]
+
+            inputs_embeds, input_ids = torch.cond(
+                input_ids.shape[1] == 0,
+                (
+                    lambda input_ids, inputs_embeds, cache_position: (
+                        branch_1(inputs_embeds, cache_position),
+                        input_ids,
+                    )
+                ),
+                (
+                    lambda input_ids, inputs_embeds, cache_position: (
+                        inputs_embeds,
+                        torch.cond(
+                            cache_position[-1] >= input_ids.shape[1],
+                            branch_2,
+                            lambda input_ids, cache_position: (
+                                torch.cond(
+                                    input_ids.shape[1] != cache_position.shape[0],
+                                    branch_3,
+                                    (lambda input_ids, cache_position: input_ids),
+                                    [input_ids, cache_position],
+                                )
+                            ),
+                            [input_ids, cache_position],
+                        ),
+                    )
+                ),
+                [input_ids, inputs_embeds, cache_position],
+            )
+        return inputs_embeds, input_ids
+
     def prepare_inputs_for_generation(
         self,
         input_ids: torch.LongTensor,
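A rough, eager-mode sketch of the `torch.cond` calling convention used in the hunk above (the tensors and branch names below are illustrative, not part of the diff): each call takes a predicate, a true branch, a false branch, and the operands handed to whichever branch runs, so both branches stay representable in the exported graph instead of being resolved by Python control flow.

    import torch

    def keep_last(x, cache_position):
        # true branch: keep only the last `len(cache_position)` columns
        return x[:, -cache_position.shape[0]:]

    def keep_all(x, cache_position):
        # false branch: pass the tensor through unchanged
        return x

    x = torch.arange(24).reshape(2, 12)
    cache_position = torch.arange(8)

    out = torch.cond(x.shape[1] != cache_position.shape[0], keep_last, keep_all, [x, cache_position])
    print(out.shape)  # torch.Size([2, 8]) here, since the predicate is true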
@@ -404,23 +501,11 @@ class GenerationMixin:
         cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)

         # 2. Generic cache-dependent input preparation
-        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
-        # Exception 1: when passing input_embeds, input_ids may be missing entries
-        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
-        # Excpetion 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
-        # generate the first token for each sequence. Later use the generated Input ids for continuation.
         if past_key_values is not None:
             model_inputs["past_key_values"] = past_key_values
-            if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
-                inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
-            elif (
-                inputs_embeds is not None  # Exception 1
-                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
-            ):
-                input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
-                input_ids = input_ids[:, cache_position]
+            inputs_embeds, input_ids = self._cache_dependant_input_preparation(
+                input_ids, inputs_embeds, cache_position
+            )

         # 3. Prepare base model inputs
         input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
@@ -1590,6 +1675,8 @@ class GenerationMixin:
             generation_config = self.generation_config
             using_model_generation_config = True

+        # `torch.export.export` usually raises an exception if it is called
+        # with ``strict=True``. deepcopy can only be processed if ``strict=False``.
         generation_config = copy.deepcopy(generation_config)

         if not using_model_generation_config:
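The comment added above contrasts strict and non-strict export; a self-contained sketch of the non-strict call (the toy module below is illustrative, not from the repository):

    import torch

    class TinyStep(torch.nn.Module):
        # Stand-in for a generation step; real models do Python-level work
        # (such as deepcopying a generation config) that strict tracing may reject.
        def forward(self, input_ids):
            return input_ids[:, -1:]

    exported = torch.export.export(
        TinyStep(),
        args=(torch.randint(0, 16, (1, 8)),),
        strict=False,  # non-strict export tolerates more Python, per the comment above
    )
    print(exported)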
@@ -2047,7 +2047,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
         cross_attention_mask: Optional[torch.Tensor] = None,
         cross_attention_states: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -236,6 +236,7 @@ from .import_utils import (
     is_torchdistx_available,
     is_torchdynamo_available,
     is_torchdynamo_compiling,
+    is_torchdynamo_exporting,
     is_torchvision_available,
     is_torchvision_v2_available,
     is_training_run_on_sagemaker,
@@ -866,6 +866,23 @@ def is_torchdynamo_compiling():
         return False


+def is_torchdynamo_exporting():
+    if not is_torch_available():
+        return False
+
+    try:
+        import torch
+
+        return torch.compiler.is_exporting()
+    except Exception:
+        try:
+            import torch._dynamo as dynamo  # noqa: F401
+
+            return dynamo.is_exporting()
+        except Exception:
+            return False
+
+
 def is_torch_tensorrt_fx_available():
     if importlib.util.find_spec("torch_tensorrt") is None:
         return False
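A hedged usage sketch of the new helper, mirroring how `_cache_dependant_input_preparation` dispatches on it earlier in this diff (the `prepare` function below is illustrative, not part of the commit):

    import torch
    from transformers.utils import is_torchdynamo_exporting  # exposed by this commit

    def prepare(input_ids, cache_position):
        if is_torchdynamo_exporting():
            # export path: avoid data-dependent Python branching while tracing
            return input_ids[:, cache_position]
        # eager path: ordinary Python control flow is fine
        if input_ids.shape[1] != cache_position.shape[0]:
            input_ids = input_ids[:, cache_position]
        return input_ids

    print(prepare(torch.randint(0, 16, (2, 12)), torch.arange(8)).shape)  # torch.Size([2, 8])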
@@ -47,7 +47,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.utils import is_ipex_available
+from transformers.utils import is_ipex_available, is_torchdynamo_exporting


 if is_torch_available():
@@ -87,6 +87,7 @@ if is_torch_available():
         GenerateDecoderOnlyOutput,
         GenerateEncoderDecoderOutput,
         GenerationConfig,
+        GenerationMixin,
         GreedySearchDecoderOnlyOutput,
         GreedySearchEncoderDecoderOutput,
         LogitsProcessorList,
@@ -2703,6 +2704,54 @@ class UtilsFunctionsTest(unittest.TestCase):
         self.assertTrue(last_token_counts[1] > last_token_counts[3] > last_token_counts[7] > 0)
         self.assertTrue(last_token_counts[8] > last_token_counts[3])

+    def test_cache_dependant_input_preparation_exporting(self):
+        self.assertFalse(
+            is_torchdynamo_exporting()
+        )  # otherwise this test does not compare two different implementation
+        # Case 1
+        input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)[:, :0]
+        inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
+        cache_position = torch.range(0, 7, dtype=torch.int64)
+        eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
+        export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
+            input_ids, inputs_embeds, cache_position
+        )
+        torch.testing.assert_close(eager1, export1)
+        torch.testing.assert_close(eager2, export2)
+
+        # Case 2
+        input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
+        inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
+        cache_position = torch.range(0, 7, dtype=torch.int64)
+        eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
+        export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
+            input_ids, inputs_embeds, cache_position
+        )
+        torch.testing.assert_close(eager1, export1)
+        torch.testing.assert_close(eager2, export2)
+
+        # Case 3
+        input_ids = torch.randint(0, 16, (2, 12), dtype=torch.int64)
+        inputs_embeds = None
+        cache_position = torch.range(0, 7, dtype=torch.int64)
+        eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
+        export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
+            input_ids, inputs_embeds, cache_position
+        )
+        torch.testing.assert_close(eager1, export1)
+        torch.testing.assert_close(eager2, export2)
+
+        # Case 4
+        input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
+        inputs_embeds = None
+        cache_position = torch.range(0, 7, dtype=torch.int64)
+        eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
+        export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
+            input_ids, inputs_embeds, cache_position
+        )
+        torch.testing.assert_close(eager1, export1)
+        torch.testing.assert_close(eager2, export2)
+
+
 global_rng = random.Random()