mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
add OpenAI GPT
This commit is contained in:
parent
793dcd236b
commit
eed51c5bdf
@ -1,8 +1,11 @@
|
||||
__version__ = "0.4.0"
|
||||
__version__ = "0.5.0"
|
||||
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
|
||||
from .tokenization_openai import OpenAIGPTTokenizer
|
||||
from .modeling import (BertConfig, BertModel, BertForPreTraining,
|
||||
BertForMaskedLM, BertForNextSentencePrediction,
|
||||
BertForSequenceClassification, BertForMultipleChoice,
|
||||
BertForTokenClassification, BertForQuestionAnswering)
|
||||
from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, OpenAIGPTDoubleHeadsModel
|
||||
from .optimization import BertAdam
|
||||
from .optimization_openai import OpenAIAdam
|
||||
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||
|
@ -1,22 +1,40 @@
|
||||
# coding: utf8
|
||||
def main():
|
||||
import sys
|
||||
try:
|
||||
from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
|
||||
except ModuleNotFoundError:
|
||||
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||
"In that case, it requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
|
||||
if len(sys.argv) != 5:
|
||||
# pylint: disable=line-too-long
|
||||
print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
|
||||
if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
|
||||
"convert_tf_checkpoint_to_pytorch",
|
||||
"convert_openai_checkpoint"
|
||||
]:
|
||||
print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT` \n or `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`")
|
||||
else:
|
||||
PYTORCH_DUMP_OUTPUT = sys.argv.pop()
|
||||
TF_CONFIG = sys.argv.pop()
|
||||
TF_CHECKPOINT = sys.argv.pop()
|
||||
convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
|
||||
if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
|
||||
try:
|
||||
from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
|
||||
except ModuleNotFoundError:
|
||||
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||
"In that case, it requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
|
||||
if len(sys.argv) != 5:
|
||||
# pylint: disable=line-too-long
|
||||
print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
|
||||
else:
|
||||
PYTORCH_DUMP_OUTPUT = sys.argv.pop()
|
||||
TF_CONFIG = sys.argv.pop()
|
||||
TF_CHECKPOINT = sys.argv.pop()
|
||||
convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
|
||||
else:
|
||||
from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
|
||||
OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
|
||||
PYTORCH_DUMP_OUTPUT = sys.argv[3]
|
||||
if len(sys.argv) == 5:
|
||||
OPENAI_GPT_CONFIG = sys.argv[4]
|
||||
else:
|
||||
OPENAI_GPT_CONFIG = ""
|
||||
convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
|
||||
OPENAI_GPT_CONFIG,
|
||||
PYTORCH_DUMP_OUTPUT)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
@ -12,7 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BERT checkpoint."""
|
||||
"""Convert OpenAI GPT checkpoint."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -20,45 +20,53 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import argparse
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from .modeling import BertConfig, BertForPreTraining
|
||||
from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
|
||||
|
||||
|
||||
def convert_openai_checkpoint_to_pytorch(open_checkpoint_folder_path, openai_config_file, pytorch_dump_path):
|
||||
def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/',
|
||||
path_names='./'):
|
||||
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
|
||||
# Load weights from TF model
|
||||
print("Loading weights...")
|
||||
names = json.load(open(path_names + 'parameters_names.json'))
|
||||
shapes = json.load(open(path + 'params_shapes.json'))
|
||||
names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
|
||||
shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
|
||||
offsets = np.cumsum([np.prod(shape) for shape in shapes])
|
||||
init_params = [np.load(path + 'params_{}.npy'.format(n)) for n in range(10)]
|
||||
init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)]
|
||||
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
|
||||
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
|
||||
if n_ctx > 0:
|
||||
init_params[0] = init_params[0][:n_ctx]
|
||||
if n_special > 0:
|
||||
init_params[0] = np.concatenate(
|
||||
[init_params[1],
|
||||
(np.random.randn(n_special, n_embd) * 0.02).astype(np.float32),
|
||||
init_params[0]
|
||||
], 0)
|
||||
else:
|
||||
init_params[0] = np.concatenate(
|
||||
[init_params[1],
|
||||
init_params[0]
|
||||
], 0)
|
||||
# if n_ctx > 0:
|
||||
# init_params[0] = init_params[0][:n_ctx]
|
||||
# if n_special > 0:
|
||||
# init_params[0] = np.concatenate(
|
||||
# [init_params[1],
|
||||
# (np.random.randn(n_special, n_embd) * 0.02).astype(np.float32),
|
||||
# init_params[0]
|
||||
# ], 0)
|
||||
# else:
|
||||
# init_params[0] = np.concatenate(
|
||||
# [init_params[1],
|
||||
# init_params[0]
|
||||
# ], 0)
|
||||
# del init_params[1]
|
||||
# if n_transfer == -1:
|
||||
# n_transfer = 0
|
||||
# else:
|
||||
# n_transfer = 1 + n_transfer * 12
|
||||
|
||||
init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
|
||||
del init_params[1]
|
||||
if n_transfer == -1:
|
||||
n_transfer = 0
|
||||
else:
|
||||
n_transfer = 1 + n_transfer * 12
|
||||
init_params = [arr.squeeze() for arr in init_params]
|
||||
|
||||
# Construct model
|
||||
if openai_config_file == "":
|
||||
config = OpenAIGPTConfig()
|
||||
else:
|
||||
config = OpenAIGPTConfig(openai_config_file)
|
||||
model = OpenAIGPTModel(config)
|
||||
try:
|
||||
assert model.embed.weight.shape == init_params[0].shape
|
||||
except AssertionError as e:
|
||||
@ -66,8 +74,10 @@ def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n
|
||||
raise
|
||||
|
||||
model.embed.weight.data = torch.from_numpy(init_params[0])
|
||||
names.pop(0)
|
||||
init_params.pop(0)
|
||||
|
||||
for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
|
||||
for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]):
|
||||
name = name[6:] # skip "model/"
|
||||
assert name[-2:] == ":0"
|
||||
name = name[:-2]
|
||||
@ -78,64 +88,22 @@ def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n
|
||||
l = re.split(r'(\d+)', m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
pointer = getattr(pointer, l[0])
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
try:
|
||||
assert pointer.shape == ip.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, ip.shape)
|
||||
raise
|
||||
pointer.data = torch.from_numpy(ip)
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
||||
config_path = os.path.abspath(bert_config_file)
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path))
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
for name, shape in init_vars:
|
||||
print("Loading TF weight {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
names.append(name)
|
||||
arrays.append(array)
|
||||
|
||||
# Initialise PyTorch model
|
||||
config = BertConfig.from_json_file(bert_config_file)
|
||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||
model = BertForPreTraining(config)
|
||||
|
||||
for name, array in zip(names, arrays):
|
||||
name = name.split('/')
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
if any(n in ["adam_v", "adam_m"] for n in name):
|
||||
print("Skipping {}".format("/".join(name)))
|
||||
continue
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
||||
l = re.split(r'_(\d+)', m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
if l[0] == 'kernel' or l[0] == 'gamma':
|
||||
if l[0] == 'g':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'output_bias' or l[0] == 'beta':
|
||||
elif l[0] == 'b':
|
||||
pointer = getattr(pointer, 'bias')
|
||||
elif l[0] == 'output_weights':
|
||||
elif l[0] == 'w':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
else:
|
||||
pointer = getattr(pointer, l[0])
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
if m_name[-11:] == '_embeddings':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif m_name == 'kernel':
|
||||
array = np.transpose(array)
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
@ -145,30 +113,33 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
|
||||
pointer.data = torch.from_numpy(array)
|
||||
|
||||
# Save pytorch-model
|
||||
print("Save PyTorch model to {}".format(pytorch_dump_path))
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
||||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--tf_checkpoint_path",
|
||||
parser.add_argument("--openai_checkpoint_folder_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path the TensorFlow checkpoint path.")
|
||||
parser.add_argument("--bert_config_file",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--pytorch_dump_path",
|
||||
parser.add_argument("--pytorch_dump_folder_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output PyTorch model.")
|
||||
parser.add_argument("--openai_config_file",
|
||||
default = "",
|
||||
type = str,
|
||||
help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
|
||||
"This specifies the model architecture.")
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
||||
args.bert_config_file,
|
||||
args.pytorch_dump_path)
|
||||
convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path,
|
||||
args.pytorch_dump_folder_path,
|
||||
args.openai_config_file)
|
||||
|
@ -416,12 +416,12 @@ class BertPreTrainingHeads(nn.Module):
|
||||
return prediction_scores, seq_relationship_score
|
||||
|
||||
|
||||
class PreTrainedModel(nn.Module):
|
||||
class BertPreTrainedModel(nn.Module):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(PreTrainedModel, self).__init__()
|
||||
super(BertPreTrainedModel, self).__init__()
|
||||
if not isinstance(config, BertConfig):
|
||||
raise ValueError(
|
||||
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
|
||||
@ -447,7 +447,7 @@ class PreTrainedModel(nn.Module):
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
|
||||
Params:
|
||||
@ -547,13 +547,16 @@ class PreTrainedModel(nn.Module):
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.info("Weights from pretrained model not used in {}: {}".format(
|
||||
model.__class__.__name__, unexpected_keys))
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
self.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
if tempdir:
|
||||
# Clean up temp dir
|
||||
shutil.rmtree(tempdir)
|
||||
return model
|
||||
|
||||
|
||||
class BertModel(PreTrainedModel):
|
||||
class BertModel(BertPreTrainedModel):
|
||||
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
|
||||
|
||||
Params:
|
||||
@ -636,7 +639,7 @@ class BertModel(PreTrainedModel):
|
||||
return encoded_layers, pooled_output
|
||||
|
||||
|
||||
class BertForPreTraining(PreTrainedModel):
|
||||
class BertForPreTraining(BertPreTrainedModel):
|
||||
"""BERT model with pre-training heads.
|
||||
This module comprises the BERT model followed by the two pre-training heads:
|
||||
- the masked language modeling head, and
|
||||
@ -707,7 +710,7 @@ class BertForPreTraining(PreTrainedModel):
|
||||
return prediction_scores, seq_relationship_score
|
||||
|
||||
|
||||
class BertForMaskedLM(PreTrainedModel):
|
||||
class BertForMaskedLM(BertPreTrainedModel):
|
||||
"""BERT model with the masked language modeling head.
|
||||
This module comprises the BERT model followed by the masked language modeling head.
|
||||
|
||||
@ -768,7 +771,7 @@ class BertForMaskedLM(PreTrainedModel):
|
||||
return prediction_scores
|
||||
|
||||
|
||||
class BertForNextSentencePrediction(PreTrainedModel):
|
||||
class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
"""BERT model with next sentence prediction head.
|
||||
This module comprises the BERT model followed by the next sentence classification head.
|
||||
|
||||
@ -830,7 +833,7 @@ class BertForNextSentencePrediction(PreTrainedModel):
|
||||
return seq_relationship_score
|
||||
|
||||
|
||||
class BertForSequenceClassification(PreTrainedModel):
|
||||
class BertForSequenceClassification(BertPreTrainedModel):
|
||||
"""BERT model for classification.
|
||||
This module is composed of the BERT model with a linear layer on top of
|
||||
the pooled output.
|
||||
@ -875,7 +878,7 @@ class BertForSequenceClassification(PreTrainedModel):
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, num_labels=2):
|
||||
def __init__(self, config, num_labels):
|
||||
super(BertForSequenceClassification, self).__init__(config)
|
||||
self.num_labels = num_labels
|
||||
self.bert = BertModel(config)
|
||||
@ -896,7 +899,7 @@ class BertForSequenceClassification(PreTrainedModel):
|
||||
return logits
|
||||
|
||||
|
||||
class BertForMultipleChoice(PreTrainedModel):
|
||||
class BertForMultipleChoice(BertPreTrainedModel):
|
||||
"""BERT model for multiple choice tasks.
|
||||
This module is composed of the BERT model with a linear layer on top of
|
||||
the pooled output.
|
||||
@ -940,7 +943,7 @@ class BertForMultipleChoice(PreTrainedModel):
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, num_choices=2):
|
||||
def __init__(self, config, num_choices):
|
||||
super(BertForMultipleChoice, self).__init__(config)
|
||||
self.num_choices = num_choices
|
||||
self.bert = BertModel(config)
|
||||
@ -965,7 +968,7 @@ class BertForMultipleChoice(PreTrainedModel):
|
||||
return reshaped_logits
|
||||
|
||||
|
||||
class BertForTokenClassification(PreTrainedModel):
|
||||
class BertForTokenClassification(BertPreTrainedModel):
|
||||
"""BERT model for token-level classification.
|
||||
This module is composed of the BERT model with a linear layer on top of
|
||||
the full hidden state of the last layer.
|
||||
@ -1010,7 +1013,7 @@ class BertForTokenClassification(PreTrainedModel):
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, num_labels=2):
|
||||
def __init__(self, config, num_labels):
|
||||
super(BertForTokenClassification, self).__init__(config)
|
||||
self.num_labels = num_labels
|
||||
self.bert = BertModel(config)
|
||||
@ -1031,7 +1034,7 @@ class BertForTokenClassification(PreTrainedModel):
|
||||
return logits
|
||||
|
||||
|
||||
class BertForQuestionAnswering(PreTrainedModel):
|
||||
class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
"""BERT model for Question Answering (span extraction).
|
||||
This module is composed of the BERT model with a linear layer on top of
|
||||
the sequence output that computes start_logits and end_logits
|
||||
|
@ -1,16 +1,28 @@
|
||||
import os
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
import logging
|
||||
import tarfile
|
||||
import tempfile
|
||||
import shutil
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
from .file_utils import cached_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt.tar.gz",
|
||||
}
|
||||
CONFIG_NAME = 'openai_gpt_config.json'
|
||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||
|
||||
def gelu(x):
|
||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
@ -26,6 +38,237 @@ ACT_FNS = {
|
||||
'gelu': gelu
|
||||
}
|
||||
|
||||
class OpenAIGPTConfig(object):
|
||||
"""Configuration class to store the configuration of a `OpenAIGPTModel`.
|
||||
"""
|
||||
def __init__(self,
|
||||
vocab_size_or_config_json_file=40478,
|
||||
n_special=0,
|
||||
n_ctx=512,
|
||||
n_embd=768,
|
||||
n_layer=12,
|
||||
n_head=12,
|
||||
intermediate_size=3072,
|
||||
afn="gelu",
|
||||
resid_pdrop=0.1,
|
||||
embd_pdrop=0.1,
|
||||
attn_pdrop=0.1,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02):
|
||||
"""Constructs OpenAIGPTConfig.
|
||||
|
||||
Args:
|
||||
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file.
|
||||
n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLF]', ...)
|
||||
n_ctx: Number of positional embeddings.
|
||||
n_embd: Dimensionality of the embeddings and hidden states.
|
||||
n_layer: Number of hidden layers in the Transformer encoder.
|
||||
n_head: Number of attention heads for each attention layer in
|
||||
the Transformer encoder.
|
||||
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
|
||||
layer in the Transformer encoder.
|
||||
afn: The non-linear activation function (function or string) in the
|
||||
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
|
||||
resid_pdrop: The dropout probabilitiy for all fully connected
|
||||
layers in the embeddings, encoder, and pooler.
|
||||
attn_pdrop: The dropout ratio for the attention
|
||||
probabilities.
|
||||
embd_pdrop: The dropout ratio for the embeddings.
|
||||
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
|
||||
`OpenAIGPTModel`.
|
||||
initializer_range: The sttdev of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
"""
|
||||
if isinstance(vocab_size_or_config_json_file, str):
|
||||
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
|
||||
json_config = json.loads(reader.read())
|
||||
for key, value in json_config.items():
|
||||
self.__dict__[key] = value
|
||||
elif isinstance(vocab_size_or_config_json_file, int):
|
||||
self.vocab_size = vocab_size_or_config_json_file
|
||||
self.n_special = n_special
|
||||
self.n_ctx = n_ctx
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_head = n_head
|
||||
self.afn = afn
|
||||
self.intermediate_size = intermediate_size
|
||||
self.resid_pdrop = resid_pdrop
|
||||
self.embd_pdrop = embd_pdrop
|
||||
self.attn_pdrop = attn_pdrop
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
else:
|
||||
raise ValueError("First argument must be either a vocabulary size (int)"
|
||||
"or the path to a pretrained model config file (str)")
|
||||
|
||||
@property
|
||||
def total_num_embeddings(self):
|
||||
return self.vocab_size + self.n_special + self.n_ctx
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, json_object):
|
||||
"""Constructs a `OpenAIGPTConfig` from a Python dictionary of parameters."""
|
||||
config = OpenAIGPTConfig(vocab_size_or_config_json_file=-1)
|
||||
for key, value in json_object.items():
|
||||
config.__dict__[key] = value
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file):
|
||||
"""Constructs a `OpenAIGPTConfig` from a json file of parameters."""
|
||||
with open(json_file, "r", encoding='utf-8') as reader:
|
||||
text = reader.read()
|
||||
return cls.from_dict(json.loads(text))
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.to_json_string())
|
||||
|
||||
def to_dict(self):
|
||||
"""Serializes this instance to a Python dictionary."""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
return output
|
||||
|
||||
def to_json_string(self):
|
||||
"""Serializes this instance to a JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||
|
||||
class OpenAIGPTPreTrainedModel(nn.Module):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(OpenAIGPTPreTrainedModel, self).__init__()
|
||||
if not isinstance(config, OpenAIGPTConfig):
|
||||
raise ValueError(
|
||||
"Parameter config in `{}(config)` should be an instance of class `OpenAIGPTConfig`. "
|
||||
"To create a model from a Google pretrained model use "
|
||||
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
||||
self.__class__.__name__, self.__class__.__name__
|
||||
))
|
||||
self.config = config
|
||||
|
||||
def init_weights(self, module):
|
||||
""" Initialize the weights.
|
||||
"""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
elif isinstance(module, LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def post_loading(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
|
||||
Params:
|
||||
pretrained_model_name: either:
|
||||
- a str with the name of a pre-trained model to load selected in the list of:
|
||||
. `openai-gpt`
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `openai_gpt_config.json` a configuration file for the model
|
||||
. `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
|
||||
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
||||
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
|
||||
*inputs, **kwargs: additional input for the specific Bert class
|
||||
(ex: num_labels for BertForSequenceClassification)
|
||||
"""
|
||||
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
|
||||
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
|
||||
else:
|
||||
archive_file = pretrained_model_name
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find any file "
|
||||
"associated to this path or url.".format(
|
||||
pretrained_model_name,
|
||||
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
|
||||
archive_file))
|
||||
return None
|
||||
if resolved_archive_file == archive_file:
|
||||
logger.info("loading archive file {}".format(archive_file))
|
||||
else:
|
||||
logger.info("loading archive file {} from cache at {}".format(
|
||||
archive_file, resolved_archive_file))
|
||||
tempdir = None
|
||||
if os.path.isdir(resolved_archive_file):
|
||||
serialization_dir = resolved_archive_file
|
||||
else:
|
||||
# Extract archive to temp dir
|
||||
tempdir = tempfile.mkdtemp()
|
||||
logger.info("extracting archive file {} to temp dir {}".format(
|
||||
resolved_archive_file, tempdir))
|
||||
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
|
||||
archive.extractall(tempdir)
|
||||
serialization_dir = tempdir
|
||||
# Load config
|
||||
config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
||||
config = OpenAIGPTConfig.from_json_file(config_file)
|
||||
logger.info("Model config {}".format(config))
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
if state_dict is None:
|
||||
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
||||
state_dict = torch.load(weights_path)
|
||||
|
||||
old_keys = []
|
||||
new_keys = []
|
||||
for key in state_dict.keys():
|
||||
new_key = None
|
||||
if 'gamma' in key:
|
||||
new_key = key.replace('gamma', 'weight')
|
||||
if 'beta' in key:
|
||||
new_key = key.replace('beta', 'bias')
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
for old_key, new_key in zip(old_keys, new_keys):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
missing_keys = []
|
||||
unexpected_keys = []
|
||||
error_msgs = []
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, '_metadata', None)
|
||||
state_dict = state_dict.copy()
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
def load(module, prefix=''):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
module._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + '.')
|
||||
load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
|
||||
if len(missing_keys) > 0:
|
||||
logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
||||
model.__class__.__name__, missing_keys))
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.info("Weights from pretrained model not used in {}: {}".format(
|
||||
model.__class__.__name__, unexpected_keys))
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
self.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
model.post_loading()
|
||||
if tempdir:
|
||||
# Clean up temp dir
|
||||
shutil.rmtree(tempdir)
|
||||
return model
|
||||
|
||||
class Conv1D(nn.Module):
|
||||
def __init__(self, nf, rf, nx):
|
||||
@ -35,15 +278,15 @@ class Conv1D(nn.Module):
|
||||
if rf == 1: # faster 1x1 conv
|
||||
w = torch.empty(nx, nf)
|
||||
nn.init.normal_(w, std=0.02)
|
||||
self.w = Parameter(w)
|
||||
self.b = Parameter(torch.zeros(nf))
|
||||
self.weight = Parameter(w)
|
||||
self.bias = Parameter(torch.zeros(nf))
|
||||
else: # was used to train LM
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, x):
|
||||
if self.rf == 1:
|
||||
size_out = x.size()[:-1] + (self.nf,)
|
||||
x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w)
|
||||
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
|
||||
x = x.view(*size_out)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
@ -132,38 +375,18 @@ class Block(nn.Module):
|
||||
return h
|
||||
|
||||
|
||||
class TransformerModel(nn.Module):
|
||||
""" Transformer model """
|
||||
|
||||
def __init__(self, cfg, vocab=40990, n_ctx=512):
|
||||
super(TransformerModel, self).__init__()
|
||||
self.vocab = vocab
|
||||
self.embed = nn.Embedding(vocab, cfg.n_embd)
|
||||
self.drop = nn.Dropout(cfg.embd_pdrop)
|
||||
block = Block(n_ctx, cfg, scale=True)
|
||||
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])
|
||||
|
||||
nn.init.normal_(self.embed.weight, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.view(-1, x.size(-2), x.size(-1))
|
||||
e = self.embed(x)
|
||||
# Add the position information to the input embeddings
|
||||
h = e.sum(dim=2)
|
||||
for block in self.h:
|
||||
h = block(h)
|
||||
return h
|
||||
|
||||
|
||||
class LMHead(nn.Module):
|
||||
class OpenAIGPTLMHead(nn.Module):
|
||||
""" Language Model Head for the transformer """
|
||||
|
||||
def __init__(self, model, cfg):
|
||||
super(LMHead, self).__init__()
|
||||
def __init__(self, model_embeddings_weights, cfg):
|
||||
super(OpenAIGPTLMHead, self).__init__()
|
||||
self.n_embd = cfg.n_embd
|
||||
embed_shape = model.embed.weight.shape
|
||||
self.set_embeddings_weights(model_embeddings_weights)
|
||||
|
||||
def set_embeddings_weights(self, model_embeddings_weights):
|
||||
embed_shape = model_embeddings_weights.shape
|
||||
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
|
||||
self.decoder.weight = model.embed.weight # Tied weights
|
||||
self.decoder.weight = model_embeddings_weights # Tied weights
|
||||
|
||||
def forward(self, h):
|
||||
# Truncated Language modeling logits (we remove the last token)
|
||||
@ -172,14 +395,14 @@ class LMHead(nn.Module):
|
||||
return lm_logits
|
||||
|
||||
|
||||
class MultipleChoiceHead(nn.Module):
|
||||
class OpenAIGPTClfHead(nn.Module):
|
||||
""" Classifier Head for the transformer """
|
||||
|
||||
def __init__(self, clf_token, cfg):
|
||||
super(MultipleChoiceHead, self).__init__()
|
||||
super(OpenAIGPTClfHead, self).__init__()
|
||||
self.n_embd = cfg.n_embd
|
||||
self.clf_token = clf_token
|
||||
self.dropout = nn.Dropout2d(cfg.clf_pdrop) # To reproduce the noise_shape parameter of TF implementation
|
||||
self.dropout = nn.Dropout2d(cfg.resid_pdrop) # To reproduce the noise_shape parameter of TF implementation
|
||||
self.linear = nn.Linear(cfg.n_embd, 1)
|
||||
|
||||
nn.init.normal_(self.linear.weight, std = 0.02)
|
||||
@ -202,101 +425,71 @@ class MultipleChoiceHead(nn.Module):
|
||||
return clf_logits.view(-1, x.size(1))
|
||||
|
||||
|
||||
class ClfHead(nn.Module):
|
||||
"""Classification Head for the transformer
|
||||
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
""" OpenAI GPT model """
|
||||
|
||||
TODO: test this class."""
|
||||
def __init__(self, clf_token, cfg, n_class):
|
||||
super(ClfHead, self).__init__()
|
||||
self.n_embd = cfg.n_embd
|
||||
self.clf_token = clf_token
|
||||
self.dropout = nn.Dropout(cfg.clf_pdrop)
|
||||
self.linear = nn.Linear(cfg.n_embd, n_class)
|
||||
def __init__(self, cfg):
|
||||
super(OpenAIGPTModel, self).__init__(cfg)
|
||||
total_embeddings_size = cfg.vocab_size + cfg.n_special + cfg.n_ctx
|
||||
self.embed = nn.Embedding(total_embeddings_size, cfg.n_embd)
|
||||
self.drop = nn.Dropout(cfg.embd_pdrop)
|
||||
block = Block(cfg.n_ctx, cfg, scale=True)
|
||||
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])
|
||||
|
||||
nn.init.normal_(self.linear.weight, std = 0.02)
|
||||
nn.init.normal_(self.linear.bias, 0)
|
||||
self.apply(self.init_weights)
|
||||
# nn.init.normal_(self.embed.weight, std=0.02)
|
||||
|
||||
def forward(self, h, x):
|
||||
clf_h = h.view(-1, self.n_embd)
|
||||
flat = x[..., 0].contiguous().view(-1)
|
||||
clf_h = clf_h[flat == self.clf_token, :]
|
||||
clf_h = self.dropout(clf_h)
|
||||
clf_logits = self.linear(clf_h)
|
||||
|
||||
return clf_logits
|
||||
|
||||
class SimilarityHead(nn.Module):
|
||||
""" Similarity Head for the transformer
|
||||
|
||||
TODO: test this class."""
|
||||
def __init__(self, clf_token, cfg):
|
||||
super(SimilarityHead, self).__init__()
|
||||
self.n_embd = cfg.n_embd
|
||||
self.clf_token = clf_token
|
||||
self.dropout = nn.Dropout(cfg.clf_pdrop)
|
||||
self.linear = nn.Linear(cfg.n_embd, 1)
|
||||
|
||||
nn.init.normal_(self.linear.weight, std = 0.02)
|
||||
nn.init.normal_(self.linear.bias, 0)
|
||||
|
||||
def forward(self, h, x):
|
||||
sim_h = h.view(-1, self.n_embd)
|
||||
flat = x[..., 0].contiguous().view(-1)
|
||||
sim_h = sim_h[flat == self.clf_token, :]
|
||||
sim_h = self.dropout(sim_h)
|
||||
sim_h = sim_h.sum(dim = 1)
|
||||
sim_logits = self.linear(sim_h)
|
||||
|
||||
return sim_logits
|
||||
|
||||
class DoubleHeadModel(nn.Module):
|
||||
""" Transformer with language model and task specific heads """
|
||||
def __init__(self, cfg, clf_token, task_head_type, vocab=40990, n_ctx=512):
|
||||
super(DoubleHeadModel, self).__init__()
|
||||
self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
|
||||
self.lm_head = LMHead(self.transformer, cfg)
|
||||
if isinstance(task_head_type, str):
|
||||
if task_head_type == 'multiple_choice':
|
||||
self.task_head = MultipleChoiceHead(clf_token, cfg)
|
||||
elif task_head_type == 'similarity':
|
||||
self.task_head = SimilarityHead(clf_token, cfg)
|
||||
elif task_head_type == 'inference':
|
||||
# the three classes correspond to entailment, contradiction and neutral.
|
||||
self.task_head = ClfHead(clf_token, cfg, 3)
|
||||
else:
|
||||
raise ValueError("task_head_type is expected to be 'multiple_choice' "
|
||||
"'similarity', 'inference' or ('classification', n_class) "
|
||||
f"got {task_head_type}.")
|
||||
elif isinstance(task_head_type, collections.abc.Sequence) and len(task_head_type) == 2 and \
|
||||
task_head_type[0] == 'classification':
|
||||
n_class = task_head_type[1]
|
||||
self.task_head = ClfHead(clf_token, cfg, n_class)
|
||||
else:
|
||||
raise ValueError("task_head_type is expected to be 'multiple_choice' "
|
||||
"'similarity', 'inference' or ('classification', n_class) "
|
||||
f"got {task_head_type}.")
|
||||
def set_num_special_tokens(self, num_special_tokens):
|
||||
# Update config
|
||||
self.config.n_special = num_special_tokens
|
||||
# # Build new embeddings and initialize
|
||||
old_embed = self.embed
|
||||
self.embed = nn.Embedding(self.config.total_num_embeddings, self.config.n_embd)
|
||||
# Initialize all new embeddings (in particular the special tokens)
|
||||
self.init_weights(self.embed)
|
||||
# Copy word and positional embeddings from the previous weights
|
||||
self.embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
|
||||
self.embed.weight.data[-self.config.n_ctx:, :] = old_embed.weight.data[-self.config.n_ctx:, :]
|
||||
|
||||
def forward(self, x):
|
||||
x = x.view(-1, x.size(-2), x.size(-1))
|
||||
e = self.embed(x)
|
||||
# Add the position information to the input embeddings
|
||||
h = e.sum(dim=2)
|
||||
for block in self.h:
|
||||
h = block(h)
|
||||
return h
|
||||
|
||||
|
||||
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
""" OpenAI GPT model with language model and classification heads """
|
||||
def __init__(self, cfg, clf_token='[CLS]'):
|
||||
super(OpenAIGPTDoubleHeadsModel, self).__init__(cfg)
|
||||
self.transformer = OpenAIGPTModel(cfg)
|
||||
self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, cfg)
|
||||
self.clf_head = OpenAIGPTClfHead(clf_token, cfg)
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def post_loading(self):
|
||||
" Set the number of special tokens to 1 (for the [CLS] token) "
|
||||
self.set_num_special_tokens(1)
|
||||
|
||||
def set_num_special_tokens(self, num_special_tokens):
|
||||
" Update input and output embeddings with new embedding matrice "
|
||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||
self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
|
||||
|
||||
def forward(self, x, lm_labels=None, clf_labels=None):
|
||||
h = self.transformer(x)
|
||||
lm_logits = self.lm_head(h)
|
||||
task_logits = self.task_head(h, x)
|
||||
|
||||
return lm_logits, task_logits
|
||||
|
||||
|
||||
class dotdict(dict):
|
||||
"""dot.notation access to dictionary attributes"""
|
||||
__getattr__ = dict.get
|
||||
__setattr__ = dict.__setitem__
|
||||
__delattr__ = dict.__delitem__
|
||||
|
||||
|
||||
DEFAULT_CONFIG = dotdict({
|
||||
'n_embd': 768,
|
||||
'n_head': 12,
|
||||
'n_layer': 12,
|
||||
'embd_pdrop': 0.1,
|
||||
'attn_pdrop': 0.1,
|
||||
'resid_pdrop': 0.1,
|
||||
'afn': 'gelu',
|
||||
'clf_pdrop': 0.1})
|
||||
clf_logits = self.clf_head(h, x)
|
||||
losses = []
|
||||
if lm_labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
losses.append(loss_fct(lm_logits, lm_labels))
|
||||
if clf_labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
losses.append(loss_fct(clf_logits, clf_labels))
|
||||
if losses:
|
||||
return losses
|
||||
return lm_logits, clf_logits
|
||||
|
@ -1,6 +1,23 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Open AI Team Authors and The HugginFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch optimization for OpenAI GPT model."""
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.optimizer import required
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
|
||||
def warmup_cosine(x, warmup=0.002):
|
||||
@ -25,26 +42,41 @@ SCHEDULES = {
|
||||
class OpenAIAdam(Optimizer):
|
||||
"""Implements Open AI version of Adam algorithm with weight decay fix.
|
||||
"""
|
||||
def __init__(self, params, lr, schedule, warmup, t_total,
|
||||
b1=0.9, b2=0.999, e=1e-8, l2=0,
|
||||
def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1,
|
||||
b1=0.9, b2=0.999, e=1e-8, weight_decay=0,
|
||||
vector_l2=False, max_grad_norm=-1, **kwargs):
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
if lr is not required and lr < 0.0:
|
||||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
||||
if schedule not in SCHEDULES:
|
||||
raise ValueError("Invalid schedule parameter: {}".format(schedule))
|
||||
if not 0 <= warmup:
|
||||
raise ValueError("Invalid warmup: {}".format(warmup))
|
||||
if not 0.0 <= warmup < 1.0 and not warmup == -1:
|
||||
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
|
||||
if not 0.0 <= b1 < 1.0:
|
||||
raise ValueError("Invalid b1 parameter: {}".format(b1))
|
||||
if not 0.0 <= b2 < 1.0:
|
||||
raise ValueError("Invalid b2 parameter: {}".format(b2))
|
||||
if not 0.0 <= e:
|
||||
if not e >= 0.0:
|
||||
raise ValueError("Invalid epsilon value: {}".format(e))
|
||||
defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
|
||||
b1=b1, b2=b2, e=e, l2=l2, vector_l2=vector_l2,
|
||||
b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
|
||||
max_grad_norm=max_grad_norm)
|
||||
super(OpenAIAdam, self).__init__(params, defaults)
|
||||
|
||||
def get_lr(self):
|
||||
lr = []
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
state = self.state[p]
|
||||
if len(state) == 0:
|
||||
return [0]
|
||||
if group['t_total'] != -1:
|
||||
schedule_fct = SCHEDULES[group['schedule']]
|
||||
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
|
||||
else:
|
||||
lr_scheduled = group['lr']
|
||||
lr.append(lr_scheduled)
|
||||
return lr
|
||||
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
@ -91,14 +123,18 @@ class OpenAIAdam(Optimizer):
|
||||
bias_correction1 = 1 - beta1 ** state['step']
|
||||
bias_correction2 = 1 - beta2 ** state['step']
|
||||
|
||||
schedule_fct = SCHEDULES[group['schedule']]
|
||||
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
|
||||
if group['t_total'] != -1:
|
||||
schedule_fct = SCHEDULES[group['schedule']]
|
||||
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
|
||||
else:
|
||||
lr_scheduled = group['lr']
|
||||
|
||||
step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
|
||||
|
||||
p.data.addcdiv_(-step_size, exp_avg, denom)
|
||||
|
||||
# Add weight decay at the end (fixed version)
|
||||
if (len(p.size()) > 1 or group['vector_l2']) and group['l2'] > 0:
|
||||
p.data.add_(-lr_scheduled * group['l2'], p.data)
|
||||
if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0:
|
||||
p.data.add_(-lr_scheduled * group['weight_decay'], p.data)
|
||||
|
||||
return loss
|
||||
|
@ -1,9 +1,39 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Open AI Team Authors and The HugginFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for OpenAI GPT."""
|
||||
import os
|
||||
import re
|
||||
import ftfy
|
||||
import json
|
||||
import spacy
|
||||
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
|
||||
from .file_utils import cached_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
||||
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
|
||||
}
|
||||
PRETRAINED_MERGES_ARCHIVE_MAP = {
|
||||
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
|
||||
}
|
||||
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
|
||||
'openai-gpt': 512,
|
||||
}
|
||||
VOCAB_NAME = 'vocab.json'
|
||||
MERGES_NAME = 'merges.txt'
|
||||
|
||||
def get_pairs(word):
|
||||
"""
|
||||
@ -32,16 +62,65 @@ def text_standardize(text):
|
||||
text = re.sub(r'[^\S\n]+', ' ', text)
|
||||
return text.strip()
|
||||
|
||||
class TextEncoder(object):
|
||||
class OpenAIGPTTokenizer(object):
|
||||
"""
|
||||
mostly a wrapper for a public python bpe tokenizer
|
||||
"""
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a PreTrainedBertModel from a pre-trained model file.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
"""
|
||||
if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
|
||||
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name]
|
||||
else:
|
||||
vocab_file = pretrained_model_name
|
||||
if os.path.isdir(vocab_file):
|
||||
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
|
||||
merges_file = os.path.join(vocab_file, MERGES_NAME)
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find any file "
|
||||
"associated to this path or url.".format(
|
||||
pretrained_model_name,
|
||||
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
||||
vocab_file))
|
||||
return None
|
||||
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
|
||||
logger.info("loading vocabulary file {}".format(vocab_file))
|
||||
logger.info("loading merges file {}".format(merges_file))
|
||||
else:
|
||||
logger.info("loading vocabulary file {} from cache at {}".format(
|
||||
vocab_file, resolved_vocab_file))
|
||||
logger.info("loading merges file {} from cache at {}".format(
|
||||
merges_file, resolved_merges_file))
|
||||
if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
|
||||
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
|
||||
# than the number of positional embeddings
|
||||
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
|
||||
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
||||
# Instantiate tokenizer.
|
||||
tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
|
||||
return tokenizer
|
||||
|
||||
def __init__(self, vocab_file, merges_file):
|
||||
try:
|
||||
import ftfy
|
||||
import spacy
|
||||
except ImportError:
|
||||
raise ImportError("Please install ftfy and spacy to use OpenAI GPT tokenizer.")
|
||||
|
||||
def __init__(self, encoder_path, bpe_path):
|
||||
self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
|
||||
self.encoder = json.load(open(encoder_path))
|
||||
self.encoder = json.load(open(vocab_file))
|
||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||
merges = open(bpe_path, encoding='utf-8').read().split('\n')[1:-1]
|
||||
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
|
||||
merges = [tuple(merge.split()) for merge in merges]
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {}
|
||||
@ -89,7 +168,7 @@ class TextEncoder(object):
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def encode(self, texts, verbose=True):
|
||||
def tokenize(self, texts, verbose=True):
|
||||
texts_tokens = []
|
||||
if verbose:
|
||||
for text in tqdm(texts, ncols=80, leave=False):
|
||||
|
6
setup.py
6
setup.py
@ -37,8 +37,8 @@ from setuptools import find_packages, setup
|
||||
|
||||
setup(
|
||||
name="pytorch_pretrained_bert",
|
||||
version="0.4.0",
|
||||
author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors",
|
||||
version="0.5.0",
|
||||
author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors, Open AI team Authors",
|
||||
author_email="thomas@huggingface.co",
|
||||
description="PyTorch version of Google AI BERT model with script to load Google pre-trained models",
|
||||
long_description=open("README.md", "r", encoding='utf-8').read(),
|
||||
@ -55,7 +55,7 @@ setup(
|
||||
'tqdm'],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
"pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main"
|
||||
"pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main",
|
||||
]
|
||||
},
|
||||
python_requires='>=3.5.0',
|
||||
|
Loading…
Reference in New Issue
Block a user