Mirror of https://github.com/huggingface/transformers.git
add weights tying, attention and hidden states output tests
This commit is contained in:
parent 04d2006f28
commit 600a42329b
@@ -141,7 +141,9 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
     """
     def __init__(self, config, **kwargs):
         super(TFBertEmbeddings, self).__init__(**kwargs)
-        self.word_embeddings = tf.keras.layers.Embedding(config.vocab_size, config.hidden_size, name='word_embeddings')
+        self.vocab_size = config.vocab_size
+        self.hidden_size = config.hidden_size

         self.position_embeddings = tf.keras.layers.Embedding(config.max_position_embeddings, config.hidden_size, name='position_embeddings')
         self.token_type_embeddings = tf.keras.layers.Embedding(config.type_vocab_size, config.hidden_size, name='token_type_embeddings')
@@ -150,8 +152,44 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

+    def build(self, input_shape):
+        """Build shared word embedding layer """
+        with tf.name_scope("word_embeddings"):
+            # Create and initialize weights. The random normal initializer was chosen
+            # arbitrarily, and works well.
+            self.word_embeddings = self.add_weight(
+                "weight",
+                shape=[self.vocab_size, self.hidden_size],
+                initializer=tf.random_normal_initializer(
+                    mean=0., stddev=self.hidden_size**-0.5))
+        super(TFBertEmbeddings, self).build(input_shape)
+
     @tf.function
-    def call(self, inputs, training=False):
+    def call(self, inputs, mode="embedding", training=False):
+        """Get token embeddings of inputs.
+        Args:
+            inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
+            mode: string, a valid value is one of "embedding" and "linear".
+        Returns:
+            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
+                shape [batch_size, length, embedding_size]; (2) mode == "linear", output
+                linear tensor, float32 with shape [batch_size, length, vocab_size].
+        Raises:
+            ValueError: if mode is not valid.
+
+        Shared weights logic adapted from
+            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
+        """
+        if mode == "embedding":
+            return self._embedding(inputs, training=training)
+        elif mode == "linear":
+            return self._linear(inputs)
+        else:
+            raise ValueError("mode {} is not valid.".format(mode))
+
+    def _embedding(self, inputs, training=False):
+        """Applies embedding based on inputs tensor."""
         # Create binary mask of size [batch_size, length]
         input_ids, position_ids, token_type_ids = inputs

         seq_length = tf.shape(input_ids)[1]
@@ -160,7 +198,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
         if token_type_ids is None:
             token_type_ids = tf.fill(tf.shape(input_ids), 0)

-        words_embeddings = self.word_embeddings(input_ids)
+        words_embeddings = tf.gather(self.word_embeddings, input_ids)
         position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
@@ -170,6 +208,21 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
         embeddings = self.dropout(embeddings)
         return embeddings

+    def _linear(self, inputs):
+        """Computes logits by running inputs through a linear layer.
+        Args:
+            inputs: A float32 tensor with shape [batch_size, length, hidden_size]
+        Returns:
+            float32 tensor with shape [batch_size, length, vocab_size].
+        """
+        batch_size = tf.shape(inputs)[0]
+        length = tf.shape(inputs)[1]
+
+        x = tf.reshape(inputs, [-1, self.hidden_size])
+        logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
+
+        return tf.reshape(logits, [batch_size, length, self.vocab_size])


 class TFBertSelfAttention(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
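The two call modes above are what ties the input and output embeddings: the same [vocab_size, hidden_size] matrix is gathered from in "embedding" mode and multiplied against (transposed) in "linear" mode. A minimal sketch of that pattern using plain TensorFlow ops, with illustrative sizes and names that are not taken from the repository:

import tensorflow as tf

vocab_size, hidden_size = 100, 16
# one shared table, analogous to the word_embeddings weight created in build()
shared = tf.Variable(tf.random.normal([vocab_size, hidden_size], stddev=hidden_size ** -0.5))

input_ids = tf.constant([[1, 5, 7]])                      # [batch, length] = [1, 3]
embedded = tf.gather(shared, input_ids)                   # "embedding" mode -> [1, 3, 16]

x = tf.reshape(embedded, [-1, hidden_size])               # flatten to [batch * length, hidden]
logits = tf.matmul(x, shared, transpose_b=True)           # "linear" mode -> [batch * length, vocab]
logits = tf.reshape(logits, [1, 3, vocab_size])           # back to [batch, length, vocab]

print(logits.shape)                                       # (1, 3, 100)

Because the output projection reuses the embedding variable directly, there is no separate decoder weight that could drift out of sync with the input embeddings.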
@@ -448,8 +501,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         self.encoder = TFBertEncoder(config, name='encoder')
         self.pooler = TFBertPooler(config, name='pooler')

-        # self.apply(self.init_weights) # TODO check weights initialization
-
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
@@ -692,22 +743,14 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
         super(TFBertForPreTraining, self).__init__(config)

         self.bert = TFBertMainLayer(config, name='bert')
-        self.cls_mlm = TFBertMLMHead(config, name='cls_mlm')
         self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')

-        self.tie_weights()
-
-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-        """
-        pass # TODO add weights tying
-
     @tf.function
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)

         sequence_output, pooled_output = outputs[:2]
-        prediction_scores = self.cls_mlm(sequence_output)
+        prediction_scores = self.bert.embeddings(sequence_output, mode="linear", training=training)
         seq_relationship_score = self.cls_nsp(pooled_output)

         outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
@@ -751,21 +794,13 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
         super(TFBertForMaskedLM, self).__init__(config)

         self.bert = TFBertMainLayer(config, name='bert')
-        self.cls_mlm = TFBertMLMHead(config, name='cls_mlm')
-
-        self.tie_weights()
-
-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-        """
-        pass # TODO add weights tying

     @tf.function
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)

         sequence_output = outputs[0]
-        prediction_scores = self.cls_mlm(sequence_output)
+        prediction_scores = self.bert.embeddings(sequence_output, mode="linear", training=training)

         outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
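With the tied head, both models now return raw vocabulary logits of shape [batch_size, length, vocab_size] rather than softmax probabilities. A hedged sketch of how such prediction_scores could feed a masked-LM loss, with made-up tensors and a masking scheme that is not part of this commit; the key point is keeping from_logits=True:

import tensorflow as tf

batch, length, vocab_size = 2, 8, 30522                                  # illustrative sizes
prediction_scores = tf.random.normal([batch, length, vocab_size])        # stand-in for the model output
labels = tf.random.uniform([batch, length], maxval=vocab_size, dtype=tf.int32)
mask = tf.cast(tf.random.uniform([batch, length]) < 0.15, tf.float32)    # 1.0 at masked positions

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
per_token_loss = loss_fn(labels, prediction_scores)                      # [batch, length]
masked_lm_loss = tf.reduce_sum(per_token_loss * mask) / tf.maximum(tf.reduce_sum(mask), 1.0)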
@@ -64,7 +64,7 @@ class TFPreTrainedModel(tf.keras.Model):
         self.config = config

     def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
-        """ Build a resized Embedding Module from a provided token Embedding Module.
+        """ Build a resized Embedding Variable from a provided token Embedding Module.
             Increasing the size will add newly initialized vectors at the end
             Reducing the size will remove vectors from the end
@@ -77,12 +77,25 @@ class TFPreTrainedModel(tf.keras.Model):
         Return: ``torch.nn.Embeddings``
             Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
         """
         raise NotImplementedError
         # if new_num_tokens is None:
         #     return old_embeddings

     def _tie_or_clone_weights(self, first_module, second_module):
         """ Tie or clone module weights depending on whether we are using TorchScript or not
         """
         raise NotImplementedError
         # old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
         # if old_num_tokens == new_num_tokens:
         #     return old_embeddings

         # # Build new embeddings
         # new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
         # new_embeddings.to(old_embeddings.weight.device)

         # # initialize all new embeddings (in particular added tokens)
         # self._init_weights(new_embeddings)

         # # Copy word embeddings from the previous weights
         # num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
         # new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]

         # return new_embeddings

     def resize_token_embeddings(self, new_num_tokens=None):
         """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
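The commented-out block above is the PyTorch resizing logic kept as a reference for the eventual TF port: build a table of the new size, initialize it, and copy the overlapping rows. A hedged sketch of how those same steps could look on a plain TF variable; the function name, initializer, and shapes below are assumptions for illustration, not code from this commit:

import tensorflow as tf

def resize_embedding_variable(old_embeddings, new_num_tokens):
    """Return a [new_num_tokens, dim] variable, copying the rows both tables share."""
    old_num_tokens, embedding_dim = old_embeddings.shape
    if new_num_tokens is None or new_num_tokens == old_num_tokens:
        return old_embeddings

    # freshly initialized table for the new size (added tokens get new vectors)
    new_embeddings = tf.Variable(
        tf.random.normal([new_num_tokens, embedding_dim], stddev=embedding_dim ** -0.5))

    # copy the weights from the previous table
    num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
    new_embeddings[:num_tokens_to_copy].assign(old_embeddings[:num_tokens_to_copy])
    return new_embeddings

old = tf.Variable(tf.random.normal([10, 4]))
resized = resize_embedding_variable(old, 12)   # first 10 rows copied, 2 newly initialized rows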
@@ -64,44 +64,40 @@ class TFCommonTestCases:


     def test_attention_outputs(self):
-        pass
-        # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

-        # for model_class in self.all_model_classes:
-        #     config.output_attentions = True
-        #     config.output_hidden_states = False
-        #     model = model_class(config)
-        #     model.eval()
-        #     outputs = model(**inputs_dict)
-        #     attentions = outputs[-1]
-        #     self.assertEqual(model.config.output_attentions, True)
-        #     self.assertEqual(model.config.output_hidden_states, False)
-        #     self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-        #     self.assertListEqual(
-        #         list(attentions[0].shape[-3:]),
-        #         [self.model_tester.num_attention_heads,
-        #          self.model_tester.seq_length,
-        #          self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
-        #     out_len = len(outputs)
+        for model_class in self.all_model_classes:
+            config.output_attentions = True
+            config.output_hidden_states = False
+            model = model_class(config)
+            outputs = model(inputs_dict)
+            attentions = [t.numpy() for t in outputs[-1]]
+            self.assertEqual(model.config.output_attentions, True)
+            self.assertEqual(model.config.output_hidden_states, False)
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+            self.assertListEqual(
+                list(attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads,
+                 self.model_tester.seq_length,
+                 self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
+            out_len = len(outputs)

-        #     # Check attention is always last and order is fine
-        #     config.output_attentions = True
-        #     config.output_hidden_states = True
-        #     model = model_class(config)
-        #     model.eval()
-        #     outputs = model(**inputs_dict)
-        #     self.assertEqual(out_len+1, len(outputs))
-        #     self.assertEqual(model.config.output_attentions, True)
-        #     self.assertEqual(model.config.output_hidden_states, True)
-
-        #     attentions = outputs[-1]
-        #     self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-        #     self.assertListEqual(
-        #         list(attentions[0].shape[-3:]),
-        #         [self.model_tester.num_attention_heads,
-        #          self.model_tester.seq_length,
-        #          self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
+            # Check attention is always last and order is fine
+            config.output_attentions = True
+            config.output_hidden_states = True
+            model = model_class(config)
+            outputs = model(inputs_dict)
+            self.assertEqual(out_len+1, len(outputs))
+            self.assertEqual(model.config.output_attentions, True)
+            self.assertEqual(model.config.output_hidden_states, True)
+
+            attentions = [t.numpy() for t in outputs[-1]]
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+            self.assertListEqual(
+                list(attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads,
+                 self.model_tester.seq_length,
+                 self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])

     def test_headmasking(self):
         pass
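Compared with the commented PyTorch version, the TF test drops model.eval(), passes the input dict as a single positional argument, and converts each attention tensor to NumPy before asserting. The contract being checked is one attention tensor per layer with trailing shape [num_heads, query_len, key_len]; a small illustrative check with made-up sizes, not taken from the test suite:

import numpy as np

batch_size, num_heads, seq_len = 2, 4, 7
attention_probs = np.zeros((batch_size, num_heads, seq_len, seq_len), dtype=np.float32)

# mirrors the assertListEqual above: everything but the batch dimension
assert list(attention_probs.shape[-3:]) == [num_heads, seq_len, seq_len]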
@@ -178,22 +174,20 @@ class TFCommonTestCases:


     def test_hidden_states_output(self):
-        pass
-        # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

-        # for model_class in self.all_model_classes:
-        #     config.output_hidden_states = True
-        #     config.output_attentions = False
-        #     model = model_class(config)
-        #     model.eval()
-        #     outputs = model(**inputs_dict)
-        #     hidden_states = outputs[-1]
-        #     self.assertEqual(model.config.output_attentions, False)
-        #     self.assertEqual(model.config.output_hidden_states, True)
-        #     self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
-        #     self.assertListEqual(
-        #         list(hidden_states[0].shape[-2:]),
-        #         [self.model_tester.seq_length, self.model_tester.hidden_size])
+        for model_class in self.all_model_classes:
+            config.output_hidden_states = True
+            config.output_attentions = False
+            model = model_class(config)
+            outputs = model(inputs_dict)
+            hidden_states = [t.numpy() for t in outputs[-1]]
+            self.assertEqual(model.config.output_attentions, False)
+            self.assertEqual(model.config.output_hidden_states, True)
+            self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
+            self.assertListEqual(
+                list(hidden_states[0].shape[-2:]),
+                [self.model_tester.seq_length, self.model_tester.hidden_size])


     def test_resize_tokens_embeddings(self):
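The `num_hidden_layers + 1` count holds because the hidden_states tuple starts with the embedding output and then has one entry per encoder layer, each shaped [batch_size, seq_length, hidden_size]. A toy illustration of that contract, with made-up sizes independent of the test suite:

import numpy as np

num_hidden_layers, seq_len, hidden_size = 2, 7, 16
# one entry for the embedding output plus one per encoder layer
hidden_states = [np.zeros((1, seq_len, hidden_size), dtype=np.float32)
                 for _ in range(num_hidden_layers + 1)]

assert len(hidden_states) == num_hidden_layers + 1
assert list(hidden_states[0].shape[-2:]) == [seq_len, hidden_size]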