Enhancing Code Readability and Maintainability with Simplified Activation Function Selection. (#28349)
* Small change to the get_activation() code
* Define the gelu activation in the proper place in these two files
* Fix GitHub issue
* Fix some typos
* My mistake: used self to call config
* Reformat the two files
* Update src/transformers/activations.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/models/electra/modeling_electra.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/models/convbert/modeling_convbert.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Rename gelu_act to activation

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Parent: 3eddda1111
Commit: 53cffeb33c
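The hunks below all apply the same pattern: resolve the activation once with get_activation() in __init__, store it as self.activation, and reuse that callable in forward() instead of looking it up on every forward pass. Here is a minimal, illustrative sketch of that pattern, not the committed code itself; the class name GeneratorPredictionsSketch and its constructor arguments are hypothetical, and it assumes torch and transformers are installed.

import torch
from torch import nn
from transformers.activations import get_activation


class GeneratorPredictionsSketch(nn.Module):
    """Illustrative stand-in for the ConvBert/Electra generator prediction heads."""

    def __init__(self, hidden_size: int, embedding_size: int):
        super().__init__()
        self.activation = get_activation("gelu")  # resolved once and stored on the module
        self.LayerNorm = nn.LayerNorm(embedding_size)
        self.dense = nn.Linear(hidden_size, embedding_size)

    def forward(self, generator_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states = self.dense(generator_hidden_states)
        hidden_states = self.activation(hidden_states)  # was: get_activation("gelu")(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


# Quick smoke test of the sketch.
head = GeneratorPredictionsSketch(hidden_size=8, embedding_size=4)
print(head(torch.randn(2, 3, 8)).shape)  # torch.Size([2, 3, 4])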
src/transformers/models/convbert/modeling_convbert.py

@@ -856,12 +856,13 @@ class ConvBertGeneratorPredictions(nn.Module):
     def __init__(self, config):
         super().__init__()
 
+        self.activation = get_activation("gelu")
         self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
         self.dense = nn.Linear(config.hidden_size, config.embedding_size)
 
     def forward(self, generator_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
         hidden_states = self.dense(generator_hidden_states)
-        hidden_states = get_activation("gelu")(hidden_states)
+        hidden_states = self.activation(hidden_states)
         hidden_states = self.LayerNorm(hidden_states)
 
         return hidden_states
src/transformers/models/electra/modeling_electra.py

@@ -631,12 +631,13 @@ class ElectraDiscriminatorPredictions(nn.Module):
         super().__init__()
 
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = get_activation(config.hidden_act)
         self.dense_prediction = nn.Linear(config.hidden_size, 1)
         self.config = config
 
     def forward(self, discriminator_hidden_states):
         hidden_states = self.dense(discriminator_hidden_states)
-        hidden_states = get_activation(self.config.hidden_act)(hidden_states)
+        hidden_states = self.activation(hidden_states)
         logits = self.dense_prediction(hidden_states).squeeze(-1)
 
         return logits
@@ -648,12 +649,13 @@ class ElectraGeneratorPredictions(nn.Module):
     def __init__(self, config):
         super().__init__()
 
+        self.activation = get_activation("gelu")
         self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
         self.dense = nn.Linear(config.hidden_size, config.embedding_size)
 
     def forward(self, generator_hidden_states):
         hidden_states = self.dense(generator_hidden_states)
-        hidden_states = get_activation("gelu")(hidden_states)
+        hidden_states = self.activation(hidden_states)
         hidden_states = self.LayerNorm(hidden_states)
 
         return hidden_states
@@ -933,6 +935,7 @@ class ElectraClassificationHead(nn.Module):
         classifier_dropout = (
             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
         )
+        self.activation = get_activation("gelu")
         self.dropout = nn.Dropout(classifier_dropout)
         self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
 
@@ -940,7 +943,7 @@ class ElectraClassificationHead(nn.Module):
         x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
         x = self.dropout(x)
         x = self.dense(x)
-        x = get_activation("gelu")(x)  # although BERT uses tanh here, it seems Electra authors used gelu here
+        x = self.activation(x)  # although BERT uses tanh here, it seems Electra authors used gelu here
         x = self.dropout(x)
         x = self.out_proj(x)
         return x
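For context, a short sketch of what get_activation() provides here, assuming transformers and torch are installed: it maps an activation name string to a module instance, which is why the result can be stored once on the module in __init__ and called repeatedly in forward().

import torch
from transformers.activations import get_activation

act = get_activation("gelu")   # returns a GELU activation module instance
print(act(torch.zeros(2, 3)))  # all zeros, since gelu(0) == 0
# Unknown names are rejected, e.g. get_activation("not_an_activation") raises a KeyError.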