added tests + fixed losses

thomwolf 2019-01-08 16:24:23 +01:00
parent eed51c5bdf
commit 3cf12b235a
4 changed files with 484 additions and 225 deletions

View File

@@ -549,7 +549,7 @@ class BertPreTrainedModel(nn.Module):
                 model.__class__.__name__, unexpected_keys))
         if len(error_msgs) > 0:
             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-                self.__class__.__name__, "\n\t".join(error_msgs)))
+                model.__class__.__name__, "\n\t".join(error_msgs)))
         if tempdir:
             # Clean up temp dir
             shutil.rmtree(tempdir)

View File

@@ -48,12 +48,10 @@ class OpenAIGPTConfig(object):
                  n_embd=768,
                  n_layer=12,
                  n_head=12,
-                 intermediate_size=3072,
                  afn="gelu",
                  resid_pdrop=0.1,
                  embd_pdrop=0.1,
                  attn_pdrop=0.1,
-                 type_vocab_size=2,
                  initializer_range=0.02):
         """Constructs OpenAIGPTConfig.
@@ -65,8 +63,6 @@ class OpenAIGPTConfig(object):
             n_layer: Number of hidden layers in the Transformer encoder.
             n_head: Number of attention heads for each attention layer in
                 the Transformer encoder.
-            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
-                layer in the Transformer encoder.
             afn: The non-linear activation function (function or string) in the
                 encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
             resid_pdrop: The dropout probabilitiy for all fully connected
@@ -74,8 +70,6 @@ class OpenAIGPTConfig(object):
             attn_pdrop: The dropout ratio for the attention
                 probabilities.
             embd_pdrop: The dropout ratio for the embeddings.
-            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
-                `OpenAIGPTModel`.
             initializer_range: The sttdev of the truncated_normal_initializer for
                 initializing all weight matrices.
         """
@@ -92,11 +86,9 @@ class OpenAIGPTConfig(object):
             self.n_layer = n_layer
             self.n_head = n_head
             self.afn = afn
-            self.intermediate_size = intermediate_size
             self.resid_pdrop = resid_pdrop
             self.embd_pdrop = embd_pdrop
             self.attn_pdrop = attn_pdrop
-            self.type_vocab_size = type_vocab_size
             self.initializer_range = initializer_range
         else:
             raise ValueError("First argument must be either a vocabulary size (int)"
@@ -133,6 +125,167 @@ class OpenAIGPTConfig(object):
         """Serializes this instance to a JSON string."""
         return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
 
+
+class Conv1D(nn.Module):
+    def __init__(self, nf, rf, nx):
+        super(Conv1D, self).__init__()
+        self.rf = rf
+        self.nf = nf
+        if rf == 1:  # faster 1x1 conv
+            w = torch.empty(nx, nf)
+            nn.init.normal_(w, std=0.02)
+            self.weight = Parameter(w)
+            self.bias = Parameter(torch.zeros(nf))
+        else:  # was used to train LM
+            raise NotImplementedError
+
+    def forward(self, x):
+        if self.rf == 1:
+            size_out = x.size()[:-1] + (self.nf,)
+            x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
+            x = x.view(*size_out)
+        else:
+            raise NotImplementedError
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(self, nx, n_ctx, cfg, scale=False):
+        super(Attention, self).__init__()
+        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
+        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
+        assert n_state % cfg.n_head == 0
+        self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
+        self.n_head = cfg.n_head
+        self.split_size = n_state
+        self.scale = scale
+        self.c_attn = Conv1D(n_state * 3, 1, nx)
+        self.c_proj = Conv1D(n_state, 1, nx)
+        self.attn_dropout = nn.Dropout(cfg.attn_pdrop)
+        self.resid_dropout = nn.Dropout(cfg.resid_pdrop)
+
+    def _attn(self, q, k, v):
+        w = torch.matmul(q, k)
+        if self.scale:
+            w = w / math.sqrt(v.size(-1))
+        # w = w * self.b + -1e9 * (1 - self.b)  # TF implem method: mask_attn_weights
+        # XD: self.b may be larger than w, so we need to crop it
+        b = self.b[:, :, :w.size(-2), :w.size(-1)]
+        w = w * b + -1e9 * (1 - b)
+        w = nn.Softmax(dim=-1)(w)
+        w = self.attn_dropout(w)
+        return torch.matmul(w, v)
+
+    def merge_heads(self, x):
+        x = x.permute(0, 2, 1, 3).contiguous()
+        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
+        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states
+
+    def split_heads(self, x, k=False):
+        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
+        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
+        if k:
+            return x.permute(0, 2, 3, 1)
+        else:
+            return x.permute(0, 2, 1, 3)
+
+    def forward(self, x):
+        x = self.c_attn(x)
+        query, key, value = x.split(self.split_size, dim=2)
+        query = self.split_heads(query)
+        key = self.split_heads(key, k=True)
+        value = self.split_heads(value)
+        a = self._attn(query, key, value)
+        a = self.merge_heads(a)
+        a = self.c_proj(a)
+        a = self.resid_dropout(a)
+        return a
+
+
+class MLP(nn.Module):
+    def __init__(self, n_state, cfg):  # in MLP: n_state=3072 (4 * n_embd)
+        super(MLP, self).__init__()
+        nx = cfg.n_embd
+        self.c_fc = Conv1D(n_state, 1, nx)
+        self.c_proj = Conv1D(nx, 1, n_state)
+        self.act = ACT_FNS[cfg.afn]
+        self.dropout = nn.Dropout(cfg.resid_pdrop)
+
+    def forward(self, x):
+        h = self.act(self.c_fc(x))
+        h2 = self.c_proj(h)
+        return self.dropout(h2)
+
+
+class Block(nn.Module):
+    def __init__(self, n_ctx, cfg, scale=False):
+        super(Block, self).__init__()
+        nx = cfg.n_embd
+        self.attn = Attention(nx, n_ctx, cfg, scale)
+        self.ln_1 = LayerNorm(nx)
+        self.mlp = MLP(4 * nx, cfg)
+        self.ln_2 = LayerNorm(nx)
+
+    def forward(self, x):
+        a = self.attn(x)
+        n = self.ln_1(x + a)
+        m = self.mlp(n)
+        h = self.ln_2(n + m)
+        return h
+
+
+class OpenAIGPTLMHead(nn.Module):
+    """ Language Model Head for the transformer """
+    def __init__(self, model_embeddings_weights, cfg):
+        super(OpenAIGPTLMHead, self).__init__()
+        self.n_embd = cfg.n_embd
+        self.set_embeddings_weights(model_embeddings_weights)
+
+    def set_embeddings_weights(self, model_embeddings_weights):
+        embed_shape = model_embeddings_weights.shape
+        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
+        self.decoder.weight = model_embeddings_weights  # Tied weights
+
+    def forward(self, hidden_state):
+        # Truncated Language modeling logits (we remove the last token)
+        # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
+        lm_logits = self.decoder(hidden_state)
+        return lm_logits
+
+
+class OpenAIGPTMultipleChoiceHead(nn.Module):
+    """ Classifier Head for the transformer """
+    def __init__(self, cfg):
+        super(OpenAIGPTMultipleChoiceHead, self).__init__()
+        self.n_embd = cfg.n_embd
+        # self.multiple_choice_token = multiple_choice_token
+        self.dropout = nn.Dropout2d(cfg.resid_pdrop)  # To reproduce the noise_shape parameter of TF implementation
+        self.linear = nn.Linear(cfg.n_embd, 1)
+        nn.init.normal_(self.linear.weight, std = 0.02)
+        nn.init.normal_(self.linear.bias, 0)
+
+    def forward(self, hidden_states, classification_token_mask):
+        # Classification logits
+        # hidden_states = hidden_states.view(-1, self.n_embd)
+        # classification_token_mask = classification_token_mask.view(-1, 1).expand_as(hidden_states)
+        multiple_choice_h = hidden_states * classification_token_mask.unsqueeze(-1)
+        multiple_choice_h = multiple_choice_h.sum(dim=-2)
+        # flat = x[..., 0].contiguous().view(-1)
+        # multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :]
+        # multiple_choice_h = multiple_choice_h.view(-1, x.size(1), self.n_embd, 1)
+        # # This double transposition is there to replicate the behavior
+        # # of the noise_shape argument in the tensorflow
+        # # implementation. For more details, see
+        # # https://github.com/huggingface/pytorch-openai-transformer-lm/issues/11
+        # multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)
+        # multiple_choice_h = multiple_choice_h.contiguous().view(-1, self.n_embd)
+        multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
+        return multiple_choice_logits
+
+
 class OpenAIGPTPreTrainedModel(nn.Module):
     """ An abstract class to handle weights initialization and
         a simple interface for dowloading and loading pretrained models.
@@ -142,7 +295,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
         if not isinstance(config, OpenAIGPTConfig):
             raise ValueError(
                 "Parameter config in `{}(config)` should be an instance of class `OpenAIGPTConfig`. "
-                "To create a model from a Google pretrained model use "
+                "To create a model from a pretrained model use "
                 "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                     self.__class__.__name__, self.__class__.__name__
                 ))
@@ -161,11 +314,12 @@ class OpenAIGPTPreTrainedModel(nn.Module):
         if isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
 
-    def post_loading(self):
+    def set_num_special_tokens(self, num_special_tokens):
         pass
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
+    def from_pretrained(cls, pretrained_model_name, num_special_tokens=0, state_dict=None, cache_dir=None,
+                        *inputs, **kwargs):
         """
         Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
         Download and cache the pre-trained model file if needed.
@@ -178,7 +332,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
                 . `openai_gpt_config.json` a configuration file for the model
                 . `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
             cache_dir: an optional path to a folder in which the pre-trained models will be cached.
-            state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
+            state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
             *inputs, **kwargs: additional input for the specific Bert class
                 (ex: num_labels for BertForSequenceClassification)
         """
@@ -263,167 +417,15 @@ class OpenAIGPTPreTrainedModel(nn.Module):
                 model.__class__.__name__, unexpected_keys))
         if len(error_msgs) > 0:
             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-                self.__class__.__name__, "\n\t".join(error_msgs)))
-        model.post_loading()
+                model.__class__.__name__, "\n\t".join(error_msgs)))
+        # Add additional embeddings for special tokens if needed
+        if num_special_tokens != config.n_special:
+            model.set_num_special_tokens(num_special_tokens)
         if tempdir:
             # Clean up temp dir
             shutil.rmtree(tempdir)
         return model
 
-
-class Conv1D(nn.Module):
-    def __init__(self, nf, rf, nx):
-        super(Conv1D, self).__init__()
-        self.rf = rf
-        self.nf = nf
-        if rf == 1:  # faster 1x1 conv
-            w = torch.empty(nx, nf)
-            nn.init.normal_(w, std=0.02)
-            self.weight = Parameter(w)
-            self.bias = Parameter(torch.zeros(nf))
-        else:  # was used to train LM
-            raise NotImplementedError
-
-    def forward(self, x):
-        if self.rf == 1:
-            size_out = x.size()[:-1] + (self.nf,)
-            x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
-            x = x.view(*size_out)
-        else:
-            raise NotImplementedError
-        return x
-
-
-class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, cfg, scale=False):
-        super(Attention, self).__init__()
-        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
-        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
-        assert n_state % cfg.n_head == 0
-        self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
-        self.n_head = cfg.n_head
-        self.split_size = n_state
-        self.scale = scale
-        self.c_attn = Conv1D(n_state * 3, 1, nx)
-        self.c_proj = Conv1D(n_state, 1, nx)
-        self.attn_dropout = nn.Dropout(cfg.attn_pdrop)
-        self.resid_dropout = nn.Dropout(cfg.resid_pdrop)
-
-    def _attn(self, q, k, v):
-        w = torch.matmul(q, k)
-        if self.scale:
-            w = w / math.sqrt(v.size(-1))
-        w = w * self.b + -1e9 * (1 - self.b)  # TF implem method: mask_attn_weights
-        w = nn.Softmax(dim=-1)(w)
-        w = self.attn_dropout(w)
-        return torch.matmul(w, v)
-
-    def merge_heads(self, x):
-        x = x.permute(0, 2, 1, 3).contiguous()
-        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
-        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states
-
-    def split_heads(self, x, k=False):
-        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
-        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
-        if k:
-            return x.permute(0, 2, 3, 1)
-        else:
-            return x.permute(0, 2, 1, 3)
-
-    def forward(self, x):
-        x = self.c_attn(x)
-        query, key, value = x.split(self.split_size, dim=2)
-        query = self.split_heads(query)
-        key = self.split_heads(key, k=True)
-        value = self.split_heads(value)
-        a = self._attn(query, key, value)
-        a = self.merge_heads(a)
-        a = self.c_proj(a)
-        a = self.resid_dropout(a)
-        return a
-
-
-class MLP(nn.Module):
-    def __init__(self, n_state, cfg):  # in MLP: n_state=3072 (4 * n_embd)
-        super(MLP, self).__init__()
-        nx = cfg.n_embd
-        self.c_fc = Conv1D(n_state, 1, nx)
-        self.c_proj = Conv1D(nx, 1, n_state)
-        self.act = ACT_FNS[cfg.afn]
-        self.dropout = nn.Dropout(cfg.resid_pdrop)
-
-    def forward(self, x):
-        h = self.act(self.c_fc(x))
-        h2 = self.c_proj(h)
-        return self.dropout(h2)
-
-
-class Block(nn.Module):
-    def __init__(self, n_ctx, cfg, scale=False):
-        super(Block, self).__init__()
-        nx = cfg.n_embd
-        self.attn = Attention(nx, n_ctx, cfg, scale)
-        self.ln_1 = LayerNorm(nx)
-        self.mlp = MLP(4 * nx, cfg)
-        self.ln_2 = LayerNorm(nx)
-
-    def forward(self, x):
-        a = self.attn(x)
-        n = self.ln_1(x + a)
-        m = self.mlp(n)
-        h = self.ln_2(n + m)
-        return h
-
-
-class OpenAIGPTLMHead(nn.Module):
-    """ Language Model Head for the transformer """
-    def __init__(self, model_embeddings_weights, cfg):
-        super(OpenAIGPTLMHead, self).__init__()
-        self.n_embd = cfg.n_embd
-        self.set_embeddings_weights(model_embeddings_weights)
-
-    def set_embeddings_weights(self, model_embeddings_weights):
-        embed_shape = model_embeddings_weights.shape
-        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
-        self.decoder.weight = model_embeddings_weights  # Tied weights
-
-    def forward(self, h):
-        # Truncated Language modeling logits (we remove the last token)
-        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
-        lm_logits = self.decoder(h_trunc)
-        return lm_logits
-
-
-class OpenAIGPTClfHead(nn.Module):
-    """ Classifier Head for the transformer """
-    def __init__(self, clf_token, cfg):
-        super(OpenAIGPTClfHead, self).__init__()
-        self.n_embd = cfg.n_embd
-        self.clf_token = clf_token
-        self.dropout = nn.Dropout2d(cfg.resid_pdrop)  # To reproduce the noise_shape parameter of TF implementation
-        self.linear = nn.Linear(cfg.n_embd, 1)
-        nn.init.normal_(self.linear.weight, std = 0.02)
-        nn.init.normal_(self.linear.bias, 0)
-
-    def forward(self, h, x):
-        # Classification logits
-        clf_h = h.view(-1, self.n_embd)
-        flat = x[..., 0].contiguous().view(-1)
-        clf_h = clf_h[flat == self.clf_token, :]
-        clf_h = clf_h.view(-1, x.size(1), self.n_embd, 1)
-        # This double transposition is there to replicate the behavior
-        # of the noise_shape argument in the tensorflow
-        # implementation. For more details, see
-        # https://github.com/huggingface/pytorch-openai-transformer-lm/issues/11
-        clf_h = self.dropout(clf_h.transpose(1, 2)).transpose(1, 2)
-        clf_h = clf_h.contiguous().view(-1, self.n_embd)
-        clf_logits = self.linear(clf_h)
-        return clf_logits.view(-1, x.size(1))
-
 
 class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
     """ OpenAI GPT model """
@@ -440,6 +442,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         # nn.init.normal_(self.embed.weight, std=0.02)
 
     def set_num_special_tokens(self, num_special_tokens):
+        " Update input embeddings with new embedding matrice "
         # Update config
         self.config.n_special = num_special_tokens
         # # Build new embeddings and initialize
@@ -451,45 +454,83 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         self.embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
         self.embed.weight.data[-self.config.n_ctx:, :] = old_embed.weight.data[-self.config.n_ctx:, :]
 
-    def forward(self, x):
-        x = x.view(-1, x.size(-2), x.size(-1))
-        e = self.embed(x)
+    def forward(self, input_ids, position_ids=None, token_type_ids=None):
+        if position_ids is None:
+            start = self.config.vocab_size + self.config.n_special
+            end = start + input_ids.size(-1)
+            position_ids = torch.arange(start, end, dtype=torch.long, device=input_ids.device)
+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_ids.size(-1))
+        position_ids = position_ids.view(-1, position_ids.size(-1))
+
+        inputs_embeds = self.embed(input_ids)
+        position_embeds = self.embed(position_ids)
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
+            token_type_embeds = self.embed(token_type_ids)
+        else:
+            token_type_embeds = 0
         # Add the position information to the input embeddings
-        h = e.sum(dim=2)
+        # h = e.sum(dim=2)
+        hidden_states = inputs_embeds + position_embeds + token_type_embeds
         for block in self.h:
-            h = block(h)
-        return h
+            hidden_states = block(hidden_states)
+        return hidden_states.view(*input_shape, hidden_states.size(-1))
 
-
-class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
-    """ OpenAI GPT model with language model and classification heads """
-    def __init__(self, cfg, clf_token='[CLS]'):
-        super(OpenAIGPTDoubleHeadsModel, self).__init__(cfg)
-        self.transformer = OpenAIGPTModel(cfg)
-        self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, cfg)
-        self.clf_head = OpenAIGPTClfHead(clf_token, cfg)
-        self.apply(self.init_weights)
-
-    def post_loading(self):
-        " Set the number of special tokens to 1 (for the [CLS] token) "
-        self.set_num_special_tokens(1)
-
-    def set_num_special_tokens(self, num_special_tokens):
-        " Update input and output embeddings with new embedding matrice "
-        self.transformer.set_num_special_tokens(num_special_tokens)
-        self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
-
-    def forward(self, x, lm_labels=None, clf_labels=None):
-        h = self.transformer(x)
-        lm_logits = self.lm_head(h)
-        clf_logits = self.clf_head(h, x)
-        losses = []
-        if lm_labels is not None:
-            loss_fct = CrossEntropyLoss()
-            losses.append(loss_fct(lm_logits, lm_labels))
-        if clf_labels is not None:
-            loss_fct = CrossEntropyLoss()
-            losses.append(loss_fct(clf_logits, clf_labels))
-        if losses:
-            return losses
-        return lm_logits, clf_logits
+
+class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
+    """ OpenAI GPT model with language model and classification heads """
+    def __init__(self, cfg):
+        super(OpenAIGPTLMHeadModel, self).__init__(cfg)
+        self.transformer = OpenAIGPTModel(cfg)
+        self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, cfg)
+        self.apply(self.init_weights)
+
+    def set_num_special_tokens(self, num_special_tokens):
+        " Update input and output embeddings with new embedding matrice "
+        self.transformer.set_num_special_tokens(num_special_tokens)
+        self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
+
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None):
+        hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
+        lm_logits = self.lm_head(hidden_states)
+        if lm_labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(lm_logits, lm_labels)
+            return loss
+        return lm_logits
+
+
+class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
+    """ OpenAI GPT model with language model and classification heads """
+    def __init__(self, cfg):
+        super(OpenAIGPTDoubleHeadsModel, self).__init__(cfg)
+        self.transformer = OpenAIGPTModel(cfg)
+        self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, cfg)
+        self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(cfg)
+        self.apply(self.init_weights)
+
+    def set_num_special_tokens(self, num_special_tokens):
+        " Update input and output embeddings with new embedding matrice "
+        self.transformer.set_num_special_tokens(num_special_tokens)
+        self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
+
+    def forward(self, input_ids, classification_token_mask, position_ids=None, token_type_ids=None,
+                lm_labels=None, multiple_choice_labels=None):
+        """
+        input_ids as to be of shape B x C x S
+        lm_labels can be masked using the -1 value
+        """
+        hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
+        lm_logits = self.lm_head(hidden_states)
+        multiple_choice_logits = self.multiple_choice_head(hidden_states, classification_token_mask)
+        losses = []
+        if lm_labels is not None:
+            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)))
+        if multiple_choice_labels is not None:
+            loss_fct = CrossEntropyLoss()
+            losses.append(loss_fct(multiple_choice_logits, multiple_choice_labels.view(-1)))
+        if losses:
+            return losses
+        return lm_logits, multiple_choice_logits
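The reworked heads above change how losses are returned: a single LM loss from OpenAIGPTLMHeadModel, and a [lm_loss, multiple_choice_loss] list from OpenAIGPTDoubleHeadsModel. The following is a minimal sketch of how the new OpenAIGPTDoubleHeadsModel forward signature is meant to be driven; the config values and random inputs are illustrative only (they roughly mirror the toy sizes used in the new test file below), not the pretrained GPT settings.

import torch
from pytorch_pretrained_bert import OpenAIGPTConfig, OpenAIGPTDoubleHeadsModel

# Toy config, not the released GPT sizes.
config = OpenAIGPTConfig(vocab_size_or_config_json_file=99, n_special=1,
                         n_ctx=33, n_embd=32, n_layer=2, n_head=4)
model = OpenAIGPTDoubleHeadsModel(config)

batch_size, n_choices, seq_length = 2, 3, 7
input_ids = torch.randint(0, 99, (batch_size, n_choices, seq_length), dtype=torch.long)

# One-hot mask marking the classification token position in each choice sequence.
classification_token_mask = torch.zeros(batch_size, n_choices, seq_length)
classification_token_mask[..., -1] = 1.0

lm_labels = input_ids.clone()                  # entries set to -1 are ignored by the LM loss
multiple_choice_labels = torch.tensor([0, 2])  # index of the correct choice per example

# With labels, forward returns a list [lm_loss, multiple_choice_loss] ...
losses = model(input_ids, classification_token_mask,
               lm_labels=lm_labels, multiple_choice_labels=multiple_choice_labels)

# ... and without labels it returns (lm_logits, multiple_choice_logits).
lm_logits, multiple_choice_logits = model(input_ids, classification_token_mask)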

View File

@@ -67,19 +67,17 @@ class OpenAIGPTTokenizer(object):
     mostly a wrapper for a public python bpe tokenizer
     """
     @classmethod
-    def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
         """
         Instantiate a PreTrainedBertModel from a pre-trained model file.
         Download and cache the pre-trained model file if needed.
         """
-        if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
-            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
-            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name]
+        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
+            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
         else:
-            vocab_file = pretrained_model_name
-            if os.path.isdir(vocab_file):
-                vocab_file = os.path.join(vocab_file, VOCAB_NAME)
-                merges_file = os.path.join(vocab_file, MERGES_NAME)
+            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
@@ -87,11 +85,12 @@ class OpenAIGPTTokenizer(object):
         except FileNotFoundError:
             logger.error(
                 "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find any file "
-                "associated to this path or url.".format(
-                    pretrained_model_name,
+                "We assumed '{}' was a path or url but couldn't find files {} and {} "
+                "at this path or url.".format(
+                    pretrained_model_name_or_path,
                     ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                    vocab_file))
+                    pretrained_model_name_or_path,
+                    vocab_file, merges_file))
             return None
         if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
             logger.info("loading vocabulary file {}".format(vocab_file))
@@ -101,29 +100,38 @@ class OpenAIGPTTokenizer(object):
                 vocab_file, resolved_vocab_file))
             logger.info("loading merges file {} from cache at {}".format(
                 merges_file, resolved_merges_file))
-        if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
+        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
             # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
             # than the number of positional embeddings
-            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
+            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
             kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
         # Instantiate tokenizer.
         tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
         return tokenizer
 
-    def __init__(self, vocab_file, merges_file):
+    def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
         try:
             import ftfy
             import spacy
         except ImportError:
             raise ImportError("Please install ftfy and spacy to use OpenAI GPT tokenizer.")
+
+        self.max_len = max_len if max_len is not None else int(1e12)
         self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
+        self.fix_text = ftfy.fix_text
         self.encoder = json.load(open(vocab_file))
         self.decoder = {v:k for k,v in self.encoder.items()}
         merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
         merges = [tuple(merge.split()) for merge in merges]
         self.bpe_ranks = dict(zip(merges, range(len(merges))))
         self.cache = {}
+        if not special_tokens:
+            self.special_tokens = {}
+        else:
+            self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
+
+    def set_special_tokens(self, special_tokens):
+        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
 
     def bpe(self, token):
         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
@@ -168,20 +176,38 @@ class OpenAIGPTTokenizer(object):
         self.cache[token] = word
         return word
 
-    def tokenize(self, texts, verbose=True):
-        texts_tokens = []
-        if verbose:
-            for text in tqdm(texts, ncols=80, leave=False):
-                text = self.nlp(text_standardize(ftfy.fix_text(text)))
-                text_tokens = []
-                for token in text:
-                    text_tokens.extend([self.encoder.get(t, 0) for t in self.bpe(token.text.lower()).split(' ')])
-                texts_tokens.append(text_tokens)
-        else:
-            for text in texts:
-                text = self.nlp(text_standardize(ftfy.fix_text(text)))
-                text_tokens = []
-                for token in text:
-                    text_tokens.extend([self.encoder.get(t, 0) for t in self.bpe(token.text.lower()).split(' ')])
-                texts_tokens.append(text_tokens)
-        return texts_tokens
+    def tokenize(self, text):
+        split_tokens = []
+        text = self.nlp(text_standardize(self.fix_text(text)))
+        for token in text:
+            split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
+        return split_tokens
+
+    def convert_tokens_to_ids(self, tokens):
+        """Converts a sequence of tokens into ids using the vocab."""
+        ids = []
+        for token in tokens:
+            if token in self.special_tokens:
+                ids.append(self.special_tokens[token])
+            else:
+                ids.append(self.encoder.get(token, 0))
+        if len(ids) > self.max_len:
+            raise ValueError(
+                "Token indices sequence length is longer than the specified maximum "
+                " sequence length for this BERT model ({} > {}). Running this"
+                " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
+            )
+        return ids
+
+    def convert_ids_to_tokens(self, ids):
+        """Converts a sequence of ids in BPE tokens using the vocab."""
+        tokens = []
+        for i in ids:
+            tokens.append(self.decoder[i])
+        return tokens
+
+    def decode(self, ids):
+        """Converts a sequence of ids in a string."""
+        tokens = self.convert_ids_to_tokens(ids)
+        out_string = ''.join(tokens).replace('</w>', ' ')
+        return out_string
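With these changes the tokenizer follows a BERT-style split between tokenize, convert_tokens_to_ids and decode, and can reserve extra ids for user-defined special tokens. A rough usage sketch follows; it assumes the tokenizer is exported at the package root like the models are, that 'openai-gpt' is a shortcut present in PRETRAINED_VOCAB_ARCHIVE_MAP, and that ftfy plus spacy (with its 'en' model) are installed; the special-token names are arbitrary examples.

from pytorch_pretrained_bert import OpenAIGPTTokenizer

# special_tokens is forwarded through from_pretrained's **kwargs into __init__
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt',
                                               special_tokens=['_start_', '_classify_'])

tokens = tokenizer.tokenize("Jim Henson was a puppeteer")   # BPE pieces ending in '</w>'
ids = tokenizer.convert_tokens_to_ids(['_start_'] + tokens + ['_classify_'])

# decode() only knows the BPE vocabulary: special-token ids live past len(encoder)
# and are not in self.decoder, so round-trip just the plain tokens here.
text = tokenizer.decode(tokenizer.convert_tokens_to_ids(tokens))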

View File

@@ -0,0 +1,192 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+import json
+import random
+
+import torch
+
+from pytorch_pretrained_bert import (OpenAIGPTConfig, OpenAIGPTModel, OpenAIGPTDoubleHeadsModel)
+
+
+class OpenAIGPTModelTest(unittest.TestCase):
+    class OpenAIGPTModelTester(object):
+
+        def __init__(self,
+                     parent,
+                     batch_size=13,
+                     seq_length=7,
+                     is_training=True,
+                     use_position_ids=True,
+                     use_token_type_ids=True,
+                     use_labels=True,
+                     vocab_size=99,
+                     n_special=1,
+                     n_ctx=33,
+                     n_embd=32,
+                     n_layer=5,
+                     n_head=4,
+                     n_choices=3,
+                     afn="gelu",
+                     resid_pdrop=0.1,
+                     attn_pdrop=0.1,
+                     embd_pdrop=0.1,
+                     type_sequence_label_size=2,
+                     initializer_range=0.02,
+                     num_labels=3,
+                     scope=None):
+            self.parent = parent
+            self.batch_size = batch_size
+            self.seq_length = seq_length
+            self.is_training = is_training
+            self.use_position_ids = use_position_ids
+            self.use_token_type_ids = use_token_type_ids
+            self.use_labels = use_labels
+            self.vocab_size = vocab_size
+            self.n_special = n_special
+            self.n_ctx = n_ctx
+            self.n_embd = n_embd
+            self.n_layer = n_layer
+            self.n_head = n_head
+            self.afn = afn
+            self.n_choices = n_choices
+            self.resid_pdrop = resid_pdrop
+            self.attn_pdrop = attn_pdrop
+            self.embd_pdrop = embd_pdrop
+            self.type_sequence_label_size = type_sequence_label_size
+            self.initializer_range = initializer_range
+            self.num_labels = num_labels
+            self.scope = scope
+
+        def prepare_config_and_inputs(self):
+            input_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.vocab_size)
+
+            position_ids = None
+            if self.use_position_ids:
+                position_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_ctx)
+                position_ids = position_ids + self.n_special + self.vocab_size
+
+            token_type_ids = None
+            if self.use_token_type_ids:
+                total_voc = self.n_ctx + self.n_special + self.vocab_size
+                token_type_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)
+
+            multiple_choice_labels = None
+            lm_labels = None
+            classification_token_mask = None
+            if self.use_labels:
+                multiple_choice_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
+                lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
+                classification_token_mask = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], 2).float()
+
+            config = OpenAIGPTConfig(
+                vocab_size_or_config_json_file=self.vocab_size,
+                n_ctx=self.n_ctx,
+                n_special=self.n_special,
+                n_embd=self.n_embd,
+                n_layer=self.n_layer,
+                n_head=self.n_head,
+                afn=self.afn,
+                resid_pdrop=self.resid_pdrop,
+                attn_pdrop=self.attn_pdrop,
+                embd_pdrop=self.embd_pdrop,
+                initializer_range=self.initializer_range)
+
+            return (config, input_ids, token_type_ids, position_ids,
+                    multiple_choice_labels, lm_labels, classification_token_mask)
+
+        def create_openai_model(self, config, input_ids, token_type_ids, position_ids,
+                                multiple_choice_labels, lm_labels, classification_token_mask):
+            model = OpenAIGPTModel(config)
+            hidden_states = model(input_ids, position_ids, token_type_ids)
+            outputs = {
+                "hidden_states": hidden_states,
+            }
+            return outputs
+
+        def check_openai_model_output(self, result):
+            self.parent.assertListEqual(
+                list(result["hidden_states"].size()),
+                [self.batch_size, self.n_choices, self.seq_length, self.n_embd])
+
+        def create_openai_double_heads(self, config, input_ids, token_type_ids, position_ids,
+                                       multiple_choice_labels, lm_labels, classification_token_mask):
+            model = OpenAIGPTDoubleHeadsModel(config)
+            loss = model(input_ids, classification_token_mask, position_ids,
+                         token_type_ids, lm_labels, multiple_choice_labels)
+            lm_logits, multiple_choice_logits = model(input_ids, classification_token_mask, position_ids, token_type_ids)
+            outputs = {
+                "loss": loss,
+                "lm_logits": lm_logits,
+                "multiple_choice_logits": multiple_choice_logits,
+            }
+            return outputs
+
+        def check_openai_double_heads_output(self, result):
+            total_voc = self.n_ctx + self.n_special + self.vocab_size
+            self.parent.assertListEqual(
+                list(result["lm_logits"].size()),
+                [self.batch_size, self.n_choices, self.seq_length, total_voc])
+            self.parent.assertListEqual(
+                list(result["multiple_choice_logits"].size()),
+                [self.batch_size, self.n_choices])
+
+        def check_openai_double_heads_loss_output(self, result):
+            self.parent.assertListEqual(
+                [list(l.size()) for l in result["loss"]],
+                [[], []])
+
+    def test_default(self):
+        self.run_tester(OpenAIGPTModelTest.OpenAIGPTModelTester(self))
+
+    def test_config_to_json_string(self):
+        config = OpenAIGPTConfig(vocab_size_or_config_json_file=99, n_embd=37)
+        obj = json.loads(config.to_json_string())
+        self.assertEqual(obj["vocab_size"], 99)
+        self.assertEqual(obj["n_embd"], 37)
+
+    def run_tester(self, tester):
+        config_and_inputs = tester.prepare_config_and_inputs()
+        output_result = tester.create_openai_model(*config_and_inputs)
+        tester.check_openai_model_output(output_result)
+
+        output_result = tester.create_openai_double_heads(*config_and_inputs)
+        tester.check_openai_double_heads_output(output_result)
+        tester.check_openai_double_heads_loss_output(output_result)
+
+    @classmethod
+    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
+        """Creates a random int32 tensor of the shape within the vocab size."""
+        if rng is None:
+            rng = random.Random()
+
+        total_dims = 1
+        for dim in shape:
+            total_dims *= dim
+
+        values = []
+        for _ in range(total_dims):
+            values.append(rng.randint(0, vocab_size - 1))
+
+        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
+
+
+if __name__ == "__main__":
+    unittest.main()
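The new test module runs under plain unittest (it already calls unittest.main() when executed directly). If you prefer to drive it programmatically, something along these lines should work, assuming OpenAIGPTModelTest has been imported from the test module:

import unittest

# Load and run only the OpenAI GPT test case with a more verbose runner.
suite = unittest.TestLoader().loadTestsFromTestCase(OpenAIGPTModelTest)
unittest.TextTestRunner(verbosity=2).run(suite)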