Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

Merge pull request #1721 from huggingface/common_attributes
Add common getter and setter for input_embeddings & output_embeddings

Commit c8f2712199
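This merge replaces each model's ad-hoc embedding plumbing with three accessors exposed through PreTrainedModel: get_input_embeddings, set_input_embeddings and get_output_embeddings. A minimal usage sketch of the new accessors (the checkpoint name and the replacement embedding are illustrative assumptions, not part of this commit):

    import torch
    from transformers import BertForMaskedLM

    model = BertForMaskedLM.from_pretrained("bert-base-uncased")  # assumed checkpoint

    input_embeddings = model.get_input_embeddings()    # nn.Embedding owned by the base model
    output_embeddings = model.get_output_embeddings()  # decoder layer of the MLM head, or None for models without one

    # The setter forwards to the base model (self.bert here)
    model.set_input_embeddings(torch.nn.Embedding(input_embeddings.num_embeddings,
                                                  input_embeddings.embedding_dim))

The getters and setter simply forward to the base model (self.bert, self.transformer, ...), so the same calls work for every architecture touched in the diff below.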
@@ -280,12 +280,13 @@ class XxxModel(XxxPreTrainedModel):
        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
    @property
    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings.word_embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}

@@ -376,17 +377,12 @@ class XxxForMaskedLM(XxxPreTrainedModel):
        super(XxxForMaskedLM, self).__init__(config)

        self.transformer = XxxModel(config)
        self.cls = XxxOnlyMLMHead(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.transformer.embeddings.word_embeddings)
    def get_output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                masked_lm_labels=None):
@@ -601,12 +601,12 @@ class BertModel(BertPreTrainedModel):

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}

@@ -750,14 +750,9 @@ class BertForPreTraining(BertPreTrainedModel):
        self.cls = BertPreTrainingHeads(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.bert.embeddings.word_embeddings)
    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                masked_lm_labels=None, next_sentence_label=None):

@@ -830,14 +825,9 @@ class BertForMaskedLM(BertPreTrainedModel):
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.bert.embeddings.word_embeddings)
    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
@@ -289,10 +289,12 @@ class CTRLModel(CTRLPreTrainedModel):

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        self.w = self._get_resized_embeddings(self.w, new_num_tokens)
    def get_input_embeddings(self):
        return self.w

    def set_input_embeddings(self, new_embeddings):
        self.w = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}

@@ -449,13 +451,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=True)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head, self.transformer.w)
    def get_output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                labels=None):
@@ -334,9 +334,6 @@ class DistilBertPreTrainedModel(PreTrainedModel):
    load_tf_weights = None
    base_model_prefix = "distilbert"

    def __init__(self, *inputs, **kwargs):
        super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """ Initialize the weights.
        """

@@ -424,12 +421,12 @@ class DistilBertModel(DistilBertPreTrainedModel):

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings.word_embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}

@@ -511,16 +508,11 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)

        self.init_weights()
        self.tie_weights()

        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.vocab_projector,
                                   self.distilbert.embeddings.word_embeddings)
    def get_output_embeddings(self):
        return self.vocab_projector

    def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):
        dlbrt_output = self.distilbert(input_ids=input_ids,
@@ -357,10 +357,12 @@ class GPT2Model(GPT2PreTrainedModel):

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}

@@ -514,14 +516,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.wte)
    def get_output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                labels=None):

@@ -622,14 +619,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.wte)
    def get_output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                mc_token_ids=None, lm_labels=None, mc_labels=None):
@@ -360,10 +360,12 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens)
    def get_input_embeddings(self):
        return self.tokens_embed

    def set_input_embeddings(self, new_embeddings):
        self.tokens_embed = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}

@@ -489,14 +491,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.tokens_embed)
    def get_output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                labels=None):

@@ -583,14 +580,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.tokens_embed)
    def get_output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                mc_token_ids=None, lm_labels=None, mc_labels=None):
@@ -169,6 +169,11 @@ class RobertaModel(BertModel):
        self.embeddings = RobertaEmbeddings(config)
        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value


@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """,
                      ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)

@@ -213,13 +218,9 @@ class RobertaForMaskedLM(BertPreTrainedModel):
        self.lm_head = RobertaLMHead(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head.decoder, self.roberta.embeddings.word_embeddings)
    def get_output_embeddings(self):
        return self.lm_head.decoder

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                masked_lm_labels=None):
@@ -639,9 +639,12 @@ class TransfoXLModel(TransfoXLPreTrainedModel):

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
    def get_input_embeddings(self):
        return self.word_emb

    def set_input_embeddings(self, new_embeddings):
        self.word_emb = new_embeddings

    def backward_compatible(self):
        self.sample_softmax = -1

@@ -826,7 +829,6 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
        self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
                                                config.cutoffs, div_val=config.div_val)
        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """
@@ -83,6 +83,94 @@ class PreTrainedModel(nn.Module):
        # Save config in model
        self.config = config

    @property
    def base_model(self):
        return getattr(self, self.base_model_prefix, self)

    def get_input_embeddings(self):
        """ Get model's input embeddings
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            raise NotImplementedError

    def set_input_embeddings(self, value):
        """ Set model's input embeddings
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            base_model.set_input_embeddings(value)
        else:
            raise NotImplementedError

    def get_output_embeddings(self):
        """ Get model's output embeddings
            Return None if the model doesn't have output embeddings
        """
        return None  # Overwrite for models with output embeddings

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        output_embeddings = self.get_output_embeddings()
        if output_embeddings is not None:
            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())

    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
        """ Tie or clone module weights depending of weither we are using TorchScript or not
        """
        if self.config.torchscript:
            output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
        else:
            output_embeddings.weight = input_embeddings.weight

        if hasattr(output_embeddings, 'bias') and output_embeddings.bias is not None:
            output_embeddings.bias.data = torch.nn.functional.pad(
                output_embeddings.bias.data,
                (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
                'constant',
                0
            )
        if hasattr(output_embeddings, 'out_features') and hasattr(input_embeddings, 'num_embeddings'):
            output_embeddings.out_features = input_embeddings.num_embeddings

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
            Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

            Arguments:

                new_num_tokens: (`optional`) int:
                    New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                    If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.

            Return: ``torch.nn.Embeddings``
                Pointer to the input tokens Embeddings Module of the model
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
        model_embeds = base_model._resize_token_embeddings(new_num_tokens)
        if new_num_tokens is None:
            return model_embeds

        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens

        # Tie weights again if needed
        if hasattr(self, 'tie_weights'):
            self.tie_weights()

        return model_embeds

    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.set_input_embeddings(new_embeddings)
        return self.get_input_embeddings()

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Module from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end
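With these hooks, a model with an LM head only needs to override get_output_embeddings(); tying (or cloning, for TorchScript) is handled once in tie_weights/_tie_or_clone_weights above. A condensed, self-contained sketch of that mechanism (TinyLMHead is an illustrative stand-in, not a class from this repository):

    import torch.nn as nn

    class TinyLMHead(nn.Module):
        """Illustrative stand-in for a model whose LM head is tied to its input embeddings."""
        def __init__(self, vocab_size=100, hidden_size=16, torchscript=False):
            super(TinyLMHead, self).__init__()
            self.torchscript = torchscript
            self.wte = nn.Embedding(vocab_size, hidden_size)               # input embeddings
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # output embeddings
            self.tie_weights()

        def get_input_embeddings(self):
            return self.wte

        def get_output_embeddings(self):
            return self.lm_head

        def tie_weights(self):
            # Same shape of logic as PreTrainedModel.tie_weights / _tie_or_clone_weights above:
            # clone when exporting to TorchScript, share the Parameter otherwise.
            output_embeddings = self.get_output_embeddings()
            if output_embeddings is None:
                return
            if self.torchscript:
                output_embeddings.weight = nn.Parameter(self.get_input_embeddings().weight.clone())
            else:
                output_embeddings.weight = self.get_input_embeddings().weight

    model = TinyLMHead()
    assert model.lm_head.weight is model.wte.weight  # shared when torchscript is False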
@@ -117,50 +205,6 @@ class PreTrainedModel(nn.Module):

        return new_embeddings

    def _tie_or_clone_weights(self, first_module, second_module):
        """ Tie or clone module weights depending of weither we are using TorchScript or not
        """
        if self.config.torchscript:
            first_module.weight = nn.Parameter(second_module.weight.clone())
        else:
            first_module.weight = second_module.weight

        if hasattr(first_module, 'bias') and first_module.bias is not None:
            first_module.bias.data = torch.nn.functional.pad(
                first_module.bias.data,
                (0, first_module.weight.shape[0] - first_module.bias.shape[0]),
                'constant',
                0
            )

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
            Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

            Arguments:

                new_num_tokens: (`optional`) int:
                    New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                    If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.

            Return: ``torch.nn.Embeddings``
                Pointer to the input tokens Embeddings Module of the model
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
        model_embeds = base_model._resize_token_embeddings(new_num_tokens)
        if new_num_tokens is None:
            return model_embeds

        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens

        # Tie weights again if needed
        if hasattr(self, 'tie_weights'):
            self.tie_weights()

        return model_embeds

    def init_weights(self):
        """ Initialize and prunes weights if needed. """
        # Initialize weights
@@ -170,6 +214,9 @@ class PreTrainedModel(nn.Module):
        if self.config.pruned_heads:
            self.prune_heads(self.config.pruned_heads)

        # Tie weights if needed
        self.tie_weights()

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.
@@ -178,14 +225,12 @@ class PreTrainedModel(nn.Module):
            heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
            E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed

        # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
        for layer, heads in heads_to_prune.items():
            union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
            self.config.pruned_heads[layer] = list(union_heads)  # Unfortunately we have to store it as list for JSON

        base_model._prune_heads(heads_to_prune)
        self.base_model._prune_heads(heads_to_prune)

    def save_pretrained(self, save_directory):
        """ Save a model and its configuration file to a directory, so that it
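Since _resize_token_embeddings now routes through the common getter and setter, resizing the vocabulary works the same way for every model. A hedged usage sketch of resize_token_embeddings (the checkpoint and the added tokens are invented for illustration):

    from transformers import BertTokenizer, BertForMaskedLM

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForMaskedLM.from_pretrained("bert-base-uncased")

    # Grow the vocabulary, then grow the embedding matrix to match;
    # resize_token_embeddings also re-ties the output embeddings afterwards.
    tokenizer.add_tokens(["<new_tok_1>", "<new_tok_2>"])
    embeddings = model.resize_token_embeddings(len(tokenizer))

    assert embeddings.num_embeddings == len(tokenizer)
    assert model.config.vocab_size == len(tokenizer)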
@@ -407,10 +407,12 @@ class XLMModel(XLMPreTrainedModel):

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)
    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}

@@ -618,12 +620,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
        self.pred_layer = XLMPredLayer(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the embeddings
        """
        self._tie_or_clone_weights(self.pred_layer.proj, self.transformer.embeddings)
    def get_output_embeddings(self):
        return self.pred_layer.proj

    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                lengths=None, cache=None, head_mask=None, labels=None):
@@ -611,10 +611,12 @@ class XLNetModel(XLNetPreTrainedModel):

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        self.word_embedding = self._get_resized_embeddings(self.word_embedding, new_num_tokens)
    def get_input_embeddings(self):
        return self.word_embedding

    def set_input_embeddings(self, new_embeddings):
        self.word_embedding = new_embeddings

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError

@@ -918,12 +920,9 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
        self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the embeddings
        """
        self._tie_or_clone_weights(self.lm_loss, self.transformer.word_embedding)
    def get_output_embeddings(self):
        return self.lm_loss

    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
                token_type_ids=None, input_mask=None, head_mask=None, labels=None):
@@ -38,6 +38,7 @@ else:


class AutoModelTest(unittest.TestCase):
    @pytest.mark.slow
    def test_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:

@@ -52,6 +53,7 @@ class AutoModelTest(unittest.TestCase):
            for value in loading_info.values():
                self.assertEqual(len(value), 0)

    @pytest.mark.slow
    def test_lmhead_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:

@@ -64,6 +66,7 @@ class AutoModelTest(unittest.TestCase):
            self.assertIsNotNone(model)
            self.assertIsInstance(model, BertForMaskedLM)

    @pytest.mark.slow
    def test_sequence_classification_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:

@@ -76,6 +79,7 @@ class AutoModelTest(unittest.TestCase):
            self.assertIsNotNone(model)
            self.assertIsInstance(model, BertForSequenceClassification)

    @pytest.mark.slow
    def test_question_answering_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
@@ -463,6 +463,15 @@ class CommonTestCases:

            self.assertTrue(models_equal)

        def test_model_common_attributes(self):
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

            for model_class in self.all_model_classes:
                model = model_class(config)
                model.get_input_embeddings()
                model.set_input_embeddings(torch.nn.Embedding(10, 10))
                model.get_output_embeddings()

        def test_tie_model_weights(self):
            if not self.test_torchscript:
                return

@@ -477,11 +486,11 @@ class CommonTestCases:
                return equal

            for model_class in self.all_model_classes:
                if not hasattr(model_class, 'tie_weights'):
                    continue

                config.torchscript = True
                model_not_tied = model_class(config)
                if model_not_tied.get_output_embeddings() is None:
                    continue

                params_not_tied = list(model_not_tied.parameters())

                config_tied = copy.deepcopy(config)

@@ -688,6 +697,7 @@ class CommonTestCases:
            config_and_inputs = self.prepare_config_and_inputs()
            self.create_and_check_presents(*config_and_inputs)

        @pytest.mark.slow
        def run_slow_tests(self):
            self.create_and_check_model_from_pretrained()
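The new test_model_common_attributes and test_tie_model_weights exercise the accessors through every model class; the same tie-or-clone behaviour can also be checked directly, as in this sketch (a default BertConfig and current constructor defaults are assumed):

    import copy
    from transformers import BertConfig, BertForMaskedLM

    config = BertConfig()                 # torchscript defaults to False -> weights are shared
    model = BertForMaskedLM(config)
    assert model.get_output_embeddings().weight is model.get_input_embeddings().weight

    config_ts = copy.deepcopy(config)
    config_ts.torchscript = True          # TorchScript export path -> weights are cloned instead
    model_ts = BertForMaskedLM(config_ts)
    assert model_ts.get_output_embeddings().weight is not model_ts.get_input_embeddings().weight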
@@ -761,6 +771,7 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):


class ModelUtilsTest(unittest.TestCase):
    @pytest.mark.slow
    def test_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
@@ -27,6 +27,7 @@ else:


class EncoderDecoderModelTest(unittest.TestCase):
    @pytest.mark.slow
    def test_model2model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
@@ -26,6 +26,7 @@ from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CON


class AutoTokenizerTest(unittest.TestCase):
    @pytest.mark.slow
    def test_tokenizer_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
@@ -16,6 +16,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera

import os
import unittest
import pytest
from io import open

from transformers.tokenization_bert import (BasicTokenizer,

@@ -125,6 +126,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
        self.assertFalse(_is_punctuation(u"A"))
        self.assertFalse(_is_punctuation(u" "))

    @pytest.mark.slow
    def test_sequence_builders(self):
        tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")
@@ -16,6 +16,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera

import os
import unittest
import pytest
from io import open

from transformers.tokenization_distilbert import (DistilBertTokenizer)

@@ -30,6 +31,7 @@ class DistilBertTokenizationTest(BertTokenizationTest):
    def get_tokenizer(self, **kwargs):
        return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    @pytest.mark.slow
    def test_sequence_builders(self):
        tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import json
import unittest
import pytest
from io import open

from transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES

@@ -78,6 +79,7 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
            [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
        )

    @pytest.mark.slow
    def test_sequence_builders(self):
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
@@ -18,11 +18,13 @@ from __future__ import print_function

import unittest
import six
import pytest

from transformers import PreTrainedTokenizer
from transformers.tokenization_gpt2 import GPT2Tokenizer


class TokenizerUtilsTest(unittest.TestCase):
    @pytest.mark.slow
    def check_tokenizer_from_pretrained(self, tokenizer_class):
        s3_models = list(tokenizer_class.max_model_input_sizes.keys())
        for model_name in s3_models[:1]:
@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
import json
import pytest

from transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES

@@ -66,6 +67,7 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

    @pytest.mark.slow
    def test_sequence_builders(self):
        tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
@@ -16,6 +16,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera

import os
import unittest
import pytest

from transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)

@@ -89,6 +90,7 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
                              u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
                              SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])

    @pytest.mark.slow
    def test_sequence_builders(self):
        tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")