add tf bert files

This commit is contained in:
thomwolf 2019-09-05 02:27:39 +02:00
parent 85df4f7cca
commit bffd17a43d
5 changed files with 1728 additions and 0 deletions

View File

@ -83,6 +83,9 @@ def url_to_filename(url, etag=None):
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited
by a period.
If the url ends with .h5 (Keras HDF5 weights) ands '.h5' to the name
so that TF 2.0 can identify it as a HDF5 file
(see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
"""
url_bytes = url.encode('utf-8')
url_hash = sha256(url_bytes)
@ -93,6 +96,9 @@ def url_to_filename(url, etag=None):
etag_hash = sha256(etag_bytes)
filename += '.' + etag_hash.hexdigest()
if url.endswith('.h5'):
filename += '.h5'
return filename

View File

@ -0,0 +1,832 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 BERT model. """
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
import math
import os
import sys
from io import open
import numpy as np
import tensorflow as tf
from .configuration_bert import BertConfig
from .modeling_tf_utils import TFPreTrainedModel
from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tf_model.h5",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-tf_model.h5",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-tf_model.h5",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-tf_model.h5",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-tf_model.h5",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-tf_model.h5",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-tf_model.h5",
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-tf_model.h5",
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-tf_model.h5",
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-tf_model.h5",
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-tf_model.h5",
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-tf_model.h5",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-tf_model.h5",
}
def load_pt_weights_in_bert(tf_model, config, pytorch_checkpoint_path):
""" Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format
We use HDF5 to easily do transfer learning
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
"""
try:
import re
import torch
import numpy
from tensorflow.python.keras import backend as K
except ImportError:
logger.error("Loading a PyTorch model in TensorFlow, requires PyTorch to be installed. Please see "
"https://pytorch.org/ for installation instructions.")
raise
pt_path = os.path.abspath(pytorch_checkpoint_path)
logger.info("Loading PyTorch weights from {}".format(pt_path))
# Load pytorch model
state_dict = torch.load(pt_path, map_location='cpu')
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
tf_inputs = tf.constant(inputs_list)
tfo = tf_model(tf_inputs, training=False) # build the network
symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
weight_value_tuples = []
for symbolic_weight in symbolic_weights:
name = symbolic_weight.name
name = name.replace('cls_mlm', 'cls') # We had to split this layer in two in the TF model to be
name = name.replace('cls_nsp', 'cls') # able to do transfer learning (Keras only allow to remove full layers)
name = name.replace(':0', '')
name = name.replace('layer_', 'layer/')
name = name.split('/')
name = name[1:]
transpose = bool(name[-1] == 'kernel')
if name[-1] == 'kernel' or name[-1] == 'embeddings':
name[-1] = 'weight'
name = '.'.join(name)
assert name in state_dict
array = state_dict[name].numpy()
if transpose:
array = numpy.transpose(array)
try:
assert list(symbolic_weight.shape) == list(array.shape)
except AssertionError as e:
e.args += (symbolic_weight.shape, array.shape)
raise e
logger.info("Initialize TF weight {}".format(symbolic_weight.name))
weight_value_tuples.append((symbolic_weight, array))
K.batch_set_value(weight_value_tuples)
tfo = tf_model(tf_inputs, training=False) # Make sure restore ops are run
return tf_model
def gelu(x):
"""Gaussian Error Linear Unit.
This is a smoother version of the RELU.
Original paper: https://arxiv.org/abs/1606.08415
Args:
x: float Tensor to perform activation.
Returns:
`x` with the GELU activation applied.
"""
cdf = 0.5 * (1.0 + tf.tanh(
(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
return x * cdf
def swish(x):
return x * tf.sigmoid(x)
ACT2FN = {"gelu": tf.keras.layers.Activation(gelu),
"relu": tf.keras.activations.relu,
"swish": tf.keras.layers.Activation(swish)}
class TFBertEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config, **kwargs):
super(TFBertEmbeddings, self).__init__(**kwargs)
self.word_embeddings = tf.keras.layers.Embedding(config.vocab_size, config.hidden_size, name='word_embeddings')
self.position_embeddings = tf.keras.layers.Embedding(config.max_position_embeddings, config.hidden_size, name='position_embeddings')
self.token_type_embeddings = tf.keras.layers.Embedding(config.type_vocab_size, config.hidden_size, name='token_type_embeddings')
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def call(self, inputs, training=False):
input_ids, position_ids, token_type_ids = inputs
seq_length = tf.shape(input_ids)[1]
if position_ids is None:
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
if token_type_ids is None:
token_type_ids = tf.fill(tf.shape(input_ids), 0)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = words_embeddings + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
if training:
embeddings = self.dropout(embeddings)
return embeddings
class TFBertSelfAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertSelfAttention, self).__init__(**kwargs)
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.output_attentions = config.output_attentions
self.num_attention_heads = config.num_attention_heads
assert config.hidden_size % config.num_attention_heads == 0
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = tf.keras.layers.Dense(self.all_head_size, name='query')
self.key = tf.keras.layers.Dense(self.all_head_size, name='key')
self.value = tf.keras.layers.Dense(self.all_head_size, name='value')
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x, batch_size):
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask = inputs
batch_size = tf.shape(hidden_states)[0]
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) # (batch size, num_heads, seq_len_q, seq_len_k)
dk = tf.cast(tf.shape(key_layer)[-1], tf.float32) # scale attention_scores
attention_scores = attention_scores / tf.math.sqrt(dk)
# Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = tf.nn.softmax(attention_scores, axis=-1)
if training:
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = tf.matmul(attention_probs, value_layer)
context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
context_layer = tf.reshape(context_layer,
(batch_size, -1, self.all_head_size)) # (batch_size, seq_len_q, all_head_size)
outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
return outputs
class TFBertSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertSelfOutput, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.hidden_size, name='dense')
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def call(self, inputs, training=False):
hidden_states, input_tensor = inputs
hidden_states = self.dense(hidden_states)
if training:
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class TFBertAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertAttention, self).__init__(**kwargs)
self.self_attention = TFBertSelfAttention(config, name='self')
self.dense_output = TFBertSelfOutput(config, name='output')
def prune_heads(self, heads):
raise NotImplementedError
def call(self, inputs, training=False):
input_tensor, attention_mask, head_mask = inputs
self_outputs = self.self_attention([input_tensor, attention_mask, head_mask], training=training)
attention_output = self.dense_output([self_outputs[0], input_tensor], training=training)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class TFBertIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertIntermediate, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.intermediate_size, name='dense')
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class TFBertOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertOutput, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.hidden_size, name='dense')
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def call(self, inputs, training=False):
hidden_states, input_tensor = inputs
hidden_states = self.dense(hidden_states)
if training:
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class TFBertLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertLayer, self).__init__(**kwargs)
self.attention = TFBertAttention(config, name='attention')
self.intermediate = TFBertIntermediate(config, name='intermediate')
self.bert_output = TFBertOutput(config, name='output')
def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask = inputs
attention_outputs = self.attention([hidden_states, attention_mask, head_mask], training=training)
attention_output = attention_outputs[0]
intermediate_output = self.intermediate(attention_output)
layer_output = self.bert_output([intermediate_output, attention_output], training=training)
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
return outputs
class TFBertEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertEncoder, self).__init__(**kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.layer = [TFBertLayer(config, name='layer_{}'.format(i)) for i in range(config.num_hidden_layers)]
def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask = inputs
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module([hidden_states, attention_mask, head_mask[i]], training=training)
hidden_states = layer_outputs[0]
if self.output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
return outputs # outputs, (hidden states), (attentions)
class TFBertPooler(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertPooler, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.hidden_size, activation='tanh', name='dense')
def call(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
return pooled_output
class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertPredictionHeadTransform, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.hidden_size, name='dense')
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class TFBertLMPredictionHead(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertLMPredictionHead, self).__init__(**kwargs)
self.vocab_size = config.vocab_size
self.transform = TFBertPredictionHeadTransform(config, name='transform')
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name='decoder')
def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,),
initializer='zeros',
trainable=True,
name='bias')
def call(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states) + self.bias
return hidden_states
class TFBertMLMHead(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertMLMHead, self).__init__(**kwargs)
self.predictions = TFBertLMPredictionHead(config, name='predictions')
def call(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class TFBertNSPHead(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertNSPHead, self).__init__(**kwargs)
self.seq_relationship = tf.keras.layers.Dense(2, name='seq_relationship')
def call(self, pooled_output):
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
class TFBertMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertMainLayer, self).__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers
self.embeddings = TFBertEmbeddings(config, name='embeddings')
self.encoder = TFBertEncoder(config, name='encoder')
self.pooler = TFBertPooler(config, name='pooler')
# self.apply(self.init_weights) # TODO check weights initialization
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
raise NotImplementedError
def call(self, inputs, training=False):
if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
attention_mask, head_mask, position_ids, token_type_ids = None, None, None, None
elif isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else None
token_type_ids = inputs[2] if len(inputs) > 2 else None
position_ids = inputs[3] if len(inputs) > 3 else None
head_mask = inputs[4] if len(inputs) > 4 else None
assert len(inputs) <= 5, "Too many inputs."
else:
input_ids = inputs.pop('input_ids')
attention_mask = inputs.pop('attention_mask', None)
token_type_ids = inputs.pop('token_type_ids', None)
position_ids = inputs.pop('position_ids', None)
head_mask = inputs.pop('head_mask', None)
assert len(inputs) == 0, "Unexpected inputs detected: {}. Check inputs dict key names.".format(list(inputs.keys()))
if attention_mask is None:
attention_mask = tf.fill(tf.shape(input_ids), 1)
if token_type_ids is None:
token_type_ids = tf.fill(tf.shape(input_ids), 0)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if not head_mask is None:
raise NotImplementedError
else:
head_mask = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids], training=training)
encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
class TFBertPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
config_class = BertConfig
pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_pt_weights = load_pt_weights_in_bert
base_model_prefix = "bert"
def __init__(self, *inputs, **kwargs):
super(TFBertPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
""" Initialize the weights.
"""
raise NotImplementedError
BERT_START_DOCSTRING = r""" The BERT model was proposed in
`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer
pre-trained using a combination of masked language modeling objective and next sentence prediction
on a large corpus comprising the Toronto Book Corpus and Wikipedia.
This model is a tf.keras.Model `tf.keras.Model`_ sub-class. Use it as a regular TF 2.0 Keras Model and
refer to the TF 2.0 documentation for all matter related to general usage and behavior.
.. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
https://arxiv.org/abs/1810.04805
.. _`tf.keras.Model`:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model
Important note on the model inputs:
The inputs of the TF 2.0 models are slightly different from the PyTorch ones since
TF 2.0 Keras doesn't accept named arguments with defaults values for input Tensor.
More precisely, input Tensors are gathered in the first arguments of the model call function: `model(inputs)`.
There are three possibilities to gather and feed the inputs to the model:
- a single Tensor with input_ids only and nothing else: `model(inputs_ids)
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associaed to the input names given in the docstring:
`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`
Parameters:
config (:class:`~pytorch_transformers.BertConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
BERT_INPUTS_DOCSTRING = r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:
(a) For sequence pairs:
``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
(b) For single sequences:
``tokens: [CLS] the dog is hairy . [SEP]``
``token_type_ids: 0 0 0 0 0 0 0``
Bert is a model with absolute position embeddings so it's usually advised to pad the inputs on
the right rather than the left.
Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
(see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class TFBertModel(TFBertPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the output of the last layer of the model.
**pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during Bert pretraining. This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFBertModel.from_pretrained('bert-base-uncased')
input_ids = tf.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config):
super(TFBertModel, self).__init__(config)
self.bert = TFBertMainLayer(config, name='bert')
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
return outputs
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class TFBertForPreTraining(TFBertPreTrainedModel):
r"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the masked language modeling loss.
Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
**next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
Indices should be in ``[0, 1]``.
``0`` indicates sequence B is a continuation of sequence A,
``1`` indicates sequence B is a random sequence.
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when both ``masked_lm_labels`` and ``next_sentence_label`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, 2)``
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFBertForPreTraining.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
prediction_scores, seq_relationship_scores = outputs[:2]
"""
def __init__(self, config):
super(TFBertForPreTraining, self).__init__(config)
self.bert = TFBertMainLayer(config, name='bert')
self.cls_mlm = TFBertMLMHead(config, name='cls_mlm')
self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
# self.apply(self.init_weights) # TODO check added weights initialization
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
"""
pass # TODO add weights tying
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
sequence_output, pooled_output = outputs[:2]
prediction_scores = self.cls_mlm(sequence_output)
seq_relationship_score = self.cls_nsp(pooled_output)
outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
# if masked_lm_labels is not None and next_sentence_label is not None:
# loss_fct = CrossEntropyLoss(ignore_index=-1)
# masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
# next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
# total_loss = masked_lm_loss + next_sentence_loss
# outputs = (total_loss,) + outputs
# TODO add example with losses using model.compile and a dictionary of losses (give names to the output layers)
return outputs # prediction_scores, seq_relationship_score, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class TFBertForMaskedLM(TFBertPreTrainedModel):
r"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the masked language modeling loss.
Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFBertForMaskedLM.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids, masked_lm_labels=input_ids)
loss, prediction_scores = outputs[:2]
"""
def __init__(self, config):
super(TFBertForMaskedLM, self).__init__(config)
self.bert = TFBertMainLayer(config, name='bert')
self.cls_mlm = TFBertMLMHead(config, name='cls_mlm')
# self.apply(self.init_weights)
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
"""
pass # TODO add weights tying
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
sequence_output = outputs[0]
prediction_scores = self.cls_mlm(sequence_output)
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
# if masked_lm_labels is not None:
# loss_fct = CrossEntropyLoss(ignore_index=-1)
# masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
# outputs = (masked_lm_loss,) + outputs
# TODO example with losses
return outputs # prediction_scores, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
r"""
**next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
Indices should be in ``[0, 1]``.
``0`` indicates sequence B is a continuation of sequence A,
``1`` indicates sequence B is a random sequence.
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``next_sentence_label`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Next sequence prediction (classification) loss.
**seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, 2)``
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFBertForNextSentencePrediction.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
seq_relationship_scores = outputs[0]
"""
def __init__(self, config):
super(TFBertForNextSentencePrediction, self).__init__(config)
self.bert = TFBertMainLayer(config, name='bert')
self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
# self.apply(self.init_weights)
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
pooled_output = outputs[1]
seq_relationship_score = self.cls_nsp(pooled_output)
outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
# if next_sentence_label is not None:
# loss_fct = CrossEntropyLoss(ignore_index=-1)
# next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
# outputs = (next_sentence_loss,) + outputs
return outputs # seq_relationship_score, (hidden_states), (attentions)

View File

@ -0,0 +1,255 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import logging
import os
import tensorflow as tf
from .configuration_utils import PretrainedConfig
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME
logger = logging.getLogger(__name__)
class TFPreTrainedModel(tf.keras.Model):
r""" Base class for all TF models.
:class:`~pytorch_transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
as well as a few methods commons to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
Class attributes (overridden by derived classes):
- ``config_class``: a class derived from :class:`~pytorch_transformers.PretrainedConfig` to use as configuration class for this model architecture.
- ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
- ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
- ``model``: an instance of the relevant subclass of :class:`~pytorch_transformers.PreTrainedModel`,
- ``config``: an instance of the relevant subclass of :class:`~pytorch_transformers.PretrainedConfig`,
- ``path``: a path (string) to the TensorFlow checkpoint.
- ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
"""
config_class = None
pretrained_model_archive_map = {}
load_pt_weights = lambda model, config, path: None
base_model_prefix = ""
def __init__(self, config, *inputs, **kwargs):
super(TFPreTrainedModel, self).__init__()
if not isinstance(config, PretrainedConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
# Save config in model
self.config = config
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
""" Build a resized Embedding Module from a provided token Embedding Module.
Increasing the size will add newly initialized vectors at the end
Reducing the size will remove vectors from the end
Args:
new_num_tokens: (`optional`) int
New number of tokens in the embedding matrix.
Increasing the size will add newly initialized vectors at the end
Reducing the size will remove vectors from the end
If not provided or None: return the provided token Embedding Module.
Return: ``torch.nn.Embeddings``
Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
"""
raise NotImplementedError
def _tie_or_clone_weights(self, first_module, second_module):
""" Tie or clone module weights depending of weither we are using TorchScript or not
"""
raise NotImplementedError
def resize_token_embeddings(self, new_num_tokens=None):
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
Arguments:
new_num_tokens: (`optional`) int:
New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
Return: ``torch.nn.Embeddings``
Pointer to the input tokens Embeddings Module of the model
"""
raise NotImplementedError
def prune_heads(self, heads_to_prune):
""" Prunes heads of the base model.
Arguments:
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
"""
raise NotImplementedError
def save_pretrained(self, save_directory):
""" Save a model and its configuration file to a directory, so that it
can be re-loaded using the `:func:`~pytorch_transformers.PreTrainedModel.from_pretrained`` class method.
"""
raise NotImplementedError
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with ``model.train()``
The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
It is up to you to train those weights with a downstream fine-tuning task.
The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.
Parameters:
pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
- a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
model_args: (`optional`) Sequence of positional arguments:
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
config: (`optional`) instance of a class derived from :class:`~pytorch_transformers.PretrainedConfig`:
Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
- the model was saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
- the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
from_pt: (`optional`) boolean, default False:
Load the model weights from a PyTorch state_dict save file (see docstring of pretrained_model_name_or_path argument).
cache_dir: (`optional`) string:
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
output_loading_info: (`optional`) boolean:
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
kwargs: (`optional`) Remaining dictionary of keyword arguments:
Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~pytorch_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
Examples::
model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
assert model.config.output_attention == True
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_pt=True, config=config)
"""
config = kwargs.pop('config', None)
cache_dir = kwargs.pop('cache_dir', None)
from_pt = kwargs.pop('from_pt', False)
force_download = kwargs.pop('force_download', False)
proxies = kwargs.pop('proxies', None)
output_loading_info = kwargs.pop('output_loading_info', False)
# Load config
if config is None:
config, model_kwargs = cls.config_class.from_pretrained(
pretrained_model_name_or_path, *model_args,
cache_dir=cache_dir, return_unused_kwargs=True,
force_download=force_download,
**kwargs
)
else:
model_kwargs = kwargs
# Load model
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
elif os.path.isdir(pretrained_model_name_or_path):
if from_pt:
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else:
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
else:
archive_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(cls.pretrained_model_archive_map.keys()),
archive_file))
return None
if resolved_archive_file == archive_file:
logger.info("loading weights file {}".format(archive_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
# Instantiate model.
model = cls(config, *model_args, **model_kwargs)
if from_pt:
# Load from a PyTorch checkpoint
return cls.load_pt_weights(model, config, resolved_archive_file)
inputs = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
ret = model(inputs, training=False) # build the network with dummy inputs
# 'by_name' allow us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
model.load_weights(resolved_archive_file, by_name=True)
ret = model(inputs, training=False) # Make sure restore ops are run
# if hasattr(model, 'tie_weights'):
# model.tie_weights() # TODO make sure word embedding weights are still tied
if output_loading_info:
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
return model, loading_info
return model

View File

@ -0,0 +1,308 @@
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import os
import shutil
import json
import random
import uuid
import unittest
import logging
import tensorflow as tf
from pytorch_transformers import TFPreTrainedModel
# from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
for key in configs_no_init.__dict__.keys():
if '_range' in key or '_std' in key:
setattr(configs_no_init, key, 0.0)
return configs_no_init
class TFCommonTestCases:
class TFCommonModelTester(unittest.TestCase):
model_tester = None
all_model_classes = ()
test_torchscript = True
test_pruning = True
test_resize_embeddings = True
def test_initialization(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# configs_no_init = _config_zero_init(config)
# for model_class in self.all_model_classes:
# model = model_class(config=configs_no_init)
# for name, param in model.named_parameters():
# if param.requires_grad:
# self.assertIn(param.data.mean().item(), [0.0, 1.0],
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
def test_attention_outputs(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# for model_class in self.all_model_classes:
# config.output_attentions = True
# config.output_hidden_states = False
# model = model_class(config)
# model.eval()
# outputs = model(**inputs_dict)
# attentions = outputs[-1]
# self.assertEqual(model.config.output_attentions, True)
# self.assertEqual(model.config.output_hidden_states, False)
# self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# self.assertListEqual(
# list(attentions[0].shape[-3:]),
# [self.model_tester.num_attention_heads,
# self.model_tester.seq_length,
# self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
# out_len = len(outputs)
# # Check attention is always last and order is fine
# config.output_attentions = True
# config.output_hidden_states = True
# model = model_class(config)
# model.eval()
# outputs = model(**inputs_dict)
# self.assertEqual(out_len+1, len(outputs))
# self.assertEqual(model.config.output_attentions, True)
# self.assertEqual(model.config.output_hidden_states, True)
# attentions = outputs[-1]
# self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# self.assertListEqual(
# list(attentions[0].shape[-3:]),
# [self.model_tester.num_attention_heads,
# self.model_tester.seq_length,
# self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
def test_headmasking(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# config.output_attentions = True
# config.output_hidden_states = True
# configs_no_init = _config_zero_init(config) # To be sure we have no Nan
# for model_class in self.all_model_classes:
# model = model_class(config=configs_no_init)
# model.eval()
# # Prepare head_mask
# # Set require_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
# head_mask = torch.ones(self.model_tester.num_hidden_layers, self.model_tester.num_attention_heads)
# head_mask[0, 0] = 0
# head_mask[-1, :-1] = 0
# head_mask.requires_grad_(requires_grad=True)
# inputs = inputs_dict.copy()
# inputs['head_mask'] = head_mask
# outputs = model(**inputs)
# # Test that we can get a gradient back for importance score computation
# output = sum(t.sum() for t in outputs[0])
# output = output.sum()
# output.backward()
# multihead_outputs = head_mask.grad
# attentions = outputs[-1]
# hidden_states = outputs[-2]
# # Remove Nan
# self.assertIsNotNone(multihead_outputs)
# self.assertEqual(len(multihead_outputs), self.model_tester.num_hidden_layers)
# self.assertAlmostEqual(
# attentions[0][..., 0, :, :].flatten().sum().item(), 0.0)
# self.assertNotEqual(
# attentions[0][..., -1, :, :].flatten().sum().item(), 0.0)
# self.assertNotEqual(
# attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
# self.assertAlmostEqual(
# attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0)
# self.assertNotEqual(
# attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
def test_head_pruning(self):
pass
# if not self.test_pruning:
# return
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# for model_class in self.all_model_classes:
# config.output_attentions = True
# config.output_hidden_states = False
# model = model_class(config=config)
# model.eval()
# heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
# -1: [0]}
# model.prune_heads(heads_to_prune)
# outputs = model(**inputs_dict)
# attentions = outputs[-1]
# self.assertEqual(
# attentions[0].shape[-3], 1)
# self.assertEqual(
# attentions[1].shape[-3], self.model_tester.num_attention_heads)
# self.assertEqual(
# attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
def test_hidden_states_output(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# for model_class in self.all_model_classes:
# config.output_hidden_states = True
# config.output_attentions = False
# model = model_class(config)
# model.eval()
# outputs = model(**inputs_dict)
# hidden_states = outputs[-1]
# self.assertEqual(model.config.output_attentions, False)
# self.assertEqual(model.config.output_hidden_states, True)
# self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
# self.assertListEqual(
# list(hidden_states[0].shape[-2:]),
# [self.model_tester.seq_length, self.model_tester.hidden_size])
def test_resize_tokens_embeddings(self):
pass
# original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# if not self.test_resize_embeddings:
# return
# for model_class in self.all_model_classes:
# config = copy.deepcopy(original_config)
# model = model_class(config)
# model_vocab_size = config.vocab_size
# # Retrieve the embeddings and clone theme
# model_embed = model.resize_token_embeddings(model_vocab_size)
# cloned_embeddings = model_embed.weight.clone()
# # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
# model_embed = model.resize_token_embeddings(model_vocab_size + 10)
# self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
# # Check that it actually resizes the embeddings matrix
# self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
# # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
# model_embed = model.resize_token_embeddings(model_vocab_size - 15)
# self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
# # Check that it actually resizes the embeddings matrix
# self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
# # Check that adding and removing tokens has not modified the first part of the embedding matrix.
# models_equal = True
# for p1, p2 in zip(cloned_embeddings, model_embed.weight):
# if p1.data.ne(p2.data).sum() > 0:
# models_equal = False
# self.assertTrue(models_equal)
def test_tie_model_weights(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# def check_same_values(layer_1, layer_2):
# equal = True
# for p1, p2 in zip(layer_1.weight, layer_2.weight):
# if p1.data.ne(p2.data).sum() > 0:
# equal = False
# return equal
# for model_class in self.all_model_classes:
# if not hasattr(model_class, 'tie_weights'):
# continue
# config.torchscript = True
# model_not_tied = model_class(config)
# params_not_tied = list(model_not_tied.parameters())
# config_tied = copy.deepcopy(config)
# config_tied.torchscript = False
# model_tied = model_class(config_tied)
# params_tied = list(model_tied.parameters())
# # Check that the embedding layer and decoding layer are the same in size and in value
# self.assertGreater(len(params_not_tied), len(params_tied))
# # Check that after resize they remain tied.
# model_tied.resize_token_embeddings(config.vocab_size + 10)
# params_tied_2 = list(model_tied.parameters())
# self.assertGreater(len(params_not_tied), len(params_tied))
# self.assertEqual(len(params_tied_2), len(params_tied))
def ids_tensor(shape, vocab_size, rng=None, name=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
rng = random.Random()
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
return tf.constant(values, shape=shape)
class TFModelUtilsTest(unittest.TestCase):
def test_model_from_pretrained(self):
pass
# logging.basicConfig(level=logging.INFO)
# for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# config = BertConfig.from_pretrained(model_name)
# self.assertIsNotNone(config)
# self.assertIsInstance(config, PretrainedConfig)
# model = BertModel.from_pretrained(model_name)
# model, loading_info = BertModel.from_pretrained(model_name, output_loading_info=True)
# self.assertIsNotNone(model)
# self.assertIsInstance(model, PreTrainedModel)
# for value in loading_info.values():
# self.assertEqual(len(value), 0)
# config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
# model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
# self.assertEqual(model.config.output_attentions, True)
# self.assertEqual(model.config.output_hidden_states, True)
# self.assertEqual(model.config, config)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,327 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import shutil
import pytest
import tensorflow as tf
from pytorch_transformers import (BertConfig)
from pytorch_transformers.modeling_tf_bert import TFBertModel, TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes = (TFBertModel,)
# BertForMaskedLM, BertForNextSentencePrediction,
# BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
# BertForTokenClassification)
class TFBertModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = BertConfig(
vocab_size_or_config_json_file=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def check_loss_output(self, result):
self.parent.assertListEqual(
list(result["loss"].size()),
[])
def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = TFBertModel(config=config)
# model.eval()
inputs = {'input_ids': input_ids,
'attention_mask': input_mask,
'token_type_ids': token_type_ids}
sequence_output, pooled_output = model(inputs)
inputs = [input_ids, input_mask]
sequence_output, pooled_output = model(inputs)
sequence_output, pooled_output = model(input_ids)
result = {
"sequence_output": sequence_output.numpy(),
"pooled_output": pooled_output.numpy(),
}
self.parent.assertListEqual(
list(result["sequence_output"].shape),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size])
def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
pass
# model = BertForMaskedLM(config=config)
# model.eval()
# loss, prediction_scores = model(input_ids, token_type_ids, input_mask, token_labels)
# result = {
# "loss": loss,
# "prediction_scores": prediction_scores,
# }
# self.parent.assertListEqual(
# list(result["prediction_scores"].size()),
# [self.batch_size, self.seq_length, self.vocab_size])
# self.check_loss_output(result)
def create_and_check_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
pass
# model = BertForNextSentencePrediction(config=config)
# model.eval()
# loss, seq_relationship_score = model(input_ids, token_type_ids, input_mask, sequence_labels)
# result = {
# "loss": loss,
# "seq_relationship_score": seq_relationship_score,
# }
# self.parent.assertListEqual(
# list(result["seq_relationship_score"].size()),
# [self.batch_size, 2])
# self.check_loss_output(result)
def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
pass
# model = BertForPreTraining(config=config)
# model.eval()
# loss, prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
# result = {
# "loss": loss,
# "prediction_scores": prediction_scores,
# "seq_relationship_score": seq_relationship_score,
# }
# self.parent.assertListEqual(
# list(result["prediction_scores"].size()),
# [self.batch_size, self.seq_length, self.vocab_size])
# self.parent.assertListEqual(
# list(result["seq_relationship_score"].size()),
# [self.batch_size, 2])
# self.check_loss_output(result)
def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
pass
# model = BertForQuestionAnswering(config=config)
# model.eval()
# loss, start_logits, end_logits = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
# result = {
# "loss": loss,
# "start_logits": start_logits,
# "end_logits": end_logits,
# }
# self.parent.assertListEqual(
# list(result["start_logits"].size()),
# [self.batch_size, self.seq_length])
# self.parent.assertListEqual(
# list(result["end_logits"].size()),
# [self.batch_size, self.seq_length])
# self.check_loss_output(result)
def create_and_check_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
pass
# config.num_labels = self.num_labels
# model = BertForSequenceClassification(config)
# model.eval()
# loss, logits = model(input_ids, token_type_ids, input_mask, sequence_labels)
# result = {
# "loss": loss,
# "logits": logits,
# }
# self.parent.assertListEqual(
# list(result["logits"].size()),
# [self.batch_size, self.num_labels])
# self.check_loss_output(result)
def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
pass
# config.num_labels = self.num_labels
# model = BertForTokenClassification(config=config)
# model.eval()
# loss, logits = model(input_ids, token_type_ids, input_mask, token_labels)
# result = {
# "loss": loss,
# "logits": logits,
# }
# self.parent.assertListEqual(
# list(result["logits"].size()),
# [self.batch_size, self.seq_length, self.num_labels])
# self.check_loss_output(result)
def create_and_check_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
pass
# config.num_choices = self.num_choices
# model = BertForMultipleChoice(config=config)
# model.eval()
# multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
# multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
# multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
# loss, logits = model(multiple_choice_inputs_ids,
# multiple_choice_token_type_ids,
# multiple_choice_input_mask,
# choice_labels)
# result = {
# "loss": loss,
# "logits": logits,
# }
# self.parent.assertListEqual(
# list(result["logits"].size()),
# [self.batch_size, self.num_choices])
# self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, token_type_ids, input_mask,
sequence_labels, token_labels, choice_labels) = config_and_inputs
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
return config, inputs_dict
def setUp(self):
self.model_tester = TFBertModelTest.TFBertModelTester(self)
self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_bert_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_model(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
def test_for_next_sequence_prediction(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)
def test_for_pretraining(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFBertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()