[Gemma] final fixes to the modeling (#29729)

* gelu_pytorch_tanh
* Force config.hidden_act to be approx gelu
* Gemma bug fixes
* force_use_exact_gelu
* Update configuration_gemma.py
* Update modeling_gemma.py
* update
* update for simpler handling
* nit
* nit
* fixup
* update
* also update the jax modeling!
* add `"gelu_pytorch_tanh": partial(nn.gelu, approximate=True),`
* fixup
* fix order
* act vs act_fn

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
This commit is contained in:
parent 229ac72b1e
commit 8e2fc52ea3
src/transformers/modeling_flax_utils.py

@@ -78,6 +78,7 @@ ACT2FN = {
     "swish": nn.swish,
     "gelu_new": partial(nn.gelu, approximate=True),
     "quick_gelu": quick_gelu,
+    "gelu_pytorch_tanh": partial(nn.gelu, approximate=True),
 }
 
 
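To see what the new entry computes, here is a small sketch (not part of the diff) checking `partial(nn.gelu, approximate=True)` against the closed-form tanh approximation of GELU. `tanh_gelu_reference` is an illustrative helper, not a transformers function.

    # Illustrative check, not part of the commit: the tanh-approximate GELU that the new
    # "gelu_pytorch_tanh" entry resolves to via partial(nn.gelu, approximate=True).
    import jax.numpy as jnp
    from flax import linen as nn

    def tanh_gelu_reference(x):
        # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
        return 0.5 * x * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3)))

    x = jnp.linspace(-4.0, 4.0, 9)
    # Should be ~0 up to floating-point rounding.
    print(jnp.max(jnp.abs(nn.gelu(x, approximate=True) - tanh_gelu_reference(x))))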
src/transformers/models/gemma/configuration_gemma.py

@@ -57,8 +57,11 @@ class GemmaConfig(PretrainedConfig):
             `num_attention_heads`.
         head_dim (`int`, *optional*, defaults to 256):
             The attention head dimension.
-        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
-            The non-linear activation function (function or string) in the decoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+            The legacy activation function. It is overwritten by the `hidden_activation`.
+        hidden_activation (`str` or `function`, *optional*):
+            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
+            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
         max_position_embeddings (`int`, *optional*, defaults to 8192):
             The maximum sequence length that this model might ever be used with.
         initializer_range (`float`, *optional*, defaults to 0.02):
@@ -108,7 +111,8 @@ class GemmaConfig(PretrainedConfig):
         num_attention_heads=16,
         num_key_value_heads=16,
         head_dim=256,
-        hidden_act="gelu",
+        hidden_act="gelu_pytorch_tanh",
+        hidden_activation=None,
         max_position_embeddings=8192,
         initializer_range=0.02,
         rms_norm_eps=1e-6,
@@ -131,6 +135,7 @@ class GemmaConfig(PretrainedConfig):
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.hidden_act = hidden_act
+        self.hidden_activation = hidden_activation
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
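A hedged sketch of the intended precedence between the two config fields, mirroring the resolution logic in the modeling changes below. `resolve_gemma_activation` is a hypothetical helper used only to illustrate the rule; it is not part of transformers.

    # Illustration only: mirrors the hidden_act / hidden_activation precedence introduced here.
    def resolve_gemma_activation(hidden_act="gelu_pytorch_tanh", hidden_activation=None):
        if hidden_activation is None:
            # Legacy configs that only set hidden_act are redirected to the approximate GeLU
            # (the modeling code emits a warning_once at this point).
            return "gelu_pytorch_tanh"
        return hidden_activation

    print(resolve_gemma_activation())                          # gelu_pytorch_tanh
    print(resolve_gemma_activation(hidden_act="gelu"))         # still gelu_pytorch_tanh
    print(resolve_gemma_activation(hidden_activation="gelu"))  # gelu (explicit opt-in to exact GeLU)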
src/transformers/models/gemma/modeling_flax_gemma.py

@@ -339,7 +339,6 @@ class FlaxGemmaAttention(nn.Module):
         return outputs
 
 
-# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaMLP with Llama->Gemma
 class FlaxGemmaMLP(nn.Module):
     config: GemmaConfig
     dtype: jnp.dtype = jnp.float32
@@ -349,7 +348,18 @@ class FlaxGemmaMLP(nn.Module):
         inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim
 
         kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
-        self.act = ACT2FN[self.config.hidden_act]
+        if self.config.hidden_activation is None:
+            logger.warning_once(
+                "Gemma's activation function should be approximate GeLU and not exact GeLU. "
+                "Changing the activation function to `gelu_pytorch_tanh`."
+                f"if you want to use the legacy `{self.config.hidden_act}`, "
+                f"edit the `model.config` to set `hidden_activation={self.config.hidden_act}` "
+                " instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details."
+            )
+            hidden_activation = "gelu_pytorch_tanh"
+        else:
+            hidden_activation = self.config.hidden_activation
+        self.act = ACT2FN[hidden_activation]
 
         self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
         self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
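For orientation, a self-contained sketch of the gated-MLP data flow that FlaxGemmaMLP implements, down_proj(act(gate_proj(x)) * up_proj(x)), with the activation resolved to the approximate GeLU as above. Sizes and the module name below are made up for the demo; this is not the model code.

    # Standalone sketch of the FlaxGemmaMLP data flow, with illustrative sizes.
    from functools import partial

    import jax
    import jax.numpy as jnp
    from flax import linen as nn

    act = partial(nn.gelu, approximate=True)  # what ACT2FN["gelu_pytorch_tanh"] resolves to

    class TinyGatedMLP(nn.Module):
        hidden: int = 8
        inner: int = 32

        @nn.compact
        def __call__(self, x):
            gate = nn.Dense(self.inner, use_bias=False)(x)   # gate_proj
            up = nn.Dense(self.inner, use_bias=False)(x)     # up_proj
            return nn.Dense(self.hidden, use_bias=False)(act(gate) * up)  # down_proj

    x = jnp.ones((1, 4, 8))
    params = TinyGatedMLP().init(jax.random.PRNGKey(0), x)
    print(TinyGatedMLP().apply(params, x).shape)  # (1, 4, 8)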
src/transformers/models/gemma/modeling_gemma.py

@@ -87,8 +87,11 @@ class GemmaRMSNorm(nn.Module):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
     def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * (1 + self.weight)
+        output = self._norm(x.float())
+        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        output = output * (1.0 + self.weight.float())
+        return output.type_as(x)
 
 
 ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
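A minimal demo (not part of the patch) of why the cast order matters: the old code cast the normalized activations to half precision before scaling by the weight, while the new code scales in float32 and casts last. The tensors below are arbitrary.

    # Demo of the half-precision gap between the two RMSNorm formulations.
    import torch

    torch.manual_seed(0)
    x = torch.randn(4, 3072).half()
    w = torch.randn(3072).half()

    def rms(x32):
        return x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + 1e-6)

    llama_style = rms(x.float()).type_as(x) * (1 + w)              # cast to fp16, then scale
    gemma_style = (rms(x.float()) * (1.0 + w.float())).type_as(x)  # scale in fp32, then cast
    # Typically small but nonzero rounding difference between the two.
    print((llama_style.float() - gemma_style.float()).abs().max())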
@@ -160,7 +163,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
-# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Gemma
 class GemmaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -170,7 +172,18 @@ class GemmaMLP(nn.Module):
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = ACT2FN[config.hidden_act]
+        if config.hidden_activation is None:
+            logger.warning_once(
+                "Gemma's activation function should be approximate GeLU and not exact GeLU.\n"
+                "Changing the activation function to `gelu_pytorch_tanh`."
+                f"if you want to use the legacy `{config.hidden_act}`, "
+                f"edit the `model.config` to set `hidden_activation={config.hidden_act}` "
+                " instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details."
+            )
+            hidden_activation = "gelu_pytorch_tanh"
+        else:
+            hidden_activation = config.hidden_activation
+        self.act_fn = ACT2FN[hidden_activation]
 
     def forward(self, x):
         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
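To make the warning concrete, a quick comparison (outside the diff) of the exact GELU with the tanh approximation it is replaced by; in PyTorch the `"gelu_pytorch_tanh"` key corresponds to `F.gelu(..., approximate="tanh")`.

    # Exact (erf-based) GELU vs the tanh approximation ("gelu_pytorch_tanh").
    # The two differ slightly, so silently using one where a checkpoint expects
    # the other degrades quality.
    import torch
    import torch.nn.functional as F

    x = torch.linspace(-4, 4, steps=9)
    exact = F.gelu(x)                       # "gelu"
    approx = F.gelu(x, approximate="tanh")  # "gelu_pytorch_tanh"
    print((exact - approx).abs().max())     # small but nonzero, largest around |x| near 2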
@@ -894,7 +907,10 @@ class GemmaModel(GemmaPreTrainedModel):
         hidden_states = inputs_embeds
 
         # normalized
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
+        # See https://github.com/huggingface/transformers/pull/29402
+        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
+        hidden_states = hidden_states * normalizer
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
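The rounding the new comment refers to can be reproduced in a couple of lines (demo only); it is shown here with bfloat16, a dtype commonly used for Gemma checkpoints, where sqrt(3072) rounds to 55.5, which is why the normalizer is now built in hidden_states.dtype instead of being applied as a Python float.

    # sqrt(3072) in full precision vs a half-precision tensor dtype.
    import torch

    hidden_size = 3072
    print(hidden_size**0.5)                                             # 55.42562584220407
    print(torch.tensor(hidden_size**0.5, dtype=torch.bfloat16).item())  # 55.5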