Gemma2: fix config initialization (cache_implementation) (#33684)

parent d5bdac3db7
commit 238b13478d
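In short, this commit turns `cache_implementation` from a class attribute of `Gemma2Config` into a regular `__init__` argument (defaulting to `"hybrid"`), adds a docstring entry for it, and allowlists it in the config-attribute consistency check. A minimal sketch of the intended behavior after the change, assuming a transformers version that includes this commit:

```python
from transformers import Gemma2Config

# With this change, the default comes from the new __init__ argument
# and is stored on the instance like any other config field.
config = Gemma2Config()
print(config.cache_implementation)  # expected: "hybrid"

# It can also be overridden explicitly at construction time.
config = Gemma2Config(cache_implementation="hybrid")
print(config.cache_implementation)
```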
@@ -85,6 +85,7 @@ class Gemma2Config(PretrainedConfig):
             size of the sliding window.
         final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
         attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
+        cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
 
     ```python
     >>> from transformers import Gemma2Model, Gemma2Config
@@ -98,7 +99,6 @@ class Gemma2Config(PretrainedConfig):
 
     model_type = "gemma2"
     keys_to_ignore_at_inference = ["past_key_values"]
-    cache_implementation = "hybrid"
 
     def __init__(
         self,
@@ -125,6 +125,7 @@ class Gemma2Config(PretrainedConfig):
         sliding_window=4096,
         final_logit_softcapping=30.0,
         attn_logit_softcapping=50.0,
+        cache_implementation="hybrid",
         **kwargs,
     ):
         super().__init__(
@@ -153,3 +154,4 @@ class Gemma2Config(PretrainedConfig):
         self.sliding_window = sliding_window
         self.final_logit_softcapping = final_logit_softcapping
         self.attn_logit_softcapping = attn_logit_softcapping
+        self.cache_implementation = cache_implementation
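Setting the value in `__init__` also means it lives in the instance's `__dict__`, which, as I understand `PretrainedConfig` serialization, is what `to_dict()` and `save_pretrained()` persist; a class attribute would not round-trip the same way. A hedged sketch of that effect:

```python
from transformers import Gemma2Config

config = Gemma2Config(cache_implementation="hybrid")

# Instance attributes assigned in __init__ are part of the serialized payload,
# so the cache choice should survive save_pretrained()/from_pretrained() round-trips.
payload = config.to_dict()
print("cache_implementation" in payload)        # expected: True
print(payload.get("cache_implementation"))      # expected: "hybrid"
```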
The same set of changes is applied to a second definition of `Gemma2Config` elsewhere in the repository (file paths were not preserved in this capture):

@@ -117,6 +117,7 @@ class Gemma2Config(PretrainedConfig):
             size of the sliding window.
         final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
         attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
+        cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
 
     ```python
     >>> from transformers import Gemma2Model, Gemma2Config
@@ -130,7 +131,6 @@ class Gemma2Config(PretrainedConfig):
 
     model_type = "gemma2"
     keys_to_ignore_at_inference = ["past_key_values"]
-    cache_implementation = "hybrid"
 
     def __init__(
         self,
@@ -157,6 +157,7 @@ class Gemma2Config(PretrainedConfig):
         sliding_window=4096,
         final_logit_softcapping=30.0,
         attn_logit_softcapping=50.0,
+        cache_implementation="hybrid",
         **kwargs,
     ):
         super().__init__(
@@ -185,6 +186,7 @@ class Gemma2Config(PretrainedConfig):
         self.sliding_window = sliding_window
         self.final_logit_softcapping = final_logit_softcapping
         self.attn_logit_softcapping = attn_logit_softcapping
+        self.cache_implementation = cache_implementation
 
 
 class Gemma2RMSNorm(GemmaRMSNorm):
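The new docstring entry describes `cache_implementation` as the cache type used with `generate`. A usage sketch follows; the checkpoint name is purely illustrative, and how `generate` maps `"hybrid"` to a concrete cache class is outside this diff:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-2b"  # illustrative checkpoint; any Gemma2 checkpoint would do
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# After this commit the loaded config carries the field explicitly.
print(model.config.cache_implementation)  # expected: "hybrid"

inputs = tokenizer("The capital of France is", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```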
@@ -44,7 +44,9 @@ SPECIAL_CASES_TO_ALLOW = {
     "Qwen2Config": ["use_sliding_window"],
     "Qwen2MoeConfig": ["use_sliding_window"],
     "Qwen2VLConfig": ["use_sliding_window"],
-    "Gemma2Config": ["tie_word_embeddings"],
+    # `cache_implementation` should be in the default generation config, but we don't yet support per-model
+    # generation configs (TODO joao)
+    "Gemma2Config": ["tie_word_embeddings", "cache_implementation"],
     # used to compute the property `self.chunk_length`
     "EncodecConfig": ["overlap"],
     # used to compute the property `self.layers_block_type`
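For context, `SPECIAL_CASES_TO_ALLOW` exempts the listed attributes from the check that every config attribute is actually used in modeling code. A simplified, hypothetical sketch of how such an allowlist can be consulted (the helper below is illustrative, not the repository's actual implementation):

```python
SPECIAL_CASES_TO_ALLOW = {
    "Gemma2Config": ["tie_word_embeddings", "cache_implementation"],
}

def is_attribute_allowed(config_class_name: str, attribute: str) -> bool:
    """Return True if the attribute is explicitly exempted for this config class."""
    return attribute in SPECIAL_CASES_TO_ALLOW.get(config_class_name, [])

# With this commit, `cache_implementation` no longer trips the consistency check.
print(is_attribute_allowed("Gemma2Config", "cache_implementation"))  # True
```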