diff --git a/pytorch_transformers/file_utils.py b/pytorch_transformers/file_utils.py index 37ebc57fb3c..94ada93d91f 100644 --- a/pytorch_transformers/file_utils.py +++ b/pytorch_transformers/file_utils.py @@ -79,6 +79,9 @@ def url_to_filename(url, etag=None): Convert `url` into a hashed filename in a repeatable way. If `etag` is specified, append its hash to the url's, delimited by a period. + If the url ends with .h5 (Keras HDF5 weights), adds '.h5' to the name + so that TF 2.0 can identify it as an HDF5 file + (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380) """ url_bytes = url.encode('utf-8') url_hash = sha256(url_bytes) @@ -89,6 +92,9 @@ def url_to_filename(url, etag=None): etag_hash = sha256(etag_bytes) filename += '.' + etag_hash.hexdigest() + if url.endswith('.h5'): + filename += '.h5' + return filename diff --git a/pytorch_transformers/modeling_tf_bert.py b/pytorch_transformers/modeling_tf_bert.py new file mode 100644 index 00000000000..dda713fe7dc --- /dev/null +++ b/pytorch_transformers/modeling_tf_bert.py @@ -0,0 +1,832 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" TF 2.0 BERT model.
""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import json +import logging +import math +import os +import sys +from io import open + +import numpy as np +import tensorflow as tf + +from .configuration_bert import BertConfig +from .modeling_tf_utils import TFPreTrainedModel +from .file_utils import add_start_docstrings + +logger = logging.getLogger(__name__) + + +TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = { + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tf_model.h5", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-tf_model.h5", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-tf_model.h5", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-tf_model.h5", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-tf_model.h5", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-tf_model.h5", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-tf_model.h5", + 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-tf_model.h5", + 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-tf_model.h5", + 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-tf_model.h5", + 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-tf_model.h5", + 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-tf_model.h5", + 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-tf_model.h5", +} + + +def load_pt_weights_in_bert(tf_model, config, pytorch_checkpoint_path): + """ Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format + We use HDF5 to easily do transfer learning + (see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357). + """ + try: + import re + import torch + import numpy + from tensorflow.python.keras import backend as K + except ImportError: + logger.error("Loading a PyTorch model in TensorFlow, requires PyTorch to be installed. 
Please see " + "https://pytorch.org/ for installation instructions.") + raise + + pt_path = os.path.abspath(pytorch_checkpoint_path) + logger.info("Loading PyTorch weights from {}".format(pt_path)) + # Load pytorch model + state_dict = torch.load(pt_path, map_location='cpu') + + inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] + tf_inputs = tf.constant(inputs_list) + tfo = tf_model(tf_inputs, training=False) # build the network + + symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights + weight_value_tuples = [] + for symbolic_weight in symbolic_weights: + name = symbolic_weight.name + name = name.replace('cls_mlm', 'cls') # We had to split this layer in two in the TF model to be + name = name.replace('cls_nsp', 'cls') # able to do transfer learning (Keras only allow to remove full layers) + name = name.replace(':0', '') + name = name.replace('layer_', 'layer/') + name = name.split('/') + name = name[1:] + + transpose = bool(name[-1] == 'kernel') + if name[-1] == 'kernel' or name[-1] == 'embeddings': + name[-1] = 'weight' + + name = '.'.join(name) + assert name in state_dict + array = state_dict[name].numpy() + + if transpose: + array = numpy.transpose(array) + + try: + assert list(symbolic_weight.shape) == list(array.shape) + except AssertionError as e: + e.args += (symbolic_weight.shape, array.shape) + raise e + + logger.info("Initialize TF weight {}".format(symbolic_weight.name)) + + weight_value_tuples.append((symbolic_weight, array)) + + K.batch_set_value(weight_value_tuples) + + tfo = tf_model(tf_inputs, training=False) # Make sure restore ops are run + return tf_model + + +def gelu(x): + """Gaussian Error Linear Unit. + This is a smoother version of the RELU. + Original paper: https://arxiv.org/abs/1606.08415 + Args: + x: float Tensor to perform activation. + Returns: + `x` with the GELU activation applied. + """ + cdf = 0.5 * (1.0 + tf.tanh( + (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) + return x * cdf + + +def swish(x): + return x * tf.sigmoid(x) + + +ACT2FN = {"gelu": tf.keras.layers.Activation(gelu), + "relu": tf.keras.activations.relu, + "swish": tf.keras.layers.Activation(swish)} + + +class TFBertEmbeddings(tf.keras.layers.Layer): + """Construct the embeddings from word, position and token_type embeddings. 
+ """ + def __init__(self, config, **kwargs): + super(TFBertEmbeddings, self).__init__(**kwargs) + self.word_embeddings = tf.keras.layers.Embedding(config.vocab_size, config.hidden_size, name='word_embeddings') + self.position_embeddings = tf.keras.layers.Embedding(config.max_position_embeddings, config.hidden_size, name='position_embeddings') + self.token_type_embeddings = tf.keras.layers.Embedding(config.type_vocab_size, config.hidden_size, name='token_type_embeddings') + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm') + self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + + def call(self, inputs, training=False): + input_ids, position_ids, token_type_ids = inputs + + seq_length = tf.shape(input_ids)[1] + if position_ids is None: + position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :] + if token_type_ids is None: + token_type_ids = tf.fill(tf.shape(input_ids), 0) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + if training: + embeddings = self.dropout(embeddings) + return embeddings + + +class TFBertSelfAttention(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super(TFBertSelfAttention, self).__init__(**kwargs) + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + self.output_attentions = config.output_attentions + + self.num_attention_heads = config.num_attention_heads + assert config.hidden_size % config.num_attention_heads == 0 + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = tf.keras.layers.Dense(self.all_head_size, name='query') + self.key = tf.keras.layers.Dense(self.all_head_size, name='key') + self.value = tf.keras.layers.Dense(self.all_head_size, name='value') + + self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x, batch_size): + x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size)) + return tf.transpose(x, perm=[0, 2, 1, 3]) + + def call(self, inputs, training=False): + hidden_states, attention_mask, head_mask = inputs + + batch_size = tf.shape(hidden_states)[0] + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) + key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) + value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) # (batch size, num_heads, seq_len_q, seq_len_k) + dk = tf.cast(tf.shape(key_layer)[-1], tf.float32) # scale attention_scores + attention_scores = attention_scores / tf.math.sqrt(dk) + # Apply the attention mask is (precomputed for all layers in TFBertModel call() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = tf.nn.softmax(attention_scores, axis=-1) + + if training: + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = tf.matmul(attention_probs, value_layer) + + context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3]) + context_layer = tf.reshape(context_layer, + (batch_size, -1, self.all_head_size)) # (batch_size, seq_len_q, all_head_size) + + outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) + return outputs + + +class TFBertSelfOutput(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super(TFBertSelfOutput, self).__init__(**kwargs) + self.dense = tf.keras.layers.Dense(config.hidden_size, name='dense') + self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm') + self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + + def call(self, inputs, training=False): + hidden_states, input_tensor = inputs + + hidden_states = self.dense(hidden_states) + if training: + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class TFBertAttention(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super(TFBertAttention, self).__init__(**kwargs) + self.self_attention = TFBertSelfAttention(config, name='self') + self.dense_output = TFBertSelfOutput(config, name='output') + + def prune_heads(self, heads): + raise NotImplementedError + + def call(self, inputs, training=False): + input_tensor, attention_mask, head_mask = inputs + + self_outputs = self.self_attention([input_tensor, attention_mask, head_mask], training=training) + attention_output = self.dense_output([self_outputs[0], input_tensor], training=training) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class TFBertIntermediate(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super(TFBertIntermediate, self).__init__(**kwargs) + self.dense = tf.keras.layers.Dense(config.intermediate_size, name='dense') + if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def call(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class TFBertOutput(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super(TFBertOutput, self).__init__(**kwargs) + self.dense = tf.keras.layers.Dense(config.hidden_size, name='dense') + self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm') + self.dropout = 
tf.keras.layers.Dropout(config.hidden_dropout_prob) + + def call(self, inputs, training=False): + hidden_states, input_tensor = inputs + + hidden_states = self.dense(hidden_states) + if training: + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class TFBertLayer(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super(TFBertLayer, self).__init__(**kwargs) + self.attention = TFBertAttention(config, name='attention') + self.intermediate = TFBertIntermediate(config, name='intermediate') + self.bert_output = TFBertOutput(config, name='output') + + def call(self, inputs, training=False): + hidden_states, attention_mask, head_mask = inputs + + attention_outputs = self.attention([hidden_states, attention_mask, head_mask], training=training) + attention_output = attention_outputs[0] + intermediate_output = self.intermediate(attention_output) + layer_output = self.bert_output([intermediate_output, attention_output], training=training) + outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + return outputs + + +class TFBertEncoder(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super(TFBertEncoder, self).__init__(**kwargs) + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.layer = [TFBertLayer(config, name='layer_{}'.format(i)) for i in range(config.num_hidden_layers)] + + def call(self, inputs, training=False): + hidden_states, attention_mask, head_mask = inputs + + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module([hidden_states, attention_mask, head_mask[i]], training=training) + hidden_states = layer_outputs[0] + + if self.output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = (hidden_states,) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states,) + if self.output_attentions: + outputs = outputs + (all_attentions,) + return outputs # outputs, (hidden states), (attentions) + + +class TFBertPooler(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super(TFBertPooler, self).__init__(**kwargs) + self.dense = tf.keras.layers.Dense(config.hidden_size, activation='tanh', name='dense') + + def call(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + return pooled_output + + +class TFBertPredictionHeadTransform(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super(TFBertPredictionHeadTransform, self).__init__(**kwargs) + self.dense = tf.keras.layers.Dense(config.hidden_size, name='dense') + if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm') + + def call(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class TFBertLMPredictionHead(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super(TFBertLMPredictionHead, self).__init__(**kwargs) + self.vocab_size = config.vocab_size + self.transform = TFBertPredictionHeadTransform(config, name='transform') + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name='decoder') + + def build(self, input_shape): + self.bias = self.add_weight(shape=(self.vocab_size,), + initializer='zeros', + trainable=True, + name='bias') + + def call(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + self.bias + return hidden_states + + +class TFBertMLMHead(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super(TFBertMLMHead, self).__init__(**kwargs) + self.predictions = TFBertLMPredictionHead(config, name='predictions') + + def call(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class TFBertNSPHead(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super(TFBertNSPHead, self).__init__(**kwargs) + self.seq_relationship = tf.keras.layers.Dense(2, name='seq_relationship') + + def call(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class TFBertMainLayer(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super(TFBertMainLayer, self).__init__(**kwargs) + self.num_hidden_layers = config.num_hidden_layers + + self.embeddings = TFBertEmbeddings(config, name='embeddings') + self.encoder = TFBertEncoder(config, name='encoder') + self.pooler = TFBertPooler(config, name='pooler') + + # self.apply(self.init_weights) # TODO check weights initialization + + def _resize_token_embeddings(self, new_num_tokens): + raise NotImplementedError + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. 
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + raise NotImplementedError + + def call(self, inputs, training=False): + if not isinstance(inputs, (dict, tuple, list)): + input_ids = inputs + attention_mask, head_mask, position_ids, token_type_ids = None, None, None, None + elif isinstance(inputs, (tuple, list)): + input_ids = inputs[0] + attention_mask = inputs[1] if len(inputs) > 1 else None + token_type_ids = inputs[2] if len(inputs) > 2 else None + position_ids = inputs[3] if len(inputs) > 3 else None + head_mask = inputs[4] if len(inputs) > 4 else None + assert len(inputs) <= 5, "Too many inputs." + else: + input_ids = inputs.pop('input_ids') + attention_mask = inputs.pop('attention_mask', None) + token_type_ids = inputs.pop('token_type_ids', None) + position_ids = inputs.pop('position_ids', None) + head_mask = inputs.pop('head_mask', None) + assert len(inputs) == 0, "Unexpected inputs detected: {}. Check inputs dict key names.".format(list(inputs.keys())) + + if attention_mask is None: + attention_mask = tf.fill(tf.shape(input_ids), 1) + if token_type_ids is None: + token_type_ids = tf.fill(tf.shape(input_ids), 0) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + + extended_attention_mask = tf.cast(extended_attention_mask, tf.float32) + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if not head_mask is None: + raise NotImplementedError + else: + head_mask = [None] * self.num_hidden_layers + # head_mask = tf.constant([0] * self.num_hidden_layers) + + embedding_output = self.embeddings([input_ids, position_ids, token_type_ids], training=training) + encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training) + + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + + outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here + return outputs # sequence_output, pooled_output, (hidden_states), (attentions) + +class TFBertPreTrainedModel(TFPreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. 
+ """ + config_class = BertConfig + pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP + load_pt_weights = load_pt_weights_in_bert + base_model_prefix = "bert" + + def __init__(self, *inputs, **kwargs): + super(TFBertPreTrainedModel, self).__init__(*inputs, **kwargs) + + def init_weights(self, module): + """ Initialize the weights. + """ + raise NotImplementedError + + +BERT_START_DOCSTRING = r""" The BERT model was proposed in + `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ + by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer + pre-trained using a combination of masked language modeling objective and next sentence prediction + on a large corpus comprising the Toronto Book Corpus and Wikipedia. + + This model is a tf.keras.Model `tf.keras.Model`_ sub-class. Use it as a regular TF 2.0 Keras Model and + refer to the TF 2.0 documentation for all matter related to general usage and behavior. + + .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`: + https://arxiv.org/abs/1810.04805 + + .. _`tf.keras.Model`: + https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model + + Important note on the model inputs: + The inputs of the TF 2.0 models are slightly different from the PyTorch ones since + TF 2.0 Keras doesn't accept named arguments with defaults values for input Tensor. + More precisely, input Tensors are gathered in the first arguments of the model call function: `model(inputs)`. + There are three possibilities to gather and feed the inputs to the model: + + - a single Tensor with input_ids only and nothing else: `model(inputs_ids) + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associaed to the input names given in the docstring: + `model({'input_ids': input_ids, 'token_type_ids': token_type_ids})` + + Parameters: + config (:class:`~pytorch_transformers.BertConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +BERT_INPUTS_DOCSTRING = r""" + Inputs: + **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Indices of input sequence tokens in the vocabulary. + To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows: + + (a) For sequence pairs: + + ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]`` + + ``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1`` + + (b) For single sequences: + + ``tokens: [CLS] the dog is hairy . [SEP]`` + + ``token_type_ids: 0 0 0 0 0 0 0`` + + Bert is a model with absolute position embeddings so it's usually advised to pad the inputs on + the right rather than the left. + + Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`. + See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and + :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. + **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``: + Mask to avoid performing attention on padding token indices. 
+ Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Segment token indices to indicate first and second portions of the inputs. + Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` + corresponds to a `sentence B` token + (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details). + **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Indices of positions of each input sequence token in the position embeddings. + Selected in the range ``[0, config.max_position_embeddings - 1]``. + **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. +""" + +@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", + BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) +class TFBertModel(TFBertPreTrainedModel): + r""" + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` + Sequence of hidden-states at the output of the last layer of the model. + **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)`` + Last layer hidden-state of the first token of the sequence (classification token) + further processed by a Linear layer and a Tanh activation function. The Linear + layer weights are trained from the next sentence prediction (classification) + objective during Bert pretraining. This output is usually *not* a good summary + of the semantic content of the input; you're often better off averaging or pooling + the sequence of hidden-states for the whole input sequence. + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+ + Examples:: + + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + model = TFBertModel.from_pretrained('bert-base-uncased') + input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1 + outputs = model(input_ids) + last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple + + """ + def __init__(self, config): + super(TFBertModel, self).__init__(config) + self.bert = TFBertMainLayer(config, name='bert') + + def call(self, inputs, training=False): + outputs = self.bert(inputs, training=training) + return outputs + + +@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training: + a `masked language modeling` head and a `next sentence prediction (classification)` head. """, + BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) +class TFBertForPreTraining(TFBertPreTrainedModel): + r""" + **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Labels for computing the masked language modeling loss. + Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-1`` are ignored (masked); the loss is only computed for the tokens with labels + in ``[0, ..., config.vocab_size]`` + **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring) + Indices should be in ``[0, 1]``. + ``0`` indicates sequence B is a continuation of sequence A, + ``1`` indicates sequence B is a random sequence. + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when both ``masked_lm_labels`` and ``next_sentence_label`` are provided) ``torch.FloatTensor`` of shape ``(1,)``: + Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss. + **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, 2)`` + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+ + Examples:: + + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + model = TFBertForPreTraining.from_pretrained('bert-base-uncased') + input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1 + outputs = model(input_ids) + prediction_scores, seq_relationship_scores = outputs[:2] + + """ + def __init__(self, config): + super(TFBertForPreTraining, self).__init__(config) + + self.bert = TFBertMainLayer(config, name='bert') + self.cls_mlm = TFBertMLMHead(config, name='cls_mlm') + self.cls_nsp = TFBertNSPHead(config, name='cls_nsp') + + # self.apply(self.init_weights) # TODO check added weights initialization + self.tie_weights() + + def tie_weights(self): + """ Make sure we are sharing the input and output embeddings. + """ + pass # TODO add weights tying + + def call(self, inputs, training=False): + outputs = self.bert(inputs, training=training) + + sequence_output, pooled_output = outputs[:2] + prediction_scores = self.cls_mlm(sequence_output) + seq_relationship_score = self.cls_nsp(pooled_output) + + outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here + + # if masked_lm_labels is not None and next_sentence_label is not None: + # loss_fct = CrossEntropyLoss(ignore_index=-1) + # masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) + # next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + # total_loss = masked_lm_loss + next_sentence_loss + # outputs = (total_loss,) + outputs + # TODO add example with losses using model.compile and a dictionary of losses (give names to the output layers) + + return outputs # prediction_scores, seq_relationship_score, (hidden_states), (attentions) + + +@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, + BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) +class TFBertForMaskedLM(TFBertPreTrainedModel): + r""" + **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Labels for computing the masked language modeling loss. + Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-1`` are ignored (masked); the loss is only computed for the tokens with labels + in ``[0, ..., config.vocab_size]`` + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Masked language modeling loss. + **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+ + Examples:: + + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + model = TFBertForMaskedLM.from_pretrained('bert-base-uncased') + input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1 + outputs = model(input_ids) + prediction_scores = outputs[0] + + """ + def __init__(self, config): + super(TFBertForMaskedLM, self).__init__(config) + + self.bert = TFBertMainLayer(config, name='bert') + self.cls_mlm = TFBertMLMHead(config, name='cls_mlm') + + # self.apply(self.init_weights) + self.tie_weights() + + def tie_weights(self): + """ Make sure we are sharing the input and output embeddings. + """ + pass # TODO add weights tying + + def call(self, inputs, training=False): + outputs = self.bert(inputs, training=training) + + sequence_output = outputs[0] + prediction_scores = self.cls_mlm(sequence_output) + + outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here + # if masked_lm_labels is not None: + # loss_fct = CrossEntropyLoss(ignore_index=-1) + # masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) + # outputs = (masked_lm_loss,) + outputs + # TODO example with losses + + return outputs # prediction_scores, (hidden_states), (attentions) + + +@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """, + BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) +class TFBertForNextSentencePrediction(TFBertPreTrainedModel): + r""" + **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring) + Indices should be in ``[0, 1]``. + ``0`` indicates sequence B is a continuation of sequence A, + ``1`` indicates sequence B is a random sequence. + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``next_sentence_label`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Next sequence prediction (classification) loss. + **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, 2)`` + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+ + Examples:: + + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + model = TFBertForNextSentencePrediction.from_pretrained('bert-base-uncased') + input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1 + outputs = model(input_ids) + seq_relationship_scores = outputs[0] + + """ + def __init__(self, config): + super(TFBertForNextSentencePrediction, self).__init__(config) + + self.bert = TFBertMainLayer(config, name='bert') + self.cls_nsp = TFBertNSPHead(config, name='cls_nsp') + + # self.apply(self.init_weights) + + def call(self, inputs, training=False): + outputs = self.bert(inputs, training=training) + + pooled_output = outputs[1] + seq_relationship_score = self.cls_nsp(pooled_output) + + outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here + # if next_sentence_label is not None: + # loss_fct = CrossEntropyLoss(ignore_index=-1) + # next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + # outputs = (next_sentence_loss,) + outputs + + return outputs # seq_relationship_score, (hidden_states), (attentions) diff --git a/pytorch_transformers/modeling_tf_utils.py b/pytorch_transformers/modeling_tf_utils.py new file mode 100644 index 00000000000..442d9dbe445 --- /dev/null +++ b/pytorch_transformers/modeling_tf_utils.py @@ -0,0 +1,255 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF general model utils.""" + +from __future__ import (absolute_import, division, print_function, + unicode_literals) + +import logging +import os + +import tensorflow as tf + +from .configuration_utils import PretrainedConfig +from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME + +logger = logging.getLogger(__name__) + + +class TFPreTrainedModel(tf.keras.Model): + r""" Base class for all TF models. + + :class:`~pytorch_transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models + as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads. + + Class attributes (overridden by derived classes): + - ``config_class``: a class derived from :class:`~pytorch_transformers.PretrainedConfig` to use as configuration class for this model architecture. + - ``pretrained_model_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
- ``load_pt_weights``: a python ``method`` for loading a PyTorch checkpoint in a TF 2.0 model, taking as arguments: + + - ``model``: an instance of the relevant subclass of :class:`~pytorch_transformers.TFPreTrainedModel`, + - ``config``: an instance of the relevant subclass of :class:`~pytorch_transformers.PretrainedConfig`, + - ``path``: a path (string) to the PyTorch checkpoint. + + - ``base_model_prefix``: a string indicating the attribute associated with the base model in derived classes of the same architecture adding modules on top of the base model. + """ + config_class = None + pretrained_model_archive_map = {} + load_pt_weights = lambda model, config, path: None + base_model_prefix = "" + + def __init__(self, config, *inputs, **kwargs): + super(TFPreTrainedModel, self).__init__() + if not isinstance(config, PretrainedConfig): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. " + "To create a model from a pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + )) + # Save config in model + self.config = config + + def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None): + """ Build a resized Embedding Module from a provided token Embedding Module. + Increasing the size will add newly initialized vectors at the end + Reducing the size will remove vectors from the end + + Args: + new_num_tokens: (`optional`) int + New number of tokens in the embedding matrix. + Increasing the size will add newly initialized vectors at the end + Reducing the size will remove vectors from the end + If not provided or None: return the provided token Embedding Module. + Return: ``torch.nn.Embeddings`` + Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None + """ + raise NotImplementedError + + def _tie_or_clone_weights(self, first_module, second_module): + """ Tie or clone module weights depending on whether we are using TorchScript or not + """ + raise NotImplementedError + + def resize_token_embeddings(self, new_num_tokens=None): + """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size. + Takes care of tying weight embeddings afterwards if the model class has a `tie_weights()` method. + + Arguments: + + new_num_tokens: (`optional`) int: + New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. + If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model. + + Return: ``torch.nn.Embeddings`` + Pointer to the input tokens Embeddings Module of the model + """ + raise NotImplementedError + + def prune_heads(self, heads_to_prune): + """ Prunes heads of the base model. + + Arguments: + + heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`). + """ + raise NotImplementedError + + def save_pretrained(self, save_directory): + """ Save a model and its configuration file to a directory, so that it + can be re-loaded using the :func:`~pytorch_transformers.TFPreTrainedModel.from_pretrained` class method.
+ """ + raise NotImplementedError + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r"""Instantiate a pretrained pytorch model from a pre-trained model configuration. + + The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated) + To train the model, you should first set it back in training mode with ``model.train()`` + + The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model. + It is up to you to train those weights with a downstream fine-tuning task. + + The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded. + + Parameters: + pretrained_model_name_or_path: either: + + - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. + - a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards. + + model_args: (`optional`) Sequence of positional arguments: + All remaning positional arguments will be passed to the underlying model's ``__init__`` method + + config: (`optional`) instance of a class derived from :class:`~pytorch_transformers.PretrainedConfig`: + Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when: + + - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or + - the model was saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory. + - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. + + from_pt: (`optional`) boolean, default False: + Load the model weights from a PyTorch state_dict save file (see docstring of pretrained_model_name_or_path argument). + + cache_dir: (`optional`) string: + Path to a directory in which a downloaded pre-trained model + configuration should be cached if the standard cache should not be used. + + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + + output_loading_info: (`optional`) boolean: + Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. + + kwargs: (`optional`) Remaining dictionary of keyword arguments: + Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). 
Behave differently depending on whether a `config` is provided or automatically loaded: + + - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done) + - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~pytorch_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function. + + Examples:: + + model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache. + model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` + model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading + assert model.config.output_attention == True + # Loading from a TF checkpoint file instead of a PyTorch model (slower) + config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json') + model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_pt=True, config=config) + + """ + config = kwargs.pop('config', None) + cache_dir = kwargs.pop('cache_dir', None) + from_pt = kwargs.pop('from_pt', False) + force_download = kwargs.pop('force_download', False) + proxies = kwargs.pop('proxies', None) + output_loading_info = kwargs.pop('output_loading_info', False) + + # Load config + if config is None: + config, model_kwargs = cls.config_class.from_pretrained( + pretrained_model_name_or_path, *model_args, + cache_dir=cache_dir, return_unused_kwargs=True, + force_download=force_download, + **kwargs + ) + else: + model_kwargs = kwargs + + # Load model + if pretrained_model_name_or_path in cls.pretrained_model_archive_map: + archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path] + elif os.path.isdir(pretrained_model_name_or_path): + if from_pt: + # Load from a PyTorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + else: + archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME) + else: + archive_file = pretrained_model_name_or_path + # redirect to the cache, if necessary + try: + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) + except EnvironmentError: + if pretrained_model_name_or_path in cls.pretrained_model_archive_map: + logger.error( + "Couldn't reach server at '{}' to download pretrained weights.".format( + archive_file)) + else: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name_or_path, + ', '.join(cls.pretrained_model_archive_map.keys()), + archive_file)) + return None + if resolved_archive_file == archive_file: + logger.info("loading weights file {}".format(archive_file)) + else: + logger.info("loading weights file {} from cache at {}".format( + archive_file, resolved_archive_file)) + + # Instantiate model. 
+ model = cls(config, *model_args, **model_kwargs) + + if from_pt: + # Load from a PyTorch checkpoint + return cls.load_pt_weights(model, config, resolved_archive_file) + + inputs = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]) + ret = model(inputs, training=False) # build the network with dummy inputs + + # 'by_name' allow us to do transfer learning by skipping/adding layers + # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 + model.load_weights(resolved_archive_file, by_name=True) + + ret = model(inputs, training=False) # Make sure restore ops are run + + # if hasattr(model, 'tie_weights'): + # model.tie_weights() # TODO make sure word embedding weights are still tied + + if output_loading_info: + loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs} + return model, loading_info + + return model diff --git a/pytorch_transformers/tests/modeling_tf_common_test.py b/pytorch_transformers/tests/modeling_tf_common_test.py new file mode 100644 index 00000000000..d432ab5fdf8 --- /dev/null +++ b/pytorch_transformers/tests/modeling_tf_common_test.py @@ -0,0 +1,308 @@ +# coding=utf-8 +# Copyright 2019 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import os +import shutil +import json +import random +import uuid + +import unittest +import logging + +import tensorflow as tf + +from pytorch_transformers import TFPreTrainedModel +# from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP + + +def _config_zero_init(config): + configs_no_init = copy.deepcopy(config) + for key in configs_no_init.__dict__.keys(): + if '_range' in key or '_std' in key: + setattr(configs_no_init, key, 0.0) + return configs_no_init + +class TFCommonTestCases: + + class TFCommonModelTester(unittest.TestCase): + + model_tester = None + all_model_classes = () + test_torchscript = True + test_pruning = True + test_resize_embeddings = True + + def test_initialization(self): + pass + # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # configs_no_init = _config_zero_init(config) + # for model_class in self.all_model_classes: + # model = model_class(config=configs_no_init) + # for name, param in model.named_parameters(): + # if param.requires_grad: + # self.assertIn(param.data.mean().item(), [0.0, 1.0], + # msg="Parameter {} of model {} seems not properly initialized".format(name, model_class)) + + + def test_attention_outputs(self): + pass + # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # for model_class in self.all_model_classes: + # config.output_attentions = True + # config.output_hidden_states = False + # model = model_class(config) + # model.eval() + # outputs = model(**inputs_dict) + # attentions = outputs[-1] + # self.assertEqual(model.config.output_attentions, True) + # self.assertEqual(model.config.output_hidden_states, False) + # self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + # self.assertListEqual( + # list(attentions[0].shape[-3:]), + # [self.model_tester.num_attention_heads, + # self.model_tester.seq_length, + # self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length]) + # out_len = len(outputs) + + # # Check attention is always last and order is fine + # config.output_attentions = True + # config.output_hidden_states = True + # model = model_class(config) + # model.eval() + # outputs = model(**inputs_dict) + # self.assertEqual(out_len+1, len(outputs)) + # self.assertEqual(model.config.output_attentions, True) + # self.assertEqual(model.config.output_hidden_states, True) + + # attentions = outputs[-1] + # self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + # self.assertListEqual( + # list(attentions[0].shape[-3:]), + # [self.model_tester.num_attention_heads, + # self.model_tester.seq_length, + # self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length]) + + + def test_headmasking(self): + pass + # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # config.output_attentions = True + # config.output_hidden_states = True + # configs_no_init = _config_zero_init(config) # To be sure we have no Nan + # for model_class in self.all_model_classes: + # model = model_class(config=configs_no_init) + # model.eval() + + # # Prepare head_mask + # # Set require_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior) + # head_mask = torch.ones(self.model_tester.num_hidden_layers, 
self.model_tester.num_attention_heads) + # head_mask[0, 0] = 0 + # head_mask[-1, :-1] = 0 + # head_mask.requires_grad_(requires_grad=True) + # inputs = inputs_dict.copy() + # inputs['head_mask'] = head_mask + + # outputs = model(**inputs) + + # # Test that we can get a gradient back for importance score computation + # output = sum(t.sum() for t in outputs[0]) + # output = output.sum() + # output.backward() + # multihead_outputs = head_mask.grad + + # attentions = outputs[-1] + # hidden_states = outputs[-2] + + # # Remove Nan + + # self.assertIsNotNone(multihead_outputs) + # self.assertEqual(len(multihead_outputs), self.model_tester.num_hidden_layers) + # self.assertAlmostEqual( + # attentions[0][..., 0, :, :].flatten().sum().item(), 0.0) + # self.assertNotEqual( + # attentions[0][..., -1, :, :].flatten().sum().item(), 0.0) + # self.assertNotEqual( + # attentions[1][..., 0, :, :].flatten().sum().item(), 0.0) + # self.assertAlmostEqual( + # attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0) + # self.assertNotEqual( + # attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0) + + + def test_head_pruning(self): + pass + # if not self.test_pruning: + # return + + # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # for model_class in self.all_model_classes: + # config.output_attentions = True + # config.output_hidden_states = False + # model = model_class(config=config) + # model.eval() + # heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), + # -1: [0]} + # model.prune_heads(heads_to_prune) + # outputs = model(**inputs_dict) + + # attentions = outputs[-1] + + # self.assertEqual( + # attentions[0].shape[-3], 1) + # self.assertEqual( + # attentions[1].shape[-3], self.model_tester.num_attention_heads) + # self.assertEqual( + # attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1) + + + def test_hidden_states_output(self): + pass + # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # for model_class in self.all_model_classes: + # config.output_hidden_states = True + # config.output_attentions = False + # model = model_class(config) + # model.eval() + # outputs = model(**inputs_dict) + # hidden_states = outputs[-1] + # self.assertEqual(model.config.output_attentions, False) + # self.assertEqual(model.config.output_hidden_states, True) + # self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1) + # self.assertListEqual( + # list(hidden_states[0].shape[-2:]), + # [self.model_tester.seq_length, self.model_tester.hidden_size]) + + + def test_resize_tokens_embeddings(self): + pass + # original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + # if not self.test_resize_embeddings: + # return + + # for model_class in self.all_model_classes: + # config = copy.deepcopy(original_config) + # model = model_class(config) + + # model_vocab_size = config.vocab_size + # # Retrieve the embeddings and clone theme + # model_embed = model.resize_token_embeddings(model_vocab_size) + # cloned_embeddings = model_embed.weight.clone() + + # # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size + # model_embed = model.resize_token_embeddings(model_vocab_size + 10) + # self.assertEqual(model.config.vocab_size, model_vocab_size + 10) + # # Check that it actually resizes the embeddings matrix + # self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10) + + # # Check that resizing the token 
embeddings with a smaller vocab size decreases the model's vocab size + # model_embed = model.resize_token_embeddings(model_vocab_size - 15) + # self.assertEqual(model.config.vocab_size, model_vocab_size - 15) + # # Check that it actually resizes the embeddings matrix + # self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15) + + # # Check that adding and removing tokens has not modified the first part of the embedding matrix. + # models_equal = True + # for p1, p2 in zip(cloned_embeddings, model_embed.weight): + # if p1.data.ne(p2.data).sum() > 0: + # models_equal = False + + # self.assertTrue(models_equal) + + + def test_tie_model_weights(self): + pass + # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # def check_same_values(layer_1, layer_2): + # equal = True + # for p1, p2 in zip(layer_1.weight, layer_2.weight): + # if p1.data.ne(p2.data).sum() > 0: + # equal = False + # return equal + + # for model_class in self.all_model_classes: + # if not hasattr(model_class, 'tie_weights'): + # continue + + # config.torchscript = True + # model_not_tied = model_class(config) + # params_not_tied = list(model_not_tied.parameters()) + + # config_tied = copy.deepcopy(config) + # config_tied.torchscript = False + # model_tied = model_class(config_tied) + # params_tied = list(model_tied.parameters()) + + # # Check that the embedding layer and decoding layer are the same in size and in value + # self.assertGreater(len(params_not_tied), len(params_tied)) + + # # Check that after resize they remain tied. + # model_tied.resize_token_embeddings(config.vocab_size + 10) + # params_tied_2 = list(model_tied.parameters()) + # self.assertGreater(len(params_not_tied), len(params_tied)) + # self.assertEqual(len(params_tied_2), len(params_tied)) + + +def ids_tensor(shape, vocab_size, rng=None, name=None): + """Creates a random int32 tensor of the shape within the vocab size.""" + if rng is None: + rng = random.Random() + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.randint(0, vocab_size - 1)) + + return tf.constant(values, shape=shape) + + +class TFModelUtilsTest(unittest.TestCase): + def test_model_from_pretrained(self): + pass + # logging.basicConfig(level=logging.INFO) + # for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + # config = BertConfig.from_pretrained(model_name) + # self.assertIsNotNone(config) + # self.assertIsInstance(config, PretrainedConfig) + + # model = BertModel.from_pretrained(model_name) + # model, loading_info = BertModel.from_pretrained(model_name, output_loading_info=True) + # self.assertIsNotNone(model) + # self.assertIsInstance(model, PreTrainedModel) + # for value in loading_info.values(): + # self.assertEqual(len(value), 0) + + # config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True) + # model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True) + # self.assertEqual(model.config.output_attentions, True) + # self.assertEqual(model.config.output_hidden_states, True) + # self.assertEqual(model.config, config) + + +if __name__ == "__main__": + unittest.main() diff --git a/pytorch_transformers/tests/modeling_tf_test.py b/pytorch_transformers/tests/modeling_tf_test.py new file mode 100644 index 00000000000..ee2580bd9e3 --- /dev/null +++ b/pytorch_transformers/tests/modeling_tf_test.py @@ -0,0 +1,327 @@ +# coding=utf-8 +# Copyright 2018 The Google AI 
Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +import shutil +import pytest + +import tensorflow as tf + +from pytorch_transformers import (BertConfig) +from pytorch_transformers.modeling_tf_bert import TFBertModel, TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP + +from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) +from .configuration_common_test import ConfigTester + + +class TFBertModelTest(TFCommonTestCases.TFCommonModelTester): + + all_model_classes = (TFBertModel,) + # BertForMaskedLM, BertForNextSentencePrediction, + # BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, + # BertForTokenClassification) + + class TFBertModelTester(object): + + def __init__(self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = BertConfig( + vocab_size_or_config_json_file=self.vocab_size, + hidden_size=self.hidden_size, + 
num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range) + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def check_loss_output(self, result): + self.parent.assertListEqual( + list(result["loss"].size()), + []) + + def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): + model = TFBertModel(config=config) + # model.eval() + inputs = {'input_ids': input_ids, + 'attention_mask': input_mask, + 'token_type_ids': token_type_ids} + sequence_output, pooled_output = model(inputs) + + inputs = [input_ids, input_mask] + sequence_output, pooled_output = model(inputs) + + sequence_output, pooled_output = model(input_ids) + + result = { + "sequence_output": sequence_output.numpy(), + "pooled_output": pooled_output.numpy(), + } + self.parent.assertListEqual( + list(result["sequence_output"].shape), + [self.batch_size, self.seq_length, self.hidden_size]) + self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size]) + + + def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): + pass + # model = BertForMaskedLM(config=config) + # model.eval() + # loss, prediction_scores = model(input_ids, token_type_ids, input_mask, token_labels) + # result = { + # "loss": loss, + # "prediction_scores": prediction_scores, + # } + # self.parent.assertListEqual( + # list(result["prediction_scores"].size()), + # [self.batch_size, self.seq_length, self.vocab_size]) + # self.check_loss_output(result) + + + def create_and_check_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): + pass + # model = BertForNextSentencePrediction(config=config) + # model.eval() + # loss, seq_relationship_score = model(input_ids, token_type_ids, input_mask, sequence_labels) + # result = { + # "loss": loss, + # "seq_relationship_score": seq_relationship_score, + # } + # self.parent.assertListEqual( + # list(result["seq_relationship_score"].size()), + # [self.batch_size, 2]) + # self.check_loss_output(result) + + + def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): + pass + # model = BertForPreTraining(config=config) + # model.eval() + # loss, prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels) + # result = { + # "loss": loss, + # "prediction_scores": prediction_scores, + # "seq_relationship_score": seq_relationship_score, + # } + # self.parent.assertListEqual( + # list(result["prediction_scores"].size()), + # [self.batch_size, self.seq_length, self.vocab_size]) + # self.parent.assertListEqual( + # list(result["seq_relationship_score"].size()), + # [self.batch_size, 2]) + # self.check_loss_output(result) + + + def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): + pass + # model = 
BertForQuestionAnswering(config=config) + # model.eval() + # loss, start_logits, end_logits = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels) + # result = { + # "loss": loss, + # "start_logits": start_logits, + # "end_logits": end_logits, + # } + # self.parent.assertListEqual( + # list(result["start_logits"].size()), + # [self.batch_size, self.seq_length]) + # self.parent.assertListEqual( + # list(result["end_logits"].size()), + # [self.batch_size, self.seq_length]) + # self.check_loss_output(result) + + + def create_and_check_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): + pass + # config.num_labels = self.num_labels + # model = BertForSequenceClassification(config) + # model.eval() + # loss, logits = model(input_ids, token_type_ids, input_mask, sequence_labels) + # result = { + # "loss": loss, + # "logits": logits, + # } + # self.parent.assertListEqual( + # list(result["logits"].size()), + # [self.batch_size, self.num_labels]) + # self.check_loss_output(result) + + + def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): + pass + # config.num_labels = self.num_labels + # model = BertForTokenClassification(config=config) + # model.eval() + # loss, logits = model(input_ids, token_type_ids, input_mask, token_labels) + # result = { + # "loss": loss, + # "logits": logits, + # } + # self.parent.assertListEqual( + # list(result["logits"].size()), + # [self.batch_size, self.seq_length, self.num_labels]) + # self.check_loss_output(result) + + + def create_and_check_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): + pass + # config.num_choices = self.num_choices + # model = BertForMultipleChoice(config=config) + # model.eval() + # multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + # multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + # multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + # loss, logits = model(multiple_choice_inputs_ids, + # multiple_choice_token_type_ids, + # multiple_choice_input_mask, + # choice_labels) + # result = { + # "loss": loss, + # "logits": logits, + # } + # self.parent.assertListEqual( + # list(result["logits"].size()), + # [self.batch_size, self.num_choices]) + # self.check_loss_output(result) + + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + (config, input_ids, token_type_ids, input_mask, + sequence_labels, token_labels, choice_labels) = config_and_inputs + inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask} + return config, inputs_dict + + def setUp(self): + self.model_tester = TFBertModelTest.TFBertModelTester(self) + self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_bert_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_bert_model(*config_and_inputs) + + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs) + + def 
test_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs) + + def test_for_next_sequence_prediction(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs) + + def test_for_pretraining(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs) + + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs) + + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs) + + @pytest.mark.slow + def test_model_from_pretrained(self): + cache_dir = "/tmp/pytorch_transformers_test/" + for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + model = TFBertModel.from_pretrained(model_name, cache_dir=cache_dir) + shutil.rmtree(cache_dir) + self.assertIsNotNone(model) + +if __name__ == "__main__": + unittest.main()
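
The loading path added in `modeling_tf_utils.py` above relies on a Keras detail worth spelling out: a subclassed `tf.keras.Model` has no variables until it has been called once, so `from_pretrained` first runs the model on a constant dummy batch and only then calls `model.load_weights(resolved_archive_file, by_name=True)`; matching weights by layer name is also what lets a checkpoint be loaded into a model with added or removed layers (transfer learning). The sketch below illustrates that pattern on a toy model; `TinyEncoder` and the file name are invented for the example and are not part of this diff.

```python
import tensorflow as tf


class TinyEncoder(tf.keras.Model):
    """Hypothetical stand-in for a subclassed transformer model."""

    def __init__(self, vocab_size=100, hidden_size=8, **kwargs):
        super(TinyEncoder, self).__init__(**kwargs)
        self.embeddings = tf.keras.layers.Embedding(vocab_size, hidden_size, name='embeddings')
        self.dense = tf.keras.layers.Dense(hidden_size, name='dense')

    def call(self, inputs, training=False):
        return self.dense(self.embeddings(inputs))


model = TinyEncoder()
dummy_inputs = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
model(dummy_inputs, training=False)              # first call creates the variables
model.save_weights('tiny_encoder.h5', save_format='h5')   # HDF5, as in the archive map

restored = TinyEncoder()
restored(dummy_inputs, training=False)           # build before loading, as from_pretrained does
# by_name=True matches layers by name, so layers added or removed relative
# to the checkpoint are simply skipped instead of raising an error.
restored.load_weights('tiny_encoder.h5', by_name=True)
```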
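To exercise the new `TFBertModel` outside the unittest harness, one can mirror what `TFBertModelTester.create_and_check_bert_model` does: build a small, randomly initialized `BertConfig`, instantiate the model, and call it on random integer ids. This is a minimal sketch assuming this branch of `pytorch_transformers` (with `modeling_tf_bert`) is installed; the config values, and therefore the output shapes noted in the comments, are arbitrary small numbers chosen for speed, similar to the tester's defaults.

```python
import tensorflow as tf

from pytorch_transformers import BertConfig
from pytorch_transformers.modeling_tf_bert import TFBertModel  # assumes this diff is installed

# Small, randomly initialized config -- no pretrained download required.
config = BertConfig(vocab_size_or_config_json_file=99,
                    hidden_size=32,
                    num_hidden_layers=2,
                    num_attention_heads=4,
                    intermediate_size=37)

model = TFBertModel(config=config)

batch_size, seq_length = 3, 7
# Random int32 token ids, playing the same role as ids_tensor() in the tests.
input_ids = tf.random.uniform((batch_size, seq_length),
                              maxval=config.vocab_size, dtype=tf.int32)
attention_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)

# The model accepts a dict of named inputs (or a list, or just input_ids).
sequence_output, pooled_output = model({'input_ids': input_ids,
                                        'attention_mask': attention_mask})

print(sequence_output.shape)  # (3, 7, 32)
print(pooled_output.shape)    # (3, 32)
```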