Mirror of https://github.com/huggingface/transformers.git
change layernorm code to pytorch's native layer norm
This commit is contained in:
parent 2f9397139d
commit e13465fb8b
@@ -224,20 +224,7 @@ try:
     from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
 except (ImportError, AttributeError) as e:
     logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
-    class BertLayerNorm(nn.Module):
-        def __init__(self, hidden_size, eps=1e-12):
-            """Construct a layernorm module in the TF style (epsilon inside the square root).
-            """
-            super(BertLayerNorm, self).__init__()
-            self.weight = nn.Parameter(torch.ones(hidden_size))
-            self.bias = nn.Parameter(torch.zeros(hidden_size))
-            self.variance_epsilon = eps
-
-        def forward(self, x):
-            u = x.mean(-1, keepdim=True)
-            s = (x - u).pow(2).mean(-1, keepdim=True)
-            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
-            return self.weight * x + self.bias
+    BertLayerNorm = torch.nn.LayerNorm
 
 class BertEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings.
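For context, a minimal sketch (not part of this commit) comparing the removed TF-style module against torch.nn.LayerNorm at the BERT default eps=1e-12. The class name TFStyleLayerNorm and the tensor shapes are illustrative only; the forward logic is copied from the removed code above.

    import torch
    import torch.nn as nn

    class TFStyleLayerNorm(nn.Module):
        """The removed implementation: epsilon inside the square root (TF style)."""
        def __init__(self, hidden_size, eps=1e-12):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias

    hidden_size = 768
    x = torch.randn(2, 4, hidden_size)

    old_ln = TFStyleLayerNorm(hidden_size, eps=1e-12)
    new_ln = nn.LayerNorm(hidden_size, eps=1e-12)  # BertLayerNorm after this commit

    # Both normalize over the last dimension with the biased (population) variance,
    # so the outputs agree to floating-point precision.
    print(torch.allclose(old_ln(x), new_ln(x), atol=1e-6))

Since both versions compute the same statistics, the swap is numerically equivalent; the practical benefit is that nn.LayerNorm dispatches to PyTorch's native implementation instead of a hand-written Python forward.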