From 3cf12b235a032b57ea72d261d16f36b5684d754c Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 8 Jan 2019 16:24:23 +0100 Subject: [PATCH] added tests + fixed losses --- pytorch_pretrained_bert/modeling.py | 2 +- pytorch_pretrained_bert/modeling_openai.py | 425 ++++++++++-------- .../tokenization_openai.py | 90 ++-- tests/modeling_openai_test.py | 192 ++++++++ 4 files changed, 484 insertions(+), 225 deletions(-) create mode 100644 tests/modeling_openai_test.py diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index d2a0cf8dd2f..021d2334ca5 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -549,7 +549,7 @@ class BertPreTrainedModel(nn.Module): model.__class__.__name__, unexpected_keys)) if len(error_msgs) > 0: raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - self.__class__.__name__, "\n\t".join(error_msgs))) + model.__class__.__name__, "\n\t".join(error_msgs))) if tempdir: # Clean up temp dir shutil.rmtree(tempdir) diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index 8e8ca0db00b..9442b1ed69b 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -48,12 +48,10 @@ class OpenAIGPTConfig(object): n_embd=768, n_layer=12, n_head=12, - intermediate_size=3072, afn="gelu", resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1, - type_vocab_size=2, initializer_range=0.02): """Constructs OpenAIGPTConfig. @@ -65,8 +63,6 @@ class OpenAIGPTConfig(object): n_layer: Number of hidden layers in the Transformer encoder. n_head: Number of attention heads for each attention layer in the Transformer encoder. - intermediate_size: The size of the "intermediate" (i.e., feed-forward) - layer in the Transformer encoder. afn: The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu" and "swish" are supported. resid_pdrop: The dropout probabilitiy for all fully connected @@ -74,8 +70,6 @@ class OpenAIGPTConfig(object): attn_pdrop: The dropout ratio for the attention probabilities. embd_pdrop: The dropout ratio for the embeddings. - type_vocab_size: The vocabulary size of the `token_type_ids` passed into - `OpenAIGPTModel`. initializer_range: The sttdev of the truncated_normal_initializer for initializing all weight matrices. 
""" @@ -92,11 +86,9 @@ class OpenAIGPTConfig(object): self.n_layer = n_layer self.n_head = n_head self.afn = afn - self.intermediate_size = intermediate_size self.resid_pdrop = resid_pdrop self.embd_pdrop = embd_pdrop self.attn_pdrop = attn_pdrop - self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range else: raise ValueError("First argument must be either a vocabulary size (int)" @@ -133,6 +125,167 @@ class OpenAIGPTConfig(object): """Serializes this instance to a JSON string.""" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" +class Conv1D(nn.Module): + def __init__(self, nf, rf, nx): + super(Conv1D, self).__init__() + self.rf = rf + self.nf = nf + if rf == 1: # faster 1x1 conv + w = torch.empty(nx, nf) + nn.init.normal_(w, std=0.02) + self.weight = Parameter(w) + self.bias = Parameter(torch.zeros(nf)) + else: # was used to train LM + raise NotImplementedError + + def forward(self, x): + if self.rf == 1: + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(*size_out) + else: + raise NotImplementedError + return x + + +class Attention(nn.Module): + def __init__(self, nx, n_ctx, cfg, scale=False): + super(Attention, self).__init__() + n_state = nx # in Attention: n_state=768 (nx=n_embd) + # [switch nx => n_state from Block to Attention to keep identical to TF implem] + assert n_state % cfg.n_head == 0 + self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) + self.n_head = cfg.n_head + self.split_size = n_state + self.scale = scale + self.c_attn = Conv1D(n_state * 3, 1, nx) + self.c_proj = Conv1D(n_state, 1, nx) + self.attn_dropout = nn.Dropout(cfg.attn_pdrop) + self.resid_dropout = nn.Dropout(cfg.resid_pdrop) + + def _attn(self, q, k, v): + w = torch.matmul(q, k) + if self.scale: + w = w / math.sqrt(v.size(-1)) + # w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights + # XD: self.b may be larger than w, so we need to crop it + b = self.b[:, :, :w.size(-2), :w.size(-1)] + w = w * b + -1e9 * (1 - b) + + w = nn.Softmax(dim=-1)(w) + w = self.attn_dropout(w) + return torch.matmul(w, v) + + def merge_heads(self, x): + x = x.permute(0, 2, 1, 3).contiguous() + new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) + return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states + + def split_heads(self, x, k=False): + new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) + x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states + if k: + return x.permute(0, 2, 3, 1) + else: + return x.permute(0, 2, 1, 3) + + def forward(self, x): + x = self.c_attn(x) + query, key, value = x.split(self.split_size, dim=2) + query = self.split_heads(query) + key = self.split_heads(key, k=True) + value = self.split_heads(value) + a = self._attn(query, key, value) + a = self.merge_heads(a) + a = self.c_proj(a) + a = self.resid_dropout(a) + return a + + +class MLP(nn.Module): + def __init__(self, n_state, cfg): # in MLP: n_state=3072 (4 * n_embd) + super(MLP, self).__init__() + nx = cfg.n_embd + self.c_fc = Conv1D(n_state, 1, nx) + self.c_proj = Conv1D(nx, 1, n_state) + self.act = ACT_FNS[cfg.afn] + self.dropout = nn.Dropout(cfg.resid_pdrop) + + def forward(self, x): + h = self.act(self.c_fc(x)) + h2 = self.c_proj(h) + return self.dropout(h2) + + +class Block(nn.Module): + def __init__(self, n_ctx, cfg, scale=False): + super(Block, self).__init__() + nx = cfg.n_embd + self.attn = Attention(nx, n_ctx, cfg, scale) + 
self.ln_1 = LayerNorm(nx) + self.mlp = MLP(4 * nx, cfg) + self.ln_2 = LayerNorm(nx) + + def forward(self, x): + a = self.attn(x) + n = self.ln_1(x + a) + m = self.mlp(n) + h = self.ln_2(n + m) + return h + + +class OpenAIGPTLMHead(nn.Module): + """ Language Model Head for the transformer """ + + def __init__(self, model_embeddings_weights, cfg): + super(OpenAIGPTLMHead, self).__init__() + self.n_embd = cfg.n_embd + self.set_embeddings_weights(model_embeddings_weights) + + def set_embeddings_weights(self, model_embeddings_weights): + embed_shape = model_embeddings_weights.shape + self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False) + self.decoder.weight = model_embeddings_weights # Tied weights + + def forward(self, hidden_state): + # Truncated Language modeling logits (we remove the last token) + # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) + lm_logits = self.decoder(hidden_state) + return lm_logits + + +class OpenAIGPTMultipleChoiceHead(nn.Module): + """ Classifier Head for the transformer """ + + def __init__(self, cfg): + super(OpenAIGPTMultipleChoiceHead, self).__init__() + self.n_embd = cfg.n_embd + # self.multiple_choice_token = multiple_choice_token + self.dropout = nn.Dropout2d(cfg.resid_pdrop) # To reproduce the noise_shape parameter of TF implementation + self.linear = nn.Linear(cfg.n_embd, 1) + + nn.init.normal_(self.linear.weight, std = 0.02) + nn.init.normal_(self.linear.bias, 0) + + def forward(self, hidden_states, classification_token_mask): + # Classification logits + # hidden_states = hidden_states.view(-1, self.n_embd) + # classification_token_mask = classification_token_mask.view(-1, 1).expand_as(hidden_states) + multiple_choice_h = hidden_states * classification_token_mask.unsqueeze(-1) + multiple_choice_h = multiple_choice_h.sum(dim=-2) + # flat = x[..., 0].contiguous().view(-1) + # multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :] + # multiple_choice_h = multiple_choice_h.view(-1, x.size(1), self.n_embd, 1) + # # This double transposition is there to replicate the behavior + # # of the noise_shape argument in the tensorflow + # # implementation. For more details, see + # # https://github.com/huggingface/pytorch-openai-transformer-lm/issues/11 + # multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2) + # multiple_choice_h = multiple_choice_h.contiguous().view(-1, self.n_embd) + multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1) + return multiple_choice_logits + + class OpenAIGPTPreTrainedModel(nn.Module): """ An abstract class to handle weights initialization and a simple interface for dowloading and loading pretrained models. @@ -142,7 +295,7 @@ class OpenAIGPTPreTrainedModel(nn.Module): if not isinstance(config, OpenAIGPTConfig): raise ValueError( "Parameter config in `{}(config)` should be an instance of class `OpenAIGPTConfig`. 
" - "To create a model from a Google pretrained model use " + "To create a model from a pretrained model use " "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( self.__class__.__name__, self.__class__.__name__ )) @@ -161,11 +314,12 @@ class OpenAIGPTPreTrainedModel(nn.Module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def post_loading(self): + def set_num_special_tokens(self, num_special_tokens): pass @classmethod - def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs): + def from_pretrained(cls, pretrained_model_name, num_special_tokens=0, state_dict=None, cache_dir=None, + *inputs, **kwargs): """ Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. @@ -178,7 +332,7 @@ class OpenAIGPTPreTrainedModel(nn.Module): . `openai_gpt_config.json` a configuration file for the model . `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance cache_dir: an optional path to a folder in which the pre-trained models will be cached. - state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models + state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models *inputs, **kwargs: additional input for the specific Bert class (ex: num_labels for BertForSequenceClassification) """ @@ -263,167 +417,15 @@ class OpenAIGPTPreTrainedModel(nn.Module): model.__class__.__name__, unexpected_keys)) if len(error_msgs) > 0: raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - self.__class__.__name__, "\n\t".join(error_msgs))) - model.post_loading() + model.__class__.__name__, "\n\t".join(error_msgs))) + # Add additional embeddings for special tokens if needed + if num_special_tokens != config.n_special: + model.set_num_special_tokens(num_special_tokens) if tempdir: # Clean up temp dir shutil.rmtree(tempdir) return model -class Conv1D(nn.Module): - def __init__(self, nf, rf, nx): - super(Conv1D, self).__init__() - self.rf = rf - self.nf = nf - if rf == 1: # faster 1x1 conv - w = torch.empty(nx, nf) - nn.init.normal_(w, std=0.02) - self.weight = Parameter(w) - self.bias = Parameter(torch.zeros(nf)) - else: # was used to train LM - raise NotImplementedError - - def forward(self, x): - if self.rf == 1: - size_out = x.size()[:-1] + (self.nf,) - x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) - x = x.view(*size_out) - else: - raise NotImplementedError - return x - - -class Attention(nn.Module): - def __init__(self, nx, n_ctx, cfg, scale=False): - super(Attention, self).__init__() - n_state = nx # in Attention: n_state=768 (nx=n_embd) - # [switch nx => n_state from Block to Attention to keep identical to TF implem] - assert n_state % cfg.n_head == 0 - self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) - self.n_head = cfg.n_head - self.split_size = n_state - self.scale = scale - self.c_attn = Conv1D(n_state * 3, 1, nx) - self.c_proj = Conv1D(n_state, 1, nx) - self.attn_dropout = nn.Dropout(cfg.attn_pdrop) - self.resid_dropout = nn.Dropout(cfg.resid_pdrop) - - def _attn(self, q, k, v): - w = torch.matmul(q, k) - if self.scale: - w = w / math.sqrt(v.size(-1)) - w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights - w = nn.Softmax(dim=-1)(w) - w = self.attn_dropout(w) - return torch.matmul(w, v) - - def merge_heads(self, x): 
- x = x.permute(0, 2, 1, 3).contiguous() - new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) - return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states - - def split_heads(self, x, k=False): - new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) - x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states - if k: - return x.permute(0, 2, 3, 1) - else: - return x.permute(0, 2, 1, 3) - - def forward(self, x): - x = self.c_attn(x) - query, key, value = x.split(self.split_size, dim=2) - query = self.split_heads(query) - key = self.split_heads(key, k=True) - value = self.split_heads(value) - a = self._attn(query, key, value) - a = self.merge_heads(a) - a = self.c_proj(a) - a = self.resid_dropout(a) - return a - - -class MLP(nn.Module): - def __init__(self, n_state, cfg): # in MLP: n_state=3072 (4 * n_embd) - super(MLP, self).__init__() - nx = cfg.n_embd - self.c_fc = Conv1D(n_state, 1, nx) - self.c_proj = Conv1D(nx, 1, n_state) - self.act = ACT_FNS[cfg.afn] - self.dropout = nn.Dropout(cfg.resid_pdrop) - - def forward(self, x): - h = self.act(self.c_fc(x)) - h2 = self.c_proj(h) - return self.dropout(h2) - - -class Block(nn.Module): - def __init__(self, n_ctx, cfg, scale=False): - super(Block, self).__init__() - nx = cfg.n_embd - self.attn = Attention(nx, n_ctx, cfg, scale) - self.ln_1 = LayerNorm(nx) - self.mlp = MLP(4 * nx, cfg) - self.ln_2 = LayerNorm(nx) - - def forward(self, x): - a = self.attn(x) - n = self.ln_1(x + a) - m = self.mlp(n) - h = self.ln_2(n + m) - return h - - -class OpenAIGPTLMHead(nn.Module): - """ Language Model Head for the transformer """ - - def __init__(self, model_embeddings_weights, cfg): - super(OpenAIGPTLMHead, self).__init__() - self.n_embd = cfg.n_embd - self.set_embeddings_weights(model_embeddings_weights) - - def set_embeddings_weights(self, model_embeddings_weights): - embed_shape = model_embeddings_weights.shape - self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False) - self.decoder.weight = model_embeddings_weights # Tied weights - - def forward(self, h): - # Truncated Language modeling logits (we remove the last token) - h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) - lm_logits = self.decoder(h_trunc) - return lm_logits - - -class OpenAIGPTClfHead(nn.Module): - """ Classifier Head for the transformer """ - - def __init__(self, clf_token, cfg): - super(OpenAIGPTClfHead, self).__init__() - self.n_embd = cfg.n_embd - self.clf_token = clf_token - self.dropout = nn.Dropout2d(cfg.resid_pdrop) # To reproduce the noise_shape parameter of TF implementation - self.linear = nn.Linear(cfg.n_embd, 1) - - nn.init.normal_(self.linear.weight, std = 0.02) - nn.init.normal_(self.linear.bias, 0) - - def forward(self, h, x): - # Classification logits - clf_h = h.view(-1, self.n_embd) - flat = x[..., 0].contiguous().view(-1) - clf_h = clf_h[flat == self.clf_token, :] - clf_h = clf_h.view(-1, x.size(1), self.n_embd, 1) - # This double transposition is there to replicate the behavior - # of the noise_shape argument in the tensorflow - # implementation. 
For more details, see - # https://github.com/huggingface/pytorch-openai-transformer-lm/issues/11 - clf_h = self.dropout(clf_h.transpose(1, 2)).transpose(1, 2) - clf_h = clf_h.contiguous().view(-1, self.n_embd) - clf_logits = self.linear(clf_h) - - return clf_logits.view(-1, x.size(1)) - class OpenAIGPTModel(OpenAIGPTPreTrainedModel): """ OpenAI GPT model """ @@ -440,6 +442,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): # nn.init.normal_(self.embed.weight, std=0.02) def set_num_special_tokens(self, num_special_tokens): + " Update input embeddings with new embedding matrice " # Update config self.config.n_special = num_special_tokens # # Build new embeddings and initialize @@ -451,45 +454,83 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): self.embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :] self.embed.weight.data[-self.config.n_ctx:, :] = old_embed.weight.data[-self.config.n_ctx:, :] - def forward(self, x): - x = x.view(-1, x.size(-2), x.size(-1)) - e = self.embed(x) + def forward(self, input_ids, position_ids=None, token_type_ids=None): + if position_ids is None: + start = self.config.vocab_size + self.config.n_special + end = start + input_ids.size(-1) + position_ids = torch.arange(start, end, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_ids.size(-1)) + position_ids = position_ids.view(-1, position_ids.size(-1)) + + inputs_embeds = self.embed(input_ids) + position_embeds = self.embed(position_ids) + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) + token_type_embeds = self.embed(token_type_ids) + else: + token_type_embeds = 0 # Add the position information to the input embeddings - h = e.sum(dim=2) + # h = e.sum(dim=2) + hidden_states = inputs_embeds + position_embeds + token_type_embeds for block in self.h: - h = block(h) - return h + hidden_states = block(hidden_states) + return hidden_states.view(*input_shape, hidden_states.size(-1)) - -class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): +class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): """ OpenAI GPT model with language model and classification heads """ - def __init__(self, cfg, clf_token='[CLS]'): - super(OpenAIGPTDoubleHeadsModel, self).__init__(cfg) + def __init__(self, cfg): + super(OpenAIGPTLMHeadModel, self).__init__(cfg) self.transformer = OpenAIGPTModel(cfg) self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, cfg) - self.clf_head = OpenAIGPTClfHead(clf_token, cfg) self.apply(self.init_weights) - def post_loading(self): - " Set the number of special tokens to 1 (for the [CLS] token) " - self.set_num_special_tokens(1) - def set_num_special_tokens(self, num_special_tokens): " Update input and output embeddings with new embedding matrice " self.transformer.set_num_special_tokens(num_special_tokens) self.lm_head.set_embeddings_weights(self.transformer.embed.weight) - def forward(self, x, lm_labels=None, clf_labels=None): - h = self.transformer(x) - lm_logits = self.lm_head(h) - clf_logits = self.clf_head(h, x) - losses = [] + def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None): + hidden_states = self.transformer(input_ids, position_ids, token_type_ids) + lm_logits = self.lm_head(hidden_states) if lm_labels is not None: loss_fct = CrossEntropyLoss() - losses.append(loss_fct(lm_logits, lm_labels)) - if clf_labels is not None: + loss = 
loss_fct(lm_logits, lm_labels) + return loss + return lm_logits + +class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): + """ OpenAI GPT model with language model and classification heads """ + def __init__(self, cfg): + super(OpenAIGPTDoubleHeadsModel, self).__init__(cfg) + self.transformer = OpenAIGPTModel(cfg) + self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, cfg) + self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(cfg) + self.apply(self.init_weights) + + def set_num_special_tokens(self, num_special_tokens): + " Update input and output embeddings with new embedding matrice " + self.transformer.set_num_special_tokens(num_special_tokens) + self.lm_head.set_embeddings_weights(self.transformer.embed.weight) + + def forward(self, input_ids, classification_token_mask, position_ids=None, token_type_ids=None, + lm_labels=None, multiple_choice_labels=None): + """ + input_ids as to be of shape B x C x S + lm_labels can be masked using the -1 value + """ + hidden_states = self.transformer(input_ids, position_ids, token_type_ids) + lm_logits = self.lm_head(hidden_states) + multiple_choice_logits = self.multiple_choice_head(hidden_states, classification_token_mask) + losses = [] + if lm_labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))) + if multiple_choice_labels is not None: loss_fct = CrossEntropyLoss() - losses.append(loss_fct(clf_logits, clf_labels)) + losses.append(loss_fct(multiple_choice_logits, multiple_choice_labels.view(-1))) if losses: return losses - return lm_logits, clf_logits + return lm_logits, multiple_choice_logits diff --git a/pytorch_pretrained_bert/tokenization_openai.py b/pytorch_pretrained_bert/tokenization_openai.py index dd0df83e93f..1492075817b 100644 --- a/pytorch_pretrained_bert/tokenization_openai.py +++ b/pytorch_pretrained_bert/tokenization_openai.py @@ -67,19 +67,17 @@ class OpenAIGPTTokenizer(object): mostly a wrapper for a public python bpe tokenizer """ @classmethod - def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): """ Instantiate a PreTrainedBertModel from a pre-trained model file. Download and cache the pre-trained model file if needed. """ - if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP: - vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name] - merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name] + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] + merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] else: - vocab_file = pretrained_model_name - if os.path.isdir(vocab_file): - vocab_file = os.path.join(vocab_file, VOCAB_NAME) - merges_file = os.path.join(vocab_file, MERGES_NAME) + vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) + merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) # redirect to the cache, if necessary try: resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) @@ -87,11 +85,12 @@ class OpenAIGPTTokenizer(object): except FileNotFoundError: logger.error( "Model name '{}' was not found in model name list ({}). 
" - "We assumed '{}' was a path or url but couldn't find any file " - "associated to this path or url.".format( - pretrained_model_name, + "We assumed '{}' was a path or url but couldn't find files {} and {} " + "at this path or url.".format( + pretrained_model_name_or_path, ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), - vocab_file)) + pretrained_model_name_or_path, + vocab_file, merges_file)) return None if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: logger.info("loading vocabulary file {}".format(vocab_file)) @@ -101,29 +100,38 @@ class OpenAIGPTTokenizer(object): vocab_file, resolved_vocab_file)) logger.info("loading merges file {} from cache at {}".format( merges_file, resolved_merges_file)) - if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: + if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: # if we're using a pretrained model, ensure the tokenizer wont index sequences longer # than the number of positional embeddings - max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name] + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) # Instantiate tokenizer. tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs) return tokenizer - def __init__(self, vocab_file, merges_file): + def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None): try: import ftfy import spacy except ImportError: raise ImportError("Please install ftfy and spacy to use OpenAI GPT tokenizer.") + self.max_len = max_len if max_len is not None else int(1e12) self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat']) + self.fix_text = ftfy.fix_text self.encoder = json.load(open(vocab_file)) self.decoder = {v:k for k,v in self.encoder.items()} merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] merges = [tuple(merge.split()) for merge in merges] self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.cache = {} + if not special_tokens: + self.special_tokens = {} + else: + self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) + + def set_special_tokens(self, special_tokens): + self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) def bpe(self, token): word = tuple(token[:-1]) + ( token[-1] + '',) @@ -168,20 +176,38 @@ class OpenAIGPTTokenizer(object): self.cache[token] = word return word - def tokenize(self, texts, verbose=True): - texts_tokens = [] - if verbose: - for text in tqdm(texts, ncols=80, leave=False): - text = self.nlp(text_standardize(ftfy.fix_text(text))) - text_tokens = [] - for token in text: - text_tokens.extend([self.encoder.get(t, 0) for t in self.bpe(token.text.lower()).split(' ')]) - texts_tokens.append(text_tokens) - else: - for text in texts: - text = self.nlp(text_standardize(ftfy.fix_text(text))) - text_tokens = [] - for token in text: - text_tokens.extend([self.encoder.get(t, 0) for t in self.bpe(token.text.lower()).split(' ')]) - texts_tokens.append(text_tokens) - return texts_tokens + def tokenize(self, text): + split_tokens = [] + text = self.nlp(text_standardize(self.fix_text(text))) + for token in text: + split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) + return split_tokens + + def convert_tokens_to_ids(self, tokens): + """Converts a sequence of tokens into ids using the vocab.""" + 
ids = [] + for token in tokens: + if token in self.special_tokens: + ids.append(self.special_tokens[token]) + else: + ids.append(self.encoder.get(token, 0)) + if len(ids) > self.max_len: + raise ValueError( + "Token indices sequence length is longer than the specified maximum " + " sequence length for this OpenAI GPT model ({} > {}). Running this" + " sequence through the model will result in indexing errors".format(len(ids), self.max_len) + ) + return ids + + def convert_ids_to_tokens(self, ids): + """Converts a sequence of ids into BPE tokens using the vocab.""" + tokens = [] + for i in ids: + tokens.append(self.decoder[i]) + return tokens + + def decode(self, ids): + """Converts a sequence of ids into a string.""" + tokens = self.convert_ids_to_tokens(ids) + out_string = ''.join(tokens).replace('</w>', ' ') + return out_string diff --git a/tests/modeling_openai_test.py b/tests/modeling_openai_test.py new file mode 100644 index 00000000000..539fbda9e4e --- /dev/null +++ b/tests/modeling_openai_test.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +import json +import random + +import torch + +from pytorch_pretrained_bert import (OpenAIGPTConfig, OpenAIGPTModel, OpenAIGPTDoubleHeadsModel) + + +class OpenAIGPTModelTest(unittest.TestCase): + class OpenAIGPTModelTester(object): + + def __init__(self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_position_ids=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + n_special=1, + n_ctx=33, + n_embd=32, + n_layer=5, + n_head=4, + n_choices=3, + afn="gelu", + resid_pdrop=0.1, + attn_pdrop=0.1, + embd_pdrop=0.1, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + scope=None): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_position_ids = use_position_ids + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.n_special = n_special + self.n_ctx = n_ctx + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.afn = afn + self.n_choices = n_choices + self.resid_pdrop = resid_pdrop + self.attn_pdrop = attn_pdrop + self.embd_pdrop = embd_pdrop + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.vocab_size) + + position_ids = None + if self.use_position_ids: + position_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_ctx) + position_ids = position_ids + self.n_special + self.vocab_size + + token_type_ids = None + if self.use_token_type_ids: + total_voc
= self.n_ctx + self.n_special + self.vocab_size + token_type_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc) + + multiple_choice_labels = None + lm_labels = None + classification_token_mask = None + if self.use_labels: + multiple_choice_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size) + lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels) + classification_token_mask = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], 2).float() + + config = OpenAIGPTConfig( + vocab_size_or_config_json_file=self.vocab_size, + n_ctx=self.n_ctx, + n_special=self.n_special, + n_embd=self.n_embd, + n_layer=self.n_layer, + n_head=self.n_head, + afn=self.afn, + resid_pdrop=self.resid_pdrop, + attn_pdrop=self.attn_pdrop, + embd_pdrop=self.embd_pdrop, + initializer_range=self.initializer_range) + + return (config, input_ids, token_type_ids, position_ids, + multiple_choice_labels, lm_labels, classification_token_mask) + + def create_openai_model(self, config, input_ids, token_type_ids, position_ids, + multiple_choice_labels, lm_labels, classification_token_mask): + model = OpenAIGPTModel(config) + hidden_states = model(input_ids, position_ids, token_type_ids) + outputs = { + "hidden_states": hidden_states, + } + return outputs + + def check_openai_model_output(self, result): + self.parent.assertListEqual( + list(result["hidden_states"].size()), + [self.batch_size, self.n_choices, self.seq_length, self.n_embd]) + + + def create_openai_double_heads(self, config, input_ids, token_type_ids, position_ids, + multiple_choice_labels, lm_labels, classification_token_mask): + model = OpenAIGPTDoubleHeadsModel(config) + loss = model(input_ids, classification_token_mask, position_ids, + token_type_ids, lm_labels, multiple_choice_labels) + lm_logits, multiple_choice_logits = model(input_ids, classification_token_mask, position_ids, token_type_ids) + outputs = { + "loss": loss, + "lm_logits": lm_logits, + "multiple_choice_logits": multiple_choice_logits, + } + return outputs + + def check_openai_double_heads_output(self, result): + total_voc = self.n_ctx + self.n_special + self.vocab_size + self.parent.assertListEqual( + list(result["lm_logits"].size()), + [self.batch_size, self.n_choices, self.seq_length, total_voc]) + self.parent.assertListEqual( + list(result["multiple_choice_logits"].size()), + [self.batch_size, self.n_choices]) + + def check_openai_double_heads_loss_output(self, result): + self.parent.assertListEqual( + [list(l.size()) for l in result["loss"]], + [[], []]) + + def test_default(self): + self.run_tester(OpenAIGPTModelTest.OpenAIGPTModelTester(self)) + + def test_config_to_json_string(self): + config = OpenAIGPTConfig(vocab_size_or_config_json_file=99, n_embd=37) + obj = json.loads(config.to_json_string()) + self.assertEqual(obj["vocab_size"], 99) + self.assertEqual(obj["n_embd"], 37) + + def run_tester(self, tester): + config_and_inputs = tester.prepare_config_and_inputs() + output_result = tester.create_openai_model(*config_and_inputs) + tester.check_openai_model_output(output_result) + + output_result = tester.create_openai_double_heads(*config_and_inputs) + tester.check_openai_double_heads_output(output_result) + tester.check_openai_double_heads_loss_output(output_result) + + @classmethod + def ids_tensor(cls, shape, vocab_size, rng=None, name=None): + """Creates a random int32 tensor of the shape within the vocab size.""" + if 
rng is None: + rng = random.Random() + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.randint(0, vocab_size - 1)) + + return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous() + + +if __name__ == "__main__": + unittest.main()
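
Note on the rewritten OpenAIGPTModel.forward: token ids, special-token ids and position ids all index one shared embedding matrix, laid out as [vocab_size | n_special | n_ctx] rows. When position_ids is None the model builds the default positions itself; the snippet below is a minimal sketch of that default, with hypothetical sizes that are not taken from the patch.

import torch

vocab_size, n_special, n_ctx = 40478, 1, 512      # hypothetical sizes
input_ids = torch.randint(0, vocab_size, (2, 7))  # batch of token ids

# Default position ids start right after the token and special-token rows,
# mirroring what OpenAIGPTModel.forward does when position_ids is None.
start = vocab_size + n_special
position_ids = torch.arange(start, start + input_ids.size(-1), dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)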
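
A hedged usage sketch for the new OpenAIGPTDoubleHeadsModel forward signature: input_ids has shape batch x choices x sequence, classification_token_mask is a float mask selecting the token whose hidden state is pooled for each choice, and lm_labels positions set to -1 are ignored by the language modeling loss. The sizes below follow the small test configuration rather than the pretrained model.

import torch
from pytorch_pretrained_bert import OpenAIGPTConfig, OpenAIGPTDoubleHeadsModel

batch_size, n_choices, seq_length = 2, 3, 7
config = OpenAIGPTConfig(vocab_size_or_config_json_file=99, n_special=1, n_ctx=33,
                         n_embd=32, n_layer=2, n_head=4)
model = OpenAIGPTDoubleHeadsModel(config)

input_ids = torch.randint(0, config.vocab_size, (batch_size, n_choices, seq_length))
classification_token_mask = torch.zeros(batch_size, n_choices, seq_length)
classification_token_mask[:, :, -1] = 1.0   # pool the hidden state of the last token of each choice
lm_labels = input_ids.clone()
lm_labels[:, :, :2] = -1                    # -1 positions are ignored by the LM loss
multiple_choice_labels = torch.randint(0, n_choices, (batch_size,))

# With labels, the forward pass returns [lm_loss, multiple_choice_loss].
lm_loss, mc_loss = model(input_ids, classification_token_mask,
                         lm_labels=lm_labels, multiple_choice_labels=multiple_choice_labels)
# Without labels, it returns (lm_logits, multiple_choice_logits).
lm_logits, mc_logits = model(input_ids, classification_token_mask)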
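
And a sketch of the reworked OpenAIGPTTokenizer API: tokenize now returns BPE sub-word strings, ids are produced separately, and special tokens are indexed after the BPE vocabulary. The directory name and special-token strings are hypothetical; ftfy and spacy (with its 'en' model) must be installed.

from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer

# 'my_gpt_model_dir' is assumed to contain the vocabulary and merges files.
tokenizer = OpenAIGPTTokenizer.from_pretrained('my_gpt_model_dir',
                                               special_tokens=['_start_', '_classify_'])

tokens = tokenizer.tokenize("Mary had a little lamb")             # list of BPE sub-word strings
ids = tokenizer.convert_tokens_to_ids(tokens + ['_classify_'])    # special tokens map to len(encoder) + i
text = tokenizer.decode(tokenizer.convert_tokens_to_ids(tokens))  # '</w>' end-of-word markers become spaces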