diff --git a/src/transformers/models/lxmert/modeling_tf_lxmert.py b/src/transformers/models/lxmert/modeling_tf_lxmert.py
index d787216f957..8049da1cfdb 100644
--- a/src/transformers/models/lxmert/modeling_tf_lxmert.py
+++ b/src/transformers/models/lxmert/modeling_tf_lxmert.py
@@ -295,11 +295,12 @@ class TFLxmertAttention(tf.keras.layers.Layer):
         attention_scores = tf.matmul(
             query_layer, key_layer, transpose_b=True
         )  # (batch size, num_heads, seq_len_q, seq_len_k)
-        dk = tf.cast(shape_list(key_layer)[-1], tf.float32)  # scale attention_scores
+        dk = tf.cast(shape_list(key_layer)[-1], dtype=attention_scores.dtype)  # scale attention_scores
         attention_scores = attention_scores / tf.math.sqrt(dk)
 
         if attention_mask is not None:
-            # Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
+            # Apply the attention mask is (precomputed for all layers in TFLxmertModel call() function)
+            attention_mask = tf.cast(attention_mask, dtype=attention_scores.dtype)
             attention_scores = attention_scores + attention_mask
 
         # Normalize the attention scores to probabilities.
@@ -721,6 +722,11 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
         if inputs["token_type_ids"] is None:
             inputs["token_type_ids"] = tf.fill(input_shape, 0)
 
+        # Positional Word Embeddings
+        embedding_output = self.embeddings(
+            inputs["input_ids"], inputs["token_type_ids"], inputs["inputs_embeds"], training=inputs["training"]
+        )
+
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
@@ -734,8 +740,10 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
 
-        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
+        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
+        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
+        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
 
         if inputs["visual_attention_mask"] is not None:
             extended_visual_attention_mask = tf.reshape(
@@ -745,16 +753,13 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
                 tf.expand_dims(inputs["visual_attention_mask"], axis=1), axis=1
             )
 
-            extended_visual_attention_mask = tf.cast(extended_visual_attention_mask, tf.float32)
-            extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
+            extended_visual_attention_mask = tf.cast(extended_visual_attention_mask, dtype=embedding_output.dtype)
+            extended_visual_attention_mask = tf.multiply(
+                tf.subtract(one_cst, extended_visual_attention_mask), ten_thousand_cst
+            )
         else:
             extended_visual_attention_mask = None
 
-        # Positional Word Embeddings
-        embedding_output = self.embeddings(
-            inputs["input_ids"], inputs["token_type_ids"], inputs["inputs_embeds"], training=inputs["training"]
-        )
-
         # Run Lxmert encoder
         encoder_outputs = self.encoder(
             embedding_output,
diff --git a/tests/test_modeling_tf_lxmert.py b/tests/test_modeling_tf_lxmert.py
index 3615117c232..3b3187eb2d4 100644
--- a/tests/test_modeling_tf_lxmert.py
+++ b/tests/test_modeling_tf_lxmert.py
@@ -706,10 +706,6 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
         # This test is too long (>30sec) and makes fail the CI
         pass
 
-    def test_mixed_precision(self):
-        # TODO JP: Make Lxmert float16 compliant
-        pass
-
    @slow
    def test_saved_model_creation_extended(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
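
For reference, the pattern applied throughout this diff is to derive every cast and constant from the activations' dtype (attention_scores.dtype, embedding_output.dtype) instead of hard-coding tf.float32, so the additive attention mask stays in the compute dtype when the model runs in float16/mixed precision; the test file then drops the test_mixed_precision override that had skipped the inherited check. The sketch below reproduces that masking arithmetic in isolation; the shapes and tensor names are illustrative and not taken from the LXMERT code.

# Illustrative sketch of the dtype-agnostic additive-mask pattern used in the
# diff above; shapes and names are made up for the example.
import tensorflow as tf

batch, heads, seq = 2, 4, 6
# Attention scores as they might come out of tf.matmul under float16.
attention_scores = tf.random.normal((batch, heads, seq, seq), dtype=tf.float16)

# 2D padding mask: 1 for real tokens, 0 for padding.
attention_mask = tf.constant([[1, 1, 1, 1, 0, 0],
                              [1, 1, 1, 0, 0, 0]])

# Broadcastable [batch, 1, 1, seq] mask, cast to the scores' dtype rather than
# a hard-coded tf.float32, then turned into a large negative additive bias.
extended_mask = tf.reshape(attention_mask, (batch, 1, 1, seq))
extended_mask = tf.cast(extended_mask, dtype=attention_scores.dtype)
one_cst = tf.constant(1.0, dtype=attention_scores.dtype)
ten_thousand_cst = tf.constant(-10000.0, dtype=attention_scores.dtype)
extended_mask = tf.multiply(tf.subtract(one_cst, extended_mask), ten_thousand_cst)

# Both operands now share a dtype; with the old float32 constants this addition
# would fail with a dtype mismatch whenever attention_scores is float16.
masked_scores = attention_scores + extended_mask
probs = tf.nn.softmax(masked_scores, axis=-1)  # padded positions get ~0 weight
print(masked_scores.dtype, probs.dtype)  # float16 float16

Note that -10000.0 is still representable in float16 (whose maximum is roughly 65504), which is why the same bias value can be reused unchanged across float32 and float16 runs.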