Commit 1383c7b87a in huggingface/transformers (parent 6070b55443)
https://github.com/huggingface/transformers.git

Fix #869

The patch defines a local Identity fallback for PyTorch versions that predate torch.nn.Identity and switches SequenceSummary over to it.
@@ -39,6 +39,20 @@ WEIGHTS_NAME = "pytorch_model.bin"
 TF_WEIGHTS_NAME = 'model.ckpt'
 
 
+try:
+    from torch.nn import Identity
+except ImportError:
+    # Older PyTorch compatibility
+    class Identity(nn.Module):
+        r"""A placeholder identity operator that is argument-insensitive.
+        """
+        def __init__(self, *args, **kwargs):
+            super(Identity, self).__init__()
+
+        def forward(self, input):
+            return input
+
+
 if not six.PY2:
     def add_start_docstrings(*docstr):
         def docstring_decorator(fn):
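The whole fix hinges on that try/except: torch.nn.Identity only appeared in PyTorch 1.1, so on older installs the import fails and the argument-insensitive stand-in above is used instead. A minimal sketch (not part of the commit) of how the shim behaves either way:

import torch
import torch.nn as nn

try:
    from torch.nn import Identity   # available from PyTorch 1.1 onwards
except ImportError:
    class Identity(nn.Module):      # same fallback the patch defines
        def __init__(self, *args, **kwargs):
            super(Identity, self).__init__()
        def forward(self, input):
            return input

layer = Identity(54, unused_kwarg=0.1)  # constructor arguments are ignored
x = torch.randn(2, 3)
assert torch.equal(layer(x), x)         # forward passes the input through unchanged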
@@ -731,7 +745,7 @@ class SequenceSummary(nn.Module):
             # We can probably just use the multi-head attention module of PyTorch >=1.1.0
             raise NotImplementedError
 
-        self.summary = nn.Identity()
+        self.summary = Identity()
         if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
             if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
                 num_classes = config.num_labels
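Because the fallback is defined at module level under the same name, the unqualified Identity here resolves to torch.nn.Identity on PyTorch >= 1.1 and to the compatibility class otherwise; the previous nn.Identity() raised an AttributeError on PyTorch < 1.1, which is the failure this commit fixes.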
@@ -739,15 +753,15 @@ class SequenceSummary(nn.Module):
                 num_classes = config.hidden_size
             self.summary = nn.Linear(config.hidden_size, num_classes)
 
-        self.activation = nn.Identity()
+        self.activation = Identity()
         if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
             self.activation = nn.Tanh()
 
-        self.first_dropout = nn.Identity()
+        self.first_dropout = Identity()
         if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
             self.first_dropout = nn.Dropout(config.summary_first_dropout)
 
-        self.last_dropout = nn.Identity()
+        self.last_dropout = Identity()
         if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
             self.last_dropout = nn.Dropout(config.summary_last_dropout)
 
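All four call sites follow the same idiom: default each stage of the summary pipeline to a no-op and only swap in a real module when the config enables it. A simplified sketch of that pattern, using a hypothetical config (SimpleNamespace stands in for the real config object; the field names mirror the ones in the diff):

from types import SimpleNamespace
import torch
import torch.nn as nn

try:
    from torch.nn import Identity
except ImportError:
    class Identity(nn.Module):
        def __init__(self, *args, **kwargs):
            super(Identity, self).__init__()
        def forward(self, input):
            return input

# Hypothetical config: tanh activation on, first dropout on, last dropout off.
config = SimpleNamespace(hidden_size=8, summary_activation='tanh',
                         summary_first_dropout=0.1, summary_last_dropout=0)

activation = Identity()
if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
    activation = nn.Tanh()

first_dropout = Identity()
if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
    first_dropout = nn.Dropout(config.summary_first_dropout)

last_dropout = Identity()
if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
    last_dropout = nn.Dropout(config.summary_last_dropout)

x = torch.randn(2, config.hidden_size)
y = last_dropout(activation(first_dropout(x)))  # disabled stages are no-ops
print(y.shape)  # torch.Size([2, 8])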