Avoid using tf.tile in embeddings for TF models (#14735)

* avoid tf.tile in embeddings

* remove more tf.tile in embeddings

* clean

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2021-12-13 18:30:46 +01:00 committed by GitHub
parent 6ac0fac85a
commit 15a9d01519
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 23 additions and 57 deletions

View File

@ -122,7 +122,6 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
self.embedding_size = config.embedding_size
self.max_position_embeddings = config.max_position_embeddings
self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
@ -183,9 +182,8 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
)
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = inputs_embeds + position_embeds + token_type_embeds
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)

View File

@ -141,7 +141,6 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
self.hidden_size = config.hidden_size
self.max_position_embeddings = config.max_position_embeddings
self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
@ -201,9 +200,8 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
)
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = inputs_embeds + position_embeds + token_type_embeds
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)

View File

@ -75,7 +75,6 @@ class TFConvBertEmbeddings(tf.keras.layers.Layer):
self.embedding_size = config.embedding_size
self.max_position_embeddings = config.max_position_embeddings
self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
@ -136,9 +135,8 @@ class TFConvBertEmbeddings(tf.keras.layers.Layer):
)
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = inputs_embeds + position_embeds + token_type_embeds
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)

View File

@ -728,7 +728,6 @@ class TFDebertaEmbeddings(tf.keras.layers.Layer):
self.max_position_embeddings = config.max_position_embeddings
self.position_biased_input = getattr(config, "position_biased_input", True)
self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
if self.embedding_size != config.hidden_size:
self.embed_proj = tf.keras.layers.Dense(config.hidden_size, bias=False)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
@ -795,7 +794,6 @@ class TFDebertaEmbeddings(tf.keras.layers.Layer):
final_embeddings = inputs_embeds
if self.position_biased_input:
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
final_embeddings += position_embeds
if self.type_vocab_size > 0:
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)

View File

@ -845,7 +845,6 @@ class TFDebertaV2Embeddings(tf.keras.layers.Layer):
self.max_position_embeddings = config.max_position_embeddings
self.position_biased_input = getattr(config, "position_biased_input", True)
self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
if self.embedding_size != config.hidden_size:
self.embed_proj = tf.keras.layers.Dense(config.hidden_size, bias=False)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
@ -912,7 +911,6 @@ class TFDebertaV2Embeddings(tf.keras.layers.Layer):
final_embeddings = inputs_embeds
if self.position_biased_input:
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
final_embeddings += position_embeds
if self.type_vocab_size > 0:
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)

View File

@ -77,8 +77,6 @@ class TFEmbeddings(tf.keras.layers.Layer):
self.dim = config.dim
self.initializer_range = config.initializer_range
self.max_position_embeddings = config.max_position_embeddings
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.dropout)
@ -117,8 +115,7 @@ class TFEmbeddings(tf.keras.layers.Layer):
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds])
final_embeddings = inputs_embeds + position_embeds
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)

View File

@ -481,7 +481,6 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
self.embedding_size = config.embedding_size
self.max_position_embeddings = config.max_position_embeddings
self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
@ -542,9 +541,8 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
)
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = inputs_embeds + position_embeds + token_type_embeds
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)

View File

@ -68,7 +68,6 @@ class TFLayoutLMEmbeddings(tf.keras.layers.Layer):
self.max_position_embeddings = config.max_position_embeddings
self.max_2d_position_embeddings = config.max_2d_position_embeddings
self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
@ -168,20 +167,17 @@ class TFLayoutLMEmbeddings(tf.keras.layers.Layer):
w_position_embeddings = tf.gather(self.w_position_embeddings, bbox[:, :, 2] - bbox[:, :, 0])
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(
inputs=[
inputs_embeds,
position_embeds,
token_type_embeds,
left_position_embeddings,
upper_position_embeddings,
right_position_embeddings,
lower_position_embeddings,
h_position_embeddings,
w_position_embeddings,
]
final_embeddings = (
inputs_embeds
+ position_embeds
+ token_type_embeds
+ left_position_embeddings
+ upper_position_embeddings
+ right_position_embeddings
+ lower_position_embeddings
+ h_position_embeddings
+ w_position_embeddings
)
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)

View File

@ -482,7 +482,6 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
self.hidden_size = config.hidden_size
self.max_position_embeddings = config.max_position_embeddings
self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
@ -559,11 +558,10 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
position_ids = tf.expand_dims(
tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0
)
position_ids = tf.tile(input=position_ids, multiples=(input_shape[0], 1))
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = inputs_embeds + position_embeds + token_type_embeds
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)

View File

@ -188,7 +188,6 @@ class TFLxmertEmbeddings(tf.keras.layers.Layer):
self.hidden_size = config.hidden_size
self.max_position_embeddings = config.max_position_embeddings
self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
@ -235,9 +234,8 @@ class TFLxmertEmbeddings(tf.keras.layers.Layer):
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = inputs_embeds + position_embeds + token_type_embeds
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)

View File

@ -121,7 +121,6 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
self.type_vocab_size = config.type_vocab_size
self.max_position_embeddings = config.max_position_embeddings
self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
self.embedding_transformation = tf.keras.layers.Dense(config.hidden_size, name="embedding_transformation")
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
@ -196,9 +195,8 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = inputs_embeds + position_embeds + token_type_embeds
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)

View File

@ -98,7 +98,6 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
self.hidden_size = config.hidden_size
self.max_position_embeddings = config.max_position_embeddings
self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
@ -155,10 +154,9 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
position_ids = tf.expand_dims(
tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0
)
position_ids = tf.tile(input=position_ids, multiples=(input_shape[0], 1))
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds])
final_embeddings = inputs_embeds + position_embeds
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)

View File

@ -79,7 +79,6 @@ class TFRemBertEmbeddings(tf.keras.layers.Layer):
self.input_embedding_size = config.input_embedding_size
self.max_position_embeddings = config.max_position_embeddings
self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
@ -138,9 +137,8 @@ class TFRemBertEmbeddings(tf.keras.layers.Layer):
)
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = inputs_embeds + position_embeds + token_type_embeds
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)

View File

@ -87,7 +87,6 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
self.hidden_size = config.hidden_size
self.max_position_embeddings = config.max_position_embeddings
self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
@ -164,11 +163,10 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
position_ids = tf.expand_dims(
tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0
)
position_ids = tf.tile(input=position_ids, multiples=(input_shape[0], 1))
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = inputs_embeds + position_embeds + token_type_embeds
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)

View File

@ -140,7 +140,6 @@ class TFRoFormerEmbeddings(tf.keras.layers.Layer):
self.type_vocab_size = config.type_vocab_size
self.embedding_size = config.embedding_size
self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
@ -186,7 +185,7 @@ class TFRoFormerEmbeddings(tf.keras.layers.Layer):
token_type_ids = tf.fill(dims=input_shape, value=0)
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, token_type_embeds])
final_embeddings = inputs_embeds + token_type_embeds
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)

View File

@ -83,7 +83,6 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
self.hidden_size = config.hidden_size
self.max_position_embeddings = config.max_position_embeddings
self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
@ -142,9 +141,8 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
)
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = inputs_embeds + position_embeds + token_type_embeds
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)