mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Merge pull request #1434 from bryant1410/patch-1
Remove unnecessary use of FusedLayerNorm in XLNet
This commit is contained in:
commit 8aa3b753bd
@@ -188,11 +188,8 @@ def swish(x):
 ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
 
-try:
-    from apex.normalization.fused_layer_norm import FusedLayerNorm as XLNetLayerNorm
-except (ImportError, AttributeError) as e:
-    logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
-    from torch.nn import LayerNorm as XLNetLayerNorm
+XLNetLayerNorm = nn.LayerNorm
+
 
 
 class XLNetRelativeAttention(nn.Module):
     def __init__(self, config):
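Below is a minimal sketch (not part of the diff) of what the change amounts to: XLNetLayerNorm becomes a plain alias for torch.nn.LayerNorm rather than apex's FusedLayerNorm, so callers construct and apply it exactly as before. The hidden size and eps values are illustrative, not taken from an actual XLNet config.

import torch
import torch.nn as nn

# After this PR, XLNetLayerNorm is simply torch.nn.LayerNorm.
XLNetLayerNorm = nn.LayerNorm

# Illustrative values: a 1024-dim hidden state and a small eps.
layer_norm = XLNetLayerNorm(1024, eps=1e-12)
hidden_states = torch.randn(8, 16, 1024)  # (batch, seq_len, d_model)
output = layer_norm(hidden_states)        # normalized over the last dimension
print(output.shape)                       # torch.Size([8, 16, 1024])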