fix tf bert model

This commit is contained in:
thomwolf 2019-09-09 17:46:01 +02:00
parent 0537139b2b
commit 50c6bc4195
6 changed files with 129 additions and 63 deletions

View File

@ -58,7 +58,7 @@ class BertConfig(PretrainedConfig):
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported.
hidden_dropout_prob: The dropout probabilitiy for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention

View File

@ -21,36 +21,62 @@ from __future__ import print_function
import argparse
import tensorflow as tf
import pytorch_transformers
from pytorch_transformers import is_torch_available
from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2,
GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2)
if is_torch_available():
import torch
import numpy as np
from pytorch_transformers import BertForPreTraining, GPT2LMHeadModel
else:
BertForPreTraining, GPT2LMHeadModel = None, None
import logging
logging.basicConfig(level=logging.INFO)
MODEL_CLASSES = {
'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2),
'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2),
'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining),
'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel),
}
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path):
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False):
if model_type not in MODEL_CLASSES:
raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
config_class, model_class, loading_fct = MODEL_CLASSES[model_type]
config_class, model_class, loading_fct, pt_model_class = MODEL_CLASSES[model_type]
# Initialise TF model
config = config_class.from_json_file(config_file)
print("Building TensorFlow model from configuration: {}".format(str(config)))
model = model_class(config)
tf_model = model_class(config)
# Load weights from tf checkpoint
model = loading_fct(model, config, pytorch_checkpoint_path)
tf_model = loading_fct(tf_model, config, pytorch_checkpoint_path)
if compare_with_pt_model:
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
tf_inputs = tf.constant(inputs_list)
tfo = tf_model(tf_inputs, training=False) # build the network
pt_model = pt_model_class.from_pretrained(None,
config=config,
state_dict=torch.load(pytorch_checkpoint_path,
map_location='cpu'))
pt_inputs = torch.tensor(inputs_list)
with torch.no_grad():
pto = pt_model(pt_inputs)
np_pt = pto[0].detach().numpy()
np_tf = tfo[0].numpy()
diff = np.amax(np.abs(np_pt - np_tf))
print("Max absolute difference between models outputs {}".format(diff))
# Save pytorch-model
print("Save TensorFlow model to {}".format(tf_dump_path))
model.save_weights(tf_dump_path)
tf_model.save_weights(tf_dump_path)
if __name__ == "__main__":
@ -77,8 +103,12 @@ if __name__ == "__main__":
type = str,
required = True,
help = "Path to the output Tensorflow dump file.")
parser.add_argument("--compare_with_pt_model",
action='store_true',
help = "Compare Tensorflow and PyTorch model predictions.")
args = parser.parse_args()
convert_pt_checkpoint_to_tf(args.model_type.lower(),
args.pytorch_checkpoint_path,
args.config_file,
args.tf_dump_path)
args.tf_dump_path,
compare_with_pt_model=args.compare_with_pt_model)

View File

@ -118,19 +118,24 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
def gelu(x):
"""Implementation of the gelu activation function.
""" Original Implementation of the gelu activation function in Google Bert repo when initialy created.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
Also see https://arxiv.org/abs/1606.08415
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def gelu_new(x):
""" Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
Also see https://arxiv.org/abs/1606.08415
"""
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new}
try:
@ -195,7 +200,7 @@ class BertSelfAttention(nn.Module):
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask, head_mask=None):
def forward(self, hidden_states, attention_mask=None, head_mask=None):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
@ -207,8 +212,9 @@ class BertSelfAttention(nn.Module):
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
@ -275,7 +281,7 @@ class BertAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, input_tensor, attention_mask, head_mask=None):
def forward(self, input_tensor, attention_mask=None, head_mask=None):
self_outputs = self.self(input_tensor, attention_mask, head_mask)
attention_output = self.output(self_outputs[0], input_tensor)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@ -318,7 +324,7 @@ class BertLayer(nn.Module):
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask, head_mask=None):
def forward(self, hidden_states, attention_mask=None, head_mask=None):
attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
attention_output = attention_outputs[0]
intermediate_output = self.intermediate(attention_output)
@ -334,7 +340,7 @@ class BertEncoder(nn.Module):
self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask, head_mask=None):
def forward(self, hidden_states, attention_mask=None, head_mask=None):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):

View File

