Gemma 2: support assisted generation (#32357)

Joao Gante 2024-07-31 16:04:48 +01:00 committed by GitHub
parent 5f1fcc299c
commit ef177a5e1c
3 changed files with 20 additions and 4 deletions

src/transformers/generation/candidate_generator.py

@@ -171,6 +171,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
                     "Please pass in `min_length` into `.generate()` instead"
                 )
 
+        # We need to roll back the cache in assisted generation, only DynamicCache is supported
+        self.generation_config.cache_implementation = None
+
     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
         Fetches the candidates to be tried for the current input.

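Why only `DynamicCache` works here: assisted generation validates the assistant's candidate tokens with the main model and must erase any rejected tokens from the KV cache, which means truncating the cache in place. A minimal sketch of that rollback, assuming the `DynamicCache.update`/`crop` helpers from `transformers.cache_utils` (toy shapes are illustrative):

```python
import torch
from transformers.cache_utils import DynamicCache

# Toy single-layer cache holding 10 positions: (batch, heads, seq_len, head_dim).
cache = DynamicCache()
cache.update(torch.randn(1, 8, 10, 64), torch.randn(1, 8, 10, 64), layer_idx=0)

# If the main model only validates 6 of those positions, assisted generation
# rolls the cache back to the last accepted token. Static caches pre-allocate
# fixed slots and cannot be truncated this way, hence `cache_implementation = None`.
cache.crop(6)
print(cache.get_seq_length())  # 6
```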
src/transformers/generation/utils.py

@@ -1779,6 +1779,20 @@ class GenerationMixin:
             cache_name = "cache_params"
         else:
             cache_name = "past_key_values"
+
+        # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
+        # which is only supported in dynamic caches atm
+        if (
+            assistant_model is not None
+            and generation_config.cache_implementation is not None
+            and self._supports_default_dynamic_cache()
+        ):
+            logger.warning_once(
+                "An assistant model is provided, using a dynamic cache instead of a cache of type="
+                f"'{generation_config.cache_implementation}'."
+            )
+            generation_config.cache_implementation = None
+
         if (model_kwargs.get(cache_name) is not None) and is_torchdynamo_compiling():
             raise ValueError(
                 "Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you "

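The user-visible effect of the block above, as a sketch (the checkpoint pair is illustrative; any main/assistant models sharing a tokenizer work): requesting a non-dynamic cache together with an assistant model no longer conflicts with candidate validation; `generate()` warns once and falls back to a dynamic cache.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
assistant = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")

inputs = tokenizer("The capital of France is", return_tensors="pt")

# The static cache request is overridden: a warning is logged once and
# generation proceeds with a rollback-friendly dynamic cache.
outputs = model.generate(
    **inputs,
    assistant_model=assistant,
    cache_implementation="static",  # ignored when `assistant_model` is set
    max_new_tokens=20,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```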
src/transformers/models/gemma2/modeling_gemma2.py

@@ -27,7 +27,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache
+from ...cache_utils import Cache, HybridCache
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -591,10 +591,9 @@ class Gemma2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
-    _supports_cache_class = False
+    _supports_cache_class = True
     _supports_quantized_cache = False
     _supports_static_cache = True
-    _is_stateful = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -841,7 +840,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if past_key_values is not None:
+        if isinstance(past_key_values, HybridCache):
             target_length = past_key_values.get_max_length()
         else:
             target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
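To make the changed branch concrete, a small standalone sketch (the `target_length` helper is illustrative, mirroring the patched code in `Gemma2Model._update_causal_mask`): only Gemma 2's default `HybridCache` pins the causal-mask length to its preallocated maximum; any other cache, such as the `DynamicCache` that assisted generation now falls back to, derives it from the attention mask or the input.

```python
import torch
from transformers.cache_utils import DynamicCache, HybridCache

def target_length(past_key_values, attention_mask, input_tensor):
    # Mirrors the patched condition: only HybridCache has a fixed max length.
    if isinstance(past_key_values, HybridCache):
        return past_key_values.get_max_length()
    return attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]

mask = torch.ones(1, 7, dtype=torch.long)
hidden = torch.randn(1, 7, 16)

# A DynamicCache (or no cache at all) now follows the attention mask instead
# of hitting the HybridCache-only `get_max_length()` path.
print(target_length(DynamicCache(), mask, hidden))  # 7
print(target_length(None, mask, hidden))            # 7
```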