diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py
index 0f294400e5f..da373603420 100644
--- a/src/transformers/modeling_flax_utils.py
+++ b/src/transformers/modeling_flax_utils.py
@@ -78,6 +78,7 @@ ACT2FN = {
     "swish": nn.swish,
     "gelu_new": partial(nn.gelu, approximate=True),
     "quick_gelu": quick_gelu,
+    "gelu_pytorch_tanh": partial(nn.gelu, approximate=True),
 }
diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py
index 2e758bcaf5c..cf7be344e82 100644
--- a/src/transformers/models/gemma/configuration_gemma.py
+++ b/src/transformers/models/gemma/configuration_gemma.py
@@ -57,8 +57,11 @@ class GemmaConfig(PretrainedConfig):
             `num_attention_heads`.
         head_dim (`int`, *optional*, defaults to 256):
             The attention head dimension.
-        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
-            The non-linear activation function (function or string) in the decoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+            The legacy activation function. It is overwritten by `hidden_activation`.
+        hidden_activation (`str` or `function`, *optional*):
+            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
+            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
         max_position_embeddings (`int`, *optional*, defaults to 8192):
             The maximum sequence length that this model might ever be used with.
         initializer_range (`float`, *optional*, defaults to 0.02):
@@ -108,7 +111,8 @@ class GemmaConfig(PretrainedConfig):
         num_attention_heads=16,
         num_key_value_heads=16,
         head_dim=256,
-        hidden_act="gelu",
+        hidden_act="gelu_pytorch_tanh",
+        hidden_activation=None,
         max_position_embeddings=8192,
         initializer_range=0.02,
         rms_norm_eps=1e-6,
@@ -131,6 +135,7 @@ class GemmaConfig(PretrainedConfig):
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.hidden_act = hidden_act
+        self.hidden_activation = hidden_activation
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
diff --git a/src/transformers/models/gemma/modeling_flax_gemma.py b/src/transformers/models/gemma/modeling_flax_gemma.py
index 6dd4f662904..235f65680fa 100644
--- a/src/transformers/models/gemma/modeling_flax_gemma.py
+++ b/src/transformers/models/gemma/modeling_flax_gemma.py
@@ -339,7 +339,6 @@ class FlaxGemmaAttention(nn.Module):
         return outputs


-# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaMLP with Llama->Gemma
 class FlaxGemmaMLP(nn.Module):
     config: GemmaConfig
     dtype: jnp.dtype = jnp.float32
@@ -349,7 +348,18 @@ class FlaxGemmaMLP(nn.Module):
         inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim

         kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
-        self.act = ACT2FN[self.config.hidden_act]
+        if self.config.hidden_activation is None:
+            logger.warning_once(
+                "Gemma's activation function should be approximate GeLU and not exact GeLU. "
+                "Changing the activation function to `gelu_pytorch_tanh`. "
+                f"If you want to use the legacy `{self.config.hidden_act}`, "
+                f"edit the `model.config` to set `hidden_activation={self.config.hidden_act}` "
+                "instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details."
+            )
+            hidden_activation = "gelu_pytorch_tanh"
+        else:
+            hidden_activation = self.config.hidden_activation
+        self.act = ACT2FN[hidden_activation]

         self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
         self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 90b63a5ec81..8ec5d64ade1 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -87,8 +87,11 @@ class GemmaRMSNorm(nn.Module):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

     def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * (1 + self.weight)
+        output = self._norm(x.float())
+        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        output = output * (1.0 + self.weight.float())
+        return output.type_as(x)


 ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
@@ -160,7 +163,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


-# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Gemma
 class GemmaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -170,7 +172,18 @@ class GemmaMLP(nn.Module):
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = ACT2FN[config.hidden_act]
+        if config.hidden_activation is None:
+            logger.warning_once(
+                "Gemma's activation function should be approximate GeLU and not exact GeLU.\n"
+                "Changing the activation function to `gelu_pytorch_tanh`. "
+                f"If you want to use the legacy `{config.hidden_act}`, "
+                f"edit the `model.config` to set `hidden_activation={config.hidden_act}` "
+                "instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details."
+            )
+            hidden_activation = "gelu_pytorch_tanh"
+        else:
+            hidden_activation = config.hidden_activation
+        self.act_fn = ACT2FN[hidden_activation]

     def forward(self, x):
         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
@@ -894,7 +907,10 @@ class GemmaModel(GemmaPreTrainedModel):
         hidden_states = inputs_embeds

         # normalized
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
+        # See https://github.com/huggingface/transformers/pull/29402
+        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
+        hidden_states = hidden_states * normalizer

         # decoder layers
         all_hidden_states = () if output_hidden_states else None
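
Context for the `hidden_activation` change (a minimal sketch, not part of the patch, assuming a transformers build that already contains it): when the field is left as `None`, both `GemmaMLP` and `FlaxGemmaMLP` fall back to `gelu_pytorch_tanh` and emit the warning added above, while setting it explicitly selects the activation directly. The gap between exact GeLU and the tanh approximation can be checked in PyTorch:

import torch
import torch.nn.functional as F

from transformers import GemmaConfig

# Default config: hidden_activation is None, so the Gemma MLP modules fall back
# to the tanh-approximate GeLU and emit the warning added in this patch.
config = GemmaConfig()
print(config.hidden_act, config.hidden_activation)  # gelu_pytorch_tanh None

# Explicitly opting into the legacy exact GeLU keeps the old behaviour and
# silences the warning.
legacy_config = GemmaConfig(hidden_activation="gelu")

# The two activations differ slightly; the tanh form is what Gemma was trained with.
x = torch.linspace(-4.0, 4.0, steps=9)
print((F.gelu(x) - F.gelu(x, approximate="tanh")).abs().max())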
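
The `GemmaRMSNorm` edit only changes where the downcast happens: the learned scale is now applied in float32 and the result is cast back to the input dtype once, instead of casting the normalized activations first and scaling in half precision. A toy comparison of the two orderings (hypothetical tensors, not the real module):

import torch

torch.manual_seed(0)
x = torch.randn(2, 8, dtype=torch.float16)
weight = torch.randn(8, dtype=torch.float16)  # stands in for GemmaRMSNorm.weight
eps = 1e-6

def _norm(v):
    return v * torch.rsqrt(v.pow(2).mean(-1, keepdim=True) + eps)

# Old path: downcast to float16 first, then scale in half precision.
old = _norm(x.float()).type_as(x) * (1 + weight)

# New path: scale in float32, downcast once at the end.
new = (_norm(x.float()) * (1.0 + weight.float())).type_as(x)

# The two orderings usually differ by a few half-precision ULPs.
print((old.float() - new.float()).abs().max())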
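
Similarly, the embedding normalizer is now materialised as a tensor in `hidden_states.dtype`, so the multiplication uses the rounded constant rather than the exact Python float, matching the reference Gemma implementation. A quick check of the rounding; note that the 55.5 quoted in the in-line comment is what a bfloat16 cast produces, while a float16 cast rounds to 55.4375:

import torch

hidden_size = 3072                 # the hidden size referenced in the comment above
normalizer = hidden_size ** 0.5    # ~55.4256 as a Python float

# Casting the scalar into a low-precision dtype before the multiply reproduces
# the rounding of the reference implementation.
print(torch.tensor(normalizer, dtype=torch.bfloat16).item())  # 55.5
print(torch.tensor(normalizer, dtype=torch.float16).item())   # 55.4375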