@ -77,6 +77,7 @@ def load_bert_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
weight_value_tuples = []
all_pytorch_weights = set(list(state_dict.keys()))
for symbolic_weight in symbolic_weights:
name = symbolic_weight.name
name = name.replace('cls_mlm', 'cls') # We had to split this layer in two in the TF model to be
@ -91,7 +92,7 @@ def load_bert_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
name[-1] = 'weight'
name = '.'.join(name)
assert name in state_dict
assert name in state_dict, "{} not found in PyTorch model".format(name)
array = state_dict[name].numpy()
if transpose:
@ -106,14 +107,28 @@ def load_bert_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
logger.info("Initialize TF weight {}".format(symbolic_weight.name))
weight_value_tuples.append((symbolic_weight, array))
all_pytorch_weights.discard(name)
K.batch_set_value(weight_value_tuples)
tfo = tf_model(tf_inputs, training=False) # Make sure restore ops are run
logger.info("Weights not loaded: {}".format(all_pytorch_weights))
return tf_model
def gelu(x):
""" Gaussian Error Linear Unit.
Original Implementation of the gelu activation function in Google Bert repo when initialy created.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
Also see https://arxiv.org/abs/1606.08415
"""
cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
return x * cdf
def gelu_new(x):
"""Gaussian Error Linear Unit.
This is a smoother version of the RELU.
Original paper: https://arxiv.org/abs/1606.08415
@ -126,14 +141,14 @@ def gelu(x):
(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
return x * cdf
def swish(x):
return x * tf.sigmoid(x)
ACT2FN = {"gelu": tf.keras.layers.Activation(gelu),
"relu": tf.keras.activations.relu,
"swish": tf.keras.layers.Activation(swish)}
"swish": tf.keras.layers.Activation(swish),
"gelu_new": tf.keras.layers.Activation(gelu_new)}
class TFBertEmbeddings(tf.keras.layers.Layer):
@ -263,8 +278,10 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) # (batch size, num_heads, seq_len_q, seq_len_k)
dk = tf.cast(tf.shape(key_layer)[-1], tf.float32) # scale attention_scores
attention_scores = attention_scores / tf.math.sqrt(dk)
# Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
attention_scores = attention_scores + attention_mask
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = tf.nn.softmax(attention_scores, axis=-1)
@ -438,31 +455,33 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
class TFBertLMPredictionHead(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
def __init__(self, config, input_embeddings, **kwargs):
super(TFBertLMPredictionHead, self).__init__(**kwargs)
self.vocab_size = config.vocab_size
self.transform = TFBertPredictionHeadTransform(config, name='transform')
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name='decoder')
self.input_embeddings = input_embeddings
def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,),
initializer='zeros',
trainable=True,
name='bias')
super(TFBertLMPredictionHead, self).build(input_shape)
def call(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states) + self.bias
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
return hidden_states
class TFBertMLMHead(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
def __init__(self, config, input_embeddings, **kwargs):
super(TFBertMLMHead, self).__init__(**kwargs)
self.predictions = TFBertLMPredictionHead(config, name='predictions')
self.predictions = TFBertLMPredictionHead(config, input_embeddings, name='predictions')
def call(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
@ -716,12 +735,13 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert')
self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
self.cls_mlm = TFBertMLMHead(config, self.bert.embeddings, name='cls_mlm')
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
sequence_output, pooled_output = outputs[:2]
prediction_scores = self.bert.embeddings(sequence_output, mode="linear", training=training)
prediction_scores = self.cls_mlm(sequence_output, training=training)
seq_relationship_score = self.cls_nsp(pooled_output)
outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
@ -757,12 +777,13 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
super(TFBertForMaskedLM, self).__init__(config, *inputs, **kwargs)
self.bert = TFBertMainLayer(config, name='bert')
self.cls_mlm = TFBertMLMHead(config, self.bert.embeddings, name='cls_mlm')
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
sequence_output = outputs[0]
prediction_scores = self.bert.embeddings(sequence_output, mode="linear", training=training)
prediction_scores = self.cls_mlm(sequence_output, training=training)
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here

View File

@ -100,9 +100,14 @@ def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
weight_value_tuples.append((symbolic_weight, array))
state_dict.pop(name)
K.batch_set_value(weight_value_tuples)
tfo = tf_model(tf_inputs, training=False) # Make sure restore ops are run
assert not state_dict, "Weights not loaded: {}".format(list(state_dict.keys()))
return tf_model

View File

@ -222,6 +222,7 @@ class PreTrainedModel(nn.Module):
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)
model_args: (`optional`) Sequence of positional arguments:
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
@ -289,42 +290,45 @@ class PreTrainedModel(nn.Module):
model_kwargs = kwargs
# Load model
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
elif os.path.isdir(pretrained_model_name_or_path):
if from_tf:
# Directly load from a TensorFlow checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else:
if from_tf:
# Directly load from a TensorFlow checkpoint
archive_file = pretrained_model_name_or_path + ".index"
else:
archive_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
except EnvironmentError as e:
if pretrained_model_name_or_path is not None:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
elif os.path.isdir(pretrained_model_name_or_path):
if from_tf:
# Directly load from a TensorFlow checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(cls.pretrained_model_archive_map.keys()),
archive_file))
raise e
if resolved_archive_file == archive_file:
logger.info("loading weights file {}".format(archive_file))
if from_tf:
# Directly load from a TensorFlow checkpoint
archive_file = pretrained_model_name_or_path + ".index"
else:
archive_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
except EnvironmentError as e:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(cls.pretrained_model_archive_map.keys()),
archive_file))
raise e
if resolved_archive_file == archive_file:
logger.info("loading weights file {}".format(archive_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
resolved_archive_file = None
# Instantiate model.
model = cls(config, *model_args, **model_kwargs)