[Gemma] final fixes to the modeling (#29729)

* gelu_pytorch_tanh

* Force config.hidden_act to be approx gelu

* Gemma bug fixes

* force_use_exact_gelu

* Update configuration_gemma.py

* Update modeling_gemma.py

* update

* update for simpler handling

* nit

* nit

* fixpup

* update

* also update the jax modeling!

* add `"gelu_pytorch_tanh": partial(nn.gelu, approximate=True),`

* fixup

* fix order

* act vs act_fn

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Arthur 2024-03-20 02:47:42 +13:00 committed by GitHub
parent 229ac72b1e
commit 8e2fc52ea3
4 changed files with 42 additions and 10 deletions


@@ -78,6 +78,7 @@ ACT2FN = {
     "swish": nn.swish,
     "gelu_new": partial(nn.gelu, approximate=True),
     "quick_gelu": quick_gelu,
+    "gelu_pytorch_tanh": partial(nn.gelu, approximate=True),
 }
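
For context, the new entry maps "gelu_pytorch_tanh" to the tanh-based GELU approximation on the JAX side. A minimal sanity check (not part of the commit; it just compares jax.nn.gelu with approximate=True against the standard tanh-GELU formula) could look like:

import jax.numpy as jnp
from jax import nn

x = jnp.linspace(-4.0, 4.0, 9)

# Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
tanh_gelu = 0.5 * x * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3)))

assert jnp.allclose(nn.gelu(x, approximate=True), tanh_gelu, atol=1e-6)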


@@ -57,8 +57,11 @@ class GemmaConfig(PretrainedConfig):
             `num_attention_heads`.
         head_dim (`int`, *optional*, defaults to 256):
             The attention head dimension.
-        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
-            The non-linear activation function (function or string) in the decoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+            The legacy activation function. It is overwritten by `hidden_activation`.
+        hidden_activation (`str` or `function`, *optional*):
+            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
+            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
         max_position_embeddings (`int`, *optional*, defaults to 8192):
             The maximum sequence length that this model might ever be used with.
         initializer_range (`float`, *optional*, defaults to 0.02):
@@ -108,7 +111,8 @@ class GemmaConfig(PretrainedConfig):
         num_attention_heads=16,
         num_key_value_heads=16,
         head_dim=256,
-        hidden_act="gelu",
+        hidden_act="gelu_pytorch_tanh",
+        hidden_activation=None,
         max_position_embeddings=8192,
         initializer_range=0.02,
         rms_norm_eps=1e-6,
@@ -131,6 +135,7 @@ class GemmaConfig(PretrainedConfig):
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.hidden_act = hidden_act
+        self.hidden_activation = hidden_activation
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
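
Taken together, these config changes keep old checkpoints that only carry `hidden_act` working, while letting new configs opt out of the fallback warning by setting `hidden_activation` explicitly. A usage sketch (assuming a transformers version that includes this commit):

from transformers import GemmaConfig

# Default: hidden_act is "gelu_pytorch_tanh" and hidden_activation stays None,
# so the model falls back to approximate GeLU and warns once.
config = GemmaConfig()
print(config.hidden_act, config.hidden_activation)  # gelu_pytorch_tanh None

# Setting hidden_activation explicitly silences the fallback warning.
config = GemmaConfig(hidden_activation="gelu_pytorch_tanh")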


@@ -339,7 +339,6 @@ class FlaxGemmaAttention(nn.Module):
         return outputs

-# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaMLP with Llama->Gemma
 class FlaxGemmaMLP(nn.Module):
     config: GemmaConfig
     dtype: jnp.dtype = jnp.float32
@@ -349,7 +348,18 @@ class FlaxGemmaMLP(nn.Module):
         inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim
         kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
-        self.act = ACT2FN[self.config.hidden_act]
+        if self.config.hidden_activation is None:
+            logger.warning_once(
+                "Gemma's activation function should be approximate GeLU and not exact GeLU. "
+                "Changing the activation function to `gelu_pytorch_tanh`. "
+                f"If you want to use the legacy `{self.config.hidden_act}`, "
+                f"edit the `model.config` to set `hidden_activation={self.config.hidden_act}` "
+                "instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details."
+            )
+            hidden_activation = "gelu_pytorch_tanh"
+        else:
+            hidden_activation = self.config.hidden_activation
+        self.act = ACT2FN[hidden_activation]
         self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
         self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
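
The resolved activation then feeds Gemma's gated MLP: the activated gate projection multiplies the up projection before the down projection maps back to the hidden size (the same pattern as the PyTorch `GemmaMLP.forward` further down). A stripped-down illustration of that pattern with made-up sizes, not the actual FlaxGemmaMLP:

from functools import partial

import jax
import jax.numpy as jnp
from flax import linen as nn

# What ACT2FN["gelu_pytorch_tanh"] resolves to after this commit.
act = partial(nn.gelu, approximate=True)

class TinyGatedMLP(nn.Module):
    hidden: int = 8
    inner: int = 32

    @nn.compact
    def __call__(self, x):
        gate = nn.Dense(self.inner, use_bias=False)(x)
        up = nn.Dense(self.inner, use_bias=False)(x)
        # The activation gates the up branch before projecting back down.
        return nn.Dense(self.hidden, use_bias=False)(act(gate) * up)

x = jnp.ones((1, 4, 8))
params = TinyGatedMLP().init(jax.random.PRNGKey(0), x)
print(TinyGatedMLP().apply(params, x).shape)  # (1, 4, 8)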


@@ -87,8 +87,11 @@ class GemmaRMSNorm(nn.Module):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

     def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * (1 + self.weight)
+        output = self._norm(x.float())
+        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        output = output * (1.0 + self.weight.float())
+        return output.type_as(x)

 ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
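
The reason for the reorder is purely numerical: scaling in float32 and casting the product rounds differently than casting first and scaling in half precision. A standalone sketch of the discrepancy (illustrative, not the module code):

import torch

torch.manual_seed(0)
x = torch.randn(4, 8).half()
w = torch.randn(8).half()  # stand-in for the RMSNorm weight

def norm(h, eps=1e-6):
    return h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)

# Old path: normalize in fp32, cast back to fp16, then scale in fp16 (Llama-style).
old = norm(x.float()).type_as(x) * (1 + w)
# New path: keep fp32 through the scaling and cast only the final product (Gemma).
new = (norm(x.float()) * (1.0 + w.float())).type_as(x)

print((old - new).abs().max())  # typically a small, nonzero rounding difference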
@@ -160,7 +163,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed

-# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Gemma
 class GemmaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -170,7 +172,18 @@ class GemmaMLP(nn.Module):
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = ACT2FN[config.hidden_act]
+        if config.hidden_activation is None:
+            logger.warning_once(
+                "Gemma's activation function should be approximate GeLU and not exact GeLU.\n"
+                "Changing the activation function to `gelu_pytorch_tanh`. "
+                f"If you want to use the legacy `{config.hidden_act}`, "
+                f"edit the `model.config` to set `hidden_activation={config.hidden_act}` "
+                "instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details."
+            )
+            hidden_activation = "gelu_pytorch_tanh"
+        else:
+            hidden_activation = config.hidden_activation
+        self.act_fn = ACT2FN[hidden_activation]

     def forward(self, x):
         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
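
The warning matters because exact GeLU and the tanh approximation are not interchangeable when trying to match the reference Gemma weights. A quick comparison in plain PyTorch (a sketch; the "gelu_pytorch_tanh" entry corresponds to the approximate="tanh" variant):

import torch
import torch.nn.functional as F

x = torch.linspace(-4, 4, 9)

exact = F.gelu(x)                       # what the old "gelu" default resolved to
approx = F.gelu(x, approximate="tanh")  # the "gelu_pytorch_tanh" variant

# Small but systematic difference (a few 1e-4 at most over this range), enough
# to hurt parity with the reference Gemma weights if the exact form is used.
print((exact - approx).abs().max())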
@@ -894,7 +907,10 @@ class GemmaModel(GemmaPreTrainedModel):
         hidden_states = inputs_embeds

         # normalized
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
+        # See https://github.com/huggingface/transformers/pull/29402
+        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
+        hidden_states = hidden_states * normalizer

         # decoder layers
         all_hidden_states = () if output_hidden_states else None
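
For reference, the rounding that the new comment describes can be reproduced directly (a standalone sketch, not part of the diff; note that 55.5 is the bfloat16 rounding of sqrt(3072), while float16 rounds it to 55.4375):

import torch

hidden_size = 3072
normalizer = hidden_size**0.5  # 55.42562...

print(torch.tensor(normalizer, dtype=torch.bfloat16))  # tensor(55.5000, dtype=torch.bfloat16)
print(torch.tensor(normalizer, dtype=torch.float16))   # tensor(55.4375, dtype=torch.float16)

# Creating the normalizer as a tensor in hidden_states.dtype, as the diff does,
# reproduces the reference model's rounded scale instead of the exact value.
hidden_states = torch.ones(1, 4, hidden_size, dtype=torch.bfloat16)
hidden_states = hidden_states * torch.tensor(hidden_size**0.5, dtype=hidden_states.dtype)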