Mirror of https://github.com/huggingface/transformers.git
fix tf bert model

commit 50c6bc4195, parent 0537139b2b
@@ -58,7 +58,7 @@ class BertConfig(PretrainedConfig):
         intermediate_size: The size of the "intermediate" (i.e., feed-forward)
             layer in the Transformer encoder.
         hidden_act: The non-linear activation function (function or string) in the
-            encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
+            encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported.
         hidden_dropout_prob: The dropout probabilitiy for all fully connected
             layers in the embeddings, encoder, and pooler.
         attention_probs_dropout_prob: The dropout ratio for the attention
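Note: once this lands, the string "gelu_new" can be selected through the config. A minimal sketch (it assumes the remaining BertConfig fields keep their defaults, which is not shown in this diff):

    from pytorch_transformers import BertConfig

    # Hypothetical config selecting the newly documented activation string;
    # every other hyper-parameter is left at its default value (assumption).
    config = BertConfig(hidden_act="gelu_new")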
@@ -21,36 +21,62 @@ from __future__ import print_function
 
 import argparse
 import tensorflow as tf
 
-import pytorch_transformers
+from pytorch_transformers import is_torch_available
 
 from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2,
                                   GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2)
 
+if is_torch_available():
+    import torch
+    import numpy as np
+    from pytorch_transformers import BertForPreTraining, GPT2LMHeadModel
+else:
+    BertForPreTraining, GPT2LMHeadModel = None, None
+
 import logging
 logging.basicConfig(level=logging.INFO)
 
 MODEL_CLASSES = {
-    'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2),
-    'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2),
+    'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining),
+    'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel),
 }
 
-def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path):
+def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False):
     if model_type not in MODEL_CLASSES:
         raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
 
-    config_class, model_class, loading_fct = MODEL_CLASSES[model_type]
+    config_class, model_class, loading_fct, pt_model_class = MODEL_CLASSES[model_type]
 
     # Initialise TF model
     config = config_class.from_json_file(config_file)
     print("Building TensorFlow model from configuration: {}".format(str(config)))
-    model = model_class(config)
+    tf_model = model_class(config)
 
     # Load weights from tf checkpoint
-    model = loading_fct(model, config, pytorch_checkpoint_path)
+    tf_model = loading_fct(tf_model, config, pytorch_checkpoint_path)
+
+    if compare_with_pt_model:
+        inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
+        tf_inputs = tf.constant(inputs_list)
+        tfo = tf_model(tf_inputs, training=False)  # build the network
+
+        pt_model = pt_model_class.from_pretrained(None,
+                                                  config=config,
+                                                  state_dict=torch.load(pytorch_checkpoint_path,
+                                                                        map_location='cpu'))
+        pt_inputs = torch.tensor(inputs_list)
+        with torch.no_grad():
+            pto = pt_model(pt_inputs)
+
+        np_pt = pto[0].detach().numpy()
+        np_tf = tfo[0].numpy()
+        diff = np.amax(np.abs(np_pt - np_tf))
+        print("Max absolute difference between models outputs {}".format(diff))
 
     # Save pytorch-model
     print("Save TensorFlow model to {}".format(tf_dump_path))
-    model.save_weights(tf_dump_path)
+    tf_model.save_weights(tf_dump_path)
 
 
 if __name__ == "__main__":
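Note: for reference, a direct call to the updated entry point could look like this; the three paths are placeholders, not files referenced by the commit:

    # Hedged usage sketch of the function changed above.
    convert_pt_checkpoint_to_tf(model_type='bert',
                                pytorch_checkpoint_path='./bert/pytorch_model.bin',  # placeholder
                                config_file='./bert/config.json',                    # placeholder
                                tf_dump_path='./bert/tf_model.h5',                   # placeholder
                                compare_with_pt_model=True)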
@@ -77,8 +103,12 @@ if __name__ == "__main__":
                         type = str,
                         required = True,
                         help = "Path to the output Tensorflow dump file.")
+    parser.add_argument("--compare_with_pt_model",
+                        action='store_true',
+                        help = "Compare Tensorflow and PyTorch model predictions.")
     args = parser.parse_args()
     convert_pt_checkpoint_to_tf(args.model_type.lower(),
                                 args.pytorch_checkpoint_path,
                                 args.config_file,
-                                args.tf_dump_path)
+                                args.tf_dump_path,
+                                compare_with_pt_model=args.compare_with_pt_model)
@@ -118,19 +118,24 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
 
 
 def gelu(x):
-    """Implementation of the gelu activation function.
+    """ Original Implementation of the gelu activation function in Google Bert repo when initialy created.
         For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
         0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
         Also see https://arxiv.org/abs/1606.08415
     """
     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
 
+def gelu_new(x):
+    """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
+        Also see https://arxiv.org/abs/1606.08415
+    """
+    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+
 def swish(x):
     return x * torch.sigmoid(x)
 
 
-ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
+ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new}
 
 
 try:
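Note: a quick numerical check (not part of the commit) shows why both variants are kept side by side: the exact erf-based gelu and the tanh approximation gelu_new are close but not identical.

    import math
    import torch

    x = torch.linspace(-3, 3, steps=7)
    exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))                                         # "gelu"
    approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))  # "gelu_new"
    print((exact - approx).abs().max())   # small but non-zero difference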
@@ -195,7 +200,7 @@ class BertSelfAttention(nn.Module):
         x = x.view(*new_x_shape)
         return x.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, attention_mask, head_mask=None):
+    def forward(self, hidden_states, attention_mask=None, head_mask=None):
         mixed_query_layer = self.query(hidden_states)
         mixed_key_layer = self.key(hidden_states)
         mixed_value_layer = self.value(hidden_states)
@@ -207,8 +212,9 @@ class BertSelfAttention(nn.Module):
         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
-        # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
-        attention_scores = attention_scores + attention_mask
+        if attention_mask is not None:
+            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+            attention_scores = attention_scores + attention_mask
 
         # Normalize the attention scores to probabilities.
         attention_probs = nn.Softmax(dim=-1)(attention_scores)
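Note: the mask BertModel passes in here is additive; the construction below is an illustration based on the rest of the code base, not code from this hunk. 0.0 marks positions to keep and a large negative value marks padding, so adding it before the softmax drives the masked probabilities towards zero.

    import torch

    scores = torch.tensor([[2.0, 1.0, 0.5]])
    mask = torch.tensor([[0.0, 0.0, -10000.0]])    # last position treated as padding
    print(torch.softmax(scores + mask, dim=-1))    # third probability is ~0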
@@ -275,7 +281,7 @@ class BertAttention(nn.Module):
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)
 
-    def forward(self, input_tensor, attention_mask, head_mask=None):
+    def forward(self, input_tensor, attention_mask=None, head_mask=None):
         self_outputs = self.self(input_tensor, attention_mask, head_mask)
         attention_output = self.output(self_outputs[0], input_tensor)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
@@ -318,7 +324,7 @@ class BertLayer(nn.Module):
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
 
-    def forward(self, hidden_states, attention_mask, head_mask=None):
+    def forward(self, hidden_states, attention_mask=None, head_mask=None):
         attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
         attention_output = attention_outputs[0]
         intermediate_output = self.intermediate(attention_output)
@@ -334,7 +340,7 @@ class BertEncoder(nn.Module):
         self.output_hidden_states = config.output_hidden_states
         self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
 
-    def forward(self, hidden_states, attention_mask, head_mask=None):
+    def forward(self, hidden_states, attention_mask=None, head_mask=None):
         all_hidden_states = ()
         all_attentions = ()
         for i, layer_module in enumerate(self.layer):
@@ -77,6 +77,7 @@ def load_bert_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
 
     symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
     weight_value_tuples = []
+    all_pytorch_weights = set(list(state_dict.keys()))
     for symbolic_weight in symbolic_weights:
         name = symbolic_weight.name
         name = name.replace('cls_mlm', 'cls')  # We had to split this layer in two in the TF model to be
@@ -91,7 +92,7 @@ def load_bert_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
             name[-1] = 'weight'
 
         name = '.'.join(name)
-        assert name in state_dict
+        assert name in state_dict, "{} not found in PyTorch model".format(name)
         array = state_dict[name].numpy()
 
         if transpose:
@@ -106,14 +107,28 @@ def load_bert_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
         logger.info("Initialize TF weight {}".format(symbolic_weight.name))
 
         weight_value_tuples.append((symbolic_weight, array))
+        all_pytorch_weights.discard(name)
 
     K.batch_set_value(weight_value_tuples)
 
     tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run
 
+    logger.info("Weights not loaded: {}".format(all_pytorch_weights))
+
     return tf_model
 
 
 def gelu(x):
+    """ Gaussian Error Linear Unit.
+        Original Implementation of the gelu activation function in Google Bert repo when initialy created.
+        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
+        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+        Also see https://arxiv.org/abs/1606.08415
+    """
+    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
+    return x * cdf
+
+def gelu_new(x):
     """Gaussian Error Linear Unit.
     This is a smoother version of the RELU.
     Original paper: https://arxiv.org/abs/1606.08415
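Note: the bookkeeping added in this hunk follows a simple pattern: start from the full set of PyTorch keys, discard each key as it is matched to a TF variable, and log whatever was never assigned. A toy sketch with illustrative key names:

    state_dict_keys = {'bert.embeddings.word_embeddings.weight', 'cls.predictions.bias'}
    assigned = ['bert.embeddings.word_embeddings.weight']   # keys matched to TF weights
    all_pytorch_weights = set(state_dict_keys)
    for name in assigned:
        all_pytorch_weights.discard(name)
    print("Weights not loaded: {}".format(all_pytorch_weights))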
@@ -126,14 +141,14 @@ def gelu(x):
                        (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
     return x * cdf
 
 
 def swish(x):
     return x * tf.sigmoid(x)
 
 
 ACT2FN = {"gelu": tf.keras.layers.Activation(gelu),
           "relu": tf.keras.activations.relu,
-          "swish": tf.keras.layers.Activation(swish)}
+          "swish": tf.keras.layers.Activation(swish),
+          "gelu_new": tf.keras.layers.Activation(gelu_new)}
 
 
 class TFBertEmbeddings(tf.keras.layers.Layer):
@@ -263,8 +278,10 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
         attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)  # (batch size, num_heads, seq_len_q, seq_len_k)
         dk = tf.cast(tf.shape(key_layer)[-1], tf.float32)  # scale attention_scores
         attention_scores = attention_scores / tf.math.sqrt(dk)
-        # Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
-        attention_scores = attention_scores + attention_mask
+
+        if attention_mask is not None:
+            # Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
+            attention_scores = attention_scores + attention_mask
 
         # Normalize the attention scores to probabilities.
         attention_probs = tf.nn.softmax(attention_scores, axis=-1)
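Note: as on the PyTorch side, the mask this TF layer receives is assumed to be additive; a sketch of how such a mask can be built from a 0/1 padding mask (this preprocessing is not part of the hunk):

    import tensorflow as tf

    padding_mask = tf.constant([[1.0, 1.0, 1.0, 0.0, 0.0]])   # 1 = real token, 0 = padding
    additive_mask = (1.0 - padding_mask) * -10000.0           # 0.0 for kept, -1e4 for masked
    # attention_scores = attention_scores + additive_mask  (broadcast over heads and query positions)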
@@ -438,31 +455,33 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
 
 
 class TFBertLMPredictionHead(tf.keras.layers.Layer):
-    def __init__(self, config, **kwargs):
+    def __init__(self, config, input_embeddings, **kwargs):
         super(TFBertLMPredictionHead, self).__init__(**kwargs)
         self.vocab_size = config.vocab_size
         self.transform = TFBertPredictionHeadTransform(config, name='transform')
 
         # The output weights are the same as the input embeddings, but there is
         # an output-only bias for each token.
-        self.decoder = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name='decoder')
+        self.input_embeddings = input_embeddings
 
     def build(self, input_shape):
         self.bias = self.add_weight(shape=(self.vocab_size,),
                                     initializer='zeros',
                                     trainable=True,
                                     name='bias')
+        super(TFBertLMPredictionHead, self).build(input_shape)
 
     def call(self, hidden_states):
         hidden_states = self.transform(hidden_states)
-        hidden_states = self.decoder(hidden_states) + self.bias
+        hidden_states = self.input_embeddings(hidden_states, mode="linear")
+        hidden_states = hidden_states + self.bias
         return hidden_states
 
 
 class TFBertMLMHead(tf.keras.layers.Layer):
-    def __init__(self, config, **kwargs):
+    def __init__(self, config, input_embeddings, **kwargs):
         super(TFBertMLMHead, self).__init__(**kwargs)
-        self.predictions = TFBertLMPredictionHead(config, name='predictions')
+        self.predictions = TFBertLMPredictionHead(config, input_embeddings, name='predictions')
 
     def call(self, sequence_output):
         prediction_scores = self.predictions(sequence_output)
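Note: the prediction head now reuses the input embeddings to produce logits instead of a separate decoder Dense layer. A minimal illustration of that "linear" mode, not the actual TFBertEmbeddings implementation: logits come from multiplying hidden states with the transpose of the word-embedding matrix (the bias is then added as above).

    import tensorflow as tf

    batch, seq_len, hidden, vocab = 2, 5, 768, 30522
    hidden_states = tf.random.normal((batch, seq_len, hidden))
    embedding_table = tf.random.normal((vocab, hidden))        # shared word-embedding matrix
    x = tf.reshape(hidden_states, [-1, hidden])                # flatten to (batch*seq_len, hidden)
    logits = tf.reshape(tf.matmul(x, embedding_table, transpose_b=True), [batch, seq_len, vocab])
    print(logits.shape)                                        # (2, 5, 30522)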
@@ -716,12 +735,13 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
 
         self.bert = TFBertMainLayer(config, name='bert')
         self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
+        self.cls_mlm = TFBertMLMHead(config, self.bert.embeddings, name='cls_mlm')
 
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)
 
         sequence_output, pooled_output = outputs[:2]
-        prediction_scores = self.bert.embeddings(sequence_output, mode="linear", training=training)
+        prediction_scores = self.cls_mlm(sequence_output, training=training)
         seq_relationship_score = self.cls_nsp(pooled_output)
 
         outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
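Note: a hedged sketch of exercising the updated pretraining heads; the input ids are placeholders and BertConfig is assumed to fall back to its BERT-base defaults, which is not shown in this diff:

    import tensorflow as tf
    from pytorch_transformers import BertConfig, TFBertForPreTraining

    config = BertConfig()                          # defaults assumed, not taken from this diff
    model = TFBertForPreTraining(config)
    outputs = model(tf.constant([[7, 6, 0, 0, 1]]), training=False)
    prediction_scores, seq_relationship_score = outputs[:2]
    print(prediction_scores.shape, seq_relationship_score.shape)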
@@ -757,12 +777,13 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
         super(TFBertForMaskedLM, self).__init__(config, *inputs, **kwargs)
 
         self.bert = TFBertMainLayer(config, name='bert')
+        self.cls_mlm = TFBertMLMHead(config, self.bert.embeddings, name='cls_mlm')
 
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)
 
         sequence_output = outputs[0]
-        prediction_scores = self.bert.embeddings(sequence_output, mode="linear", training=training)
+        prediction_scores = self.cls_mlm(sequence_output, training=training)
 
         outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
 
@@ -100,9 +100,14 @@ def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
 
         weight_value_tuples.append((symbolic_weight, array))
 
+        state_dict.pop(name)
+
     K.batch_set_value(weight_value_tuples)
 
     tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run
 
+    assert not state_dict, "Weights not loaded: {}".format(list(state_dict.keys()))
+
     return tf_model
 
 
@@ -222,6 +222,7 @@ class PreTrainedModel(nn.Module):
                 - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                 - a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                 - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+                - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)
 
             model_args: (`optional`) Sequence of positional arguments:
                 All remaning positional arguments will be passed to the underlying model's ``__init__`` method
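Note: the new `None` option mirrors how the conversion script above calls from_pretrained; for example (the paths are placeholders):

    import torch
    from pytorch_transformers import BertConfig, BertForPreTraining

    config = BertConfig.from_json_file('./bert/config.json')                    # placeholder path
    state_dict = torch.load('./bert/pytorch_model.bin', map_location='cpu')     # placeholder path
    model = BertForPreTraining.from_pretrained(None, config=config, state_dict=state_dict)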
@@ -289,42 +290,45 @@ class PreTrainedModel(nn.Module):
         model_kwargs = kwargs
 
         # Load model
-        if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
-            archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
-        elif os.path.isdir(pretrained_model_name_or_path):
-            if from_tf:
-                # Directly load from a TensorFlow checkpoint
-                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
-            else:
-                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
-        else:
-            if from_tf:
-                # Directly load from a TensorFlow checkpoint
-                archive_file = pretrained_model_name_or_path + ".index"
-            else:
-                archive_file = pretrained_model_name_or_path
-        # redirect to the cache, if necessary
-        try:
-            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
-        except EnvironmentError as e:
-            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
-                logger.error(
-                    "Couldn't reach server at '{}' to download pretrained weights.".format(
-                        archive_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find any file "
-                    "associated to this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ', '.join(cls.pretrained_model_archive_map.keys()),
-                        archive_file))
-            raise e
-        if resolved_archive_file == archive_file:
-            logger.info("loading weights file {}".format(archive_file))
-        else:
-            logger.info("loading weights file {} from cache at {}".format(
-                archive_file, resolved_archive_file))
+        if pretrained_model_name_or_path is not None:
+            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
+                archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
+            elif os.path.isdir(pretrained_model_name_or_path):
+                if from_tf:
+                    # Directly load from a TensorFlow checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
+                else:
+                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+            else:
+                if from_tf:
+                    # Directly load from a TensorFlow checkpoint
+                    archive_file = pretrained_model_name_or_path + ".index"
+                else:
+                    archive_file = pretrained_model_name_or_path
+            # redirect to the cache, if necessary
+            try:
+                resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
+            except EnvironmentError as e:
+                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
+                    logger.error(
+                        "Couldn't reach server at '{}' to download pretrained weights.".format(
+                            archive_file))
+                else:
+                    logger.error(
+                        "Model name '{}' was not found in model name list ({}). "
+                        "We assumed '{}' was a path or url but couldn't find any file "
+                        "associated to this path or url.".format(
+                            pretrained_model_name_or_path,
+                            ', '.join(cls.pretrained_model_archive_map.keys()),
+                            archive_file))
+                raise e
+            if resolved_archive_file == archive_file:
+                logger.info("loading weights file {}".format(archive_file))
+            else:
+                logger.info("loading weights file {} from cache at {}".format(
+                    archive_file, resolved_archive_file))
+        else:
+            resolved_archive_file = None
 
         # Instantiate model.
         model = cls(config, *model_args, **model_kwargs)