mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-16 11:08:23 +06:00
Fixed Hybrid Cache Shape Initialization. (#32163)
* fixed hybrid cache init, added test * Fix Test Typo --------- Co-authored-by: Aaron Haag <aaron.haag@siemens.com>
This commit is contained in:
parent
e3d8285a84
commit
51ab25e293
@ -1813,7 +1813,9 @@ class GenerationMixin:
|
||||
)
|
||||
model_kwargs[cache_name] = self._get_cache(
|
||||
generation_config.cache_implementation,
|
||||
getattr(generation_config, "num_beams", 1) * batch_size,
|
||||
getattr(generation_config, "num_beams", 1)
|
||||
* getattr(generation_config, "num_return_sequences", 1)
|
||||
* batch_size,
|
||||
generation_config.max_length,
|
||||
model_kwargs,
|
||||
)
|
||||
|
@ -292,6 +292,30 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
]
|
||||
self.assertListEqual(decoded, expected_text)
|
||||
|
||||
def test_hybrid_cache_n_sequences(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"google/gemma-2-9b",
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
|
||||
inputs = tokenizer(["Hello I am doing"], return_tensors="pt").to(model.device)
|
||||
|
||||
gen_out = model.generate(
|
||||
**inputs,
|
||||
do_sample=False,
|
||||
max_new_tokens=20,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
expected_text = [
|
||||
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
|
||||
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
|
||||
]
|
||||
self.assertListEqual(decoded, expected_text)
|
||||
|
||||
@require_auto_gptq
|
||||
def test_sink_cache_hard(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ")
|
||||
|
Loading…
Reference in New Issue
Block a user