mirror of https://github.com/huggingface/transformers.git
commit 9c35c132fa
parent b9c77b98d5

apex LayerNorm
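The diff swaps torch's built-in nn.LayerNorm for an apex-provided LayerNorm in three modules (PositionwiseFF, MultiHeadAttn, RelMultiHeadAttn); see the import-guard sketch after the diff.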
@@ -217,7 +217,7 @@ class PositionwiseFF(nn.Module):
             nn.Dropout(dropout),
         )
 
-        self.layer_norm = nn.LayerNorm(d_model)
+        self.layer_norm = LayerNorm(d_model)
 
         self.pre_lnorm = pre_lnorm
 
@@ -254,7 +254,7 @@ class MultiHeadAttn(nn.Module):
         self.dropatt = nn.Dropout(dropatt)
         self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
 
-        self.layer_norm = nn.LayerNorm(d_model)
+        self.layer_norm = LayerNorm(d_model)
 
         self.scale = 1 / (d_head ** 0.5)
 
@@ -335,7 +335,7 @@ class RelMultiHeadAttn(nn.Module):
         self.dropatt = nn.Dropout(dropatt)
         self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
 
-        self.layer_norm = nn.LayerNorm(d_model)
+        self.layer_norm = LayerNorm(d_model)
 
        self.scale = 1 / (d_head ** 0.5)
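For context, the LayerNorm name used on the + lines typically resolves to apex's fused CUDA kernel, imported with a fallback to the stock PyTorch implementation when apex is not installed. The commit's actual import block falls outside the hunks shown above, so the guard below is a sketch of the usual pattern, not a copy of this commit:

try:
    # NVIDIA apex ships a fused CUDA LayerNorm kernel, typically faster
    # than torch.nn.LayerNorm, especially under mixed-precision training
    from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
except ImportError:
    # fall back to the stock PyTorch implementation when apex is absent
    from torch.nn import LayerNorm

With a guard like this in place, each module's self.layer_norm = LayerNorm(d_model) picks up the fused kernel when apex is available and behaves identically otherwise, which is why the diff only needs to drop the nn. prefix.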