diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index e7736537e0c..385d68cfbef 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1813,7 +1813,9 @@ class GenerationMixin:
                     )
                 model_kwargs[cache_name] = self._get_cache(
                     generation_config.cache_implementation,
-                    getattr(generation_config, "num_beams", 1) * batch_size,
+                    getattr(generation_config, "num_beams", 1)
+                    * getattr(generation_config, "num_return_sequences", 1)
+                    * batch_size,
                     generation_config.max_length,
                     model_kwargs,
                 )
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index 2729a2989ab..e8d120216a0 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -292,6 +292,30 @@ class CacheIntegrationTest(unittest.TestCase):
         ]
         self.assertListEqual(decoded, expected_text)
 
+    def test_hybrid_cache_n_sequences(self):
+        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+        model = AutoModelForCausalLM.from_pretrained(
+            "google/gemma-2-9b",
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+            attn_implementation="eager",
+        )
+
+        inputs = tokenizer(["Hello I am doing"], return_tensors="pt").to(model.device)
+
+        gen_out = model.generate(
+            **inputs,
+            do_sample=False,
+            max_new_tokens=20,
+            num_return_sequences=2,
+        )
+        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
+        expected_text = [
+            "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+            "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+        ]
+        self.assertListEqual(decoded, expected_text)
+
     @require_auto_gptq
     def test_sink_cache_hard(self):
         tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ")
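
For reference, a minimal sketch of the scenario the patch covers, not part of the diff and assuming access to the google/gemma-2-9b checkpoint and enough GPU memory. Gemma 2 sets `cache_implementation="hybrid"`, so `generate()` pre-allocates a fixed-size cache; before this change the cache batch dimension was sized as `num_beams * batch_size`, which is too small once `num_return_sequences > 1` expands the input batch.

```python
# Not part of the patch: a repro sketch of sampling multiple sequences with a
# hybrid-cache model. Assumes the google/gemma-2-9b weights are available.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)

inputs = tokenizer(["Hello I am doing"], return_tensors="pt").to(model.device)

# With num_return_sequences=2 the input batch is expanded to
# batch_size * num_return_sequences rows; the pre-allocated cache must be
# sized for that expanded batch, which is what the change above accounts for.
out = model.generate(**inputs, do_sample=False, max_new_tokens=20, num_return_sequences=2)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```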