Mirror of https://github.com/huggingface/transformers.git
Making TF Lxmert model compliant with AMP (#10257)
* Fix AMP
* Rework cast
* Apply style
This commit is contained in:
parent d27b28d958
commit 2fc6284f04
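
The gist of the change: under Keras automatic mixed precision (AMP), layer outputs are float16 while variables stay float32, so any tensor hard-cast to tf.float32 can no longer be combined with the model's activations. Below is a minimal standalone sketch of that failure mode and of the dtype-following fix applied throughout this commit; it is not code from the repository.

import tensorflow as tf

# Enable Keras automatic mixed precision: layers compute in float16 while
# their variables stay in float32.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

dense = tf.keras.layers.Dense(4)
scores = dense(tf.random.normal((2, 8)))
print(scores.dtype)           # float16 activations
print(dense.variable_dtype)   # float32 master weights

# A tensor hard-cast to float32 can no longer be combined with the float16
# activations; TensorFlow does not promote dtypes automatically.
mask = tf.constant([[-10000.0, 0.0, 0.0, 0.0]], dtype=tf.float32)
# scores + mask               # would raise InvalidArgumentError

# The pattern used throughout this commit: follow the partner tensor's dtype.
masked = scores + tf.cast(mask, dtype=scores.dtype)
print(masked.dtype)           # float16

# Restore the default policy so the sketch does not leak global state.
tf.keras.mixed_precision.set_global_policy("float32")
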
@@ -295,11 +295,12 @@ class TFLxmertAttention(tf.keras.layers.Layer):
         attention_scores = tf.matmul(
             query_layer, key_layer, transpose_b=True
         )  # (batch size, num_heads, seq_len_q, seq_len_k)
-        dk = tf.cast(shape_list(key_layer)[-1], tf.float32)  # scale attention_scores
+        dk = tf.cast(shape_list(key_layer)[-1], dtype=attention_scores.dtype)  # scale attention_scores
         attention_scores = attention_scores / tf.math.sqrt(dk)
 
         if attention_mask is not None:
-            # Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
+            # Apply the attention mask is (precomputed for all layers in TFLxmertModel call() function)
+            attention_mask = tf.cast(attention_mask, dtype=attention_scores.dtype)
             attention_scores = attention_scores + attention_mask
 
         # Normalize the attention scores to probabilities.
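
The hunk above applies the same idea inside attention: the scaling factor and the additive mask are cast to the dtype of the scores instead of a fixed float32. A self-contained sketch of the pattern, where key_layer.shape stands in for the library's shape_list helper and the tensor shapes are chosen arbitrarily:

import tensorflow as tf

def scaled_attention_scores(query_layer, key_layer, attention_mask=None):
    # query_layer, key_layer: (batch, num_heads, seq_len, head_dim)
    attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
    # Cast the scaling factor to the scores' dtype so the division stays in
    # float16 when a mixed precision policy is active.
    dk = tf.cast(key_layer.shape[-1], dtype=attention_scores.dtype)
    attention_scores = attention_scores / tf.math.sqrt(dk)
    if attention_mask is not None:
        # Additive mask must match the scores' dtype as well.
        attention_mask = tf.cast(attention_mask, dtype=attention_scores.dtype)
        attention_scores = attention_scores + attention_mask
    return tf.nn.softmax(attention_scores, axis=-1)

q = tf.random.normal((1, 2, 3, 4), dtype=tf.float16)
k = tf.random.normal((1, 2, 3, 4), dtype=tf.float16)
print(scaled_attention_scores(q, k).dtype)  # float16
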
@@ -721,6 +722,11 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
         if inputs["token_type_ids"] is None:
             inputs["token_type_ids"] = tf.fill(input_shape, 0)
 
+        # Positional Word Embeddings
+        embedding_output = self.embeddings(
+            inputs["input_ids"], inputs["token_type_ids"], inputs["inputs_embeds"], training=inputs["training"]
+        )
+
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
@@ -734,8 +740,10 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
 
-        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
+        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
+        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
+        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
 
         if inputs["visual_attention_mask"] is not None:
             extended_visual_attention_mask = tf.reshape(
@@ -745,16 +753,13 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
                 tf.expand_dims(inputs["visual_attention_mask"], axis=1), axis=1
             )
 
-            extended_visual_attention_mask = tf.cast(extended_visual_attention_mask, tf.float32)
-            extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
+            extended_visual_attention_mask = tf.cast(extended_visual_attention_mask, dtype=embedding_output.dtype)
+            extended_visual_attention_mask = tf.multiply(
+                tf.subtract(one_cst, extended_visual_attention_mask), ten_thousand_cst
+            )
         else:
             extended_visual_attention_mask = None
 
-        # Positional Word Embeddings
-        embedding_output = self.embeddings(
-            inputs["input_ids"], inputs["token_type_ids"], inputs["inputs_embeds"], training=inputs["training"]
-        )
-
         # Run Lxmert encoder
         encoder_outputs = self.encoder(
             embedding_output,
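
The three hunks in TFLxmertMainLayer follow from one constraint: the additive attention masks must be built in the embeddings' compute dtype. That is why the embedding lookup is moved ahead of the mask construction and why the 1.0 and -10000.0 constants are created with tf.constant(..., dtype=embedding_output.dtype) rather than as Python floats that would default to float32. A standalone sketch of the mask construction (the function name and shapes are illustrative, not the library's API):

import tensorflow as tf

def make_extended_attention_mask(attention_mask, compute_dtype):
    # (batch, seq_len) -> (batch, 1, 1, seq_len) so the mask broadcasts over
    # num_heads and the query dimension.
    mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
    mask = tf.cast(mask, dtype=compute_dtype)
    one_cst = tf.constant(1.0, dtype=compute_dtype)
    ten_thousand_cst = tf.constant(-10000.0, dtype=compute_dtype)
    # Kept positions become 0.0; masked positions become -10000.0, which
    # effectively zeroes them out after the softmax.
    return tf.multiply(tf.subtract(one_cst, mask), ten_thousand_cst)

mask = tf.constant([[1, 1, 1, 0]])
print(make_extended_attention_mask(mask, tf.float16))
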
@@ -706,10 +706,6 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
         # This test is too long (>30sec) and makes fail the CI
         pass
 
-    def test_mixed_precision(self):
-        # TODO JP: Make Lxmert float16 compliant
-        pass
-
     @slow
     def test_saved_model_creation_extended(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
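
Removing the local test_mixed_precision override means TFLxmertModel is no longer skipped by the shared mixed precision test in the common TF tester. A rough smoke test in the same spirit, assuming a small LxmertConfig; this is an illustrative sketch, not the library's actual test code:

import tensorflow as tf
from transformers import LxmertConfig, TFLxmertModel

# Run the model once under the mixed_float16 policy and check that no dtype
# mismatch is raised (the config values below are assumptions chosen only to
# keep the sketch small and fast).
tf.keras.mixed_precision.set_global_policy("mixed_float16")
try:
    config = LxmertConfig(
        hidden_size=64, intermediate_size=64, num_attention_heads=4,
        l_layers=1, x_layers=1, r_layers=1,
    )
    model = TFLxmertModel(config)
    inputs = {
        "input_ids": tf.constant([[101, 2023, 2003, 102]]),
        "visual_feats": tf.random.normal((1, 10, config.visual_feat_dim)),
        "visual_pos": tf.random.uniform((1, 10, config.visual_pos_dim)),
    }
    outputs = model(inputs)  # should run end to end without dtype errors
    print(outputs.language_output.dtype)
finally:
    # Restore the default policy so other code is unaffected.
    tf.keras.mixed_precision.set_global_policy("float32")
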