fix model2model

thomwolf 2019-10-17 16:33:31 +02:00
parent 8cd56e3036
commit 56e2ee4ead


@@ -28,7 +28,7 @@ from .modeling_utils import PreTrainedModel, SequenceSummary
 logger = logging.getLogger(__name__)
-class PreTrainedSeq2seq(PreTrainedModel):
+class PreTrainedSeq2seq(nn.Module):
     r"""
         :class:`~transformers.Seq2seq` is a generic model class that will be
         instantiated as a Seq2seq model with one of the base model classes of
@@ -43,7 +43,7 @@ class PreTrainedSeq2seq(PreTrainedModel):
         self.decoder = decoder
     @classmethod
-    def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs):
+    def from_pretrained(cls, encoder_pretrained_model_name_or_path=None, decoder_pretrained_model_name_or_path=None, *model_args, **kwargs):
         r""" Instantiates an encoder and a decoder from one or two base classes
         of the library from pre-trained model checkpoints.
@@ -177,8 +177,8 @@ class PreTrainedSeq2seq(PreTrainedModel):
 class Model2Model(PreTrainedSeq2seq):
-    def __init__(self):
-        super(Model2Model, self).__init__()
+    def __init__(self, *args, **kwargs):
+        super(Model2Model, self).__init__(*args, **kwargs)
         self.tie_weights()
     def tie_weights(self):
@@ -197,7 +197,14 @@ class Model2Model(PreTrainedSeq2seq):
         by a model-specific keyword (bert, )...
         """
         # self._tie_or_clone_weights(self.encoder, self.decoder)
-        raise NotImplementedError
+        pass
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+        model = super(Model2Model, cls).from_pretrained(encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
+                                                        decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
+                                                        **kwargs)
+        return model
 class Model2LSTM(PreTrainedSeq2seq):
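
For reference, a minimal usage sketch of what this fix enables, assuming Model2Model is importable from the transformers package at this commit; the import path and the "bert-base-uncased" checkpoint name are illustrative assumptions, not part of the diff:

# Minimal sketch, not part of the commit: the import path below is an assumption.
from transformers import Model2Model

# After this fix, Model2Model.from_pretrained forwards a single checkpoint name to both
# encoder_pretrained_model_name_or_path and decoder_pretrained_model_name_or_path,
# so encoder and decoder are initialized from the same pretrained weights, and
# tie_weights() no longer raises NotImplementedError during __init__.
model = Model2Model.from_pretrained("bert-base-uncased")

# PreTrainedSeq2seq keeps the two sub-models as separate attributes (self.encoder / self.decoder).
print(type(model.encoder).__name__, type(model.decoder).__name__)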