Add common properties input_embeddings and output_embeddings

This commit is contained in:
thomwolf 2019-11-04 12:28:56 +01:00
parent 8a62835577
commit 9b45d0f878
12 changed files with 179 additions and 153 deletions

View File

@ -280,12 +280,14 @@ class XxxModel(XxxPreTrainedModel):
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.embeddings.word_embeddings
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
self.embeddings.word_embeddings = new_embeddings
@property
def input_embeddings(self):
return self.embeddings.word_embeddings
@input_embeddings.setter
def input_embeddings(self, new_embeddings):
self.embeddings.word_embeddings = new_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@ -376,17 +378,13 @@ class XxxForMaskedLM(XxxPreTrainedModel):
super(XxxForMaskedLM, self).__init__(config)
self.transformer = XxxModel(config)
self.cls = XxxOnlyMLMHead(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
self.init_weights()
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.cls.predictions.decoder,
self.transformer.embeddings.word_embeddings)
@property
def output_embeddings(self):
return self.lm_head
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
masked_lm_labels=None):

View File

@ -601,12 +601,14 @@ class BertModel(BertPreTrainedModel):
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.embeddings.word_embeddings
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
self.embeddings.word_embeddings = new_embeddings
@property
def input_embeddings(self):
return self.embeddings.word_embeddings
@input_embeddings.setter
def input_embeddings(self, new_embeddings):
self.embeddings.word_embeddings = new_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@ -750,14 +752,10 @@ class BertForPreTraining(BertPreTrainedModel):
self.cls = BertPreTrainingHeads(config)
self.init_weights()
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.cls.predictions.decoder,
self.bert.embeddings.word_embeddings)
@property
def output_embeddings(self):
return self.cls.predictions.decoder
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
masked_lm_labels=None, next_sentence_label=None):
@ -830,14 +828,10 @@ class BertForMaskedLM(BertPreTrainedModel):
self.cls = BertOnlyMLMHead(config)
self.init_weights()
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.cls.predictions.decoder,
self.bert.embeddings.word_embeddings)
@property
def output_embeddings(self):
return self.cls.predictions.decoder
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):

View File

@ -289,10 +289,14 @@ class CTRLModel(CTRLPreTrainedModel):
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
self.w = self._get_resized_embeddings(self.w, new_num_tokens)
@property
def input_embeddings(self):
return self.w
@input_embeddings.setter
def input_embeddings(self, new_embeddings):
self.w = new_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@ -449,13 +453,10 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=True)
self.init_weights()
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.lm_head, self.transformer.w)
@property
def output_embeddings(self):
return self.lm_head
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
labels=None):

View File

@ -334,9 +334,6 @@ class DistilBertPreTrainedModel(PreTrainedModel):
load_tf_weights = None
base_model_prefix = "distilbert"
def __init__(self, *inputs, **kwargs):
super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs)
def _init_weights(self, module):
""" Initialize the weights.
"""
@ -424,12 +421,14 @@ class DistilBertModel(DistilBertPreTrainedModel):
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.embeddings.word_embeddings
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
self.embeddings.word_embeddings = new_embeddings
@property
def input_embeddings(self):
return self.embeddings.word_embeddings
@input_embeddings.setter
def input_embeddings(self, new_embeddings):
self.embeddings.word_embeddings = new_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@ -511,16 +510,12 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
self.init_weights()
self.tie_weights()
self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.vocab_projector,
self.distilbert.embeddings.word_embeddings)
@property
def output_embeddings(self):
return self.vocab_projector
def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):
dlbrt_output = self.distilbert(input_ids=input_ids,

View File

@ -357,10 +357,14 @@ class GPT2Model(GPT2PreTrainedModel):
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
@property
def input_embeddings(self):
return self.wte
@input_embeddings.setter
def input_embeddings(self, new_embeddings):
self.wte = new_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@ -514,14 +518,10 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.init_weights()
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.lm_head,
self.transformer.wte)
@property
def output_embeddings(self):
return self.lm_head
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
labels=None):
@ -622,14 +622,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
self.multiple_choice_head = SequenceSummary(config)
self.init_weights()
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.lm_head,
self.transformer.wte)
@property
def output_embeddings(self):
return self.lm_head
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
mc_token_ids=None, lm_labels=None, mc_labels=None):

View File

@ -360,10 +360,14 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens)
@property
def input_embeddings(self):
return self.tokens_embed
@input_embeddings.setter
def input_embeddings(self, new_embeddings):
self.tokens_embed = new_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@ -489,14 +493,10 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.init_weights()
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.lm_head,
self.transformer.tokens_embed)
@property
def output_embeddings(self):
return self.lm_head
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
labels=None):
@ -583,14 +583,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
self.multiple_choice_head = SequenceSummary(config)
self.init_weights()
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.lm_head,
self.transformer.tokens_embed)
@property
def output_embeddings(self):
return self.lm_head
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
mc_token_ids=None, lm_labels=None, mc_labels=None):

View File

@ -169,6 +169,10 @@ class RobertaModel(BertModel):
self.embeddings = RobertaEmbeddings(config)
self.init_weights()
@property
def input_embeddings(self):
return self.embeddings.word_embeddings
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """,
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
@ -213,13 +217,10 @@ class RobertaForMaskedLM(BertPreTrainedModel):
self.lm_head = RobertaLMHead(config)
self.init_weights()
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.lm_head.decoder, self.roberta.embeddings.word_embeddings)
@property
def output_embeddings(self):
return self.lm_head.decoder
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
masked_lm_labels=None):

View File

@ -639,9 +639,14 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
@property
def input_embeddings(self):
return self.word_emb
@input_embeddings.setter
def input_embeddings(self, new_embeddings):
self.word_emb = new_embeddings
def backward_compatible(self):
self.sample_softmax = -1
@ -826,7 +831,6 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
config.cutoffs, div_val=config.div_val)
self.init_weights()
self.tie_weights()
def tie_weights(self):
"""

