GPT-2 option to avoid predicting special tokens

thomwolf 2019-05-07 16:25:53 +02:00
parent e211785ada
commit d1b6979aa5
2 changed files with 17 additions and 13 deletions


@@ -115,6 +115,7 @@ class GPT2Config(object):
                  n_head=12,
                  layer_norm_epsilon=1e-5,
                  initializer_range=0.02,
+                 predict_special_tokens=True
                  ):
         """Constructs GPT2Config.
@@ -130,6 +131,7 @@ class GPT2Config(object):
             layer_norm_epsilon: epsilon to use in the layer norm layers
             initializer_range: The sttdev of the truncated_normal_initializer for
                 initializing all weight matrices.
+            predict_special_tokens: should we predict special tokens (when the model has a LM head)
         """
         if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                         and isinstance(vocab_size_or_config_json_file, unicode)):
@@ -147,6 +149,7 @@ class GPT2Config(object):
             self.n_head = n_head
             self.layer_norm_epsilon = layer_norm_epsilon
             self.initializer_range = initializer_range
+            self.predict_special_tokens = predict_special_tokens
         else:
             raise ValueError(
                 "First argument must be either a vocabulary size (int)"
@@ -297,18 +300,20 @@ class GPT2LMHead(nn.Module):
     def __init__(self, model_embeddings_weights, config):
         super(GPT2LMHead, self).__init__()
         self.n_embd = config.n_embd
+        self.vocab_size = config.vocab_size
+        self.predict_special_tokens = config.predict_special_tokens
         embed_shape = model_embeddings_weights.shape
         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.set_embeddings_weights(model_embeddings_weights)

-    def set_embeddings_weights(self, model_embeddings_weights):
-        embed_shape = model_embeddings_weights.shape
+    def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
+        self.predict_special_tokens = predict_special_tokens
         self.decoder.weight = model_embeddings_weights  # Tied weights

     def forward(self, hidden_state):
-        # Truncated Language modeling logits (we remove the last token)
-        # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
         lm_logits = self.decoder(hidden_state)
+        if not self.predict_special_tokens:
+            lm_logits = lm_logits[..., :self.vocab_size]
         return lm_logits
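The slicing in forward works because special tokens are appended after the base vocabulary, so their logits occupy the trailing indices of the decoder output; truncating to config.vocab_size drops exactly those positions. A standalone sketch of that step, with made-up sizes (nothing here is the library's API, just plain torch):

    import torch

    vocab_size = 50257          # base vocabulary (assumed)
    num_special_tokens = 3      # e.g. tokens later appended by set_num_special_tokens
    hidden = torch.randn(2, 5, 768)
    decoder = torch.nn.Linear(768, vocab_size + num_special_tokens, bias=False)

    lm_logits = decoder(hidden)              # (2, 5, vocab_size + num_special_tokens)
    lm_logits = lm_logits[..., :vocab_size]  # drop the special-token logits
    assert lm_logits.shape[-1] == vocab_size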
@@ -353,9 +358,6 @@ class GPT2PreTrainedModel(nn.Module):
             )
         self.config = config

-    def set_num_special_tokens(self, num_special_tokens):
-        pass
-
     def init_weights(self, module):
         """ Initialize the weights.
         """
@@ -650,12 +652,13 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
         self.apply(self.init_weights)

-    def set_num_special_tokens(self, num_special_tokens):
+    def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
         """ Update input and output embeddings with new embedding matrice
             Make sure we are sharing the embeddings
         """
+        self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
         self.transformer.set_num_special_tokens(num_special_tokens)
-        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
+        self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

     def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
         hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
@@ -729,12 +732,13 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         self.multiple_choice_head = GPT2MultipleChoiceHead(config)
         self.apply(self.init_weights)

-    def set_num_special_tokens(self, num_special_tokens):
+    def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
         """ Update input and output embeddings with new embedding matrice
             Make sure we are sharing the embeddings
         """
+        self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
         self.transformer.set_num_special_tokens(num_special_tokens)
-        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
+        self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

     def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
         hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
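Both GPT2LMHeadModel and GPT2DoubleHeadsModel now accept predict_special_tokens in set_num_special_tokens and push it into the config and the tied LM head. A hedged sketch of the intended call pattern; the import path and checkpoint shortcut are assumptions, the method signature is the one added in this commit.

    # Sketch of resizing embeddings while keeping special tokens out of the LM softmax.
    from pytorch_pretrained_bert import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained('gpt2')
    # Add special tokens for fine-tuning, but do not predict them from the LM head.
    model.set_num_special_tokens(3, predict_special_tokens=False)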


@@ -263,8 +263,8 @@ class GPT2Tokenizer(object):
     def encode(self, text):
         return self.convert_tokens_to_ids(self.tokenize(text))

-    def decode(self, tokens):
-        text = ''.join([self.decoder[token] for token in tokens])
+    def decode(self, tokens, skip_special_tokens=False):
+        text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
         return text
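On the tokenizer side, decode now routes through convert_ids_to_tokens so callers can drop special tokens when mapping ids back to text. A minimal usage sketch; the import path and checkpoint shortcut are assumptions, the skip_special_tokens keyword is the one added here.

    # Sketch of round-tripping text through the tokenizer.
    from pytorch_pretrained_bert import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    ids = tokenizer.encode("Hello world")
    text = tokenizer.decode(ids, skip_special_tokens=True)  # new keyword from this commit
    print(text)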