added conversion script

This commit is contained in:
thomwolf 2018-11-01 17:40:05 +01:00
parent 90d360a7a9
commit c5d532e5f6
2 changed files with 100 additions and 21 deletions

82
convert_tf_checkpoint.py Normal file
View File

@ -0,0 +1,82 @@
# coding=utf-8
"""Convert BERT checkpoint."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import argparse
import tensorflow as tf
import torch
from .modeling_pytorch import BertConfig, BertModel
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--tf_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path the TensorFlow checkpoint path.")
parser.add_argument("--bert_config_file",
default = None,
type = str,
required = True,
help = "The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
args = parser.parse_args()
def convert():
# Load weights from TF model
path = args.tf_checkpoint_path
print("Converting TensorFlow checkpoint from {}".format(path))
init_vars = tf.train.list_variables(path)
names = []
arrays = []
for name, shape in init_vars:
print("Loading {} with shape {}".format(name, shape))
array = tf.train.load_variable(path, name)
print("Numpy array shape {}".format(array.shape))
names.append(name)
arrays.append(array)
# Initialise PyTorch model and fill weights-in
config = BertConfig.from_json_file(args.bert_config_file)
model = BertModel(config)
for name, array in zip(names, arrays):
name = name[5:] # skip "bert/"
assert name[-2:] == ":0"
name = name[:-2]
name = name.split('/')
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+\d+', m_name):
l = re.split(r'(\d+)', m_name)
else:
l = [m_name]
pointer = getattr(pointer, l[0])
if len(l) >= 2:
num = int(l[1])
pointer = pointer[num]
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
pointer.data = torch.from_numpy(array)
# Save pytorch-model
torch.save(model.state_dict(), args.pytorch_dump_path)
if __name__ == "__main__":
convert()
return None

View File

@ -119,7 +119,7 @@ class BERTLayerNorm(nn.Module):
self.variance_epsilon = variance_epsilon self.variance_epsilon = variance_epsilon
def forward(self, x): def forward(self, x):
# TODO check it's identical to TF implementation in details # TODO check it's identical to TF implementation in details (epsilon and axes)
u = x.mean(-1, keepdim=True) u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True) s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon) x = (x - u) / torch.sqrt(s + self.variance_epsilon)
@ -128,9 +128,7 @@ class BERTLayerNorm(nn.Module):
# inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) # inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
class BERTEmbeddings(nn.Module): class BERTEmbeddings(nn.Module):
def __init__(self, embedding_size, vocab_size, def __init__(self, config):
token_type_vocab_size, max_position_embeddings,
config):
self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size) self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size)
@ -323,27 +321,32 @@ class BERTEncoder(nn.Module):
Return: Return:
float Tensor of shape [batch_size, seq_length, hidden_size] float Tensor of shape [batch_size, seq_length, hidden_size]
""" """
all_encoder_layers = []
for layer_module in self.layer: for layer_module in self.layer:
hidden_states = layer_module(hidden_states, attention_mask) hidden_states = layer_module(hidden_states, attention_mask)
return hidden_states all_encoder_layers.append(hidden_states)
return all_encoder_layers
class BERTPooler(nn.Module): class BERTPooler(nn.Module):
def __init__(self, config): def __init__(self, config):
super(BERTPooler, self).__init__() super(BERTPooler, self).__init__()
layer = BERTLayer(n_ctx, cfg, scale=True) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) self.activation = nn.Tanh()
def forward(self, hidden_states, attention_mask): def forward(self, hidden_states):
""" """
Args: Args:
hidden_states: float Tensor of shape [batch_size, seq_length, hidden_size] hidden_states: float Tensor of shape [batch_size, seq_length, hidden_size]
Return: Return:
float Tensor of shape [batch_size, seq_length, hidden_size] float Tensor of shape [batch_size, hidden_size]
""" """
for layer_module in self.layer: # We "pool" the model by simply taking the hidden state corresponding
hidden_states = layer_module(hidden_states, attention_mask) # to the first token. We assume that this has been pre-trained
return hidden_states first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertModel(nn.Module): class BertModel(nn.Module):
@ -381,14 +384,6 @@ class BertModel(nn.Module):
is invalid. is invalid.
""" """
super(BertModel).__init__() super(BertModel).__init__()
config = copy.deepcopy(config)
if not is_training:
config.hidden_dropout_prob = 0.0
config.attention_probs_dropout_prob = 0.0
batch_size = input_ids.size(0)
seq_length = input_ids.size(1)
self.embeddings = BERTEmbeddings(config) self.embeddings = BERTEmbeddings(config)
self.encoder = BERTEncoder(config) self.encoder = BERTEncoder(config)
self.pooler = BERTPooler(config) self.pooler = BERTPooler(config)
@ -396,4 +391,6 @@ class BertModel(nn.Module):
def forward(self, input_ids, token_type_ids, attention_mask): def forward(self, input_ids, token_type_ids, attention_mask):
embedding_output = self.embeddings(input_ids, token_type_ids) embedding_output = self.embeddings(input_ids, token_type_ids)
all_encoder_layers = self.encoder(embedding_output, attention_mask) all_encoder_layers = self.encoder(embedding_output, attention_mask)
return all_encoder_layers sequence_output = all_encoder_layers[-1]
pooled_output = self.pooler(sequence_output)
return all_encoder_layers, pooled_output