View File

@ -83,6 +83,77 @@ class PreTrainedModel(nn.Module):
# Save config in model
self.config = config
@property
def base_model(self):
return getattr(self, self.base_model_prefix, self)
@property
def input_embeddings(self):
base_model = getattr(self, self.base_model_prefix, self)
return base_model.input_embeddings
@property
def output_embeddings(self):
return None # Overwrite for models with output embeddings
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
if self.output_embeddings is not None:
self._tie_or_clone_weights(self.output_embeddings, self.input_embeddings)
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
""" Tie or clone module weights depending of weither we are using TorchScript or not
"""
if self.config.torchscript:
output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
else:
output_embeddings.weight = input_embeddings.weight
if hasattr(output_embeddings, 'bias') and output_embeddings.bias is not None:
output_embeddings.bias.data = torch.nn.functional.pad(
output_embeddings.bias.data,
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
'constant',
0
)
if hasattr(output_embeddings, 'out_features') and hasattr(input_embeddings, 'num_embeddings'):
output_embeddings.out_features = input_embeddings.num_embeddings
def resize_token_embeddings(self, new_num_tokens=None):
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
Arguments:
new_num_tokens: (`optional`) int:
New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
Return: ``torch.nn.Embeddings``
Pointer to the input tokens Embeddings Module of the model
"""
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
model_embeds = base_model._resize_token_embeddings(new_num_tokens)
if new_num_tokens is None:
return model_embeds
# Update base model and current model config
self.config.vocab_size = new_num_tokens
base_model.vocab_size = new_num_tokens
# Tie weights again if needed
if hasattr(self, 'tie_weights'):
self.tie_weights()
return model_embeds
def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.input_embeddings
self.input_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
return self.input_embeddings
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
""" Build a resized Embedding Module from a provided token Embedding Module.
Increasing the size will add newly initialized vectors at the end
@ -117,50 +188,6 @@ class PreTrainedModel(nn.Module):
return new_embeddings
def _tie_or_clone_weights(self, first_module, second_module):
""" Tie or clone module weights depending of weither we are using TorchScript or not
"""
if self.config.torchscript:
first_module.weight = nn.Parameter(second_module.weight.clone())
else:
first_module.weight = second_module.weight
if hasattr(first_module, 'bias') and first_module.bias is not None:
first_module.bias.data = torch.nn.functional.pad(
first_module.bias.data,
(0, first_module.weight.shape[0] - first_module.bias.shape[0]),
'constant',
0
)
def resize_token_embeddings(self, new_num_tokens=None):
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
Arguments:
new_num_tokens: (`optional`) int:
New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
Return: ``torch.nn.Embeddings``
Pointer to the input tokens Embeddings Module of the model
"""
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
model_embeds = base_model._resize_token_embeddings(new_num_tokens)
if new_num_tokens is None:
return model_embeds
# Update base model and current model config
self.config.vocab_size = new_num_tokens
base_model.vocab_size = new_num_tokens
# Tie weights again if needed
if hasattr(self, 'tie_weights'):
self.tie_weights()
return model_embeds
def init_weights(self):
""" Initialize and prunes weights if needed. """
# Initialize weights
@ -170,6 +197,9 @@ class PreTrainedModel(nn.Module):
if self.config.pruned_heads:
self.prune_heads(self.config.pruned_heads)
# Tie weights if needed
self.tie_weights()
def prune_heads(self, heads_to_prune):
""" Prunes heads of the base model.
@ -178,14 +208,12 @@ class PreTrainedModel(nn.Module):
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
"""
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
# save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
for layer, heads in heads_to_prune.items():
union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
base_model._prune_heads(heads_to_prune)
self.base_model._prune_heads(heads_to_prune)
def save_pretrained(self, save_directory):
""" Save a model and its configuration file to a directory, so that it

View File

@ -407,10 +407,14 @@ class XLMModel(XLMPreTrainedModel):
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)
@property
def input_embeddings(self):
return self.embeddings
@input_embeddings.setter
def input_embeddings(self, new_embeddings):
self.embeddings = new_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@ -618,12 +622,10 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
self.pred_layer = XLMPredLayer(config)
self.init_weights()
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the embeddings
"""
self._tie_or_clone_weights(self.pred_layer.proj, self.transformer.embeddings)
@property
def output_embeddings(self):
return self.pred_layer.proj
def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, labels=None):

View File

@ -611,10 +611,14 @@ class XLNetModel(XLNetPreTrainedModel):
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
self.word_embedding = self._get_resized_embeddings(self.word_embedding, new_num_tokens)
@property
def input_embeddings(self):
return self.word_embedding
@input_embeddings.setter
def input_embeddings(self, new_embeddings):
self.word_embedding = new_embeddings
def _prune_heads(self, heads_to_prune):
raise NotImplementedError
@ -918,12 +922,10 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
self.init_weights()
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the embeddings
"""
self._tie_or_clone_weights(self.lm_loss, self.transformer.word_embedding)
@property
def output_embeddings(self):
return self.lm_loss
def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, labels=None):

View File

@ -463,6 +463,15 @@ class CommonTestCases:
self.assertTrue(models_equal)
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertTrue(hasattr(model, 'input_embeddings'))
setattr(model, 'input_embeddings', torch.nn.Embedding(10, 10))
self.assertTrue(hasattr(model, 'output_embeddings'))
def test_tie_model_weights(self):
if not self.test_torchscript:
return