Mirror of https://github.com/huggingface/transformers.git
Add common properties input_embeddings and output_embeddings
parent 8a62835577
commit 9b45d0f878
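The diff below promotes input_embeddings and output_embeddings to common properties defined once on PreTrainedModel: every base model exposes its token embedding matrix through input_embeddings (with a setter), models with a language-modeling head expose their projection layer through output_embeddings, and the base class handles weight tying and vocabulary resizing generically. As a rough usage sketch of the resulting API (illustrative only, not part of the diff; the checkpoint name and token count are placeholders):

    from transformers import BertForMaskedLM

    model = BertForMaskedLM.from_pretrained('bert-base-uncased')

    # after this commit the same properties exist on every model
    word_embeddings = model.input_embeddings   # nn.Embedding shared with the base model
    decoder = model.output_embeddings          # LM decoder layer, or None for models without one

    # resizing now goes through the base class and re-ties the weights afterwards
    model.resize_token_embeddings(new_num_tokens=30600)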
@@ -280,12 +280,14 @@ class XxxModel(XxxPreTrainedModel):

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
    @property
    def input_embeddings(self):
        return self.embeddings.word_embeddings

    @input_embeddings.setter
    def input_embeddings(self, new_embeddings):
        self.embeddings.word_embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -376,17 +378,13 @@ class XxxForMaskedLM(XxxPreTrainedModel):
        super(XxxForMaskedLM, self).__init__(config)

        self.transformer = XxxModel(config)
        self.cls = XxxOnlyMLMHead(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.transformer.embeddings.word_embeddings)
    @property
    def output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                masked_lm_labels=None):
@@ -601,12 +601,14 @@ class BertModel(BertPreTrainedModel):

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
    @property
    def input_embeddings(self):
        return self.embeddings.word_embeddings

    @input_embeddings.setter
    def input_embeddings(self, new_embeddings):
        self.embeddings.word_embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -750,14 +752,10 @@ class BertForPreTraining(BertPreTrainedModel):
        self.cls = BertPreTrainingHeads(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.bert.embeddings.word_embeddings)
    @property
    def output_embeddings(self):
        return self.cls.predictions.decoder

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                masked_lm_labels=None, next_sentence_label=None):
@@ -830,14 +828,10 @@ class BertForMaskedLM(BertPreTrainedModel):
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.bert.embeddings.word_embeddings)
    @property
    def output_embeddings(self):
        return self.cls.predictions.decoder

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
@@ -289,10 +289,14 @@ class CTRLModel(CTRLPreTrainedModel):

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        self.w = self._get_resized_embeddings(self.w, new_num_tokens)
    @property
    def input_embeddings(self):
        return self.w

    @input_embeddings.setter
    def input_embeddings(self, new_embeddings):
        self.w = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -449,13 +453,10 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=True)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head, self.transformer.w)
    @property
    def output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                labels=None):
@@ -334,9 +334,6 @@ class DistilBertPreTrainedModel(PreTrainedModel):
    load_tf_weights = None
    base_model_prefix = "distilbert"

    def __init__(self, *inputs, **kwargs):
        super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """ Initialize the weights.
        """
@@ -424,12 +421,14 @@ class DistilBertModel(DistilBertPreTrainedModel):

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
    @property
    def input_embeddings(self):
        return self.embeddings.word_embeddings

    @input_embeddings.setter
    def input_embeddings(self, new_embeddings):
        self.embeddings.word_embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -511,16 +510,12 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)

        self.init_weights()
        self.tie_weights()

        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.vocab_projector,
                                   self.distilbert.embeddings.word_embeddings)
    @property
    def output_embeddings(self):
        return self.vocab_projector

    def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):
        dlbrt_output = self.distilbert(input_ids=input_ids,
@@ -357,10 +357,14 @@ class GPT2Model(GPT2PreTrainedModel):

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
    @property
    def input_embeddings(self):
        return self.wte

    @input_embeddings.setter
    def input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -514,14 +518,10 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.wte)
    @property
    def output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                labels=None):
@@ -622,14 +622,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.wte)
    @property
    def output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                mc_token_ids=None, lm_labels=None, mc_labels=None):
@@ -360,10 +360,14 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens)
    @property
    def input_embeddings(self):
        return self.tokens_embed

    @input_embeddings.setter
    def input_embeddings(self, new_embeddings):
        self.tokens_embed = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -489,14 +493,10 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.tokens_embed)
    @property
    def output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                labels=None):
@@ -583,14 +583,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.tokens_embed)
    @property
    def output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                mc_token_ids=None, lm_labels=None, mc_labels=None):
@@ -169,6 +169,10 @@ class RobertaModel(BertModel):
        self.embeddings = RobertaEmbeddings(config)
        self.init_weights()

    @property
    def input_embeddings(self):
        return self.embeddings.word_embeddings


@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """,
                      ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
@@ -213,13 +217,10 @@ class RobertaForMaskedLM(BertPreTrainedModel):
        self.lm_head = RobertaLMHead(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head.decoder, self.roberta.embeddings.word_embeddings)
    @property
    def output_embeddings(self):
        return self.lm_head.decoder

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                masked_lm_labels=None):
@@ -639,9 +639,14 @@ class TransfoXLModel(TransfoXLPreTrainedModel):

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
    @property
    def input_embeddings(self):
        return self.word_emb

    @input_embeddings.setter
    def input_embeddings(self, new_embeddings):
        self.word_emb = new_embeddings

    def backward_compatible(self):
        self.sample_softmax = -1
@@ -826,7 +831,6 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
        self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
                                                config.cutoffs, div_val=config.div_val)
        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """
@@ -83,6 +83,77 @@ class PreTrainedModel(nn.Module):
        # Save config in model
        self.config = config

    @property
    def base_model(self):
        return getattr(self, self.base_model_prefix, self)

    @property
    def input_embeddings(self):
        base_model = getattr(self, self.base_model_prefix, self)
        return base_model.input_embeddings

    @property
    def output_embeddings(self):
        return None  # Overwrite for models with output embeddings

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        if self.output_embeddings is not None:
            self._tie_or_clone_weights(self.output_embeddings, self.input_embeddings)

    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
        """ Tie or clone module weights depending of weither we are using TorchScript or not
        """
        if self.config.torchscript:
            output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
        else:
            output_embeddings.weight = input_embeddings.weight

        if hasattr(output_embeddings, 'bias') and output_embeddings.bias is not None:
            output_embeddings.bias.data = torch.nn.functional.pad(
                output_embeddings.bias.data,
                (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
                'constant',
                0
            )
        if hasattr(output_embeddings, 'out_features') and hasattr(input_embeddings, 'num_embeddings'):
            output_embeddings.out_features = input_embeddings.num_embeddings

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
            Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

            Arguments:

                new_num_tokens: (`optional`) int:
                    New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                    If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.

            Return: ``torch.nn.Embeddings``
                Pointer to the input tokens Embeddings Module of the model
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
        model_embeds = base_model._resize_token_embeddings(new_num_tokens)
        if new_num_tokens is None:
            return model_embeds

        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens

        # Tie weights again if needed
        if hasattr(self, 'tie_weights'):
            self.tie_weights()

        return model_embeds

    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self.input_embeddings
        self.input_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        return self.input_embeddings

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Module from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end
@@ -117,50 +188,6 @@ class PreTrainedModel(nn.Module):

        return new_embeddings

    def _tie_or_clone_weights(self, first_module, second_module):
        """ Tie or clone module weights depending of weither we are using TorchScript or not
        """
        if self.config.torchscript:
            first_module.weight = nn.Parameter(second_module.weight.clone())
        else:
            first_module.weight = second_module.weight

        if hasattr(first_module, 'bias') and first_module.bias is not None:
            first_module.bias.data = torch.nn.functional.pad(
                first_module.bias.data,
                (0, first_module.weight.shape[0] - first_module.bias.shape[0]),
                'constant',
                0
            )

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
            Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

            Arguments:

                new_num_tokens: (`optional`) int:
                    New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                    If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.

            Return: ``torch.nn.Embeddings``
                Pointer to the input tokens Embeddings Module of the model
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
        model_embeds = base_model._resize_token_embeddings(new_num_tokens)
        if new_num_tokens is None:
            return model_embeds

        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens

        # Tie weights again if needed
        if hasattr(self, 'tie_weights'):
            self.tie_weights()

        return model_embeds

    def init_weights(self):
        """ Initialize and prunes weights if needed. """
        # Initialize weights
@@ -170,6 +197,9 @@ class PreTrainedModel(nn.Module):
        if self.config.pruned_heads:
            self.prune_heads(self.config.pruned_heads)

        # Tie weights if needed
        self.tie_weights()

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.
@@ -178,14 +208,12 @@ class PreTrainedModel(nn.Module):
            heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
            E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed

        # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
        for layer, heads in heads_to_prune.items():
            union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
            self.config.pruned_heads[layer] = list(union_heads)  # Unfortunately we have to store it as list for JSON

        base_model._prune_heads(heads_to_prune)
        self.base_model._prune_heads(heads_to_prune)

    def save_pretrained(self, save_directory):
        """ Save a model and its configuration file to a directory, so that it
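Taken together, the PreTrainedModel hunks above reduce the per-model boilerplate to a pair of properties: a base model defines input_embeddings (with a setter), a head model defines output_embeddings, and the shared tie_weights(), _tie_or_clone_weights() and resize_token_embeddings() in the base class do the rest (init_weights() now also calls tie_weights()). A minimal sketch of that contract follows, assuming a PretrainedConfig-like config with vocab_size and hidden_size; MyModel and MyLMHeadModel are invented names, not part of this commit:

    import torch.nn as nn
    from transformers.modeling_utils import PreTrainedModel  # module path assumed for this version

    class MyModel(PreTrainedModel):                     # hypothetical base model
        base_model_prefix = 'my_model'

        def __init__(self, config):
            super(MyModel, self).__init__(config)
            self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
            self.init_weights()

        def _init_weights(self, module):
            pass                                        # real models initialize weights here

        @property
        def input_embeddings(self):                     # read by the generic tie_weights()
            return self.wte

        @input_embeddings.setter
        def input_embeddings(self, new_embeddings):     # used by the base _resize_token_embeddings()
            self.wte = new_embeddings

    class MyLMHeadModel(PreTrainedModel):               # hypothetical model with an LM head
        base_model_prefix = 'transformer'

        def __init__(self, config):
            super(MyLMHeadModel, self).__init__(config)
            self.transformer = MyModel(config)
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.init_weights()                         # init_weights() now also ties the weights

        def _init_weights(self, module):
            pass

        @property
        def output_embeddings(self):                    # picked up by PreTrainedModel.tie_weights()
            return self.lm_head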
@@ -407,10 +407,14 @@ class XLMModel(XLMPreTrainedModel):

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)
    @property
    def input_embeddings(self):
        return self.embeddings

    @input_embeddings.setter
    def input_embeddings(self, new_embeddings):
        self.embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -618,12 +622,10 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
        self.pred_layer = XLMPredLayer(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the embeddings
        """
        self._tie_or_clone_weights(self.pred_layer.proj, self.transformer.embeddings)
    @property
    def output_embeddings(self):
        return self.pred_layer.proj

    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                lengths=None, cache=None, head_mask=None, labels=None):
@@ -611,10 +611,14 @@ class XLNetModel(XLNetPreTrainedModel):

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        self.word_embedding = self._get_resized_embeddings(self.word_embedding, new_num_tokens)
    @property
    def input_embeddings(self):
        return self.word_embedding

    @input_embeddings.setter
    def input_embeddings(self, new_embeddings):
        self.word_embedding = new_embeddings

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError
@@ -918,12 +922,10 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
        self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the embeddings
        """
        self._tie_or_clone_weights(self.lm_loss, self.transformer.word_embedding)
    @property
    def output_embeddings(self):
        return self.lm_loss

    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
                token_type_ids=None, input_mask=None, head_mask=None, labels=None):
@@ -463,6 +463,15 @@ class CommonTestCases:

        self.assertTrue(models_equal)

    def test_model_common_attributes(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertTrue(hasattr(model, 'input_embeddings'))
            setattr(model, 'input_embeddings', torch.nn.Embedding(10, 10))
            self.assertTrue(hasattr(model, 'output_embeddings'))

    def test_tie_model_weights(self):
        if not self.test_torchscript:
            return