mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
[generate] revert change in Aria: the maximum cache length must match max_length
(#36120)
* revert inputs_embeds len * Update test_utils.py * make fixup
This commit is contained in:
parent
b41591d847
commit
636ee57489
@ -1470,7 +1470,6 @@ class GenerationMixin:
|
||||
elif (
|
||||
model_input_name == "inputs_embeds"
|
||||
and input_ids_length != inputs_tensor.shape[1]
|
||||
and input_ids_length != 0
|
||||
and not self.config.is_encoder_decoder
|
||||
):
|
||||
generation_config.max_length -= inputs_tensor.shape[1]
|
||||
|
@ -1786,12 +1786,12 @@ class GenerationTesterMixin:
|
||||
model.config.use_cache = True
|
||||
model.config.is_decoder = True
|
||||
batch_size = input_ids.shape[0]
|
||||
max_cache_len = 30
|
||||
max_length = 30
|
||||
|
||||
# here we force to not stop at eos and go until max-length
|
||||
model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1
|
||||
generation_kwargs = {
|
||||
"max_length": max_cache_len,
|
||||
"max_length": max_length,
|
||||
"cache_implementation": "static",
|
||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||
}
|
||||
@ -1810,11 +1810,11 @@ class GenerationTesterMixin:
|
||||
num_hidden_layers = text_config.num_hidden_layers
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
max_cache_len += inputs_embeds.shape[1] - 1 # the last generated token has no cache
|
||||
outputs = model.generate(inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict)
|
||||
|
||||
# we should get `max_length` in shape, not `max_length - embeds_length`
|
||||
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
|
||||
# we should get `max_length - 1` in shape, not `max_length - embeds_length`.
|
||||
# -1 because the last generated token isn't yet in the cache.
|
||||
cache_shape = (batch_size, num_key_value_heads, max_length - 1, head_dim)
|
||||
self.assertTrue(isinstance(outputs.past_key_values, StaticCache))
|
||||
self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers)
|
||||
self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape)
|
||||
|
Loading…
Reference in New Issue
Block a user