Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
Fix cache_position initialisation for generation with use_cache=False (#30485)

* Fix cache_position init for generation

* Update src/transformers/generation/utils.py

* Fix cache position update

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
commit 4fda78c3f8 (parent 54a2361a29)
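For context, a repro sketch of the scenario this commit fixes: text generation with use_cache=False. The checkpoint name and generation arguments are illustrative, not taken from the commit.

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Hello", return_tensors="pt")
# With this fix, cache_position is left as None when use_cache=False instead
# of being initialised and advanced as if a KV cache existed.
out = model.generate(**inputs, max_new_tokens=5, use_cache=False)
print(tok.decode(out[0]))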
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -667,7 +667,11 @@ class GenerationMixin:
                 dim=-1,
             )
 
-        if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
+        if (
+            model_kwargs.get("use_cache", True)
+            and "cache_position" in model_kwargs
+            and model_kwargs["cache_position"] is not None
+        ):
             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
 
         return model_kwargs
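The guard above only advances cache_position between decoding steps when a KV cache is actually in use. A minimal standalone sketch of that behaviour (not the transformers API itself; the helper name is ours):

import torch

# Mirrors the guarded update: cache_position only advances when use_cache is
# truthy and a position tensor exists; with use_cache=False it stays None.
def update_cache_position(model_kwargs, num_new_tokens=1):
    if (
        model_kwargs.get("use_cache", True)
        and "cache_position" in model_kwargs
        and model_kwargs["cache_position"] is not None
    ):
        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
    return model_kwargs

print(update_cache_position({"use_cache": False, "cache_position": None}))
# {'use_cache': False, 'cache_position': None}
print(update_cache_position({"use_cache": True, "cache_position": torch.arange(5)})["cache_position"])
# tensor([5])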
@@ -1293,6 +1297,10 @@ class GenerationMixin:
 
     def _get_initial_cache_position(self, input_ids, model_kwargs):
         """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
+        if not model_kwargs.get("use_cache", True):
+            model_kwargs["cache_position"] = None
+            return model_kwargs
+
         past_length = 0
         if "past_key_values" in model_kwargs:
             if isinstance(model_kwargs["past_key_values"], Cache):
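The early return added above means the pre-fill stage never builds a cache_position tensor when caching is off. A standalone sketch of the intended behaviour, assuming a simple past_length-based layout (the helper name and the arange construction are ours, not the exact library code):

import torch

def get_initial_cache_position(input_ids, model_kwargs, past_length=0):
    # With use_cache=False there is no KV cache to index into, so
    # cache_position is explicitly None.
    if not model_kwargs.get("use_cache", True):
        model_kwargs["cache_position"] = None
        return model_kwargs
    # Otherwise cache_position covers the pre-fill tokens that follow any
    # already-cached past_length.
    model_kwargs["cache_position"] = torch.arange(past_length, input_ids.shape[-1])
    return model_kwargs

ids = torch.ones(1, 4, dtype=torch.long)
print(get_initial_cache_position(ids, {"use_cache": False})["cache_position"])  # None
print(get_initial_cache_position(ids, {"use_cache": True})["cache_position"])   # tensor([0, 1, 2, 3])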