Gemma 2: support assisted generation (#32357)
commit ef177a5e1c
parent 5f1fcc299c
@@ -171,6 +171,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
                     "Please pass in `min_length` into `.generate()` instead"
                 )
 
+        # We need to roll back the cache in assisted generation, only DynamicCache is supported
+        self.generation_config.cache_implementation = None
+
     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
         Fetches the candidates to be tried for the current input.
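A minimal illustration (not part of the commit) of the rollback this comment refers to: assisted generation drafts several candidate tokens at once, and whenever the main model rejects some of them, their key/value entries have to be dropped from the cache again. DynamicCache can shrink via crop(); pre-allocated cache types cannot, which is why cache_implementation is forced to None above. Shapes and values below are made up for the example.

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# Pretend one layer has cached 10 tokens (batch=1, heads=2, seq_len=10, head_dim=4).
keys = torch.randn(1, 2, 10, 4)
values = torch.randn(1, 2, 10, 4)
cache.update(keys, values, layer_idx=0)
print(cache.get_seq_length())  # 10

# Suppose the last 3 drafted tokens were rejected: roll the cache back to 7 tokens.
cache.crop(7)
print(cache.get_seq_length())  # 7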
@@ -1779,6 +1779,20 @@ class GenerationMixin:
             cache_name = "cache_params"
         else:
             cache_name = "past_key_values"
+
+        # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
+        # which is only supported in dynamic caches atm
+        if (
+            assistant_model is not None
+            and generation_config.cache_implementation is not None
+            and self._supports_default_dynamic_cache()
+        ):
+            logger.warning_once(
+                "An assistant model is provided, using a dynamic cache instead of a cache of type="
+                f"'{generation_config.cache_implementation}'."
+            )
+            generation_config.cache_implementation = None
+
         if (model_kwargs.get(cache_name) is not None) and is_torchdynamo_compiling():
             raise ValueError(
                 "Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you "
@@ -27,7 +27,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache
+from ...cache_utils import Cache, HybridCache
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -591,10 +591,9 @@ class Gemma2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
-    _supports_cache_class = False
+    _supports_cache_class = True
     _supports_quantized_cache = False
     _supports_static_cache = True
-    _is_stateful = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -841,7 +840,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if past_key_values is not None:
+        if isinstance(past_key_values, HybridCache):
             target_length = past_key_values.get_max_length()
         else:
             target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
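For context, a hedged usage sketch of what this commit enables; the checkpoint names and generation settings are illustrative, not taken from the commit. A larger Gemma 2 model generates while a smaller Gemma 2 checkpoint drafts candidate tokens, and since Gemma 2 checkpoints default to the hybrid cache implementation, the warning added in GenerationMixin should fire and generation proceeds with a dynamic cache.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed checkpoints for illustration; any pair sharing a tokenizer works.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it", torch_dtype=torch.bfloat16, device_map="auto"
)
assistant = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it", torch_dtype=torch.bfloat16, device_map="auto"
)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
# assistant_model switches generate() into assisted (speculative) decoding;
# any configured non-dynamic cache is downgraded to DynamicCache with a warning.
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))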