mirror of https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00

added conversion script

This commit is contained in:
parent 90d360a7a9
commit c5d532e5f6

convert_tf_checkpoint.py (new file, 82 additions)
@@ -0,0 +1,82 @@
# coding=utf-8
"""Convert BERT checkpoint."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import argparse
import tensorflow as tf
import torch

from .modeling_pytorch import BertConfig, BertModel

parser = argparse.ArgumentParser()

## Required parameters
parser.add_argument("--tf_checkpoint_path",
                    default = None,
                    type = str,
                    required = True,
                    help = "Path to the TensorFlow checkpoint.")
parser.add_argument("--bert_config_file",
                    default = None,
                    type = str,
                    required = True,
                    help = "The config json file corresponding to the pre-trained BERT model. \n"
                           "This specifies the model architecture.")
parser.add_argument("--pytorch_dump_path",
                    default = None,
                    type = str,
                    required = True,
                    help = "Path to the output PyTorch model.")

args = parser.parse_args()


def convert():
    # Load weights from TF model
    path = args.tf_checkpoint_path
    print("Converting TensorFlow checkpoint from {}".format(path))

    init_vars = tf.train.list_variables(path)
    names = []
    arrays = []
    for name, shape in init_vars:
        print("Loading {} with shape {}".format(name, shape))
        array = tf.train.load_variable(path, name)
        print("Numpy array shape {}".format(array.shape))
        names.append(name)
        arrays.append(array)

    # Initialise PyTorch model and fill the weights in
    config = BertConfig.from_json_file(args.bert_config_file)
    model = BertModel(config)
    for name, array in zip(names, arrays):
        name = name[5:]  # skip "bert/"
        if name.endswith(":0"):  # names from list_variables usually carry no ":0" suffix
            name = name[:-2]
        name = name.split('/')
        # Walk down the module tree: "layer_0"-style segments index into a
        # list of sub-modules, everything else is a plain attribute lookup.
        pointer = model
        for m_name in name:
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                l = re.split(r'_(\d+)', m_name)
            else:
                l = [m_name]
            pointer = getattr(pointer, l[0])
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        pointer.data = torch.from_numpy(array)

    # Save the PyTorch model
    torch.save(model.state_dict(), args.pytorch_dump_path)


if __name__ == "__main__":
    convert()
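For illustration only (not part of the commit): the attribute walk above turns a slash-separated TF variable name into nested module lookups, with "name_N" segments indexing into a list of sub-modules. A minimal sketch of that mapping on a hypothetical checkpoint name:

    import re

    # Hypothetical TF variable name, after the leading "bert/" is stripped.
    tf_name = "encoder/layer_0/attention/self/query/kernel"

    attr_path = []
    for m_name in tf_name.split('/'):
        if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
            prefix, num, _ = re.split(r'_(\d+)', m_name)
            attr_path.append("{}[{}]".format(prefix, num))
        else:
            attr_path.append(m_name)

    print("model." + ".".join(attr_path))
    # -> model.encoder.layer[0].attention.self.query.kernel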
@@ -119,7 +119,7 @@ class BERTLayerNorm(nn.Module):
         self.variance_epsilon = variance_epsilon

     def forward(self, x):
-        # TODO check it's identical to TF implementation in details
+        # TODO check it's identical to TF implementation in details (epsilon and axes)
         u = x.mean(-1, keepdim=True)
         s = (x - u).pow(2).mean(-1, keepdim=True)
         x = (x - u) / torch.sqrt(s + self.variance_epsilon)
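A quick numerical check of the forward pass above (an illustrative sketch, not part of the commit; the epsilon value is an assumption): with no affine parameters, the normalization should agree with torch.nn.functional.layer_norm over the last axis.

    import torch
    import torch.nn.functional as F

    eps = 1e-12  # assumed variance_epsilon value for this sketch
    x = torch.randn(2, 5, 8)

    u = x.mean(-1, keepdim=True)
    s = (x - u).pow(2).mean(-1, keepdim=True)
    manual = (x - u) / torch.sqrt(s + eps)

    # layer_norm without weight/bias normalizes over the last dimension
    assert torch.allclose(manual, F.layer_norm(x, (8,), eps=eps), atol=1e-6)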
@@ -128,9 +128,7 @@ class BERTLayerNorm(nn.Module):
         # inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)

 class BERTEmbeddings(nn.Module):
-    def __init__(self, embedding_size, vocab_size,
-                 token_type_vocab_size, max_position_embeddings,
-                 config):
+    def __init__(self, config):

         self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size)

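For illustration (toy values, not from the commit): the config-driven constructor simply reads the embedding sizes off the config object.

    import torch.nn as nn

    class ToyConfig(object):  # hypothetical stand-in for BertConfig
        vocab_size = 100
        embedding_size = 8

    config = ToyConfig()
    word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size)
    print(word_embeddings.weight.shape)  # torch.Size([100, 8])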
@@ -323,27 +321,32 @@ class BERTEncoder(nn.Module):
         Return:
             float Tensor of shape [batch_size, seq_length, hidden_size]
         """
+        all_encoder_layers = []
         for layer_module in self.layer:
             hidden_states = layer_module(hidden_states, attention_mask)
-        return hidden_states
+            all_encoder_layers.append(hidden_states)
+        return all_encoder_layers


 class BERTPooler(nn.Module):
     def __init__(self, config):
         super(BERTPooler, self).__init__()
-        layer = BERTLayer(n_ctx, cfg, scale=True)
-        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()

-    def forward(self, hidden_states, attention_mask):
+    def forward(self, hidden_states):
         """
         Args:
             hidden_states: float Tensor of shape [batch_size, seq_length, hidden_size]
         Return:
-            float Tensor of shape [batch_size, seq_length, hidden_size]
+            float Tensor of shape [batch_size, hidden_size]
         """
-        for layer_module in self.layer:
-            hidden_states = layer_module(hidden_states, attention_mask)
-        return hidden_states
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token. We assume that this has been pre-trained
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output


 class BertModel(nn.Module):
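To make the new pooling behavior concrete (a sketch with toy sizes, not part of the commit): the pooler collapses the sequence dimension by taking the first token's hidden state, then applies a dense layer and tanh.

    import torch
    import torch.nn as nn

    batch_size, seq_length, hidden_size = 2, 5, 8  # toy sizes
    hidden_states = torch.randn(batch_size, seq_length, hidden_size)

    dense = nn.Linear(hidden_size, hidden_size)
    first_token_tensor = hidden_states[:, 0]  # [batch_size, hidden_size]
    pooled_output = torch.tanh(dense(first_token_tensor))
    print(pooled_output.shape)  # torch.Size([2, 8])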
@@ -381,14 +384,6 @@ class BertModel(nn.Module):
         is invalid.
         """
         super(BertModel).__init__()
-        config = copy.deepcopy(config)
-        if not is_training:
-            config.hidden_dropout_prob = 0.0
-            config.attention_probs_dropout_prob = 0.0
-
-        batch_size = input_ids.size(0)
-        seq_length = input_ids.size(1)
-
         self.embeddings = BERTEmbeddings(config)
         self.encoder = BERTEncoder(config)
         self.pooler = BERTPooler(config)
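With the is_training dropout-zeroing removed, the standard PyTorch idiom (not shown in this commit; noted here as an assumption about intent) is to disable dropout at inference time by toggling module modes:

    import torch
    import torch.nn as nn

    dropout = nn.Dropout(p=0.1)
    x = torch.ones(4)

    dropout.eval()                     # eval mode: dropout is the identity
    assert torch.equal(dropout(x), x)
    dropout.train()                    # training mode: elements may be zeroed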
@@ -396,4 +391,6 @@ class BertModel(nn.Module):
     def forward(self, input_ids, token_type_ids, attention_mask):
         embedding_output = self.embeddings(input_ids, token_type_ids)
         all_encoder_layers = self.encoder(embedding_output, attention_mask)
-        return all_encoder_layers
+        sequence_output = all_encoder_layers[-1]
+        pooled_output = self.pooler(sequence_output)
+        return all_encoder_layers, pooled_output
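A hedged usage sketch of the updated forward return (the config file path and input tensors are illustrative; shapes follow the docstrings above):

    import torch
    # from modeling_pytorch import BertConfig, BertModel  # as defined in this commit

    batch_size, seq_length = 2, 5
    input_ids = torch.zeros(batch_size, seq_length, dtype=torch.long)
    token_type_ids = torch.zeros(batch_size, seq_length, dtype=torch.long)
    attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)

    # model = BertModel(BertConfig.from_json_file("bert_config.json"))
    # all_encoder_layers, pooled_output = model(input_ids, token_type_ids, attention_mask)
    # all_encoder_layers: list with one [batch_size, seq_length, hidden_size] tensor per layer
    # pooled_output:      [batch_size, hidden_size]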