[generate] remove cache v4.47 deprecations (#36212)

Authored by Joao Gante on 2025-02-17 13:55:03 +00:00, committed by GitHub
parent 936aeb70ab
commit dad513e0c2
4 changed files with 18 additions and 28 deletions

View File

@@ -363,8 +363,7 @@ class DynamicCache(Cache):
         ```
     """
 
-    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
-    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
+    def __init__(self) -> None:
         super().__init__()
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
         self.key_cache: List[torch.Tensor] = []
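With the deprecation window closed, `DynamicCache` is constructed with no arguments and grows its per-layer lists lazily as `update()` is called. A minimal usage sketch, assuming `torch` and a transformers release that includes this change:

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()  # no num_hidden_layers argument anymore

# Layer storage is created lazily on the first update() for each layer index.
key = torch.randn(1, 4, 3, 8)    # (batch, num_heads, seq_len, head_dim)
value = torch.randn(1, 4, 3, 8)
cache.update(key, value, layer_idx=0)

print(cache.get_seq_length())  # -> 3
```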
@@ -466,10 +465,7 @@ class DynamicCache(Cache):
         return legacy_cache
 
     @classmethod
-    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
-    def from_legacy_cache(
-        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
-    ) -> "DynamicCache":
+    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
         """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
         backward compatibility."""
         cache = cls()
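`from_legacy_cache` likewise drops `num_hidden_layers` and infers the layer count from the legacy tuple itself. A small roundtrip sketch (shapes are illustrative):

```python
import torch
from transformers import DynamicCache

# Legacy format: a tuple with one (key, value) pair of tensors per layer.
legacy = tuple((torch.randn(1, 4, 3, 8), torch.randn(1, 4, 3, 8)) for _ in range(2))

cache = DynamicCache.from_legacy_cache(legacy)  # layer count inferred from the tuple
assert len(cache.key_cache) == 2
assert cache.to_legacy_cache()[0][0].shape == legacy[0][0].shape
```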
@@ -495,10 +491,7 @@ class DynamicCache(Cache):
                 self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
                 self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
 
-    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
-    def batch_split(
-        self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
-    ) -> List["DynamicCache"]:
+    def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
         """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
         `_split_model_inputs()` in `generation.utils`"""
         out = []
@@ -511,8 +504,7 @@ class DynamicCache(Cache):
         return out
 
     @classmethod
-    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
-    def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int = None) -> "DynamicCache":
+    def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
         """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
         `generation.utils`"""
         cache = cls()
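`batch_split` and `from_batch_splits` (the two hunks above) now take only the batch sizes and the list of splits, respectively; `generate` uses the pair to process a large batch in chunks and stitch the caches back together. A hedged roundtrip sketch:

```python
import torch
from transformers import DynamicCache

full = DynamicCache()
full.update(torch.randn(4, 2, 5, 8), torch.randn(4, 2, 5, 8), layer_idx=0)

# Split a batch of 4 into chunks of 2, then merge the chunks back.
splits = full.batch_split(full_batch_size=4, split_size=2)
assert len(splits) == 2 and splits[0].key_cache[0].shape[0] == 2

merged = DynamicCache.from_batch_splits(splits)
assert torch.equal(merged.key_cache[0], full.key_cache[0])
```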
@@ -1527,10 +1519,7 @@ class EncoderDecoderCache(Cache):
         self.check_dynamic_cache(self.crop.__name__)
         self.self_attention_cache.crop(maximum_length)
 
-    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
-    def batch_split(
-        self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
-    ) -> "List[EncoderDecoderCache]":
+    def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
         """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
         `_split_model_inputs()` in `generation.utils`"""
         self.check_dynamic_cache(self.batch_split.__name__)
@@ -1543,10 +1532,7 @@ class EncoderDecoderCache(Cache):
         return out
 
     @classmethod
-    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
-    def from_batch_splits(
-        cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int = None
-    ) -> "EncoderDecoderCache":
+    def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
         """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
         `generation.utils`"""
         self_attention_cache = DynamicCache()
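The `EncoderDecoderCache` counterparts mirror the `DynamicCache` signatures and, as before, require both wrapped caches to be dynamic (that is what `check_dynamic_cache` enforces). A sketch under those assumptions:

```python
import torch
from transformers import DynamicCache, EncoderDecoderCache

self_attn, cross_attn = DynamicCache(), DynamicCache()
self_attn.update(torch.randn(4, 2, 5, 8), torch.randn(4, 2, 5, 8), layer_idx=0)
cross_attn.update(torch.randn(4, 2, 7, 8), torch.randn(4, 2, 7, 8), layer_idx=0)

cache = EncoderDecoderCache(self_attn, cross_attn)
splits = cache.batch_split(full_batch_size=4, split_size=2)  # two EncoderDecoderCache chunks
merged = EncoderDecoderCache.from_batch_splits(splits)
assert torch.equal(merged.self_attention_cache.key_cache[0], self_attn.key_cache[0])
```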

View File

@@ -4520,7 +4520,7 @@ def _ranking_fast(
     return selected_idx
 
 
-def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int = None):
+def _split(data, full_batch_size: int, split_size: int = None):
     """
     Takes care of three cases:
     1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim
@@ -4538,7 +4538,7 @@ def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int =
     elif isinstance(data, DynamicCache) or (
         isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache)
     ):
-        return data.batch_split(full_batch_size, split_size, num_hidden_layers)
+        return data.batch_split(full_batch_size, split_size)
     elif isinstance(data, tuple):
         # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
         if isinstance(data[0], tuple):
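`_split` is a private helper, so the sketch below is only an illustrative, simplified re-implementation of the dispatch it performs (the name `split_for_batched_generation` is made up); the point is that caches now carry their own layer count, so the helper no longer threads `num_hidden_layers` through:

```python
import torch
from transformers import DynamicCache, EncoderDecoderCache

def split_for_batched_generation(data, full_batch_size: int, split_size: int):
    """Illustrative, simplified version of the `_split` dispatch above."""
    if data is None:
        return [None] * (full_batch_size // split_size)
    if isinstance(data, torch.Tensor):
        # Tensors are sliced along the batch dimension.
        return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
    if isinstance(data, (DynamicCache, EncoderDecoderCache)):
        # Caches split themselves; no num_hidden_layers needed anymore.
        return data.batch_split(full_batch_size, split_size)
    raise TypeError(f"Unexpected attribute type: {type(data)}")

hidden = torch.randn(4, 7, 16)
print([t.shape for t in split_for_batched_generation(hidden, 4, 2)])  # two (2, 7, 16) chunks
```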
@@ -4591,11 +4591,9 @@ def _split_model_inputs(
     keys_to_ignore = ["cache_position", "encoder_outputs", "logits_to_keep"]
     non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]
 
-    num_hidden_layers = config.get_text_config().num_hidden_layers
-
     # we split the tensors and tuples of tensors
     data_split_list = [
-        {k: _split(model_input[k], full_batch_size, num_hidden_layers, split_size)[i] for k in non_bool_keys}
+        {k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys}
         for i in range(full_batch_size // split_size)
     ]
     # bool values are the same and replicated for each split
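`_split_model_inputs` builds `full_batch_size // split_size` per-chunk input dicts. A self-contained sketch of that arithmetic on a plain dict of tensors (the helper name is hypothetical):

```python
import torch

def split_model_input_dict(model_input: dict, full_batch_size: int, split_size: int):
    """Hypothetical helper: one input dict per batch chunk, mirroring the comprehension above."""
    num_chunks = full_batch_size // split_size  # e.g. 4 // 2 -> 2 chunks
    return [
        {k: v[i * split_size : (i + 1) * split_size] for k, v in model_input.items()}
        for i in range(num_chunks)
    ]

inputs = {"input_ids": torch.randint(0, 100, (4, 6)), "attention_mask": torch.ones(4, 6)}
chunks = split_model_input_dict(inputs, full_batch_size=4, split_size=2)
assert len(chunks) == 2 and chunks[0]["input_ids"].shape[0] == 2
```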
@@ -4632,7 +4630,6 @@ def stack_model_outputs(model_outputs: List[ModelOutput], config: PretrainedConf
 
     # Infer the class from the first object in the list
     model_output_cls = type(model_outputs[0])
-    num_hidden_layers = config.get_text_config().num_hidden_layers
 
     # Ensure all objects are of the same type
     if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
@@ -4649,9 +4646,9 @@ def stack_model_outputs(model_outputs: List[ModelOutput], config: PretrainedConf
             return torch.cat(data, dim=0)
         # New cache format
         elif isinstance(data[0], DynamicCache):
-            return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
+            return DynamicCache.from_batch_splits(data)
         elif isinstance(data[0], EncoderDecoderCache):
-            return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
+            return EncoderDecoderCache.from_batch_splits(data)
         elif isinstance(data[0], tuple):
             # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
             if isinstance(data[0][0], tuple):
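On the reassembly side, `stack_model_outputs` concatenates tensors and rebuilds caches via `from_batch_splits`, now without passing a layer count. An illustrative, simplified stand-in for that concatenation step (the function name is made up):

```python
import torch
from transformers import DynamicCache, EncoderDecoderCache

def concat_generation_outputs(parts):
    """Illustrative dispatch: tensors are concatenated, caches are merged."""
    if isinstance(parts[0], torch.Tensor):
        return torch.cat(parts, dim=0)
    if isinstance(parts[0], DynamicCache):
        return DynamicCache.from_batch_splits(parts)  # layer count inferred from the splits
    if isinstance(parts[0], EncoderDecoderCache):
        return EncoderDecoderCache.from_batch_splits(parts)
    raise TypeError(f"Unexpected attribute type: {type(parts[0])}")

assert concat_generation_outputs([torch.randn(2, 5), torch.randn(2, 5)]).shape == (4, 5)
```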

View File

@@ -22,6 +22,7 @@ from parameterized import parameterized
 
 from transformers import PhimoeConfig, StaticCache, is_torch_available, set_seed
 from transformers.testing_utils import (
+    is_flaky,
     require_torch,
     slow,
     torch_device,
@@ -449,6 +450,7 @@ class PhimoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
         self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
 
     @parameterized.expand([("longrope",)])
+    @is_flaky()  # TODO (joao): unify rope tests in the mixin
     def test_model_rope_scaling_short_long_factor(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         n_factors = config.hidden_size // config.num_key_value_heads // 2
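`is_flaky` from `transformers.testing_utils` retries a decorated test a few times before reporting a failure, which is why it is attached to the LongRoPE scaling test here. A minimal sketch on a toy test (assumes a transformers dev install):

```python
import random
import unittest

from transformers.testing_utils import is_flaky


class ToyTest(unittest.TestCase):
    @is_flaky()  # retried a few times before being counted as a failure
    def test_sometimes_noisy(self):
        self.assertGreater(random.random(), 0.05)


if __name__ == "__main__":
    unittest.main()
```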

View File

@@ -27,6 +27,7 @@ from transformers import (
     is_vision_available,
 )
 from transformers.testing_utils import (
+    is_flaky,
     require_flash_attn,
     require_torch,
     require_torch_gpu,
@@ -347,6 +348,10 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
     def test_generate_compile_fullgraph(self):
         pass
 
+    @is_flaky()  # TODO (joao/raushan): Investigate why this test is flaky on this model
+    def test_prompt_lookup_decoding_matches_greedy_search(self):
+        super().test_prompt_lookup_decoding_matches_greedy_search()
+
 
 @require_torch
 class Qwen2_5_VLIntegrationTest(unittest.TestCase):
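The Qwen2.5-VL change uses a common pattern in the test suite: override an inherited mixin test only to attach a decorator, then delegate to `super()`. A sketch with a hypothetical mixin (`HypotheticalGenerationMixinTests` is made up for illustration):

```python
import unittest

from transformers.testing_utils import is_flaky


class HypotheticalGenerationMixinTests:
    # Stand-in for a shared test mixin; only the override pattern matters here.
    def test_prompt_lookup_decoding_matches_greedy_search(self):
        assert True


class HypotheticalModelTest(HypotheticalGenerationMixinTests, unittest.TestCase):
    @is_flaky()  # decorate the inherited test without copying its body
    def test_prompt_lookup_decoding_matches_greedy_search(self):
        super().test_prompt_lookup_decoding_matches_greedy_search()


if __name__ == "__main__":
    unittest.main()
```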