Fix #869

commit 1383c7b87a
parent 6070b55443
@@ -39,6 +39,20 @@ WEIGHTS_NAME = "pytorch_model.bin"
 TF_WEIGHTS_NAME = 'model.ckpt'
 
+try:
+    from torch.nn import Identity
+except ImportError:
+    # Older PyTorch compatibility
+    class Identity(nn.Module):
+        r"""A placeholder identity operator that is argument-insensitive.
+        """
+        def __init__(self, *args, **kwargs):
+            super(Identity, self).__init__()
+
+        def forward(self, input):
+            return input
+
+
 
 if not six.PY2:
     def add_start_docstrings(*docstr):
         def docstring_decorator(fn):
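torch.nn.Identity was only added in PyTorch 1.1, so the hunk above backports it: when the import fails, the except branch defines an equivalent no-op module at module level. A minimal self-contained check of the shim's behavior (the tensor shape and constructor arguments are illustrative, not from the source):

import torch
import torch.nn as nn

# The shim from the hunk above: prefer the built-in Identity when it
# exists (PyTorch >= 1.1), otherwise fall back to a no-op module.
try:
    from torch.nn import Identity
except ImportError:
    class Identity(nn.Module):
        r"""A placeholder identity operator that is argument-insensitive."""
        def __init__(self, *args, **kwargs):
            super(Identity, self).__init__()

        def forward(self, input):
            return input

layer = Identity(54, unused_kwarg='ignored')  # arguments are accepted and discarded
x = torch.randn(2, 3)
assert torch.equal(layer(x), x)  # the output is the input, unchanged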
@@ -731,7 +745,7 @@ class SequenceSummary(nn.Module):
             # We can probably just use the multi-head attention module of PyTorch >=1.1.0
             raise NotImplementedError
 
-        self.summary = nn.Identity()
+        self.summary = Identity()
         if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
             if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
                 num_classes = config.num_labels
@@ -739,15 +753,15 @@ class SequenceSummary(nn.Module):
                 num_classes = config.hidden_size
             self.summary = nn.Linear(config.hidden_size, num_classes)
 
-        self.activation = nn.Identity()
+        self.activation = Identity()
         if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
             self.activation = nn.Tanh()
 
-        self.first_dropout = nn.Identity()
+        self.first_dropout = Identity()
         if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
             self.first_dropout = nn.Dropout(config.summary_first_dropout)
 
-        self.last_dropout = nn.Identity()
+        self.last_dropout = Identity()
         if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
             self.last_dropout = nn.Dropout(config.summary_last_dropout)
 
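The hunks in SequenceSummary.__init__ then swap nn.Identity() for the module-level Identity() shim, which is the point of the fix: nn.Identity does not exist on PyTorch < 1.1, so referencing it raised an AttributeError there. The assignments follow a default-then-override pattern in which every stage starts as a no-op and is replaced only when the config opts in. A condensed, hypothetical sketch of that logic (SummaryHead and the SimpleNamespace config values are illustrative; the config attribute names follow the diff):

from types import SimpleNamespace

import torch
import torch.nn as nn
from torch.nn import Identity  # or the backport shim above on PyTorch < 1.1

class SummaryHead(nn.Module):
    """Hypothetical condensation of the SequenceSummary logic touched by this diff."""
    def __init__(self, config):
        super(SummaryHead, self).__init__()
        # Each stage defaults to a no-op and is swapped out only when
        # the corresponding config attribute asks for it.
        self.summary = Identity()
        if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
            if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        self.activation = nn.Tanh() if getattr(config, 'summary_activation', None) == 'tanh' else Identity()
        self.first_dropout = nn.Dropout(config.summary_first_dropout) \
            if getattr(config, 'summary_first_dropout', 0) > 0 else Identity()
        self.last_dropout = nn.Dropout(config.summary_last_dropout) \
            if getattr(config, 'summary_last_dropout', 0) > 0 else Identity()

    def forward(self, hidden_states):
        output = self.first_dropout(hidden_states)
        output = self.summary(output)
        output = self.activation(output)
        return self.last_dropout(output)

config = SimpleNamespace(summary_use_proj=True, summary_proj_to_labels=True,
                         num_labels=2, hidden_size=8, summary_activation='tanh',
                         summary_first_dropout=0.1, summary_last_dropout=0.1)
head = SummaryHead(config)
print(head(torch.randn(4, 8)).shape)  # torch.Size([4, 2])

Because unset flags leave Identity() in place, forward can apply every stage unconditionally instead of branching on the config.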