Remove the staticmethod used to load the config

This commit is contained in:
Rémi Louf 2019-10-10 14:13:37 +02:00
parent d7092d592c
commit 81ee29ee8d

View File

@@ -715,7 +715,7 @@ class BertDecoderModel(BertPreTrainedModel):
     """
     def __init__(self, config):
-        super(BertModel, self).__init__(config)
+        super(BertDecoderModel, self).__init__(config)
         self.embeddings = BertEmbeddings(config)
         self.decoder = BertDecoder(config)
@@ -1357,28 +1357,27 @@ class Bert2Rnd(BertPreTrainedModel):
         pretrained weights we need to override the `from_pretrained` method of the base `PreTrainedModel`
         class.
         """
-        pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
-        config = cls._load_config(pretrained_model_or_path, *model_args, **model_kwargs)
-        model = cls(config)
-        model.encoder = pretrained_encoder
-        return model
-
-    def _load_config(self, pretrained_model_name_or_path, *args, **kwargs):
-        config = kwargs.pop('config', None)
+        # Load the configuration
+        config = model_kwargs.pop('config', None)
         if config is None:
-            cache_dir = kwargs.pop('cache_dir', None)
-            force_download = kwargs.pop('force_download', False)
-            config, _ = self.config_class.from_pretrained(
-                pretrained_model_name_or_path,
-                *args,
+            cache_dir = model_kwargs.pop('cache_dir', None)
+            force_download = model_kwargs.pop('force_download', False)
+            config, _ = cls.config_class.from_pretrained(
+                pretrained_model_or_path,
+                *model_args,
                 cache_dir=cache_dir,
                 return_unused_kwargs=True,
                 force_download=force_download,
-                **kwargs
+                **model_kwargs
             )
-        return config
+        model = cls(config)
+
+        # The encoder is loaded with pretrained weights
+        pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
+        model.encoder = pretrained_encoder
+        return model

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
         encoder_outputs = self.encoder(input_ids,