diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 4436b066107..b4f845b424e 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1470,7 +1470,6 @@ class GenerationMixin: elif ( model_input_name == "inputs_embeds" and input_ids_length != inputs_tensor.shape[1] - and input_ids_length != 0 and not self.config.is_encoder_decoder ): generation_config.max_length -= inputs_tensor.shape[1] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index a953221eb4d..6c8a5e1285f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1786,12 +1786,12 @@ class GenerationTesterMixin: model.config.use_cache = True model.config.is_decoder = True batch_size = input_ids.shape[0] - max_cache_len = 30 + max_length = 30 # here we force to not stop at eos and go until max-length model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1 generation_kwargs = { - "max_length": max_cache_len, + "max_length": max_length, "cache_implementation": "static", "return_dict_in_generate": True, # Required to return `past_key_values` } @@ -1810,11 +1810,11 @@ class GenerationTesterMixin: num_hidden_layers = text_config.num_hidden_layers inputs_embeds = model.get_input_embeddings()(input_ids) - max_cache_len += inputs_embeds.shape[1] - 1 # the last generated token has no cache outputs = model.generate(inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict) - # we should get `max_length` in shape, not `max_length - embeds_length` - cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim) + # we should get `max_length - 1` in shape, not `max_length - embeds_length`. + # -1 because the last generated token isn't yet in the cache. + cache_shape = (batch_size, num_key_value_heads, max_length - 1, head_dim) self.assertTrue(isinstance(outputs.past_key_values, StaticCache)) self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers) self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape)