[generate] remove cache v4.47 deprecations (#36212)
commit dad513e0c2
parent 936aeb70ab
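The diff below removes the `num_hidden_layers` argument that was deprecated for removal in v4.47. For orientation, a minimal usage sketch of the resulting cache API (example code, not part of the commit; the tensor shapes and layer count are illustrative assumptions):

```python
import torch
from transformers import DynamicCache

# Layers are appended lazily by `update`, so no layer count is passed to the constructor.
cache = DynamicCache()  # previously also accepted the deprecated num_hidden_layers kwarg
cache.update(torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8), layer_idx=0)
print(len(cache.key_cache))  # 1
```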
@@ -363,8 +363,7 @@ class DynamicCache(Cache):
         ```
     """

-    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
-    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
+    def __init__(self) -> None:
         super().__init__()
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
         self.key_cache: List[torch.Tensor] = []
@@ -466,10 +465,7 @@ class DynamicCache(Cache):
         return legacy_cache

     @classmethod
-    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
-    def from_legacy_cache(
-        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
-    ) -> "DynamicCache":
+    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
         """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
         backward compatibility."""
         cache = cls()
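Usage sketch for the updated `from_legacy_cache` above (the legacy tuple built here is a toy example with assumed shape `(batch, heads, seq_len, head_dim)`):

```python
import torch
from transformers import DynamicCache

legacy = tuple(
    (torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8))  # one (key, value) pair per layer
    for _ in range(2)
)
cache = DynamicCache.from_legacy_cache(legacy)  # the num_hidden_layers kwarg is no longer needed
print(len(cache.key_cache))  # 2
```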
@@ -495,10 +491,7 @@ class DynamicCache(Cache):
            self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
            self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

-    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
-    def batch_split(
-        self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
-    ) -> List["DynamicCache"]:
+    def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
         """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
         `_split_model_inputs()` in `generation.utils`"""
         out = []
@@ -511,8 +504,7 @@ class DynamicCache(Cache):
         return out

     @classmethod
-    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
-    def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int = None) -> "DynamicCache":
+    def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
         """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
         `generation.utils`"""
         cache = cls()
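A hedged round-trip sketch for `batch_split` / `from_batch_splits` as they read after this change; the single-layer cache and tensor shapes are assumptions for illustration:

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
cache.update(torch.zeros(4, 2, 3, 8), torch.zeros(4, 2, 3, 8), layer_idx=0)

splits = cache.batch_split(full_batch_size=4, split_size=2)  # two caches of batch size 2
restored = DynamicCache.from_batch_splits(splits)            # batch size 4 again
print(restored.key_cache[0].shape)  # torch.Size([4, 2, 3, 8])
```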
@@ -1527,10 +1519,7 @@ class EncoderDecoderCache(Cache):
         self.check_dynamic_cache(self.crop.__name__)
         self.self_attention_cache.crop(maximum_length)

-    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
-    def batch_split(
-        self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
-    ) -> "List[EncoderDecoderCache]":
+    def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
         """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
         `_split_model_inputs()` in `generation.utils`"""
         self.check_dynamic_cache(self.batch_split.__name__)
@@ -1543,10 +1532,7 @@ class EncoderDecoderCache(Cache):
         return out

     @classmethod
-    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
-    def from_batch_splits(
-        cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int = None
-    ) -> "EncoderDecoderCache":
+    def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
         """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
         `generation.utils`"""
         self_attention_cache = DynamicCache()
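The same round trip sketched for `EncoderDecoderCache`; both wrapped caches must be `DynamicCache` instances for these methods to apply, and the toy shapes are assumptions:

```python
import torch
from transformers import DynamicCache, EncoderDecoderCache

self_attn, cross_attn = DynamicCache(), DynamicCache()
self_attn.update(torch.zeros(4, 2, 3, 8), torch.zeros(4, 2, 3, 8), layer_idx=0)
cross_attn.update(torch.zeros(4, 2, 5, 8), torch.zeros(4, 2, 5, 8), layer_idx=0)

cache = EncoderDecoderCache(self_attn, cross_attn)
splits = cache.batch_split(full_batch_size=4, split_size=2)
restored = EncoderDecoderCache.from_batch_splits(splits)
print(restored.self_attention_cache.key_cache[0].shape)  # torch.Size([4, 2, 3, 8])
```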
@@ -4520,7 +4520,7 @@ def _ranking_fast(
     return selected_idx


-def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int = None):
+def _split(data, full_batch_size: int, split_size: int = None):
     """
     Takes care of three cases:
     1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim
@@ -4538,7 +4538,7 @@ def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int = None):
     elif isinstance(data, DynamicCache) or (
         isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache)
     ):
-        return data.batch_split(full_batch_size, split_size, num_hidden_layers)
+        return data.batch_split(full_batch_size, split_size)
     elif isinstance(data, tuple):
         # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
         if isinstance(data[0], tuple):
@@ -4591,11 +4591,9 @@ def _split_model_inputs(
     keys_to_ignore = ["cache_position", "encoder_outputs", "logits_to_keep"]
     non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]

-    num_hidden_layers = config.get_text_config().num_hidden_layers
-
     # we split the tensors and tuples of tensors
     data_split_list = [
-        {k: _split(model_input[k], full_batch_size, num_hidden_layers, split_size)[i] for k in non_bool_keys}
+        {k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys}
         for i in range(full_batch_size // split_size)
     ]
     # bool values are the same and replicated for each split
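Conceptually, the splitting step above now needs only the batch sizes. A simplified, self-contained sketch (`split_inputs_sketch` is a hypothetical stand-in, not the transformers helper):

```python
import torch

def split_inputs_sketch(model_input: dict, full_batch_size: int, split_size: int):
    """Hypothetical, simplified stand-in for the per-key splitting shown above."""
    return [
        {k: (v[i : i + split_size] if isinstance(v, torch.Tensor) else v) for k, v in model_input.items()}
        for i in range(0, full_batch_size, split_size)
    ]

chunks = split_inputs_sketch({"input_ids": torch.ones(4, 7, dtype=torch.long)}, full_batch_size=4, split_size=2)
print(len(chunks), chunks[0]["input_ids"].shape)  # 2 torch.Size([2, 7])
```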
@@ -4632,7 +4630,6 @@ def stack_model_outputs(model_outputs: List[ModelOutput], config: PretrainedConfig) -> ModelOutput:

     # Infer the class from the first object in the list
     model_output_cls = type(model_outputs[0])
-    num_hidden_layers = config.get_text_config().num_hidden_layers

     # Ensure all objects are of the same type
     if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
@@ -4649,9 +4646,9 @@
             return torch.cat(data, dim=0)
         # New cache format
         elif isinstance(data[0], DynamicCache):
-            return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
+            return DynamicCache.from_batch_splits(data)
         elif isinstance(data[0], EncoderDecoderCache):
-            return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
+            return EncoderDecoderCache.from_batch_splits(data)
         elif isinstance(data[0], tuple):
             # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
             if isinstance(data[0][0], tuple):
@@ -22,6 +22,7 @@ from parameterized import parameterized

 from transformers import PhimoeConfig, StaticCache, is_torch_available, set_seed
 from transformers.testing_utils import (
+    is_flaky,
     require_torch,
     slow,
     torch_device,
@@ -449,6 +450,7 @@ class PhimoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
         self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

     @parameterized.expand([("longrope",)])
+    @is_flaky()  # TODO (joao): unify rope tests in the mixin
     def test_model_rope_scaling_short_long_factor(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         n_factors = config.hidden_size // config.num_key_value_heads // 2
@@ -27,6 +27,7 @@ from transformers import (
     is_vision_available,
 )
 from transformers.testing_utils import (
+    is_flaky,
     require_flash_attn,
     require_torch,
     require_torch_gpu,
@@ -347,6 +348,10 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     def test_generate_compile_fullgraph(self):
         pass

+    @is_flaky()  # TODO (joao/raushan): Investigate why this test is flaky on this model
+    def test_prompt_lookup_decoding_matches_greedy_search(self):
+        super().test_prompt_lookup_decoding_matches_greedy_search()
+

 @require_torch
 class Qwen2_5_VLIntegrationTest(unittest.TestCase):
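For reference, a minimal sketch of how the `is_flaky` decorator added in these tests is typically applied; the test class and method here are hypothetical:

```python
import unittest

from transformers.testing_utils import is_flaky


class ExampleTest(unittest.TestCase):
    @is_flaky()  # reruns the wrapped test a bounded number of times before reporting failure
    def test_occasionally_flaky(self):
        self.assertTrue(True)
```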