switch from properties to methods

thomwolf 2019-11-04 15:34:10 +01:00
parent 9b45d0f878
commit 1724cee8c4
12 changed files with 70 additions and 75 deletions
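In short, every model loses the `input_embeddings` property/setter pair and the `output_embeddings` property, and gains explicit `get_input_embeddings()`, `set_input_embeddings()` and `get_output_embeddings()` methods. A minimal before/after sketch of calling code (the checkpoint name is only illustrative, not part of this commit):

import torch
from transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')

# Before this commit (property-based access):
#   embeddings = model.input_embeddings
#   model.input_embeddings = torch.nn.Embedding(10, 10)

# After this commit (method-based access):
embeddings = model.get_input_embeddings()                # the word-embedding nn.Embedding
model.set_input_embeddings(torch.nn.Embedding(10, 10))   # swap in a new embedding module
print(model.get_output_embeddings())                     # None: the bare encoder has no LM head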

View File

@@ -281,11 +281,10 @@ class XxxModel(XxxPreTrainedModel):
         self.init_weights()
-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.embeddings.word_embeddings
-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings):
         self.embeddings.word_embeddings = new_embeddings
     def _prune_heads(self, heads_to_prune):
@@ -382,8 +381,7 @@ class XxxForMaskedLM(XxxPreTrainedModel):
         self.init_weights()
-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.lm_head
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,

View File

@@ -601,13 +601,11 @@ class BertModel(BertPreTrainedModel):
         self.init_weights()
-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.embeddings.word_embeddings
-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
-        self.embeddings.word_embeddings = new_embeddings
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
@@ -753,8 +751,7 @@ class BertForPreTraining(BertPreTrainedModel):
         self.init_weights()
-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.cls.predictions.decoder
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
@@ -829,8 +826,7 @@ class BertForMaskedLM(BertPreTrainedModel):
         self.init_weights()
-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.cls.predictions.decoder
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,

View File

@@ -289,12 +289,10 @@ class CTRLModel(CTRLPreTrainedModel):
         self.init_weights()
-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.w
-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings):
         self.w = new_embeddings
     def _prune_heads(self, heads_to_prune):
@@ -454,8 +452,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
         self.init_weights()
-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.lm_head
     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,

View File

@@ -421,12 +421,10 @@ class DistilBertModel(DistilBertPreTrainedModel):
         self.init_weights()
-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.embeddings.word_embeddings
-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings):
         self.embeddings.word_embeddings = new_embeddings
     def _prune_heads(self, heads_to_prune):
@@ -513,8 +511,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
         self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.vocab_projector
     def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):

View File

@@ -357,12 +357,10 @@ class GPT2Model(GPT2PreTrainedModel):
         self.init_weights()
-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.wte
-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings):
         self.wte = new_embeddings
     def _prune_heads(self, heads_to_prune):
@@ -519,8 +517,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         self.init_weights()
-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.lm_head
     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
@@ -623,8 +620,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         self.init_weights()
-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.lm_head
     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,

View File

@@ -360,12 +360,10 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         self.init_weights()
-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.tokens_embed
-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings):
         self.tokens_embed = new_embeddings
     def _prune_heads(self, heads_to_prune):
@@ -494,8 +492,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         self.init_weights()
-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.lm_head
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
@@ -584,8 +581,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         self.init_weights()
-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.lm_head
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,

View File

@@ -169,10 +169,11 @@ class RobertaModel(BertModel):
         self.embeddings = RobertaEmbeddings(config)
         self.init_weights()
-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.embeddings.word_embeddings
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
 @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """,
     ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
@@ -218,8 +219,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
         self.init_weights()
-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.lm_head.decoder
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,

View File

@@ -639,12 +639,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         self.init_weights()
-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.word_emb
-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings):
         self.word_emb = new_embeddings
     def backward_compatible(self):

View File

@@ -87,21 +87,37 @@ class PreTrainedModel(nn.Module):
     def base_model(self):
         return getattr(self, self.base_model_prefix, self)
-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         """ Get model's input embeddings
         """
         base_model = getattr(self, self.base_model_prefix, self)
-        return base_model.input_embeddings
+        if base_model is not self:
+            return base_model.get_input_embeddings()
+        else:
+            raise NotImplementedError
-    @property
-    def output_embeddings(self):
+    def set_input_embeddings(self, value):
+        """ Set model's input embeddings
+        """
+        base_model = getattr(self, self.base_model_prefix, self)
+        if base_model is not self:
+            base_model.set_input_embeddings(value)
+        else:
+            raise NotImplementedError
+    def get_output_embeddings(self):
         """ Get model's output embeddings
             Return None if the model doesn't have output embeddings
         """
         return None  # Overwrite for models with output embeddings
     def tie_weights(self):
         """ Make sure we are sharing the input and output embeddings.
             Export to TorchScript can't handle parameter sharing so we are cloning them instead.
         """
-        if self.output_embeddings is not None:
-            self._tie_or_clone_weights(self.output_embeddings, self.input_embeddings)
+        output_embeddings = self.get_output_embeddings()
+        if output_embeddings is not None:
+            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
     def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
         """ Tie or clone module weights depending on whether we are using TorchScript or not
@@ -150,9 +166,10 @@ class PreTrainedModel(nn.Module):
         return model_embeds
     def _resize_token_embeddings(self, new_num_tokens):
-        old_embeddings = self.input_embeddings
-        self.input_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
-        return self.input_embeddings
+        old_embeddings = self.get_input_embeddings()
+        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
+        self.set_input_embeddings(new_embeddings)
+        return self.get_input_embeddings()
     def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
         """ Build a resized Embedding Module from a provided token Embedding Module.

View File

@@ -407,12 +407,10 @@ class XLMModel(XLMPreTrainedModel):
         self.init_weights()
-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.embeddings
-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings):
         self.embeddings = new_embeddings
     def _prune_heads(self, heads_to_prune):
@@ -623,8 +621,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
         self.init_weights()
-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.pred_layer.proj
     def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,

View File

@@ -611,12 +611,10 @@ class XLNetModel(XLNetPreTrainedModel):
         self.init_weights()
-    @property
-    def input_embeddings(self):
+    def get_input_embeddings(self):
         return self.word_embedding
-    @input_embeddings.setter
-    def input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings):
         self.word_embedding = new_embeddings
     def _prune_heads(self, heads_to_prune):
@@ -923,8 +921,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         self.init_weights()
-    @property
-    def output_embeddings(self):
+    def get_output_embeddings(self):
         return self.lm_loss
     def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,

View File

@@ -429,6 +429,12 @@ class CommonTestCases:
                 list(hidden_states[0].shape[-2:]),
                 [self.model_tester.seq_length, self.model_tester.hidden_size])
+        def test_debug(self):
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            for model_class in self.all_model_classes:
+                model = model_class(config)
+                model_embed = model.resize_token_embeddings(config.vocab_size + 10)
         def test_resize_tokens_embeddings(self):
             original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             if not self.test_resize_embeddings:
@@ -468,9 +474,9 @@ class CommonTestCases:
             for model_class in self.all_model_classes:
                 model = model_class(config)
-                self.assertTrue(hasattr(model, 'input_embeddings'))
-                setattr(model, 'input_embeddings', torch.nn.Embedding(10, 10))
-                self.assertTrue(hasattr(model, 'output_embeddings'))
+                model.get_input_embeddings()
+                model.set_input_embeddings(torch.nn.Embedding(10, 10))
+                model.get_output_embeddings()
         def test_tie_model_weights(self):
             if not self.test_torchscript:
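The common-test hunk above exercises the accessors on every model class; on head models the base-class implementations dispatch through `base_model_prefix` to the encoder, as the `PreTrainedModel` changes show. A small sketch of that dispatch with a randomly initialized model (default config values, for illustration only):

import torch
from transformers import BertConfig, BertForMaskedLM

model = BertForMaskedLM(BertConfig())   # random weights, default config

# BertForMaskedLM does not override the input-embedding accessors, so the base-class
# methods forward to the module found under base_model_prefix ('bert')
assert model.get_input_embeddings() is model.bert.get_input_embeddings()

new_embed = torch.nn.Embedding(10, 10)
model.set_input_embeddings(new_embed)   # routed to model.bert.set_input_embeddings(new_embed)
assert model.bert.embeddings.word_embeddings is new_embed

# the masked-LM head does override get_output_embeddings()
assert model.get_output_embeddings() is model.cls.predictions.decoder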