Mirror of https://github.com/huggingface/transformers.git
Fixed Hybrid Cache Shape Initialization. (#32163)

* fixed hybrid cache init, added test
* Fix Test Typo

Co-authored-by: Aaron Haag <aaron.haag@siemens.com>

parent e3d8285a84
commit 51ab25e293
@@ -1813,7 +1813,9 @@ class GenerationMixin:
                     )
                 model_kwargs[cache_name] = self._get_cache(
                     generation_config.cache_implementation,
-                    getattr(generation_config, "num_beams", 1) * batch_size,
+                    getattr(generation_config, "num_beams", 1)
+                    * getattr(generation_config, "num_return_sequences", 1)
+                    * batch_size,
                     generation_config.max_length,
                     model_kwargs,
                 )
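
The hunk above widens the batch dimension used when GenerationMixin pre-allocates a static/hybrid cache so that it also accounts for num_return_sequences: generate() repeats the prompt along dim 0 before decoding, so a cache sized only for num_beams * batch_size has too few rows. A minimal sketch of the shape arithmetic in plain PyTorch (variable names are illustrative, not transformers internals):

# Sketch of the shape bug fixed above; names are illustrative, not transformers internals.
import torch

batch_size = 1            # rows in the tokenized prompt
num_beams = 1             # no beam search
num_return_sequences = 2  # generate() expands the input to this many rows per prompt

# generate() repeats the prompt along dim 0 before decoding:
input_ids = torch.tensor([[101, 2023, 2003]])                        # (1, seq_len)
expanded = input_ids.repeat_interleave(num_return_sequences, dim=0)  # (2, seq_len)

old_cache_batch = num_beams * batch_size                         # 1 -> smaller than the expanded input
new_cache_batch = num_beams * num_return_sequences * batch_size  # 2 -> matches the expanded input
assert new_cache_batch == expanded.shape[0]
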
@@ -292,6 +292,30 @@ class CacheIntegrationTest(unittest.TestCase):
         ]
         self.assertListEqual(decoded, expected_text)
 
+    def test_hybrid_cache_n_sequences(self):
+        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+        model = AutoModelForCausalLM.from_pretrained(
+            "google/gemma-2-9b",
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+            attn_implementation="eager",
+        )
+
+        inputs = tokenizer(["Hello I am doing"], return_tensors="pt").to(model.device)
+
+        gen_out = model.generate(
+            **inputs,
+            do_sample=False,
+            max_new_tokens=20,
+            num_return_sequences=2,
+        )
+        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
+        expected_text = [
+            "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+            "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+        ]
+        self.assertListEqual(decoded, expected_text)
+
     @require_auto_gptq
     def test_sink_cache_hard(self):
         tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ")
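
The new test reaches the fixed sizing through Gemma-2, whose generation config selects the hybrid cache by default. As a usage note, the same _get_cache path can also be requested explicitly with the cache_implementation generation argument; the sketch below mirrors the test and is illustrative only, not part of the commit.

# Usage sketch, not part of the commit: requesting the hybrid cache explicitly
# goes through the same GenerationMixin._get_cache sizing that was fixed above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "google/gemma-2-9b"  # same checkpoint the test uses
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, device_map="auto", torch_dtype=torch.bfloat16
)

inputs = tokenizer(["Hello I am doing"], return_tensors="pt").to(model.device)

gen_out = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=20,
    num_return_sequences=2,
    cache_implementation="hybrid",  # pre-allocated hybrid cache, sized per the fix
)
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))  # two identical greedy continuations
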