mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix model2model
This commit is contained in:
parent
8cd56e3036
commit
56e2ee4ead
@ -28,7 +28,7 @@ from .modeling_utils import PreTrainedModel, SequenceSummary
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreTrainedSeq2seq(PreTrainedModel):
|
||||
class PreTrainedSeq2seq(nn.Module):
|
||||
r"""
|
||||
:class:`~transformers.Seq2seq` is a generic model class that will be
|
||||
instantiated as a Seq2seq model with one of the base model classes of
|
||||
@ -43,7 +43,7 @@ class PreTrainedSeq2seq(PreTrainedModel):
|
||||
self.decoder = decoder
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
def from_pretrained(cls, encoder_pretrained_model_name_or_path=None, decoder_pretrained_model_name_or_path=None, *model_args, **kwargs):
|
||||
r""" Instantiates an encoder and a decoder from one or two base classes
|
||||
of the library from pre-trained model checkpoints.
|
||||
|
||||
@ -177,8 +177,8 @@ class PreTrainedSeq2seq(PreTrainedModel):
|
||||
|
||||
|
||||
class Model2Model(PreTrainedSeq2seq):
|
||||
def __init__(self):
|
||||
super(Model2Model, self).__init__()
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Model2Model, self).__init__(*args, **kwargs)
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
@ -197,7 +197,14 @@ class Model2Model(PreTrainedSeq2seq):
|
||||
by a model-specific keyword (bert, )...
|
||||
"""
|
||||
# self._tie_or_clone_weights(self.encoder, self.decoder)
|
||||
raise NotImplementedError
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
||||
model = super(Model2Model, cls).from_pretrained(encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
||||
class Model2LSTM(PreTrainedSeq2seq):
|
||||
|
Loading…
Reference in New Issue
Block a user