Mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
add fallback path for apex used in modeling.py
parent c8ea286048
commit 3b0a14b761
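The change replaces a hard apex requirement with an optional-import fallback. As a rough, standalone sketch of that pattern (names mirror the diff below; this is an illustration, not additional code from the commit):

import torch
from torch import nn

try:
    # Prefer apex's fused kernel when it is installed.
    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except ImportError:
    # Otherwise fall back to a pure-PyTorch layer norm with the same interface.
    class BertLayerNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-12):
            super(BertLayerNorm, self).__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias

# Either way, callers construct the layer identically:
layer_norm = BertLayerNorm(768, eps=1e-12)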
@@ -31,10 +31,6 @@ import shutil
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
-try:
-    from apex.normalization.fused_layer_norm import FusedLayerNorm
-except ImportError:
-    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this.")
 
 from .file_utils import cached_path
 
@@ -157,22 +153,24 @@ class BertConfig(object):
         """Serializes this instance to a JSON string."""
         return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
 
-
-class BertLayerNorm(nn.Module):
-    def __init__(self, config, variance_epsilon=1e-12):
-        """Construct a layernorm module in the TF style (epsilon inside the square root).
-        """
-        super(BertLayerNorm, self).__init__()
-        self.gamma = nn.Parameter(torch.ones(config.hidden_size))
-        self.beta = nn.Parameter(torch.zeros(config.hidden_size))
-        self.variance_epsilon = variance_epsilon
-
-    def forward(self, x):
-        u = x.mean(-1, keepdim=True)
-        s = (x - u).pow(2).mean(-1, keepdim=True)
-        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
-        return self.gamma * x + self.beta
-
+try:
+    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
+except ImportError:
+    print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
+    class BertLayerNorm(nn.Module):
+        def __init__(self, hidden_size, eps=1e-12):
+            """Construct a layernorm module in the TF style (epsilon inside the square root).
+            """
+            super(BertLayerNorm, self).__init__()
+            self.weight = nn.Parameter(torch.ones(hidden_size))
+            self.bias = nn.Parameter(torch.zeros(hidden_size))
+            self.variance_epsilon = eps
+
+        def forward(self, x):
+            u = x.mean(-1, keepdim=True)
+            s = (x - u).pow(2).mean(-1, keepdim=True)
+            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
+            return self.weight * x + self.bias
 
 class BertEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings.
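The fallback's "TF style" normalization keeps the epsilon inside the square root, which is also the formula torch.nn.LayerNorm uses, so the two should agree numerically. A quick sanity check (not part of the commit; nn.LayerNorm is used here purely as a reference):

import torch
from torch import nn

hidden_size = 8
x = torch.randn(2, 4, hidden_size)

# Fallback formula from the diff, with default weight=1 and bias=0.
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
fallback_out = (x - u) / torch.sqrt(s + 1e-12)

# torch.nn.LayerNorm computes (x - mean) / sqrt(var + eps) * weight + bias.
reference = nn.LayerNorm(hidden_size, eps=1e-12)(x)

print(torch.allclose(fallback_out, reference, atol=1e-6))  # expected: True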
@@ -185,7 +183,7 @@ class BertEmbeddings(nn.Module):
 
         # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
         # any TensorFlow checkpoint file
-        self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
     def forward(self, input_ids, token_type_ids=None):
@@ -260,7 +258,7 @@ class BertSelfOutput(nn.Module):
     def __init__(self, config):
         super(BertSelfOutput, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
     def forward(self, hidden_states, input_tensor):
@@ -299,7 +297,7 @@ class BertOutput(nn.Module):
     def __init__(self, config):
         super(BertOutput, self).__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
-        self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
     def forward(self, hidden_states, input_tensor):
@@ -361,7 +359,7 @@ class BertPredictionHeadTransform(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.transform_act_fn = ACT2FN[config.hidden_act] \
             if isinstance(config.hidden_act, str) else config.hidden_act
-        self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
 
     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)
@@ -443,12 +441,9 @@ class PreTrainedBertModel(nn.Module):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-        elif isinstance(module, FusedLayerNorm):
+        elif isinstance(module, BertLayerNorm):
             module.bias.data.normal_(mean=0.0, std=self.config.initializer_range)
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-        elif isinstance(module, BertLayerNorm):
-            module.beta.data.normal_(mean=0.0, std=self.config.initializer_range)
-            module.gamma.data.normal_(mean=0.0, std=self.config.initializer_range)
         if isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
 
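The last hunk can collapse the two initialization branches into one because the fallback class exposes its affine parameters as weight and bias (the removed pure-PyTorch class used gamma/beta), and apex's FusedLayerNorm is assumed to expose the same names. A minimal sketch of that shared initialization, using torch.nn.LayerNorm as a stand-in for whichever BertLayerNorm ends up bound (the helper name and the 0.02 default are illustrative, not from the commit):

import torch
from torch import nn

def init_layer_norm_like_bert(module, initializer_range=0.02):
    # Mirrors the single init_bert_weights branch after this commit:
    # normal-initialize both affine parameters of a layer-norm-style module.
    module.bias.data.normal_(mean=0.0, std=initializer_range)
    module.weight.data.normal_(mean=0.0, std=initializer_range)

ln = nn.LayerNorm(768, eps=1e-12)  # stand-in for either BertLayerNorm variant
init_layer_norm_like_bert(ln)
print(ln.weight.std().item(), ln.bias.std().item())  # each roughly 0.02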