diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py
index 324cdc17c9d..3f1df0a49dc 100644
--- a/pytorch_transformers/modeling_utils.py
+++ b/pytorch_transformers/modeling_utils.py
@@ -39,6 +39,20 @@
 WEIGHTS_NAME = "pytorch_model.bin"
 TF_WEIGHTS_NAME = 'model.ckpt'
 
+try:
+    from torch.nn import Identity
+except ImportError:
+    # Older PyTorch compatibility
+    class Identity(nn.Module):
+        r"""A placeholder identity operator that is argument-insensitive.
+        """
+        def __init__(self, *args, **kwargs):
+            super(Identity, self).__init__()
+
+        def forward(self, input):
+            return input
+
+
 if not six.PY2:
     def add_start_docstrings(*docstr):
         def docstring_decorator(fn):
@@ -731,7 +745,7 @@ class SequenceSummary(nn.Module):
             # We can probably just use the multi-head attention module of PyTorch >=1.1.0
             raise NotImplementedError
 
-        self.summary = nn.Identity()
+        self.summary = Identity()
         if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
             if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
                 num_classes = config.num_labels
@@ -739,15 +753,15 @@ class SequenceSummary(nn.Module):
                 num_classes = config.hidden_size
             self.summary = nn.Linear(config.hidden_size, num_classes)
 
-        self.activation = nn.Identity()
+        self.activation = Identity()
         if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
            self.activation = nn.Tanh()
 
-        self.first_dropout = nn.Identity()
+        self.first_dropout = Identity()
         if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
             self.first_dropout = nn.Dropout(config.summary_first_dropout)
 
-        self.last_dropout = nn.Identity()
+        self.last_dropout = Identity()
         if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
             self.last_dropout = nn.Dropout(config.summary_last_dropout)
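
For reference, `torch.nn.Identity` only exists in PyTorch >= 1.1.0, so the `try`/`except ImportError` block above keeps `modeling_utils.py` importable on older releases. Below is a minimal, self-contained sketch of the same fallback pattern; the example tensor and assertion are illustrative and not part of the patch:

```python
import torch
import torch.nn as nn

try:
    from torch.nn import Identity  # available in PyTorch >= 1.1.0
except ImportError:
    # Older PyTorch compatibility: a shim that accepts and ignores any arguments.
    class Identity(nn.Module):
        def __init__(self, *args, **kwargs):
            super(Identity, self).__init__()

        def forward(self, input):
            return input

# Either implementation simply passes its input through, which is why it can
# stand in as the default "no-op" summary/activation/dropout in SequenceSummary.
layer = Identity(123, unused_kwarg='ignored')  # extra constructor args are discarded
x = torch.randn(2, 4)
assert torch.equal(layer(x), x)
```

Defining the shim at module level means the rest of the file can reference a bare `Identity` unconditionally, regardless of the installed PyTorch version.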