Fix cache_position initialisation for generation with use_cache=False (#30485)

* Fix cache_position init for generation

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fix cache position update

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Zhakshylyk Nurlanov 2024-05-07 11:13:11 +02:00 committed by GitHub
parent 54a2361a29
commit 4fda78c3f8

@@ -667,7 +667,11 @@ class GenerationMixin:
                 dim=-1,
             )
 
-        if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
+        if (
+            model_kwargs.get("use_cache", True)
+            and "cache_position" in model_kwargs
+            and model_kwargs["cache_position"] is not None
+        ):
             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
 
         return model_kwargs
@@ -1293,6 +1297,10 @@ class GenerationMixin:
     def _get_initial_cache_position(self, input_ids, model_kwargs):
         """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
+        if not model_kwargs.get("use_cache", True):
+            model_kwargs["cache_position"] = None
+            return model_kwargs
+
         past_length = 0
         if "past_key_values" in model_kwargs:
             if isinstance(model_kwargs["past_key_values"], Cache):
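In practice, these two hunks mean that generation with use_cache=False no longer initialises a cache_position tensor nor advances it between decoding steps. A minimal sketch of the affected call path, assuming the small checkpoint sshleifer/tiny-gpt2 (the model name is only illustrative, not part of this change):

# Sketch of generating without a KV cache; any causal LM checkpoint behaves the same way.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

inputs = tokenizer("Hello world", return_tensors="pt")

# With use_cache=False, _get_initial_cache_position now leaves cache_position as None,
# and _update_model_kwargs_for_generation skips the per-step cache_position advance.
output_ids = model.generate(**inputs, max_new_tokens=5, do_sample=False, use_cache=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))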