Gemma2: fix config initialization (cache_implementation) (#33684)

Joao Gante 2024-09-24 18:22:00 +01:00 committed by GitHub
parent d5bdac3db7
commit 238b13478d
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 3 deletions

View File

@@ -85,6 +85,7 @@ class Gemma2Config(PretrainedConfig):
             size of the sliding window.
         final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
         attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
+        cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
     ```python
     >>> from transformers import Gemma2Model, Gemma2Config
@@ -98,7 +99,6 @@ class Gemma2Config(PretrainedConfig):
     model_type = "gemma2"
     keys_to_ignore_at_inference = ["past_key_values"]
-    cache_implementation = "hybrid"
     def __init__(
         self,
@@ -125,6 +125,7 @@ class Gemma2Config(PretrainedConfig):
         sliding_window=4096,
         final_logit_softcapping=30.0,
         attn_logit_softcapping=50.0,
+        cache_implementation="hybrid",
         **kwargs,
     ):
         super().__init__(
@@ -153,3 +154,4 @@ class Gemma2Config(PretrainedConfig):
         self.sliding_window = sliding_window
         self.final_logit_softcapping = final_logit_softcapping
         self.attn_logit_softcapping = attn_logit_softcapping
+        self.cache_implementation = cache_implementation
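
The user-visible effect of the change above: `cache_implementation` becomes a regular `__init__` argument stored on the config instance instead of a class-level constant, so it can be set per config and serialized with it. A minimal sketch of the resulting behaviour (the `"static"` override below is purely illustrative and is not something this commit recommends for Gemma2):

```python
from transformers import Gemma2Config

# Default: the config now carries `cache_implementation="hybrid"` as an
# instance attribute set in __init__.
config = Gemma2Config()
print(config.cache_implementation)  # "hybrid"

# Because it is an __init__ argument, it can be overridden at construction
# time (illustrative only; suitability of other cache types for Gemma2 is
# not claimed here).
config_override = Gemma2Config(cache_implementation="static")
print(config_override.cache_implementation)  # "static"
```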

View File

@@ -117,6 +117,7 @@ class Gemma2Config(PretrainedConfig):
             size of the sliding window.
         final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
         attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
+        cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
     ```python
     >>> from transformers import Gemma2Model, Gemma2Config
@@ -130,7 +131,6 @@ class Gemma2Config(PretrainedConfig):
     model_type = "gemma2"
     keys_to_ignore_at_inference = ["past_key_values"]
-    cache_implementation = "hybrid"
     def __init__(
         self,
@@ -157,6 +157,7 @@ class Gemma2Config(PretrainedConfig):
         sliding_window=4096,
         final_logit_softcapping=30.0,
         attn_logit_softcapping=50.0,
+        cache_implementation="hybrid",
         **kwargs,
     ):
         super().__init__(
@@ -185,6 +186,7 @@ class Gemma2Config(PretrainedConfig):
         self.sliding_window = sliding_window
         self.final_logit_softcapping = final_logit_softcapping
         self.attn_logit_softcapping = attn_logit_softcapping
+        self.cache_implementation = cache_implementation
 class Gemma2RMSNorm(GemmaRMSNorm):

View File

@@ -44,7 +44,9 @@ SPECIAL_CASES_TO_ALLOW = {
     "Qwen2Config": ["use_sliding_window"],
     "Qwen2MoeConfig": ["use_sliding_window"],
     "Qwen2VLConfig": ["use_sliding_window"],
-    "Gemma2Config": ["tie_word_embeddings"],
+    # `cache_implementation` should be in the default generation config, but we don't yet support per-model
+    # generation configs (TODO joao)
+    "Gemma2Config": ["tie_word_embeddings", "cache_implementation"],
     # used to compute the property `self.chunk_length`
     "EncodecConfig": ["overlap"],
     # used to compute the property `self.layers_block_type`
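
As the new allow-list comment notes, `cache_implementation` conceptually belongs in the generation config rather than the model config, but per-model default generation configs are not yet supported. As a hedged sketch of the alternative it alludes to: `GenerationConfig` also exposes a `cache_implementation` field that can be passed to `generate` per call (the `model` and `inputs` objects below are assumed to already exist and are shown only to indicate where the config would be used):

```python
from transformers import GenerationConfig

# Sketch: choosing the cache type through the generation config instead of
# the model config. `generate` reads `cache_implementation` when deciding
# which cache class to build.
gen_config = GenerationConfig(cache_implementation="hybrid", max_new_tokens=20)

# `model` and `inputs` are assumed: an already-loaded Gemma2 model and its
# tokenized inputs.
# outputs = model.generate(**inputs, generation_config=gen_config)
```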