mirror of https://github.com/huggingface/transformers.git
remove the staticmethod used to load the config
commit 81ee29ee8d
parent d7092d592c
@@ -715,7 +715,7 @@ class BertDecoderModel(BertPreTrainedModel):
 
     """
     def __init__(self, config):
-        super(BertModel, self).__init__(config)
+        super(BertDecoderModel, self).__init__(config)
 
         self.embeddings = BertEmbeddings(config)
         self.decoder = BertDecoder(config)
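
A note on this first hunk: calling super() with the wrong class is a real runtime bug, not a style fix. Assuming BertDecoderModel and BertModel are sibling subclasses of BertPreTrainedModel (as their names suggest), super(BertModel, self) inside BertDecoderModel.__init__ raises a TypeError. A minimal, self-contained sketch of that failure mode, with illustrative class names not taken from the commit:

# Python sketch of the pre-fix bug; Base/Sibling/Decoder are hypothetical names.
class Base:
    def __init__(self, config):
        self.config = config

class Sibling(Base):
    pass

class Decoder(Base):
    def __init__(self, config):
        # Names a class that `self` is not an instance of, so Python raises:
        # TypeError: super(type, obj): obj must be an instance or subtype of type
        super(Sibling, self).__init__(config)

Decoder(config={})  # raises TypeError at the super() call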
@@ -1357,28 +1357,27 @@ class Bert2Rnd(BertPreTrainedModel):
         pretrained weights we need to override the `from_pretrained` method of the base `PreTrainedModel`
         class.
         """
-        pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
 
-        config = cls._load_config(pretrained_model_or_path, *model_args, **model_kwargs)
-        model = cls(config)
-        model.encoder = pretrained_encoder
-
-        return model
-
-    def _load_config(self, pretrained_model_name_or_path, *args, **kwargs):
-        config = kwargs.pop('config', None)
+        # Load the configuration
+        config = model_kwargs.pop('config', None)
         if config is None:
-            cache_dir = kwargs.pop('cache_dir', None)
-            force_download = kwargs.pop('force_download', False)
-            config, _ = self.config_class.from_pretrained(
-                pretrained_model_name_or_path,
-                *args,
+            cache_dir = model_kwargs.pop('cache_dir', None)
+            force_download = model_kwargs.pop('force_download', False)
+            config, _ = cls.config_class.from_pretrained(
+                pretrained_model_or_path,
+                *model_args,
                 cache_dir=cache_dir,
                 return_unused_kwargs=True,
                 force_download=force_download,
-                **kwargs
+                **model_kwargs
            )
-        return config
+        model = cls(config)
+
+        # The encoder is loaded with pretrained weights
+        pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
+        model.encoder = pretrained_encoder
+
+        return model
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
         encoder_outputs = self.encoder(input_ids,
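
For reference, a usage sketch of the patched from_pretrained classmethod. The checkpoint name 'bert-base-uncased' is illustrative, and the import of Bert2Rnd depends on the branch's module layout; neither is specified by the commit:

# Default path: the config is now loaded inline with
# cls.config_class.from_pretrained(..., return_unused_kwargs=True),
# then the encoder weights come from BertModel.from_pretrained.
model = Bert2Rnd.from_pretrained('bert-base-uncased')

# An explicit config is popped from the kwargs (model_kwargs.pop('config', None))
# before they are forwarded to BertModel.from_pretrained for the encoder.
config = Bert2Rnd.config_class.from_pretrained('bert-base-uncased')
model = Bert2Rnd.from_pretrained('bert-base-uncased', config=config)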