mirror of https://github.com/huggingface/transformers.git

commit 56e2ee4ead
parent 8cd56e3036

fix model2model
@@ -28,7 +28,7 @@ from .modeling_utils import PreTrainedModel, SequenceSummary
 logger = logging.getLogger(__name__)
 
 
-class PreTrainedSeq2seq(PreTrainedModel):
+class PreTrainedSeq2seq(nn.Module):
     r"""
         :class:`~transformers.Seq2seq` is a generic model class that will be
         instantiated as a Seq2seq model with one of the base model classes of
@@ -43,7 +43,7 @@ class PreTrainedSeq2seq(PreTrainedModel):
         self.decoder = decoder
 
     @classmethod
-    def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs):
+    def from_pretrained(cls, encoder_pretrained_model_name_or_path=None, decoder_pretrained_model_name_or_path=None, *model_args, **kwargs):
         r""" Instantiates an encoder and a decoder from one or two base classes
         of the library from pre-trained model checkpoints.
 
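Note: with both checkpoint arguments now defaulting to None, a caller (or subclass) can supply either one or two checkpoint names. A minimal usage sketch, assuming this class is importable at this commit; the checkpoint name is only illustrative:

    # Hedged sketch: load an encoder and a decoder from (possibly different)
    # checkpoints via the classmethod whose signature changed above.
    model = PreTrainedSeq2seq.from_pretrained(
        encoder_pretrained_model_name_or_path="bert-base-uncased",
        decoder_pretrained_model_name_or_path="bert-base-uncased",
    )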
@@ -177,8 +177,8 @@ class PreTrainedSeq2seq(PreTrainedModel):
 
 
 class Model2Model(PreTrainedSeq2seq):
-    def __init__(self):
-        super(Model2Model, self).__init__()
+    def __init__(self, *args, **kwargs):
+        super(Model2Model, self).__init__(*args, **kwargs)
         self.tie_weights()
 
     def tie_weights(self):
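The constructor now forwards its arguments to the base class, which (per the self.decoder = decoder context line above) stores the encoder and decoder modules. A sketch under that assumption, with encoder and decoder standing for any two pretrained modules:

    # Assumes PreTrainedSeq2seq.__init__(self, encoder, decoder), inferred
    # from the context lines above; tie_weights() then runs after __init__.
    model = Model2Model(encoder, decoder)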
@@ -197,7 +197,14 @@ class Model2Model(PreTrainedSeq2seq):
             by a model-specific keyword (bert, )...
         """
         # self._tie_or_clone_weights(self.encoder, self.decoder)
-        raise NotImplementedError
+        pass
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+        model = super(Model2Model, cls).from_pretrained(encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
+                                                        decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
+                                                        **kwargs)
+        return model
 
 
 class Model2LSTM(PreTrainedSeq2seq):
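The new classmethod loads one checkpoint into both the encoder and the decoder. A hypothetical call, with an illustrative checkpoint name:

    # Hedged usage sketch: the single name below is forwarded to both the
    # encoder and decoder loaders by the classmethod added in this commit.
    model = Model2Model.from_pretrained("bert-base-uncased")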