Mirror of https://github.com/huggingface/transformers.git

Remove hard-coded uses of float32 to fix mixed precision use (#6648)

parent 0344428f79
commit 4fca874ea9
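For context on why these casts matter: once a Keras mixed-precision policy is active, layer outputs come back as float16 while variables (and anything hard-coded to tf.float32) stay float32, and TensorFlow raises an error on mixed-dtype arithmetic rather than promoting. A minimal sketch of that failure mode, assuming TF 2.x with the Keras mixed-precision API; this snippet is illustrative and not part of the commit:

```python
import tensorflow as tf

# Enable mixed precision globally: variables stay float32, layer compute runs in float16.
# (Older TF 2.x releases exposed this under tf.keras.mixed_precision.experimental instead.)
tf.keras.mixed_precision.set_global_policy("mixed_float16")

dense = tf.keras.layers.Dense(4)
x = tf.random.uniform((2, 4))   # float32 input
y = dense(x)                    # float16 output under the mixed_float16 policy
print(y.dtype)                  # float16

# Adding a float16 tensor to a float32 tensor is an error in TF -- there is no
# implicit promotion -- which is why the diff casts everything to a common dtype.
try:
    _ = y + tf.ones((2, 4), dtype=tf.float32)
except (tf.errors.InvalidArgumentError, TypeError) as e:
    print("dtype mismatch:", type(e).__name__)
```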
@@ -215,8 +215,8 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
         if inputs_embeds is None:
             inputs_embeds = tf.gather(self.word_embeddings, input_ids)

-        position_embeddings = self.position_embeddings(position_ids)
-        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+        position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
+        token_type_embeddings = tf.cast(self.token_type_embeddings(token_type_ids), inputs_embeds.dtype)
         embeddings = inputs_embeds + position_embeddings + token_type_embeddings
         embeddings = self.LayerNorm(embeddings)
         embeddings = self.dropout(embeddings, training=training)
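The same pattern outside the model, to make the fix concrete: every term of the embedding sum is cast to one reference dtype before adding. The shapes and tensors below are made up for illustration and are not from the library:

```python
import tensorflow as tf

inputs_embeds = tf.random.uniform((2, 8, 16), dtype=tf.float16)  # e.g. float16 under mixed precision
position_embeddings = tf.random.uniform((1, 8, 16))              # float32, as a plain lookup on a float32 variable would be
token_type_embeddings = tf.random.uniform((2, 8, 16))            # float32

# Mirror the diff: cast the float32 terms to the dtype of inputs_embeds before summing.
position_embeddings = tf.cast(position_embeddings, inputs_embeds.dtype)
token_type_embeddings = tf.cast(token_type_embeddings, inputs_embeds.dtype)

embeddings = inputs_embeds + position_embeddings + token_type_embeddings
assert embeddings.dtype == tf.float16
```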
@@ -281,7 +281,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
         attention_scores = tf.matmul(
             query_layer, key_layer, transpose_b=True
         )  # (batch size, num_heads, seq_len_q, seq_len_k)
-        dk = tf.cast(shape_list(key_layer)[-1], tf.float32)  # scale attention_scores
+        dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype)  # scale attention_scores
         attention_scores = attention_scores / tf.math.sqrt(dk)

         if attention_mask is not None:
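Applied to the attention scaling factor: the head dimension is cast to the scores' own dtype instead of a hard-coded tf.float32, so the division works when the scores are float16. A standalone sketch (tf.shape stands in for the library's shape_list helper):

```python
import tensorflow as tf

query_layer = tf.random.uniform((2, 4, 8, 16), dtype=tf.float16)  # (batch, heads, seq_q, head_dim)
key_layer = tf.random.uniform((2, 4, 8, 16), dtype=tf.float16)    # (batch, heads, seq_k, head_dim)

attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)  # (batch, heads, seq_q, seq_k)

# Scale by sqrt(head_dim) in the scores' own dtype; a float32 scaling factor here
# would make the division fail against float16 scores.
dk = tf.cast(tf.shape(key_layer)[-1], attention_scores.dtype)
attention_scores = attention_scores / tf.math.sqrt(dk)
print(attention_scores.dtype)  # float16
```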
@@ -613,6 +613,8 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         if token_type_ids is None:
             token_type_ids = tf.fill(input_shape, 0)

+        embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
+
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
@@ -626,7 +628,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.

-        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
+        extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype)
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

         # Prepare head mask if needed
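The additive attention mask now follows the activation dtype as well; note that -10000.0 still fits comfortably in float16, whose largest finite value is about 65504. A self-contained sketch of the same construction, with made-up shapes:

```python
import tensorflow as tf

embedding_output = tf.random.uniform((2, 6, 16), dtype=tf.float16)
attention_mask = tf.constant([[1, 1, 1, 0, 0, 0],
                              [1, 1, 1, 1, 1, 0]])                       # (batch, seq_len), 1 = keep

extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]   # (batch, 1, 1, seq_len)

# Cast to the activation dtype rather than tf.float32 so the additive mask can be
# combined with float16 attention scores.
extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
print(extended_attention_mask.dtype)  # float16
```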
@@ -640,7 +642,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
             head_mask = [None] * self.num_hidden_layers
             # head_mask = tf.constant([0] * self.num_hidden_layers)

-        embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
         encoder_outputs = self.encoder(
             embedding_output,
             extended_attention_mask,
@@ -134,8 +134,8 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):

         if inputs_embeds is None:
             inputs_embeds = tf.gather(self.word_embeddings, input_ids)
-        position_embeddings = self.position_embeddings(position_ids)
-        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+        position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
+        token_type_embeddings = tf.cast(self.token_type_embeddings(token_type_ids), inputs_embeds.dtype)

         embeddings = inputs_embeds + position_embeddings + token_type_embeddings
         embeddings = self.LayerNorm(embeddings)
@@ -194,7 +194,7 @@ class TFElectraPreTrainedModel(TFBertPreTrainedModel):
     config_class = ElectraConfig
     base_model_prefix = "electra"

-    def get_extended_attention_mask(self, attention_mask, input_shape):
+    def get_extended_attention_mask(self, attention_mask, input_shape, dtype):
         if attention_mask is None:
             attention_mask = tf.fill(input_shape, 1)

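Passing the target dtype into the helper keeps mask construction independent of any particular activation dtype. A simplified, standalone sketch of such a dtype-parameterized helper; this is not the exact method body from the file:

```python
import tensorflow as tf

def get_extended_attention_mask(attention_mask, input_shape, dtype):
    """Turn a (batch, seq_len) 0/1 mask into an additive (batch, 1, 1, seq_len) mask in `dtype`."""
    if attention_mask is None:
        attention_mask = tf.fill(input_shape, 1)

    extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
    extended_attention_mask = tf.cast(extended_attention_mask, dtype)
    return (1.0 - extended_attention_mask) * -10000.0

mask = get_extended_attention_mask(tf.constant([[1, 1, 0]]), (1, 3), tf.float16)
print(mask.dtype, mask.numpy())  # float16, [[[[0., 0., -10000.]]]]
```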
@@ -211,7 +211,7 @@ class TFElectraPreTrainedModel(TFBertPreTrainedModel):
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.

-        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
+        extended_attention_mask = tf.cast(extended_attention_mask, dtype)
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

         return extended_attention_mask
@@ -314,11 +314,11 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
         if token_type_ids is None:
             token_type_ids = tf.fill(input_shape, 0)

-        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
-        head_mask = self.get_head_mask(head_mask)
-
         hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)

+        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, hidden_states.dtype)
+        head_mask = self.get_head_mask(head_mask)
+
         if hasattr(self, "embeddings_project"):
             hidden_states = self.embeddings_project(hidden_states, training=training)

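With the hard-coded float32 casts removed, a TF checkpoint should run end to end under a mixed-precision policy. A hedged usage sketch, assuming the TFBertModel/BertTokenizer classes and a TF version that provides tf.keras.mixed_precision.set_global_policy; not taken from the commit itself:

```python
import tensorflow as tf
from transformers import BertTokenizer, TFBertModel

# Set the policy before the model is built so its layers pick it up.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("mixed precision test", return_tensors="tf")
outputs = model(inputs)

# Under mixed_float16 the hidden states should come back as float16 while the
# model's variables remain float32.
print(outputs[0].dtype)
```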