Remove unnecessary use of FusedLayerNorm

Santiago Castro 2019-09-22 20:31:36 -04:00 committed by GitHub
parent a2d4950f5c
commit 98dd19b96b

@@ -133,11 +133,7 @@ def swish(x):
 ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
-try:
-    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
-except (ImportError, AttributeError) as e:
-    logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
-    BertLayerNorm = torch.nn.LayerNorm
+BertLayerNorm = torch.nn.LayerNorm
 
 
 class BertEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings.