WIP XLNet

This commit is contained in:
thomwolf 2019-09-10 12:17:18 +02:00
parent f851fb55ca
commit 32aabe8c33
7 changed files with 1540 additions and 68 deletions

View File

@ -95,7 +95,7 @@ except (ImportError, AssertionError):
if _tf_available:
logger.info("TensorFlow version {} available.".format(tf.__version__))
from .modeling_tf_utils import TFPreTrainedModel
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary
from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering,
TFAutoModelWithLMHead)
@ -107,7 +107,7 @@ if _tf_available:
load_bert_pt_weights_in_tf2,
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer, TFGPT2Embeddings,
from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer,
TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel,
load_gpt2_pt_weights_in_tf2,
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)

View File

@ -54,6 +54,7 @@ class PretrainedConfig(object):
self.output_attentions = kwargs.pop('output_attentions', False)
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
self.torchscript = kwargs.pop('torchscript', False)
self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
self.pruned_heads = kwargs.pop('pruned_heads', {})
def save_pretrained(self, save_directory):

View File

@ -28,7 +28,8 @@ from io import open
import numpy as np
import tensorflow as tf
from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary, shape_list
from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
TFSequenceSummary, shape_list)
from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings
@ -65,6 +66,7 @@ def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
weight_value_tuples = []
all_pytorch_weights = set(list(state_dict.keys()))
for symbolic_weight in symbolic_weights:
name = symbolic_weight.name
name = name.replace(':0', '')
@ -100,13 +102,13 @@ def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
weight_value_tuples.append((symbolic_weight, array))
state_dict.pop(name)
all_pytorch_weights.discard(name)
K.batch_set_value(weight_value_tuples)
tfo = tf_model(tf_inputs, training=False) # Make sure restore ops are run
assert not state_dict, "Weights not loaded: {}".format(list(state_dict.keys()))
logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))
return tf_model
@ -267,65 +269,6 @@ class TFBlock(tf.keras.layers.Layer):
outputs = [x] + output_attn[1:]
return outputs # x, present, (attentions)
class TFGPT2Embeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config, **kwargs):
super(TFGPT2Embeddings, self).__init__(**kwargs)
self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size
def build(self, input_shape):
"""Build shared word embedding layer
Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
self.weight = self.add_weight(
"weight",
shape=[self.vocab_size, self.hidden_size],
initializer=tf.random_normal_initializer(
mean=0., stddev=self.hidden_size**-0.5))
super(TFGPT2Embeddings, self).build(input_shape)
def call(self, inputs, mode="embedding"):
"""Get token embeddings of inputs.
Args:
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
mode: string, a valid value is one of "embedding" and "linear".
Returns:
outputs: (1) If mode == "embedding", output embedding tensor, float32 with
shape [batch_size, length, embedding_size]; (2) mode == "linear", output
linear tensor, float32 with shape [batch_size, length, vocab_size].
Raises:
ValueError: if mode is not valid.
Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
if mode == "embedding":
return self._embedding(inputs)
elif mode == "linear":
return self._linear(inputs)
else:
raise ValueError("mode {} is not valid.".format(mode))
def _embedding(self, input_ids):
"""Applies embedding based on inputs tensor."""
return tf.gather(self.weight, input_ids)
def _linear(self, inputs):
"""Computes logits by running inputs through a linear layer.
Args:
inputs: A float32 tensor with shape [..., hidden_size]
Returns:
float32 tensor with shape [..., vocab_size].
"""
first_dims = shape_list(inputs)[:-1]
x = tf.reshape(inputs, [-1, self.hidden_size])
logits = tf.matmul(x, self.weight, transpose_b=True)
return tf.reshape(logits, first_dims + [self.vocab_size])
class TFGPT2MainLayer(tf.keras.layers.Layer):
def __init__(self, config, *inputs, **kwargs):
@ -336,10 +279,13 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
self.vocab_size = config.vocab_size
self.n_embd = config.n_embd
self.wte = TFGPT2Embeddings(config, name='wte')
self.wte = TFSharedEmbeddings(config.vocab_size, config.hidden_size, name='wte')
self.wpe = tf.keras.layers.Embedding(config.n_positions, config.n_embd, name='wpe')
self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
self.h = [TFBlock(config.n_ctx, config, scale=True, name='h_{}'.format(i)) for i in range(config.n_layer)]
self.h = [TFBlock(config.n_ctx,
config,
scale=True,
name='h_{}'.format(i)) for i in range(config.n_layer)]
self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')
def _resize_token_embeddings(self, new_num_tokens):

View File

@ -288,6 +288,69 @@ class TFConv1D(tf.keras.layers.Layer):
return x
class TFSharedEmbeddings(tf.keras.layers.Layer):
"""Construct shared token embeddings.
"""
def __init__(self, vocab_size, hidden_size, initializer_range=None, **kwargs):
super(TFSharedEmbeddings, self).__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
"""Build shared word embedding layer
Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
initializer_range = self.hidden_size**-0.5 if self.initializer_range is None else self.initializer_range
self.weight = self.add_weight(
"weight",
shape=[self.vocab_size, self.hidden_size],
initializer=tf.random_normal_initializer(
mean=0., stddev=initializer_range))
super(TFSharedEmbeddings, self).build(input_shape)
def call(self, inputs, mode="embedding"):
"""Get token embeddings of inputs.
Args:
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
mode: string, a valid value is one of "embedding" and "linear".
Returns:
outputs: (1) If mode == "embedding", output embedding tensor, float32 with
shape [batch_size, length, embedding_size]; (2) mode == "linear", output
linear tensor, float32 with shape [batch_size, length, vocab_size].
Raises:
ValueError: if mode is not valid.
Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
if mode == "embedding":
return self._embedding(inputs)
elif mode == "linear":
return self._linear(inputs)
else:
raise ValueError("mode {} is not valid.".format(mode))
def _embedding(self, input_ids):
"""Applies embedding based on inputs tensor."""
return tf.gather(self.weight, input_ids)
def _linear(self, inputs):
"""Computes logits by running inputs through a linear layer.
Args:
inputs: A float32 tensor with shape [..., hidden_size]
Returns:
float32 tensor with shape [..., vocab_size].
"""
first_dims = shape_list(inputs)[:-1]
x = tf.reshape(inputs, [-1, self.hidden_size])
logits = tf.matmul(x, self.weight, transpose_b=True)
return tf.reshape(logits, first_dims + [self.vocab_size])
class TFSequenceSummary(tf.keras.layers.Layer):
r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
Args of the config class:

File diff suppressed because it is too large Load Diff

View File

@ -262,7 +262,7 @@ class TFCommonTestCases:
# self.assertEqual(len(params_tied_2), len(params_tied))
def ids_tensor(shape, vocab_size, rng=None, name=None):
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=tf.int32):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
rng = random.Random()
@ -275,7 +275,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
return tf.constant(values, shape=shape)
return tf.constant(values, shape=shape, dtype=dtype)
class TFModelUtilsTest(unittest.TestCase):

View File

