Merge pull request #1434 from bryant1410/patch-1

Remove unnecessary use of FusedLayerNorm in XLNet

commit 8aa3b753bd
Author: Thomas Wolf (committed by GitHub)
Date:   2019-10-15 09:44:19 +02:00

@@ -188,11 +188,8 @@ def swish(x):
 
 ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
 
-try:
-    from apex.normalization.fused_layer_norm import FusedLayerNorm as XLNetLayerNorm
-except (ImportError, AttributeError) as e:
-    logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
-    from torch.nn import LayerNorm as XLNetLayerNorm
+XLNetLayerNorm = nn.LayerNorm
+
 
 class XLNetRelativeAttention(nn.Module):
     def __init__(self, config):
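
Note: the change relies on torch.nn.LayerNorm performing the same
normalization that apex's FusedLayerNorm provided, so the alias is a
drop-in substitute. A minimal sketch of the alias in use (not part of
the commit; the hidden size and eps value below are illustrative):

    import torch
    import torch.nn as nn

    # The alias introduced by this commit: plain PyTorch LayerNorm
    # stands in for apex's FusedLayerNorm with no other code changes.
    XLNetLayerNorm = nn.LayerNorm

    norm = XLNetLayerNorm(16, eps=1e-12)  # illustrative hidden size / eps
    x = torch.randn(2, 4, 16)             # (batch, seq_len, hidden)
    y = norm(x)                           # normalized over the last dim
    print(y.shape)                        # torch.Size([2, 4, 16])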