[generate] revert change in Aria: the maximum cache length must match max_length (#36120)

* revert inputs_embeds len

* Update test_utils.py

* make fixup
This commit is contained in:
Joao Gante 2025-02-13 14:36:33 +00:00 committed by GitHub
parent b41591d847
commit 636ee57489
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 6 deletions

View File

@ -1470,7 +1470,6 @@ class GenerationMixin:
elif (
model_input_name == "inputs_embeds"
and input_ids_length != inputs_tensor.shape[1]
and input_ids_length != 0
and not self.config.is_encoder_decoder
):
generation_config.max_length -= inputs_tensor.shape[1]

View File

@ -1786,12 +1786,12 @@ class GenerationTesterMixin:
model.config.use_cache = True
model.config.is_decoder = True
batch_size = input_ids.shape[0]
max_cache_len = 30
max_length = 30
# here we force to not stop at eos and go until max-length
model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1
generation_kwargs = {
"max_length": max_cache_len,
"max_length": max_length,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
}
@ -1810,11 +1810,11 @@ class GenerationTesterMixin:
num_hidden_layers = text_config.num_hidden_layers
inputs_embeds = model.get_input_embeddings()(input_ids)
max_cache_len += inputs_embeds.shape[1] - 1 # the last generated token has no cache
outputs = model.generate(inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict)
# we should get `max_length` in shape, not `max_length - embeds_length`
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
# we should get `max_length - 1` in shape, not `max_length - embeds_length`.
# -1 because the last generated token isn't yet in the cache.
cache_shape = (batch_size, num_key_value_heads, max_length - 1, head_dim)
self.assertTrue(isinstance(outputs.past_key_values, StaticCache))
self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape)