remove the staticmethod used to load the config

This commit is contained in:
Rémi Louf 2019-10-10 14:13:37 +02:00
parent d7092d592c
commit 81ee29ee8d

View File

@ -715,7 +715,7 @@ class BertDecoderModel(BertPreTrainedModel):
"""
def __init__(self, config):
super(BertModel, self).__init__(config)
super(BertDecoderModel, self).__init__(config)
self.embeddings = BertEmbeddings(config)
self.decoder = BertDecoder(config)
@ -1357,28 +1357,27 @@ class Bert2Rnd(BertPreTrainedModel):
pretrained weights we need to override the `from_pretrained` method of the base `PreTrainedModel`
class.
"""
pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
config = cls._load_config(pretrained_model_or_path, *model_args, **model_kwargs)
model = cls(config)
model.encoder = pretrained_encoder
return model
def _load_config(self, pretrained_model_name_or_path, *args, **kwargs):
config = kwargs.pop('config', None)
# Load the configuration
config = model_kwargs.pop('config', None)
if config is None:
cache_dir = kwargs.pop('cache_dir', None)
force_download = kwargs.pop('force_download', False)
config, _ = self.config_class.from_pretrained(
pretrained_model_name_or_path,
*args,
cache_dir = model_kwargs.pop('cache_dir', None)
force_download = model_kwargs.pop('force_download', False)
config, _ = cls.config_class.from_pretrained(
pretrained_model_or_path,
*model_args,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
**kwargs
**model_kwargs
)
return config
model = cls(config)
# The encoder is loaded with pretrained weights
pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
model.encoder = pretrained_encoder
return model
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
encoder_outputs = self.encoder(input_ids,