@ -0,0 +1,341 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import unittest
import json
import random
import shutil
import pytest
from pytorch_transformers import XLNetConfig, is_tf_available
if is_tf_available():
import tensorflow as tf
from pytorch_transformers.modeling_tf_xlnet import (TFXLNetModel, TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
# XLNetLMHeadModel,
# XLNetForSequenceClassification, XLNetForQuestionAnswering)
else:
pytestmark = pytest.mark.skip("Require TensorFlow")
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes=(TFXLNetModel, ) if is_tf_available() else ()
# all_model_classes=(TFXLNetModel, TFXLNetLMHeadModel,
# TFXLNetForSequenceClassification, TFXLNetForQuestionAnswering) if is_tf_available() else ()
test_pruning = False
class TFXLNetModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
mem_len=10,
clamp_len=-1,
reuse_len=15,
is_training=True,
use_labels=True,
vocab_size=99,
cutoffs=[10, 50, 80],
hidden_size=32,
num_attention_heads=4,
d_inner=128,
num_hidden_layers=5,
max_position_embeddings=10,
type_sequence_label_size=2,
untie_r=True,
bi_data=False,
same_length=False,
initializer_range=0.05,
seed=1,
type_vocab_size=2,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.mem_len = mem_len
# self.key_len = seq_length + mem_len
self.clamp_len = clamp_len
self.reuse_len = reuse_len
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.cutoffs = cutoffs
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.d_inner = d_inner
self.num_hidden_layers = num_hidden_layers
self.max_position_embeddings = max_position_embeddings
self.bi_data = bi_data
self.untie_r = untie_r
self.same_length = same_length
self.initializer_range = initializer_range
self.seed = seed
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
def prepare_config_and_inputs(self):
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
input_mask = ids_tensor([self.batch_size, self.seq_length], 2, dtype=tf.float32)
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
perm_mask = tf.zeros((self.batch_size, self.seq_length + 1, self.seq_length), dtype=tf.float32)
perm_mask_last = tf.ones((self.batch_size, self.seq_length + 1, 1), dtype=tf.float32)
perm_mask = tf.concat([perm_mask, perm_mask_last], axis=-1)
# perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
target_mapping = tf.zeros((self.batch_size, 1, self.seq_length), dtype=torch.float32)
target_mapping_last = tf.ones((self.batch_size, 1, 1), dtype=torch.float32)
target_mapping = tf.concat([target_mapping, target_mapping_last], axis=-1)
# target_mapping[:, 0, -1] = 1.0 # predict last token
sequence_labels = None
lm_labels = None
is_impossible_labels = None
if self.use_labels:
lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
is_impossible_labels = ids_tensor([self.batch_size], 2, dtype=tf.float32)
config = XLNetConfig(
vocab_size_or_config_json_file=self.vocab_size,
d_model=self.hidden_size,
n_head=self.num_attention_heads,
d_inner=self.d_inner,
n_layer=self.num_hidden_layers,
untie_r=self.untie_r,
max_position_embeddings=self.max_position_embeddings,
mem_len=self.mem_len,
clamp_len=self.clamp_len,
same_length=self.same_length,
reuse_len=self.reuse_len,
bi_data=self.bi_data,
initializer_range=self.initializer_range,
num_labels=self.type_sequence_label_size)
return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels)
def set_seed(self):
random.seed(self.seed)
tf.random.set_seed(self.seed)
def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
model = TFXLNetModel(config)
inputs = {'input_ids': input_ids,
'input_mask': input_mask,
'token_type_ids': token_type_ids}
_, _ = model(inputs)
inputs = [input_ids, input_mask]
outputs, mems_1 = model(inputs)
result = {
"mems_1": [mem.numpy() for m in mems_1],
"outputs": outputs.numpy(),
}
self.parent.assertListEqual(
list(result["outputs"].shape),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
pass
# model = XLNetLMHeadModel(config)
# model.eval()
# loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
# loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)
# logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
# result = {
# "loss_1": loss_1,
# "mems_1": mems_1,
# "all_logits_1": all_logits_1,
# "loss_2": loss_2,
# "mems_2": mems_2,
# "all_logits_2": all_logits_2,
# }
# self.parent.assertListEqual(
# list(result["loss_1"].size()),
# [])
# self.parent.assertListEqual(
# list(result["all_logits_1"].size()),
# [self.batch_size, self.seq_length, self.vocab_size])
# self.parent.assertListEqual(
# list(list(mem.size()) for mem in result["mems_1"]),
# [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
# self.parent.assertListEqual(
# list(result["loss_2"].size()),
# [])
# self.parent.assertListEqual(
# list(result["all_logits_2"].size()),
# [self.batch_size, self.seq_length, self.vocab_size])
# self.parent.assertListEqual(
# list(list(mem.size()) for mem in result["mems_2"]),
# [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
pass
# model = XLNetForQuestionAnswering(config)
# model.eval()
# outputs = model(input_ids_1)
# start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
# outputs = model(input_ids_1, start_positions=sequence_labels,
# end_positions=sequence_labels,
# cls_index=sequence_labels,
# is_impossible=is_impossible_labels,
# p_mask=input_mask)
# outputs = model(input_ids_1, start_positions=sequence_labels,
# end_positions=sequence_labels,
# cls_index=sequence_labels,
# is_impossible=is_impossible_labels)
# total_loss, mems = outputs
# outputs = model(input_ids_1, start_positions=sequence_labels,
# end_positions=sequence_labels)
# total_loss, mems = outputs
# result = {
# "loss": total_loss,
# "start_top_log_probs": start_top_log_probs,
# "start_top_index": start_top_index,
# "end_top_log_probs": end_top_log_probs,
# "end_top_index": end_top_index,
# "cls_logits": cls_logits,
# "mems": mems,
# }
# self.parent.assertListEqual(
# list(result["loss"].size()),
# [])
# self.parent.assertListEqual(
# list(result["start_top_log_probs"].size()),
# [self.batch_size, model.config.start_n_top])
# self.parent.assertListEqual(
# list(result["start_top_index"].size()),
# [self.batch_size, model.config.start_n_top])
# self.parent.assertListEqual(
# list(result["end_top_log_probs"].size()),
# [self.batch_size, model.config.start_n_top * model.config.end_n_top])
# self.parent.assertListEqual(
# list(result["end_top_index"].size()),
# [self.batch_size, model.config.start_n_top * model.config.end_n_top])
# self.parent.assertListEqual(
# list(result["cls_logits"].size()),
# [self.batch_size])
# self.parent.assertListEqual(
# list(list(mem.size()) for mem in result["mems"]),
# [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
pass
# model = XLNetForSequenceClassification(config)
# model.eval()
# logits, mems_1 = model(input_ids_1)
# loss, logits, mems_1 = model(input_ids_1, labels=sequence_labels)
# result = {
# "loss": loss,
# "mems_1": mems_1,
# "logits": logits,
# }
# self.parent.assertListEqual(
# list(result["loss"].size()),
# [])
# self.parent.assertListEqual(
# list(result["logits"].size()),
# [self.batch_size, self.type_sequence_label_size])
# self.parent.assertListEqual(
# list(list(mem.size()) for mem in result["mems_1"]),
# [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels,
sequence_labels, is_impossible_labels) = config_and_inputs
inputs_dict = {'input_ids': input_ids_1}
return config, inputs_dict
def setUp(self):
self.model_tester = TFXLNetModelTest.TFXLNetModelTester(self)
self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_xlnet_base_model(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)
def test_xlnet_lm_head(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_lm_head(*config_and_inputs)
def test_xlnet_sequence_classif(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)
def test_xlnet_qa(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_qa(*config_and_inputs)
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFXLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()