add OpenAI GPT

This commit is contained in:
thomwolf 2019-01-08 12:26:58 +01:00
parent 793dcd236b
commit eed51c5bdf
8 changed files with 573 additions and 270 deletions

View File

@@ -1,8 +1,11 @@
__version__ = "0.4.0"
__version__ = "0.5.0"
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .tokenization_openai import OpenAIGPTTokenizer
from .modeling import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForMultipleChoice,
BertForTokenClassification, BertForQuestionAnswering)
from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, OpenAIGPTDoubleHeadsModel
from .optimization import BertAdam
from .optimization_openai import OpenAIAdam
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
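For context, a minimal sketch of the new top-level API this file now exports (the 'openai-gpt' shortcut resolves through the archive maps added in the new modules below; this sketch is illustrative, not part of the commit):

from pytorch_pretrained_bert import OpenAIGPTTokenizer, OpenAIGPTModel

tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
model = OpenAIGPTModel.from_pretrained('openai-gpt')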

View File

@@ -1,22 +1,40 @@
# coding: utf8
def main():
import sys
try:
from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
except ModuleNotFoundError:
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
if len(sys.argv) != 5:
# pylint: disable=line-too-long
print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
"convert_tf_checkpoint_to_pytorch",
"convert_openai_checkpoint"
]:
print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT` \n or `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`")
else:
PYTORCH_DUMP_OUTPUT = sys.argv.pop()
TF_CONFIG = sys.argv.pop()
TF_CHECKPOINT = sys.argv.pop()
convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
try:
from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
except ModuleNotFoundError:
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
if len(sys.argv) != 5:
# pylint: disable=line-too-long
print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
else:
PYTORCH_DUMP_OUTPUT = sys.argv.pop()
TF_CONFIG = sys.argv.pop()
TF_CHECKPOINT = sys.argv.pop()
convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
else:
from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
PYTORCH_DUMP_OUTPUT = sys.argv[3]
if len(sys.argv) == 5:
OPENAI_GPT_CONFIG = sys.argv[4]
else:
OPENAI_GPT_CONFIG = ""
convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
OPENAI_GPT_CONFIG,
PYTORCH_DUMP_OUTPUT)
if __name__ == '__main__':
main()
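The dispatcher above now accepts two command forms; for illustration, with placeholder paths:

pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch ./bert_model.ckpt ./bert_config.json ./pytorch_model.bin
pytorch_pretrained_bert convert_openai_checkpoint ./openai_checkpoint_folder ./pytorch_dump_folder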

View File

@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""
"""Convert OpenAI GPT checkpoint."""
from __future__ import absolute_import
from __future__ import division
@@ -20,45 +20,53 @@ from __future__ import print_function
import os
import re
import json
import argparse
import tensorflow as tf
import torch
import numpy as np
from .modeling import BertConfig, BertForPreTraining
from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
def convert_openai_checkpoint_to_pytorch(open_checkpoint_folder_path, openai_config_file, pytorch_dump_path):
def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/',
path_names='./'):
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
# Load weights from TF model
print("Loading weights...")
names = json.load(open(path_names + 'parameters_names.json'))
shapes = json.load(open(path + 'params_shapes.json'))
names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
offsets = np.cumsum([np.prod(shape) for shape in shapes])
init_params = [np.load(path + 'params_{}.npy'.format(n)) for n in range(10)]
init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)]
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
if n_ctx > 0:
init_params[0] = init_params[0][:n_ctx]
if n_special > 0:
init_params[0] = np.concatenate(
[init_params[1],
(np.random.randn(n_special, n_embd) * 0.02).astype(np.float32),
init_params[0]
], 0)
else:
init_params[0] = np.concatenate(
[init_params[1],
init_params[0]
], 0)
# if n_ctx > 0:
# init_params[0] = init_params[0][:n_ctx]
# if n_special > 0:
# init_params[0] = np.concatenate(
# [init_params[1],
# (np.random.randn(n_special, n_embd) * 0.02).astype(np.float32),
# init_params[0]
# ], 0)
# else:
# init_params[0] = np.concatenate(
# [init_params[1],
# init_params[0]
# ], 0)
# del init_params[1]
# if n_transfer == -1:
# n_transfer = 0
# else:
# n_transfer = 1 + n_transfer * 12
init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
del init_params[1]
if n_transfer == -1:
n_transfer = 0
else:
n_transfer = 1 + n_transfer * 12
init_params = [arr.squeeze() for arr in init_params]
# Construct model
if openai_config_file == "":
config = OpenAIGPTConfig()
else:
config = OpenAIGPTConfig(openai_config_file)
model = OpenAIGPTModel(config)
try:
assert model.embed.weight.shape == init_params[0].shape
except AssertionError as e:
@@ -66,8 +74,10 @@ def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n
raise
model.embed.weight.data = torch.from_numpy(init_params[0])
names.pop(0)
init_params.pop(0)
for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]):
name = name[6:] # skip "model/"
assert name[-2:] == ":0"
name = name[:-2]
@@ -78,64 +88,22 @@ def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n
l = re.split(r'(\d+)', m_name)
else:
l = [m_name]
pointer = getattr(pointer, l[0])
if len(l) >= 2:
num = int(l[1])
pointer = pointer[num]
try:
assert pointer.shape == ip.shape
except AssertionError as e:
e.args += (pointer.shape, ip.shape)
raise
pointer.data = torch.from_numpy(ip)
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
config_path = os.path.abspath(bert_config_file)
tf_path = os.path.abspath(tf_checkpoint_path)
print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
# Initialise PyTorch model
config = BertConfig.from_json_file(bert_config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = BertForPreTraining(config)
for name, array in zip(names, arrays):
name = name.split('/')
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if any(n in ["adam_v", "adam_m"] for n in name):
print("Skipping {}".format("/".join(name)))
continue
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
l = re.split(r'_(\d+)', m_name)
else:
l = [m_name]
if l[0] == 'kernel' or l[0] == 'gamma':
if l[0] == 'g':
pointer = getattr(pointer, 'weight')
elif l[0] == 'output_bias' or l[0] == 'beta':
elif l[0] == 'b':
pointer = getattr(pointer, 'bias')
elif l[0] == 'output_weights':
elif l[0] == 'w':
pointer = getattr(pointer, 'weight')
else:
pointer = getattr(pointer, l[0])
if len(l) >= 2:
num = int(l[1])
pointer = pointer[num]
if m_name[-11:] == '_embeddings':
pointer = getattr(pointer, 'weight')
elif m_name == 'kernel':
array = np.transpose(array)
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
try:
assert pointer.shape == array.shape
except AssertionError as e:
@@ -145,30 +113,33 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
pointer.data = torch.from_numpy(array)
# Save pytorch-model
print("Save PyTorch model to {}".format(pytorch_dump_path))
torch.save(model.state_dict(), pytorch_dump_path)
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
torch.save(model.state_dict(), pytorch_weights_dump_path)
print("Save configuration file to {}".format(pytorch_config_dump_path))
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
f.write(config.to_json_string())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--tf_checkpoint_path",
parser.add_argument("--openai_checkpoint_folder_path",
default = None,
type = str,
required = True,
help = "Path the TensorFlow checkpoint path.")
parser.add_argument("--bert_config_file",
default = None,
type = str,
required = True,
help = "The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_path",
parser.add_argument("--pytorch_dump_folder_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
parser.add_argument("--openai_config_file",
default = "",
type = str,
help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
"This specifies the model architecture.")
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
args.bert_config_file,
args.pytorch_dump_path)
convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path,
args.openai_config_file,
args.pytorch_dump_folder_path)
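Since __main__.py imports this module as .convert_openai_checkpoint_to_pytorch, the argparse entry point can also be run directly; a sketch with placeholder paths:

python -m pytorch_pretrained_bert.convert_openai_checkpoint_to_pytorch \
    --openai_checkpoint_folder_path ./openai_checkpoint_folder \
    --pytorch_dump_folder_path ./pytorch_dump_folder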

View File

@@ -416,12 +416,12 @@ class BertPreTrainingHeads(nn.Module):
return prediction_scores, seq_relationship_score
class PreTrainedModel(nn.Module):
class BertPreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedModel, self).__init__()
super(BertPreTrainedModel, self).__init__()
if not isinstance(config, BertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
@@ -447,7 +447,7 @@ class PreTrainedModel(nn.Module):
@classmethod
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
@@ -547,13 +547,16 @@ class PreTrainedModel(nn.Module):
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return model
class BertModel(PreTrainedModel):
class BertModel(BertPreTrainedModel):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
Params:
@@ -636,7 +639,7 @@ class BertModel(PreTrainedModel):
return encoded_layers, pooled_output
class BertForPreTraining(PreTrainedModel):
class BertForPreTraining(BertPreTrainedModel):
"""BERT model with pre-training heads.
This module comprises the BERT model followed by the two pre-training heads:
- the masked language modeling head, and
@@ -707,7 +710,7 @@ class BertForPreTraining(PreTrainedModel):
return prediction_scores, seq_relationship_score
class BertForMaskedLM(PreTrainedModel):
class BertForMaskedLM(BertPreTrainedModel):
"""BERT model with the masked language modeling head.
This module comprises the BERT model followed by the masked language modeling head.
@@ -768,7 +771,7 @@ class BertForMaskedLM(PreTrainedModel):
return prediction_scores
class BertForNextSentencePrediction(PreTrainedModel):
class BertForNextSentencePrediction(BertPreTrainedModel):
"""BERT model with next sentence prediction head.
This module comprises the BERT model followed by the next sentence classification head.
@@ -830,7 +833,7 @@ class BertForNextSentencePrediction(PreTrainedModel):
return seq_relationship_score
class BertForSequenceClassification(PreTrainedModel):
class BertForSequenceClassification(BertPreTrainedModel):
"""BERT model for classification.
This module is composed of the BERT model with a linear layer on top of
the pooled output.
@@ -875,7 +878,7 @@ class BertForSequenceClassification(PreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels=2):
def __init__(self, config, num_labels):
super(BertForSequenceClassification, self).__init__(config)
self.num_labels = num_labels
self.bert = BertModel(config)
@@ -896,7 +899,7 @@ class BertForSequenceClassification(PreTrainedModel):
return logits
class BertForMultipleChoice(PreTrainedModel):
class BertForMultipleChoice(BertPreTrainedModel):
"""BERT model for multiple choice tasks.
This module is composed of the BERT model with a linear layer on top of
the pooled output.
@@ -940,7 +943,7 @@ class BertForMultipleChoice(PreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_choices=2):
def __init__(self, config, num_choices):
super(BertForMultipleChoice, self).__init__(config)
self.num_choices = num_choices
self.bert = BertModel(config)
@@ -965,7 +968,7 @@ class BertForMultipleChoice(PreTrainedModel):
return reshaped_logits
class BertForTokenClassification(PreTrainedModel):
class BertForTokenClassification(BertPreTrainedModel):
"""BERT model for token-level classification.
This module is composed of the BERT model with a linear layer on top of
the full hidden state of the last layer.
@@ -1010,7 +1013,7 @@ class BertForTokenClassification(PreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels=2):
def __init__(self, config, num_labels):
super(BertForTokenClassification, self).__init__(config)
self.num_labels = num_labels
self.bert = BertModel(config)
@@ -1031,7 +1034,7 @@ class BertForTokenClassification(PreTrainedModel):
return logits
class BertForQuestionAnswering(PreTrainedModel):
class BertForQuestionAnswering(BertPreTrainedModel):
"""BERT model for Question Answering (span extraction).
This module is composed of the BERT model with a linear layer on top of
the sequence output that computes start_logits and end_logits

View File

@@ -1,16 +1,28 @@
import os
import copy
import json
import math
import re
import logging
import tarfile
import tempfile
import shutil
import collections
import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .modeling import BertLayerNorm as LayerNorm
from .file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt.tar.gz",
}
CONFIG_NAME = 'openai_gpt_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
def gelu(x):
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
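This gelu is the tanh approximation; a quick numerical sanity check against the exact erf form (a sketch, not part of the commit):

import math
import torch
x = torch.linspace(-3, 3, 7)
exact = 0.5 * x * (1 + torch.erf(x / math.sqrt(2)))
assert torch.allclose(gelu(x), exact, atol=1e-3)  # tanh approximation is within ~1e-3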
@@ -26,6 +38,237 @@ ACT_FNS = {
'gelu': gelu
}
class OpenAIGPTConfig(object):
"""Configuration class to store the configuration of a `OpenAIGPTModel`.
"""
def __init__(self,
vocab_size_or_config_json_file=40478,
n_special=0,
n_ctx=512,
n_embd=768,
n_layer=12,
n_head=12,
intermediate_size=3072,
afn="gelu",
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
type_vocab_size=2,
initializer_range=0.02):
"""Constructs OpenAIGPTConfig.
Args:
vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `OpenAIGPTModel` or a configuration json file.
n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLS]', ...)
n_ctx: Number of positional embeddings.
n_embd: Dimensionality of the embeddings and hidden states.
n_layer: Number of hidden layers in the Transformer encoder.
n_head: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
afn: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
resid_pdrop: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attn_pdrop: The dropout ratio for the attention
probabilities.
embd_pdrop: The dropout ratio for the embeddings.
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`OpenAIGPTModel`.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
"""
if isinstance(vocab_size_or_config_json_file, str):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.n_special = n_special
self.n_ctx = n_ctx
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
self.afn = afn
self.intermediate_size = intermediate_size
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attn_pdrop = attn_pdrop
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@property
def total_num_embeddings(self):
return self.vocab_size + self.n_special + self.n_ctx
@classmethod
def from_dict(cls, json_object):
"""Constructs a `OpenAIGPTConfig` from a Python dictionary of parameters."""
config = OpenAIGPTConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `OpenAIGPTConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
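For reference, a short sketch of the two constructor paths handled above, plus the total_num_embeddings property (default n_ctx is 512; illustrative only):

import json
config = OpenAIGPTConfig(vocab_size_or_config_json_file=40478, n_special=1)
assert config.total_num_embeddings == 40478 + 1 + 512  # vocab + special tokens + positions
round_trip = OpenAIGPTConfig.from_dict(json.loads(config.to_json_string()))
assert round_trip.n_embd == config.n_embd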
class OpenAIGPTPreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(OpenAIGPTPreTrainedModel, self).__init__()
if not isinstance(config, OpenAIGPTConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `OpenAIGPTConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
def init_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def post_loading(self):
pass
@classmethod
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
"""
Instantiate an OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `openai-gpt`
- a path or url to a pretrained model archive containing:
. `openai_gpt_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of the pre-trained weights
*inputs, **kwargs: additional inputs for the specific OpenAI GPT class
"""
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
else:
archive_file = pretrained_model_name
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
archive_file))
return None
if resolved_archive_file == archive_file:
logger.info("loading archive file {}".format(archive_file))
else:
logger.info("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file))
tempdir = None
if os.path.isdir(resolved_archive_file):
serialization_dir = resolved_archive_file
else:
# Extract archive to temp dir
tempdir = tempfile.mkdtemp()
logger.info("extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
archive.extractall(tempdir)
serialization_dir = tempdir
# Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME)
config = OpenAIGPTConfig.from_json_file(config_file)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
state_dict = torch.load(weights_path)
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
model.post_loading()
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return model
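A usage sketch of the loader above: 'openai-gpt' resolves through PRETRAINED_MODEL_ARCHIVE_MAP, and for the double-heads model post_loading() then reserves one special-token row for the [CLS] token:

model = OpenAIGPTDoubleHeadsModel.from_pretrained('openai-gpt')
model.eval()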
class Conv1D(nn.Module):
def __init__(self, nf, rf, nx):
@@ -35,15 +278,15 @@ class Conv1D(nn.Module):
if rf == 1: # faster 1x1 conv
w = torch.empty(nx, nf)
nn.init.normal_(w, std=0.02)
self.w = Parameter(w)
self.b = Parameter(torch.zeros(nf))
self.weight = Parameter(w)
self.bias = Parameter(torch.zeros(nf))
else: # was used to train LM
raise NotImplementedError
def forward(self, x):
if self.rf == 1:
size_out = x.size()[:-1] + (self.nf,)
x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w)
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(*size_out)
else:
raise NotImplementedError
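With rf == 1 the forward pass is a single addmm, so Conv1D behaves like a linear layer over the last dimension; a minimal equivalence check (assuming the constructor stores nf and rf, as the forward pass implies):

import torch
conv = Conv1D(nf=8, rf=1, nx=4)
x = torch.randn(2, 3, 4)
# addmm(bias, x_flat, weight) is the same affine map as x @ weight + bias
assert torch.allclose(conv(x), x @ conv.weight + conv.bias, atol=1e-6)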
@@ -132,38 +375,18 @@ class Block(nn.Module):
return h
class TransformerModel(nn.Module):
""" Transformer model """
def __init__(self, cfg, vocab=40990, n_ctx=512):
super(TransformerModel, self).__init__()
self.vocab = vocab
self.embed = nn.Embedding(vocab, cfg.n_embd)
self.drop = nn.Dropout(cfg.embd_pdrop)
block = Block(n_ctx, cfg, scale=True)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])
nn.init.normal_(self.embed.weight, std=0.02)
def forward(self, x):
x = x.view(-1, x.size(-2), x.size(-1))
e = self.embed(x)
# Add the position information to the input embeddings
h = e.sum(dim=2)
for block in self.h:
h = block(h)
return h
class LMHead(nn.Module):
class OpenAIGPTLMHead(nn.Module):
""" Language Model Head for the transformer """
def __init__(self, model, cfg):
super(LMHead, self).__init__()
def __init__(self, model_embeddings_weights, cfg):
super(OpenAIGPTLMHead, self).__init__()
self.n_embd = cfg.n_embd
embed_shape = model.embed.weight.shape
self.set_embeddings_weights(model_embeddings_weights)
def set_embeddings_weights(self, model_embeddings_weights):
embed_shape = model_embeddings_weights.shape
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
self.decoder.weight = model.embed.weight # Tied weights
self.decoder.weight = model_embeddings_weights # Tied weights
def forward(self, h):
# Truncated Language modeling logits (we remove the last token)
@@ -172,14 +395,14 @@ class LMHead(nn.Module):
return lm_logits
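The rename also switches the tie from a model reference to the raw embedding weights, so set_embeddings_weights can re-tie the decoder after the embedding matrix is resized; a sketch of the invariant (model and cfg as in OpenAIGPTModel below):

lm_head = OpenAIGPTLMHead(model.embed.weight, cfg)
assert lm_head.decoder.weight is model.embed.weight  # tied, not copied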
class MultipleChoiceHead(nn.Module):
class OpenAIGPTClfHead(nn.Module):
""" Classifier Head for the transformer """
def __init__(self, clf_token, cfg):
super(MultipleChoiceHead, self).__init__()
super(OpenAIGPTClfHead, self).__init__()
self.n_embd = cfg.n_embd
self.clf_token = clf_token
self.dropout = nn.Dropout2d(cfg.clf_pdrop) # To reproduce the noise_shape parameter of TF implementation
self.dropout = nn.Dropout2d(cfg.resid_pdrop) # To reproduce the noise_shape parameter of TF implementation
self.linear = nn.Linear(cfg.n_embd, 1)
nn.init.normal_(self.linear.weight, std = 0.02)
@@ -202,101 +425,71 @@ class MultipleChoiceHead(nn.Module):
return clf_logits.view(-1, x.size(1))
class ClfHead(nn.Module):
"""Classification Head for the transformer
TODO: test this class."""
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
""" OpenAI GPT model """
def __init__(self, clf_token, cfg, n_class):
super(ClfHead, self).__init__()
self.n_embd = cfg.n_embd
self.clf_token = clf_token
self.dropout = nn.Dropout(cfg.clf_pdrop)
self.linear = nn.Linear(cfg.n_embd, n_class)
def __init__(self, cfg):
super(OpenAIGPTModel, self).__init__(cfg)
total_embeddings_size = cfg.vocab_size + cfg.n_special + cfg.n_ctx
self.embed = nn.Embedding(total_embeddings_size, cfg.n_embd)
self.drop = nn.Dropout(cfg.embd_pdrop)
block = Block(cfg.n_ctx, cfg, scale=True)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])
nn.init.normal_(self.linear.weight, std = 0.02)
nn.init.normal_(self.linear.bias, 0)
self.apply(self.init_weights)
# nn.init.normal_(self.embed.weight, std=0.02)
def forward(self, h, x):
clf_h = h.view(-1, self.n_embd)
flat = x[..., 0].contiguous().view(-1)
clf_h = clf_h[flat == self.clf_token, :]
clf_h = self.dropout(clf_h)
clf_logits = self.linear(clf_h)
return clf_logits
class SimilarityHead(nn.Module):
""" Similarity Head for the transformer
TODO: test this class."""
def __init__(self, clf_token, cfg):
super(SimilarityHead, self).__init__()
self.n_embd = cfg.n_embd
self.clf_token = clf_token
self.dropout = nn.Dropout(cfg.clf_pdrop)
self.linear = nn.Linear(cfg.n_embd, 1)
nn.init.normal_(self.linear.weight, std = 0.02)
nn.init.normal_(self.linear.bias, 0)
def forward(self, h, x):
sim_h = h.view(-1, self.n_embd)
flat = x[..., 0].contiguous().view(-1)
sim_h = sim_h[flat == self.clf_token, :]
sim_h = self.dropout(sim_h)
sim_h = sim_h.sum(dim = 1)
sim_logits = self.linear(sim_h)
return sim_logits
class DoubleHeadModel(nn.Module):
""" Transformer with language model and task specific heads """
def __init__(self, cfg, clf_token, task_head_type, vocab=40990, n_ctx=512):
super(DoubleHeadModel, self).__init__()
self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
self.lm_head = LMHead(self.transformer, cfg)
if isinstance(task_head_type, str):
if task_head_type == 'multiple_choice':
self.task_head = MultipleChoiceHead(clf_token, cfg)
elif task_head_type == 'similarity':
self.task_head = SimilarityHead(clf_token, cfg)
elif task_head_type == 'inference':
# the three classes correspond to entailment, contradiction and neutral.
self.task_head = ClfHead(clf_token, cfg, 3)
else:
raise ValueError("task_head_type is expected to be 'multiple_choice' "
"'similarity', 'inference' or ('classification', n_class) "
f"got {task_head_type}.")
elif isinstance(task_head_type, collections.abc.Sequence) and len(task_head_type) == 2 and \
task_head_type[0] == 'classification':
n_class = task_head_type[1]
self.task_head = ClfHead(clf_token, cfg, n_class)
else:
raise ValueError("task_head_type is expected to be 'multiple_choice' "
"'similarity', 'inference' or ('classification', n_class) "
f"got {task_head_type}.")
def set_num_special_tokens(self, num_special_tokens):
# Update config
self.config.n_special = num_special_tokens
# # Build new embeddings and initialize
old_embed = self.embed
self.embed = nn.Embedding(self.config.total_num_embeddings, self.config.n_embd)
# Initialize all new embeddings (in particular the special tokens)
self.init_weights(self.embed)
# Copy word and positional embeddings from the previous weights
self.embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
self.embed.weight.data[-self.config.n_ctx:, :] = old_embed.weight.data[-self.config.n_ctx:, :]
def forward(self, x):
x = x.view(-1, x.size(-2), x.size(-1))
e = self.embed(x)
# Add the position information to the input embeddings
h = e.sum(dim=2)
for block in self.h:
h = block(h)
return h
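The forward pass above expects token and position indices stacked in the last dimension; positions index the tail of the shared embedding matrix (after the vocabulary and any special tokens), and the sum over dim 2 adds the two embeddings. A sketch of building such an input with the default config (Block appears to mask attention with a fixed (n_ctx, n_ctx) buffer in this version, so full-length inputs are used):

import torch
cfg = OpenAIGPTConfig()
model = OpenAIGPTModel(cfg)
batch, seq_len = 2, cfg.n_ctx
tokens = torch.randint(0, cfg.vocab_size, (batch, seq_len))
offset = cfg.vocab_size + cfg.n_special  # positions start after words and special tokens
positions = torch.arange(offset, offset + seq_len).expand(batch, seq_len)
h = model(torch.stack([tokens, positions], dim=-1))  # h: (batch, seq_len, n_embd)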
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
""" OpenAI GPT model with language model and classification heads """
def __init__(self, cfg, clf_token='[CLS]'):
super(OpenAIGPTDoubleHeadsModel, self).__init__(cfg)
self.transformer = OpenAIGPTModel(cfg)
self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, cfg)
self.clf_head = OpenAIGPTClfHead(clf_token, cfg)
self.apply(self.init_weights)
def post_loading(self):
" Set the number of special tokens to 1 (for the [CLS] token) "
self.set_num_special_tokens(1)
def set_num_special_tokens(self, num_special_tokens):
" Update input and output embeddings with new embedding matrice "
self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
def forward(self, x, lm_labels=None, clf_labels=None):
h = self.transformer(x)
lm_logits = self.lm_head(h)
task_logits = self.task_head(h, x)
return lm_logits, task_logits
class dotdict(dict):
"""dot.notation access to dictionary attributes"""
__getattr__ = dict.get
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
DEFAULT_CONFIG = dotdict({
'n_embd': 768,
'n_head': 12,
'n_layer': 12,
'embd_pdrop': 0.1,
'attn_pdrop': 0.1,
'resid_pdrop': 0.1,
'afn': 'gelu',
'clf_pdrop': 0.1})
clf_logits = self.clf_head(h, x)
losses = []
if lm_labels is not None:
loss_fct = CrossEntropyLoss()
losses.append(loss_fct(lm_logits, lm_labels))
if clf_labels is not None:
loss_fct = CrossEntropyLoss()
losses.append(loss_fct(clf_logits, clf_labels))
if losses:
return losses
return lm_logits, clf_logits
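A usage sketch of the two return paths above (x laid out as for OpenAIGPTModel, with the classification token present; the label tensors are hypothetical):

lm_logits, clf_logits = model(x)  # no labels: logits are returned
losses = model(x, lm_labels=lm_labels, clf_labels=clf_labels)  # with labels: a list of losses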

View File

@@ -1,6 +1,23 @@
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for OpenAI GPT model."""
import math
import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_
def warmup_cosine(x, warmup=0.002):
@@ -25,26 +42,41 @@ SCHEDULES = {
class OpenAIAdam(Optimizer):
"""Implements Open AI version of Adam algorithm with weight decay fix.
"""
def __init__(self, params, lr, schedule, warmup, t_total,
b1=0.9, b2=0.999, e=1e-8, l2=0,
def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1,
b1=0.9, b2=0.999, e=1e-8, weight_decay=0,
vector_l2=False, max_grad_norm=-1, **kwargs):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
if schedule not in SCHEDULES:
raise ValueError("Invalid schedule parameter: {}".format(schedule))
if not 0 <= warmup:
raise ValueError("Invalid warmup: {}".format(warmup))
if not 0.0 <= warmup < 1.0 and not warmup == -1:
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
if not 0.0 <= b1 < 1.0:
raise ValueError("Invalid b1 parameter: {}".format(b1))
if not 0.0 <= b2 < 1.0:
raise ValueError("Invalid b2 parameter: {}".format(b2))
if not 0.0 <= e:
if not e >= 0.0:
raise ValueError("Invalid epsilon value: {}".format(e))
defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
b1=b1, b2=b2, e=e, l2=l2, vector_l2=vector_l2,
b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
max_grad_norm=max_grad_norm)
super(OpenAIAdam, self).__init__(params, defaults)
def get_lr(self):
lr = []
for group in self.param_groups:
for p in group['params']:
state = self.state[p]
if len(state) == 0:
return [0]
if group['t_total'] != -1:
schedule_fct = SCHEDULES[group['schedule']]
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
else:
lr_scheduled = group['lr']
lr.append(lr_scheduled)
return lr
def step(self, closure=None):
"""Performs a single optimization step.
@@ -91,14 +123,18 @@ class OpenAIAdam(Optimizer):
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
schedule_fct = SCHEDULES[group['schedule']]
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
if group['t_total'] != -1:
schedule_fct = SCHEDULES[group['schedule']]
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
else:
lr_scheduled = group['lr']
step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
p.data.addcdiv_(-step_size, exp_avg, denom)
# Add weight decay at the end (fixed version)
if (len(p.size()) > 1 or group['vector_l2']) and group['l2'] > 0:
p.data.add_(-lr_scheduled * group['l2'], p.data)
if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0:
p.data.add_(-lr_scheduled * group['weight_decay'], p.data)
return loss
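A construction sketch matching the new keyword defaults (model, loss, and the hyperparameter values are placeholders, not recommendations from this commit):

optimizer = OpenAIAdam(model.parameters(), lr=6.25e-5, schedule='warmup_linear',
                       warmup=0.002, t_total=1000, weight_decay=0.01, max_grad_norm=1)
loss.backward()
optimizer.step()
print(optimizer.get_lr())  # current scheduled learning rates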

View File

@@ -1,9 +1,39 @@
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
import os
import re
import ftfy
import json
import spacy
from tqdm import tqdm
import logging
from .file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'openai-gpt': 512,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
def get_pairs(word):
"""
@@ -32,16 +62,65 @@ def text_standardize(text):
text = re.sub(r'[^\S\n]+', ' ', text)
return text.strip()
class TextEncoder(object):
class OpenAIGPTTokenizer(object):
"""
Mostly a wrapper for a public Python BPE tokenizer.
"""
@classmethod
def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
"""
Instantiate an OpenAIGPTTokenizer from a pre-trained vocabulary file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name]
else:
vocab_file = pretrained_model_name
if os.path.isdir(vocab_file):
merges_file = os.path.join(vocab_file, MERGES_NAME)
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except FileNotFoundError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
vocab_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
logger.info("loading merges file {}".format(merges_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
logger.info("loading merges file {} from cache at {}".format(
merges_file, resolved_merges_file))
if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer won't index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
return tokenizer
def __init__(self, encoder_path, bpe_path):
def __init__(self, vocab_file, merges_file):
try:
import ftfy
import spacy
except ImportError:
raise ImportError("Please install ftfy and spacy to use OpenAI GPT tokenizer.")
self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
self.encoder = json.load(open(encoder_path))
self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()}
merges = open(bpe_path, encoding='utf-8').read().split('\n')[1:-1]
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
merges = [tuple(merge.split()) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
@@ -89,7 +168,7 @@ class TextEncoder(object):
self.cache[token] = word
return word
def encode(self, texts, verbose=True):
def tokenize(self, texts, verbose=True):
texts_tokens = []
if verbose:
for text in tqdm(texts, ncols=80, leave=False):
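A usage sketch of the renamed entry point (vocabulary and merges files resolve through the archive maps above; tokenize takes a list of texts and returns one token list per text; requires ftfy and the spacy 'en' model):

tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
texts_tokens = tokenizer.tokenize(["Mr. Hugging Face is on a mission."])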

View File

@@ -37,8 +37,8 @@ from setuptools import find_packages, setup
setup(
name="pytorch_pretrained_bert",
version="0.4.0",
author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors",
version="0.5.0",
author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors, Open AI team Authors",
author_email="thomas@huggingface.co",
description="PyTorch version of Google AI BERT model with script to load Google pre-trained models",
long_description=open("README.md", "r", encoding='utf-8').read(),
@@ -55,7 +55,7 @@ setup(
'tqdm'],
entry_points={
'console_scripts': [
"pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main"
"pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main",
]
},
python_requires='>=3.5.0',