mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
GPT-2 option to avoid predicting special tokens
This commit is contained in:
parent
e211785ada
commit
d1b6979aa5
@ -115,6 +115,7 @@ class GPT2Config(object):
|
|||||||
n_head=12,
|
n_head=12,
|
||||||
layer_norm_epsilon=1e-5,
|
layer_norm_epsilon=1e-5,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
|
predict_special_tokens=True
|
||||||
):
|
):
|
||||||
"""Constructs GPT2Config.
|
"""Constructs GPT2Config.
|
||||||
|
|
||||||
@ -130,6 +131,7 @@ class GPT2Config(object):
|
|||||||
layer_norm_epsilon: epsilon to use in the layer norm layers
|
layer_norm_epsilon: epsilon to use in the layer norm layers
|
||||||
initializer_range: The sttdev of the truncated_normal_initializer for
|
initializer_range: The sttdev of the truncated_normal_initializer for
|
||||||
initializing all weight matrices.
|
initializing all weight matrices.
|
||||||
|
predict_special_tokens: should we predict special tokens (when the model has a LM head)
|
||||||
"""
|
"""
|
||||||
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
||||||
and isinstance(vocab_size_or_config_json_file, unicode)):
|
and isinstance(vocab_size_or_config_json_file, unicode)):
|
||||||
@ -147,6 +149,7 @@ class GPT2Config(object):
|
|||||||
self.n_head = n_head
|
self.n_head = n_head
|
||||||
self.layer_norm_epsilon = layer_norm_epsilon
|
self.layer_norm_epsilon = layer_norm_epsilon
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
|
self.predict_special_tokens = predict_special_tokens
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"First argument must be either a vocabulary size (int)"
|
"First argument must be either a vocabulary size (int)"
|
||||||
@ -297,18 +300,20 @@ class GPT2LMHead(nn.Module):
|
|||||||
def __init__(self, model_embeddings_weights, config):
|
def __init__(self, model_embeddings_weights, config):
|
||||||
super(GPT2LMHead, self).__init__()
|
super(GPT2LMHead, self).__init__()
|
||||||
self.n_embd = config.n_embd
|
self.n_embd = config.n_embd
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
self.predict_special_tokens = config.predict_special_tokens
|
||||||
embed_shape = model_embeddings_weights.shape
|
embed_shape = model_embeddings_weights.shape
|
||||||
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
|
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
|
||||||
self.set_embeddings_weights(model_embeddings_weights)
|
self.set_embeddings_weights(model_embeddings_weights)
|
||||||
|
|
||||||
def set_embeddings_weights(self, model_embeddings_weights):
|
def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
|
||||||
embed_shape = model_embeddings_weights.shape
|
self.predict_special_tokens = predict_special_tokens
|
||||||
self.decoder.weight = model_embeddings_weights # Tied weights
|
self.decoder.weight = model_embeddings_weights # Tied weights
|
||||||
|
|
||||||
def forward(self, hidden_state):
|
def forward(self, hidden_state):
|
||||||
# Truncated Language modeling logits (we remove the last token)
|
|
||||||
# h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
|
|
||||||
lm_logits = self.decoder(hidden_state)
|
lm_logits = self.decoder(hidden_state)
|
||||||
|
if not self.predict_special_tokens:
|
||||||
|
lm_logits = lm_logits[..., :self.vocab_size]
|
||||||
return lm_logits
|
return lm_logits
|
||||||
|
|
||||||
|
|
||||||
@ -353,9 +358,6 @@ class GPT2PreTrainedModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
def set_num_special_tokens(self, num_special_tokens):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def init_weights(self, module):
|
def init_weights(self, module):
|
||||||
""" Initialize the weights.
|
""" Initialize the weights.
|
||||||
"""
|
"""
|
||||||
@ -650,12 +652,13 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
|
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
def set_num_special_tokens(self, num_special_tokens):
|
def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
|
||||||
""" Update input and output embeddings with new embedding matrice
|
""" Update input and output embeddings with new embedding matrice
|
||||||
Make sure we are sharing the embeddings
|
Make sure we are sharing the embeddings
|
||||||
"""
|
"""
|
||||||
|
self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
|
||||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||||
self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
|
self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)
|
||||||
|
|
||||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
|
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
|
||||||
hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
|
hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
|
||||||
@ -729,12 +732,13 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
self.multiple_choice_head = GPT2MultipleChoiceHead(config)
|
self.multiple_choice_head = GPT2MultipleChoiceHead(config)
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
def set_num_special_tokens(self, num_special_tokens):
|
def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
|
||||||
""" Update input and output embeddings with new embedding matrice
|
""" Update input and output embeddings with new embedding matrice
|
||||||
Make sure we are sharing the embeddings
|
Make sure we are sharing the embeddings
|
||||||
"""
|
"""
|
||||||
|
self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
|
||||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||||
self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
|
self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)
|
||||||
|
|
||||||
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
|
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
|
||||||
hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
|
hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
|
||||||
|
@ -263,8 +263,8 @@ class GPT2Tokenizer(object):
|
|||||||
def encode(self, text):
|
def encode(self, text):
|
||||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||||
|
|
||||||
def decode(self, tokens):
|
def decode(self, tokens, skip_special_tokens=False):
|
||||||
text = ''.join([self.decoder[token] for token in tokens])
|
text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
|
||||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user