Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Merge pull request #17 from lukovnikov/master
activation function in BERTIntermediate
This commit is contained in:
commit 8513741b57

Changed file: modeling.py (13 changed lines)
@@ -25,6 +25,7 @@ import six
 import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
+from six import string_types
 
 def gelu(x):
     """Implementation of the gelu activation function.
@@ -34,6 +35,13 @@ def gelu(x):
     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
 
 
+def swish(x):
+    return x * torch.sigmoid(x)
+
+
+ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
+
+
 class BertConfig(object):
     """Configuration class to store the configuration of a `BertModel`.
     """
@@ -60,7 +68,7 @@ class BertConfig(object):
         intermediate_size: The size of the "intermediate" (i.e., feed-forward)
             layer in the Transformer encoder.
         hidden_act: The non-linear activation function (function or string) in the
-            encoder and pooler.
+            encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
         hidden_dropout_prob: The dropout probabilitiy for all fully connected
             layers in the embeddings, encoder, and pooler.
         attention_probs_dropout_prob: The dropout ratio for the attention
@@ -237,7 +245,8 @@ class BERTIntermediate(nn.Module):
     def __init__(self, config):
         super(BERTIntermediate, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
-        self.intermediate_act_fn = gelu
+        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
+            if isinstance(config.hidden_act, string_types) else config.hidden_act
 
     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)
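For context, the effect of this patch is that BertConfig.hidden_act may now be either a string key into ACT2FN ("gelu", "relu", "swish") or a Python callable, and BERTIntermediate resolves it once at construction time. Below is a minimal, self-contained sketch of that resolution logic; the class name Intermediate and its constructor arguments are simplified stand-ins for illustration (not the repository's API), and isinstance(..., str) stands in for six.string_types.

import math
import torch
import torch.nn as nn

def gelu(x):
    # erf-based GELU, matching the formula in modeling.py
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def swish(x):
    # Swish / SiLU: x * sigmoid(x)
    return x * torch.sigmoid(x)

ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}

class Intermediate(nn.Module):
    # Simplified stand-in for BERTIntermediate (hypothetical name and signature).
    def __init__(self, hidden_size, intermediate_size, hidden_act="gelu"):
        super(Intermediate, self).__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)
        # A string picks a built-in activation; a callable is used as given.
        self.intermediate_act_fn = ACT2FN[hidden_act] if isinstance(hidden_act, str) else hidden_act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))

# Both forms work: a supported string name or any custom callable.
layer = Intermediate(768, 3072, hidden_act="swish")
custom = Intermediate(768, 3072, hidden_act=lambda x: torch.nn.functional.relu(x))
out = layer(torch.randn(2, 5, 768))  # -> shape (2, 5, 3072)

Passing an unsupported string simply raises a KeyError from the ACT2FN lookup, which is the same failure mode as the patched code.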