mirror of https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00

Gemma2: fix config initialization (cache_implementation) (#33684)

This commit is contained in:
parent d5bdac3db7
commit 238b13478d
@@ -85,6 +85,7 @@ class Gemma2Config(PretrainedConfig):
             size of the sliding window.
         final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
         attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
+        cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
 
     ```python
     >>> from transformers import Gemma2Model, Gemma2Config
@@ -98,7 +99,6 @@ class Gemma2Config(PretrainedConfig):
 
     model_type = "gemma2"
     keys_to_ignore_at_inference = ["past_key_values"]
-    cache_implementation = "hybrid"
 
     def __init__(
         self,
@@ -125,6 +125,7 @@ class Gemma2Config(PretrainedConfig):
         sliding_window=4096,
         final_logit_softcapping=30.0,
         attn_logit_softcapping=50.0,
+        cache_implementation="hybrid",
         **kwargs,
     ):
         super().__init__(
@@ -153,3 +154,4 @@ class Gemma2Config(PretrainedConfig):
         self.sliding_window = sliding_window
         self.final_logit_softcapping = final_logit_softcapping
         self.attn_logit_softcapping = attn_logit_softcapping
+        self.cache_implementation = cache_implementation
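In short, the hunks above turn `cache_implementation` from a class attribute into an `__init__` argument that is stored on the instance, so it can be overridden at construction time and is serialized like any other config attribute. A minimal usage sketch under that assumption (the override value and printed outputs are illustrative, not taken from this commit):

```python
# Minimal sketch of the fixed Gemma2Config behaviour; assumes a transformers
# version that includes this commit.
from transformers import Gemma2Config

# Default: the __init__ argument falls back to "hybrid".
config = Gemma2Config()
print(config.cache_implementation)  # hybrid

# After this fix the value can be overridden per instance and, being set on
# the instance rather than the class, it should also round-trip through
# serialization helpers such as to_dict().
custom = Gemma2Config(cache_implementation="static")
print(custom.cache_implementation)  # static
```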
@@ -117,6 +117,7 @@ class Gemma2Config(PretrainedConfig):
             size of the sliding window.
         final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
         attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
+        cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
 
     ```python
     >>> from transformers import Gemma2Model, Gemma2Config
@@ -130,7 +131,6 @@ class Gemma2Config(PretrainedConfig):
 
     model_type = "gemma2"
     keys_to_ignore_at_inference = ["past_key_values"]
-    cache_implementation = "hybrid"
 
     def __init__(
         self,
@@ -157,6 +157,7 @@ class Gemma2Config(PretrainedConfig):
         sliding_window=4096,
         final_logit_softcapping=30.0,
         attn_logit_softcapping=50.0,
+        cache_implementation="hybrid",
         **kwargs,
     ):
         super().__init__(
@@ -185,6 +186,7 @@ class Gemma2Config(PretrainedConfig):
         self.sliding_window = sliding_window
         self.final_logit_softcapping = final_logit_softcapping
         self.attn_logit_softcapping = attn_logit_softcapping
+        self.cache_implementation = cache_implementation
 
 
 class Gemma2RMSNorm(GemmaRMSNorm):
@@ -44,7 +44,9 @@ SPECIAL_CASES_TO_ALLOW = {
     "Qwen2Config": ["use_sliding_window"],
     "Qwen2MoeConfig": ["use_sliding_window"],
     "Qwen2VLConfig": ["use_sliding_window"],
-    "Gemma2Config": ["tie_word_embeddings"],
+    # `cache_implementation` should be in the default generation config, but we don't yet support per-model
+    # generation configs (TODO joao)
+    "Gemma2Config": ["tie_word_embeddings", "cache_implementation"],
     # used to compute the property `self.chunk_length`
     "EncodecConfig": ["overlap"],
     # used to compute the property `self.layers_block_type`
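The allowlist change is needed because `cache_implementation` is consumed by `generate` rather than by the Gemma2 modeling code, and the repository's consistency check flags config `__init__` arguments that the modeling files never read. A rough sketch of that idea, under the assumption that the check works allowlist-first (illustrative only; this is not the actual logic in `utils/check_config_attributes.py`, and `should_flag` is a hypothetical helper):

```python
# Rough sketch of an allowlist-driven config-attribute check.
SPECIAL_CASES_TO_ALLOW = {
    "Gemma2Config": ["tie_word_embeddings", "cache_implementation"],
}


def should_flag(config_class_name: str, attribute: str, used_in_modeling_code: bool) -> bool:
    """Flag a config __init__ attribute only if it is unused AND not allowlisted."""
    if used_in_modeling_code:
        return False
    return attribute not in SPECIAL_CASES_TO_ALLOW.get(config_class_name, [])


# `cache_implementation` is read by `generate`, not by the Gemma2 modeling files,
# so without the new allowlist entry the check would report it as unused.
assert not should_flag("Gemma2Config", "cache_implementation", used_in_modeling_code=False)
assert should_flag("Gemma2Config", "some_unused_attribute", used_in_modeling_code=False)
```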