GPT-2 option to avoid predicting special tokens

mirror of https://github.com/huggingface/transformers.git
commit d1b6979aa5 (parent e211785ada)
@@ -115,6 +115,7 @@ class GPT2Config(object):
                  n_head=12,
                  layer_norm_epsilon=1e-5,
                  initializer_range=0.02,
+                 predict_special_tokens=True
                  ):
         """Constructs GPT2Config.

@@ -130,6 +131,7 @@ class GPT2Config(object):
             layer_norm_epsilon: epsilon to use in the layer norm layers
             initializer_range: The sttdev of the truncated_normal_initializer for
                 initializing all weight matrices.
+            predict_special_tokens: should we predict special tokens (when the model has a LM head)
         """
         if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                         and isinstance(vocab_size_or_config_json_file, unicode)):
@@ -147,6 +149,7 @@ class GPT2Config(object):
             self.n_head = n_head
             self.layer_norm_epsilon = layer_norm_epsilon
             self.initializer_range = initializer_range
+            self.predict_special_tokens = predict_special_tokens
         else:
             raise ValueError(
                 "First argument must be either a vocabulary size (int)"
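The new flag defaults to True, so existing configurations keep predicting the appended special-token positions; setting it to False is what activates the logit slicing added to GPT2LMHead below. A minimal, hedged sketch of constructing a config with the option (the import path and the explicit vocabulary size of 50257 are assumptions about the surrounding library, not part of this diff):

    from pytorch_pretrained_bert import GPT2Config  # assumed import location

    # 50257 = standard GPT-2 vocabulary size (illustrative value)
    config = GPT2Config(vocab_size_or_config_json_file=50257,
                        predict_special_tokens=False)
    print(config.predict_special_tokens)  # False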
@@ -297,18 +300,20 @@ class GPT2LMHead(nn.Module):
     def __init__(self, model_embeddings_weights, config):
         super(GPT2LMHead, self).__init__()
         self.n_embd = config.n_embd
+        self.vocab_size = config.vocab_size
+        self.predict_special_tokens = config.predict_special_tokens
         embed_shape = model_embeddings_weights.shape
         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.set_embeddings_weights(model_embeddings_weights)

-    def set_embeddings_weights(self, model_embeddings_weights):
-        embed_shape = model_embeddings_weights.shape
+    def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
+        self.predict_special_tokens = predict_special_tokens
         self.decoder.weight = model_embeddings_weights  # Tied weights

     def forward(self, hidden_state):
         # Truncated Language modeling logits (we remove the last token)
         # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
         lm_logits = self.decoder(hidden_state)
+        if not self.predict_special_tokens:
+            lm_logits = lm_logits[..., :self.vocab_size]
         return lm_logits
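The forward() change above only takes effect when predict_special_tokens is False: the tied decoder still produces one logit per row of the (possibly enlarged) embedding matrix, and the slice keeps only the first vocab_size of them, dropping the rows appended for special tokens. A standalone illustration of that masking step, with made-up shapes (not code from the repository):

    import torch
    import torch.nn as nn

    vocab_size, n_special, n_embd = 50257, 3, 768   # illustrative sizes
    hidden_state = torch.randn(2, 5, n_embd)        # (batch, sequence, n_embd)
    decoder = nn.Linear(n_embd, vocab_size + n_special, bias=False)

    lm_logits = decoder(hidden_state)               # (2, 5, vocab_size + n_special)
    lm_logits = lm_logits[..., :vocab_size]         # same slice as the new forward() branch
    print(lm_logits.shape)                          # torch.Size([2, 5, 50257])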
@@ -353,9 +358,6 @@ class GPT2PreTrainedModel(nn.Module):
             )
         self.config = config

-    def set_num_special_tokens(self, num_special_tokens):
-        pass
-
     def init_weights(self, module):
         """ Initialize the weights.
         """
@@ -650,12 +652,13 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
         self.apply(self.init_weights)

-    def set_num_special_tokens(self, num_special_tokens):
+    def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
         """ Update input and output embeddings with new embedding matrice
             Make sure we are sharing the embeddings
         """
+        self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
         self.transformer.set_num_special_tokens(num_special_tokens)
-        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
+        self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

     def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
         hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
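The flag is now threaded through set_num_special_tokens (here and in GPT2DoubleHeadsModel below), which keeps the config, the transformer config and the LM head in sync. A hedged usage sketch (the from_pretrained shortcut name and the number of special tokens are assumptions for illustration; only the set_num_special_tokens signature comes from this diff):

    from pytorch_pretrained_bert import GPT2LMHeadModel  # assumed import location

    model = GPT2LMHeadModel.from_pretrained('gpt2')      # assumed pretrained shortcut
    # Grow the embedding matrix for 3 added special tokens, but keep their logits
    # out of the language-modeling head.
    model.set_num_special_tokens(3, predict_special_tokens=False)
    assert model.config.predict_special_tokens is False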
@@ -729,12 +732,13 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         self.multiple_choice_head = GPT2MultipleChoiceHead(config)
         self.apply(self.init_weights)

-    def set_num_special_tokens(self, num_special_tokens):
+    def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
         """ Update input and output embeddings with new embedding matrice
             Make sure we are sharing the embeddings
         """
+        self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
         self.transformer.set_num_special_tokens(num_special_tokens)
-        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
+        self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

     def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
         hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
@@ -263,8 +263,8 @@ class GPT2Tokenizer(object):
     def encode(self, text):
         return self.convert_tokens_to_ids(self.tokenize(text))

-    def decode(self, tokens):
-        text = ''.join([self.decoder[token] for token in tokens])
+    def decode(self, tokens, skip_special_tokens=False):
+        text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
         return text
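On the tokenizer side, decode() can now drop special tokens while detokenizing. A hedged round-trip sketch (the from_pretrained shortcut is an assumption; the encode/decode signatures are the ones shown in this diff):

    from pytorch_pretrained_bert import GPT2Tokenizer  # assumed import location

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')  # assumed pretrained shortcut
    ids = tokenizer.encode("Hello world")
    print(tokenizer.decode(ids))                             # 'Hello world'
    # With skip_special_tokens=True, ids belonging to added special tokens would be
    # filtered out before byte-decoding; plain text comes back unchanged.
    print(tokenizer.decode(ids, skip_special_tokens=True))   # 'Hello world'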