Fixed Hybrid Cache Shape Initialization. (#32163)

* Fixed hybrid cache init, added test

* Fixed test typo

---------

Co-authored-by: Aaron Haag <aaron.haag@siemens.com>
Author: OsamaS99, 2024-08-01 14:57:42 +02:00 (committed by GitHub)
commit 51ab25e293, parent e3d8285a84
2 changed files with 27 additions and 1 deletion


@@ -1813,7 +1813,9 @@ class GenerationMixin:
                 )
             model_kwargs[cache_name] = self._get_cache(
                 generation_config.cache_implementation,
-                getattr(generation_config, "num_beams", 1) * batch_size,
+                getattr(generation_config, "num_beams", 1)
+                * getattr(generation_config, "num_return_sequences", 1)
+                * batch_size,
                 generation_config.max_length,
                 model_kwargs,
             )
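
The change above sizes the cache batch dimension for the batch that decoding actually uses: with num_return_sequences > 1 the inputs are expanded before decoding, so a cache allocated for only num_beams * batch_size is too small. A minimal sketch of that arithmetic, with illustrative numbers (the variable names below are hypothetical, not taken from the diff):

# Illustrative numbers only (not from the commit): why the cache batch
# dimension must include num_return_sequences. In the greedy/sampling path,
# generate() repeats each input row num_return_sequences times before
# decoding, so the key/value cache must cover the expanded batch.
batch_size = 2
num_beams = 1
num_return_sequences = 3

# Batch the model actually decodes with, after input expansion.
expanded_batch = batch_size * num_beams * num_return_sequences  # 6

cache_batch_before_fix = num_beams * batch_size                        # 2
cache_batch_after_fix = num_beams * num_return_sequences * batch_size  # 6

assert cache_batch_after_fix == expanded_batch
assert cache_batch_before_fix != expanded_batch  # the mismatch fixed here

The new test below exercises exactly this case: greedy decoding with num_return_sequences=2 on a Gemma-2 model that uses the hybrid cache.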


@@ -292,6 +292,30 @@ class CacheIntegrationTest(unittest.TestCase):
         ]
         self.assertListEqual(decoded, expected_text)
 
+    def test_hybrid_cache_n_sequences(self):
+        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+        model = AutoModelForCausalLM.from_pretrained(
+            "google/gemma-2-9b",
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+            attn_implementation="eager",
+        )
+        inputs = tokenizer(["Hello I am doing"], return_tensors="pt").to(model.device)
+        gen_out = model.generate(
+            **inputs,
+            do_sample=False,
+            max_new_tokens=20,
+            num_return_sequences=2,
+        )
+        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
+        expected_text = [
+            "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+            "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+        ]
+        self.assertListEqual(decoded, expected_text)
+
     @require_auto_gptq
     def test_sink_cache_hard(self):
         tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ")