mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00

added n_inner argument to gpt2 config (#6296)

This commit is contained in:
parent 0a0d53dcf8
commit 2f2aa0c89c
@@ -59,6 +59,8 @@ class GPT2Config(PretrainedConfig):
             Number of hidden layers in the Transformer encoder.
         n_head (:obj:`int`, optional, defaults to 12):
             Number of attention heads for each attention layer in the Transformer encoder.
+        n_inner (:obj:`int`, optional, defaults to None):
+            Dimensionality of the inner feed-forward layers. :obj:`None` will set it to 4 times n_embd
         activation_function (:obj:`str`, optional, defaults to 'gelu'):
             Activation function selected in the list ["relu", "swish", "gelu", "tanh", "gelu_new"].
         resid_pdrop (:obj:`float`, optional, defaults to 0.1):
@@ -122,6 +124,7 @@ class GPT2Config(PretrainedConfig):
         n_embd=768,
         n_layer=12,
         n_head=12,
+        n_inner=None,
         activation_function="gelu_new",
         resid_pdrop=0.1,
         embd_pdrop=0.1,
@@ -145,6 +148,7 @@ class GPT2Config(PretrainedConfig):
         self.n_embd = n_embd
         self.n_layer = n_layer
         self.n_head = n_head
+        self.n_inner = n_inner
         self.activation_function = activation_function
         self.resid_pdrop = resid_pdrop
         self.embd_pdrop = embd_pdrop
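
Taken together, the configuration changes add one optional knob. A minimal usage sketch (assuming a transformers build that includes this commit): `n_inner` defaults to `None`, in which case the model blocks fall back to an inner feed-forward width of `4 * n_embd`; passing an integer overrides that.

from transformers import GPT2Config

# Default: n_inner is None, so downstream blocks use 4 * n_embd (3072 for n_embd=768).
default_config = GPT2Config()
print(default_config.n_inner)   # None

# Explicit override: request a custom feed-forward width.
custom_config = GPT2Config(n_embd=768, n_inner=2048)
print(custom_config.n_inner)    # 2048
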
@@ -240,10 +240,11 @@ class Block(nn.Module):
     def __init__(self, n_ctx, config, scale=False):
         super().__init__()
         nx = config.n_embd
+        inner_dim = config.n_inner if config.n_inner is not None else 4 * nx
         self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
         self.attn = Attention(nx, n_ctx, config, scale)
         self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
-        self.mlp = MLP(4 * nx, config)
+        self.mlp = MLP(inner_dim, config)

     def forward(
         self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False,
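
A quick way to see the effect on the PyTorch side (a sketch, not part of this commit; it assumes the block's input projection keeps the usual Conv1D weight layout of (n_embd, inner_dim) and the attribute path model.h[0].mlp.c_fc):

from transformers import GPT2Config, GPT2Model

# n_inner set explicitly: the Block should build its MLP with that width.
model = GPT2Model(GPT2Config(n_embd=768, n_inner=2048))
print(model.h[0].mlp.c_fc.weight.shape)          # expected: torch.Size([768, 2048])

# n_inner left as None: the historical 4 * n_embd default still applies.
default_model = GPT2Model(GPT2Config(n_embd=768))
print(default_model.h[0].mlp.c_fc.weight.shape)  # expected: torch.Size([768, 3072])
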
@@ -194,10 +194,11 @@ class TFBlock(tf.keras.layers.Layer):
     def __init__(self, n_ctx, config, scale=False, **kwargs):
         super().__init__(**kwargs)
         nx = config.n_embd
+        inner_dim = config.n_inner if config.n_inner is not None else 4 * nx
         self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1")
         self.attn = TFAttention(nx, n_ctx, config, scale, name="attn")
         self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
-        self.mlp = TFMLP(4 * nx, config, name="mlp")
+        self.mlp = TFMLP(inner_dim, config, name="mlp")

     def call(self, x, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
         a = self.ln_1(x)
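
The TensorFlow block mirrors the same fallback, so a corresponding check can be sketched there too (again an illustration under assumptions: Keras creates the weights lazily, so the model is called once on its built-in dummy inputs, and the attribute path transformer.h[0].mlp.c_fc is assumed):

from transformers import GPT2Config, TFGPT2Model

model = TFGPT2Model(GPT2Config(n_embd=768, n_inner=2048))

# Weights only exist after the first call; the dummy inputs trigger the build.
model(model.dummy_inputs)

print(model.transformer.h[0].mlp.c_fc.weight.shape)  # expected: (768, 2048)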