add weight tying, attention and hidden states output tests

thomwolf 2019-09-05 12:02:14 +02:00
parent 04d2006f28
commit 600a42329b
3 changed files with 121 additions and 79 deletions

View File

@@ -141,7 +141,9 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
"""
def __init__(self, config, **kwargs):
super(TFBertEmbeddings, self).__init__(**kwargs)
self.word_embeddings = tf.keras.layers.Embedding(config.vocab_size, config.hidden_size, name='word_embeddings')
self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size
self.position_embeddings = tf.keras.layers.Embedding(config.max_position_embeddings, config.hidden_size, name='position_embeddings')
self.token_type_embeddings = tf.keras.layers.Embedding(config.type_vocab_size, config.hidden_size, name='token_type_embeddings')
@@ -150,8 +152,44 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def build(self, input_shape):
"""Build shared word embedding layer """
with tf.name_scope("word_embeddings"):
# Create and initialize weights. The random normal initializer was chosen
# arbitrarily, and works well.
self.word_embeddings = self.add_weight(
"weight",
shape=[self.vocab_size, self.hidden_size],
initializer=tf.random_normal_initializer(
mean=0., stddev=self.hidden_size**-0.5))
super(TFBertEmbeddings, self).build(input_shape)
@tf.function
def call(self, inputs, training=False):
def call(self, inputs, mode="embedding", training=False):
"""Get token embeddings of inputs.
Args:
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
mode: string, a valid value is one of "embedding" and "linear".
Returns:
outputs: (1) If mode == "embedding", output embedding tensor, float32 with
shape [batch_size, length, embedding_size]; (2) mode == "linear", output
linear tensor, float32 with shape [batch_size, length, vocab_size].
Raises:
ValueError: if mode is not valid.
Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
if mode == "embedding":
return self._embedding(inputs, training=training)
elif mode == "linear":
return self._linear(inputs)
else:
raise ValueError("mode {} is not valid.".format(mode))
def _embedding(self, inputs, training=False):
"""Applies embedding based on inputs tensor."""
# Create binary mask of size [batch_size, length]
input_ids, position_ids, token_type_ids = inputs
seq_length = tf.shape(input_ids)[1]
@@ -160,7 +198,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
if token_type_ids is None:
token_type_ids = tf.fill(tf.shape(input_ids), 0)
words_embeddings = self.word_embeddings(input_ids)
words_embeddings = tf.gather(self.word_embeddings, input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
@@ -170,6 +208,21 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
embeddings = self.dropout(embeddings)
return embeddings
def _linear(self, inputs):
"""Computes logits by running inputs through a linear layer.
Args:
inputs: A float32 tensor with shape [batch_size, length, hidden_size]
Returns:
float32 tensor with shape [batch_size, length, vocab_size].
"""
batch_size = tf.shape(inputs)[0]
length = tf.shape(inputs)[1]
x = tf.reshape(inputs, [-1, self.hidden_size])
logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
return tf.reshape(logits, [batch_size, length, self.vocab_size])
class TFBertSelfAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
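For reference, here is a minimal standalone sketch of the shared-embedding pattern this hunk introduces: one weight matrix created in build() serves both the input lookup ("embedding" mode) and the output projection to vocabulary logits ("linear" mode). The class name and sizes below are illustrative only, not the library's actual class.

import tensorflow as tf

class SharedEmbedding(tf.keras.layers.Layer):
    """Toy layer: a single matrix used for input embeddings and output logits."""
    def __init__(self, vocab_size, hidden_size, **kwargs):
        super(SharedEmbedding, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    def build(self, input_shape):
        # Same trick as above: create the matrix in build() so both modes can reuse it.
        self.word_embeddings = self.add_weight(
            "weight",
            shape=[self.vocab_size, self.hidden_size],
            initializer=tf.random_normal_initializer(
                mean=0., stddev=self.hidden_size ** -0.5))
        super(SharedEmbedding, self).build(input_shape)

    def call(self, inputs, mode="embedding"):
        if mode == "embedding":
            # [batch, length] int ids -> [batch, length, hidden] vectors
            return tf.gather(self.word_embeddings, inputs)
        elif mode == "linear":
            # [batch, length, hidden] states -> [batch, length, vocab] logits
            batch_size = tf.shape(inputs)[0]
            length = tf.shape(inputs)[1]
            x = tf.reshape(inputs, [-1, self.hidden_size])
            logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
            return tf.reshape(logits, [batch_size, length, self.vocab_size])
        raise ValueError("mode {} is not valid.".format(mode))

layer = SharedEmbedding(vocab_size=100, hidden_size=16)
ids = tf.constant([[1, 2, 3]])
hidden = layer(ids, mode="embedding")   # shape (1, 3, 16)
logits = layer(hidden, mode="linear")   # shape (1, 3, 100)

Calling the layer twice like this is exactly what the pre-training and MLM heads do below via self.bert.embeddings(sequence_output, mode="linear").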
@@ -448,8 +501,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
self.encoder = TFBertEncoder(config, name='encoder')
self.pooler = TFBertPooler(config, name='pooler')
# self.apply(self.init_weights) # TODO check weights initialization
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
@@ -692,22 +743,14 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
super(TFBertForPreTraining, self).__init__(config)
self.bert = TFBertMainLayer(config, name='bert')
self.cls_mlm = TFBertMLMHead(config, name='cls_mlm')
self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
"""
pass # TODO add weights tying
@tf.function
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
sequence_output, pooled_output = outputs[:2]
prediction_scores = self.cls_mlm(sequence_output)
prediction_scores = self.bert.embeddings(sequence_output, mode="linear", training=training)
seq_relationship_score = self.cls_nsp(pooled_output)
outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
@@ -751,21 +794,13 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
super(TFBertForMaskedLM, self).__init__(config)
self.bert = TFBertMainLayer(config, name='bert')
self.cls_mlm = TFBertMLMHead(config, name='cls_mlm')
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
"""
pass # TODO add weights tying
@tf.function
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
sequence_output = outputs[0]
prediction_scores = self.cls_mlm(sequence_output)
prediction_scores = self.bert.embeddings(sequence_output, mode="linear", training=training)
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
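With the matrix created in build(), tying needs no extra machinery at the head level: TFBertForPreTraining and TFBertForMaskedLM simply push sequence_output back through self.bert.embeddings with mode="linear", so the input and output embeddings are literally the same variable. A small hedged illustration of what that buys, using toy sizes and a plain tf.Variable rather than the real model: the gradient of a masked-LM style loss flows into one matrix that plays both roles.

import tensorflow as tf

vocab_size, hidden_size = 50, 8

# One matrix, used both as the input embedding and as the output projection.
tied_embeddings = tf.Variable(
    tf.random.normal([vocab_size, hidden_size], stddev=hidden_size ** -0.5))

input_ids = tf.constant([[3, 7, 12]])
labels = tf.constant([[7, 12, 3]])

with tf.GradientTape() as tape:
    hidden = tf.gather(tied_embeddings, input_ids)              # "embedding" mode
    # (a real model runs the transformer stack between these two steps)
    logits = tf.einsum("blh,vh->blv", hidden, tied_embeddings)  # "linear" mode
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

# A single trainable matrix receives gradient from both roles,
# and the separate vocab_size x hidden_size output head disappears.
grad = tape.gradient(loss, tied_embeddings)
print(grad.shape)  # (50, 8)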

View File

@@ -64,7 +64,7 @@ class TFPreTrainedModel(tf.keras.Model):
self.config = config
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
""" Build a resized Embedding Module from a provided token Embedding Module.
""" Build a resized Embedding Variable from a provided token Embedding Module.
Increasing the size will add newly initialized vectors at the end
Reducing the size will remove vectors from the end
@@ -77,12 +77,25 @@ class TFPreTrainedModel(tf.keras.Model):
Return: ``torch.nn.Embeddings``
Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
"""
raise NotImplementedError
# if new_num_tokens is None:
# return old_embeddings
def _tie_or_clone_weights(self, first_module, second_module):
""" Tie or clone module weights depending of weither we are using TorchScript or not
"""
raise NotImplementedError
# old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
# if old_num_tokens == new_num_tokens:
# return old_embeddings
# # Build new embeddings
# new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
# new_embeddings.to(old_embeddings.weight.device)
# # initialize all new embeddings (in particular added tokens)
# self._init_weights(new_embeddings)
# # Copy word embeddings from the previous weights
# num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
# new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
# return new_embeddings
def resize_token_embeddings(self, new_num_tokens=None):
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.

View File

@@ -64,44 +64,40 @@ class TFCommonTestCases:
def test_attention_outputs(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# for model_class in self.all_model_classes:
# config.output_attentions = True
# config.output_hidden_states = False
# model = model_class(config)
# model.eval()
# outputs = model(**inputs_dict)
# attentions = outputs[-1]
# self.assertEqual(model.config.output_attentions, True)
# self.assertEqual(model.config.output_hidden_states, False)
# self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# self.assertListEqual(
# list(attentions[0].shape[-3:]),
# [self.model_tester.num_attention_heads,
# self.model_tester.seq_length,
# self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
# out_len = len(outputs)
for model_class in self.all_model_classes:
config.output_attentions = True
config.output_hidden_states = False
model = model_class(config)
outputs = model(inputs_dict)
attentions = [t.numpy() for t in outputs[-1]]
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads,
self.model_tester.seq_length,
self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
out_len = len(outputs)
# # Check attention is always last and order is fine
# config.output_attentions = True
# config.output_hidden_states = True
# model = model_class(config)
# model.eval()
# outputs = model(**inputs_dict)
# self.assertEqual(out_len+1, len(outputs))
# self.assertEqual(model.config.output_attentions, True)
# self.assertEqual(model.config.output_hidden_states, True)
# attentions = outputs[-1]
# self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# self.assertListEqual(
# list(attentions[0].shape[-3:]),
# [self.model_tester.num_attention_heads,
# self.model_tester.seq_length,
# self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
# Check attention is always last and order is fine
config.output_attentions = True
config.output_hidden_states = True
model = model_class(config)
outputs = model(inputs_dict)
self.assertEqual(out_len+1, len(outputs))
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, True)
attentions = [t.numpy() for t in outputs[-1]]
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads,
self.model_tester.seq_length,
self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
def test_headmasking(self):
pass
@@ -178,22 +174,20 @@ class TFCommonTestCases:
def test_hidden_states_output(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# for model_class in self.all_model_classes:
# config.output_hidden_states = True
# config.output_attentions = False
# model = model_class(config)
# model.eval()
# outputs = model(**inputs_dict)
# hidden_states = outputs[-1]
# self.assertEqual(model.config.output_attentions, False)
# self.assertEqual(model.config.output_hidden_states, True)
# self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
# self.assertListEqual(
# list(hidden_states[0].shape[-2:]),
# [self.model_tester.seq_length, self.model_tester.hidden_size])
for model_class in self.all_model_classes:
config.output_hidden_states = True
config.output_attentions = False
model = model_class(config)
outputs = model(inputs_dict)
hidden_states = [t.numpy() for t in outputs[-1]]
self.assertEqual(model.config.output_attentions, False)
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[self.model_tester.seq_length, self.model_tester.hidden_size])
def test_resize_tokens_embeddings(self):