Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
remove the staticmethod used to load the config

parent d7092d592c
commit 81ee29ee8d
@@ -715,7 +715,7 @@ class BertDecoderModel(BertPreTrainedModel):
     """
     def __init__(self, config):
-        super(BertModel, self).__init__(config)
+        super(BertDecoderModel, self).__init__(config)
 
         self.embeddings = BertEmbeddings(config)
         self.decoder = BertDecoder(config)
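The removed line still named `BertModel`, presumably a copy-paste leftover from `BertModel.__init__`. Because `BertDecoderModel` derives from `BertPreTrainedModel` rather than from `BertModel`, the old-style `super(BertModel, self)` call would fail as soon as the model is constructed. A minimal sketch of that failure mode, using stand-in class names rather than the real ones:

class Base(object):
    def __init__(self, config):
        self.config = config

class EncoderModel(Base):       # stands in for BertModel
    pass

class DecoderModel(Base):       # stands in for BertDecoderModel
    def __init__(self, config):
        # Correct: name the class this method is defined in.
        super(DecoderModel, self).__init__(config)
        # The removed line was the equivalent of
        #     super(EncoderModel, self).__init__(config)
        # which raises "TypeError: super(type, obj): obj must be an
        # instance or subtype of type", since a DecoderModel instance
        # is not an EncoderModel.

model = DecoderModel({"hidden_size": 768})
print(model.config["hidden_size"])   # 768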
@@ -1357,28 +1357,27 @@ class Bert2Rnd(BertPreTrainedModel):
         pretrained weights we need to override the `from_pretrained` method of the base `PreTrainedModel`
         class.
         """
-        pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
-
-        config = cls._load_config(pretrained_model_or_path, *model_args, **model_kwargs)
-        model = cls(config)
-        model.encoder = pretrained_encoder
-
-        return model
-
-    def _load_config(self, pretrained_model_name_or_path, *args, **kwargs):
-        config = kwargs.pop('config', None)
+        # Load the configuration
+        config = model_kwargs.pop('config', None)
         if config is None:
-            cache_dir = kwargs.pop('cache_dir', None)
-            force_download = kwargs.pop('force_download', False)
-            config, _ = self.config_class.from_pretrained(
-                pretrained_model_name_or_path,
-                *args,
+            cache_dir = model_kwargs.pop('cache_dir', None)
+            force_download = model_kwargs.pop('force_download', False)
+            config, _ = cls.config_class.from_pretrained(
+                pretrained_model_or_path,
+                *model_args,
                 cache_dir=cache_dir,
                 return_unused_kwargs=True,
                 force_download=force_download,
-                **kwargs
+                **model_kwargs
             )
-        return config
+        model = cls(config)
+
+        # The encoder is loaded with pretrained weights
+        pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
+        model.encoder = pretrained_encoder
+
+        return model
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
         encoder_outputs = self.encoder(input_ids,
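Read as a whole, the `from_pretrained` override for `Bert2Rnd` after this commit comes out roughly as below. The decorator and the method signature are not part of the hunk, so they are assumptions inferred from the names used in the added lines (`cls`, `pretrained_model_or_path`, `*model_args`, `**model_kwargs`):

@classmethod
def from_pretrained(cls, pretrained_model_or_path, *model_args, **model_kwargs):
    # Load the configuration (unless one was passed in explicitly)
    config = model_kwargs.pop('config', None)
    if config is None:
        cache_dir = model_kwargs.pop('cache_dir', None)
        force_download = model_kwargs.pop('force_download', False)
        config, _ = cls.config_class.from_pretrained(
            pretrained_model_or_path,
            *model_args,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            **model_kwargs
        )
    model = cls(config)

    # The encoder is loaded with pretrained weights;
    # the decoder keeps its fresh random initialization.
    pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
    model.encoder = pretrained_encoder

    return model

Under those assumptions, a call such as Bert2Rnd.from_pretrained('bert-base-uncased') would return a model whose encoder is initialized from the pretrained checkpoint while the rest of the model is newly initialized, without the separate `_load_config` staticmethod that this commit removes.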