thomwolf 2019-07-23 17:52:20 +02:00
parent 6070b55443
commit 1383c7b87a

@@ -39,6 +39,20 @@ WEIGHTS_NAME = "pytorch_model.bin"
 TF_WEIGHTS_NAME = 'model.ckpt'
+
+try:
+    from torch.nn import Identity
+except ImportError:
+    # Older PyTorch compatibility
+    class Identity(nn.Module):
+        r"""A placeholder identity operator that is argument-insensitive.
+        """
+        def __init__(self, *args, **kwargs):
+            super(Identity, self).__init__()
+
+        def forward(self, input):
+            return input
+
 if not six.PY2:
     def add_start_docstrings(*docstr):
         def docstring_decorator(fn):
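
The fallback above mirrors torch.nn.Identity, which is not available in older PyTorch releases (it arrived around 1.1). A minimal sketch of the behavior, assuming either import path; the constructor arguments shown are hypothetical and simply ignored:

import torch
import torch.nn as nn

try:
    from torch.nn import Identity  # PyTorch >= 1.1
except ImportError:
    # Same shim as in the hunk above, for older PyTorch
    class Identity(nn.Module):
        def __init__(self, *args, **kwargs):
            super(Identity, self).__init__()

        def forward(self, input):
            return input

layer = Identity(54, mode="unused")  # arguments are accepted but ignored
x = torch.randn(2, 3)
assert torch.equal(layer(x), x)      # the input passes through unchanged
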
@@ -731,7 +745,7 @@ class SequenceSummary(nn.Module):
             # We can probably just use the multi-head attention module of PyTorch >=1.1.0
             raise NotImplementedError
 
-        self.summary = nn.Identity()
+        self.summary = Identity()
         if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
             if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
                 num_classes = config.num_labels
@@ -739,15 +753,15 @@ class SequenceSummary(nn.Module):
                 num_classes = config.hidden_size
             self.summary = nn.Linear(config.hidden_size, num_classes)
 
-        self.activation = nn.Identity()
+        self.activation = Identity()
         if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
             self.activation = nn.Tanh()
 
-        self.first_dropout = nn.Identity()
+        self.first_dropout = Identity()
         if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
             self.first_dropout = nn.Dropout(config.summary_first_dropout)
 
-        self.last_dropout = nn.Identity()
+        self.last_dropout = Identity()
        if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
             self.last_dropout = nn.Dropout(config.summary_last_dropout)
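
Taken together, the replacements give SequenceSummary a uniform pattern: every optional stage starts as a pass-through Identity and is swapped for a real module only when the config enables it. A condensed, self-contained sketch of that pattern follows; MiniSummary, Cfg, and its fields are illustrative, not part of the library, and getattr with a default stands in for the hasattr checks used above:

import torch
import torch.nn as nn
from torch.nn import Identity  # on older PyTorch, the shim above stands in

class MiniSummary(nn.Module):
    """Illustrative cut-down of the pattern above: each stage defaults
    to the pass-through Identity and is replaced only on request."""
    def __init__(self, config):
        super(MiniSummary, self).__init__()
        self.summary = Identity()
        if getattr(config, 'summary_use_proj', False):
            self.summary = nn.Linear(config.hidden_size, config.hidden_size)

        self.activation = Identity()
        if getattr(config, 'summary_activation', None) == 'tanh':
            self.activation = nn.Tanh()

        self.first_dropout = Identity()
        if getattr(config, 'summary_first_dropout', 0) > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

    def forward(self, hidden_states):
        # Stages the config never enabled are still Identity, so they
        # return their input unchanged
        output = self.first_dropout(hidden_states)
        output = self.summary(output)
        return self.activation(output)

class Cfg:  # hypothetical config; dropout left unset, so it stays Identity
    hidden_size = 4
    summary_use_proj = True
    summary_activation = 'tanh'

print(MiniSummary(Cfg())(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
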