diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 6fe71a44655..1ab5ea527e4 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -171,6 +171,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
                     "Please pass in `min_length` into `.generate()` instead"
                 )
 
+        # We need to roll back the cache in assisted generation, only DynamicCache is supported
+        self.generation_config.cache_implementation = None
+
     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
         Fetches the candidates to be tried for the current input.
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 8918487696e..a5e53a294ba 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1779,6 +1779,20 @@ class GenerationMixin:
             cache_name = "cache_params"
         else:
             cache_name = "past_key_values"
+
+        # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
+        # which is only supported in dynamic caches atm
+        if (
+            assistant_model is not None
+            and generation_config.cache_implementation is not None
+            and self._supports_default_dynamic_cache()
+        ):
+            logger.warning_once(
+                "An assistant model is provided, using a dynamic cache instead of a cache of type="
+                f"'{generation_config.cache_implementation}'."
+            )
+            generation_config.cache_implementation = None
+
         if (model_kwargs.get(cache_name) is not None) and is_torchdynamo_compiling():
             raise ValueError(
                 "Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you "
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 53e08df4ce5..ce2ee1ef1a6 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -27,7 +27,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache
+from ...cache_utils import Cache, HybridCache
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -591,10 +591,9 @@ class Gemma2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
-    _supports_cache_class = False
+    _supports_cache_class = True
     _supports_quantized_cache = False
     _supports_static_cache = True
-    _is_stateful = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -841,7 +840,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if past_key_values is not None:
+        if isinstance(past_key_values, HybridCache):
             target_length = past_key_values.get_max_length()
         else:
             target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
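For context on the `candidate_generator.py` hunk: assisted generation has to discard draft tokens that the target model rejects, which means cropping the key/value cache back to the last verified position. `DynamicCache` supports this via `crop()`; preallocated caches (static, hybrid) do not, which is why `cache_implementation` is forced to `None` here. A minimal sketch of the rollback primitive, with tensor shapes chosen purely for illustration:

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()

# Dummy key/value states with shape (batch, num_heads, seq_len, head_dim)
keys = torch.randn(1, 2, 8, 16)
values = torch.randn(1, 2, 8, 16)
cache.update(keys, values, layer_idx=0)
print(cache.get_seq_length())  # 8

# Roll back to the last verified token, as assisted generation does after
# the target model rejects part of the draft
cache.crop(5)
print(cache.get_seq_length())  # 5
```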
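The `utils.py` hunk surfaces the same constraint in `generate()` itself: requesting a non-dynamic cache together with an assistant model now degrades gracefully rather than failing downstream. A sketch of the resulting behavior, using small public checkpoints for illustration (any compatible target/draft pair works):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
assistant = AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# With this PR, the static cache request is overridden: generate() emits a
# one-time warning and resets `cache_implementation` to None, so the default
# DynamicCache (which supports rollback) is used instead.
outputs = model.generate(
    **inputs,
    assistant_model=assistant,
    cache_implementation="static",
    max_new_tokens=20,
)
```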
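Finally, the Gemma2 change: now that the model may also run with a `DynamicCache` (e.g. during assisted generation), the causal-mask code can no longer assume every cache is a preallocated `HybridCache` with a fixed max length. A simplified paraphrase of the new target-length selection follows; the helper name is mine, not the model's:

```python
import torch
from transformers.cache_utils import DynamicCache, HybridCache

def target_length_for_mask(past_key_values, attention_mask, input_tensor):
    # Only HybridCache is preallocated, so only it has a meaningful max length
    if isinstance(past_key_values, HybridCache):
        return past_key_values.get_max_length()
    # Growing caches (e.g. DynamicCache from assisted generation) fall back
    # to the lengths that are actually visible
    if attention_mask is not None:
        return attention_mask.shape[-1]
    return input_tensor.shape[1]

# With a dynamic cache and no attention mask, the mask width tracks the input
print(target_length_for_mask(DynamicCache(), None, torch.empty(1, 7)))  # 7
```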