Mirror of https://github.com/huggingface/transformers.git
Fix loss computation in TFBertForPreTraining (#17898)
commit 0094565fc5
parent 1dfa03f12b
@@ -124,29 +124,20 @@ class TFBertPreTrainingLoss:
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
             from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
+        unmasked_lm_losses = loss_fn(y_true=labels["labels"], y_pred=logits[0])
         # make sure only labels that are not equal to -100
-        # are taken into account as loss
-        masked_lm_active_loss = tf.not_equal(tf.reshape(tensor=labels["labels"], shape=(-1,)), -100)
-        masked_lm_reduced_logits = tf.boolean_mask(
-            tensor=tf.reshape(tensor=logits[0], shape=(-1, shape_list(logits[0])[2])),
-            mask=masked_lm_active_loss,
-        )
-        masked_lm_labels = tf.boolean_mask(
-            tensor=tf.reshape(tensor=labels["labels"], shape=(-1,)), mask=masked_lm_active_loss
-        )
-        next_sentence_active_loss = tf.not_equal(tf.reshape(tensor=labels["next_sentence_label"], shape=(-1,)), -100)
-        next_sentence_reduced_logits = tf.boolean_mask(
-            tensor=tf.reshape(tensor=logits[1], shape=(-1, 2)), mask=next_sentence_active_loss
-        )
-        next_sentence_label = tf.boolean_mask(
-            tensor=tf.reshape(tensor=labels["next_sentence_label"], shape=(-1,)), mask=next_sentence_active_loss
-        )
-        masked_lm_loss = loss_fn(y_true=masked_lm_labels, y_pred=masked_lm_reduced_logits)
-        next_sentence_loss = loss_fn(y_true=next_sentence_label, y_pred=next_sentence_reduced_logits)
-        masked_lm_loss = tf.reshape(tensor=masked_lm_loss, shape=(-1, shape_list(next_sentence_loss)[0]))
-        masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss, axis=0)
+        # are taken into account for the loss computation
+        lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
+        lm_loss_denominator = tf.reduce_sum(lm_loss_mask, axis=1)
+        masked_lm_losses = tf.math.multiply_no_nan(unmasked_lm_losses, lm_loss_mask)
+        reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses, axis=1) / lm_loss_denominator
 
-        return masked_lm_loss + next_sentence_loss
+        unmasked_ns_loss = loss_fn(y_true=labels["next_sentence_label"], y_pred=logits[1])
+        ns_loss_mask = tf.cast(labels["next_sentence_label"] != -100, dtype=unmasked_ns_loss.dtype)
+        # Just zero out samples where label is -100, no reduction
+        masked_ns_loss = tf.math.multiply_no_nan(unmasked_ns_loss, ns_loss_mask)
+
+        return reduced_masked_lm_loss + masked_ns_loss
 
 
 class TFBertEmbeddings(tf.keras.layers.Layer):
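A note on the change: as the diff reads, the old path boolean-masked the flattened batch and then reshaped the surviving losses by batch size, which can misalign samples once they contain different numbers of unmasked labels; the new path keeps everything at (batch, seq_len) and takes an explicit per-sample mean. Below is a minimal, self-contained sketch of the new computation outside the model. The dummy shapes and label values are assumptions for illustration, and the sketch clips the -100 ignore-labels to 0 with tf.nn.relu purely so the cross-entropy call runs on CPU; the mask zeroes those positions afterwards, so the clipping does not affect the result. (The commit itself passes the labels straight through, as shown above.)

    import tensorflow as tf

    # Dummy inputs (shapes are assumptions): per-token MLM labels with -100
    # marking ignored positions, and one NSP label per sample
    lm_labels = tf.constant([[3, -100, 5, -100, -100],
                             [-100, 2, -100, -100, 6]])  # (batch=2, seq_len=5)
    ns_labels = tf.constant([1, -100])                   # (batch=2,)
    lm_logits = tf.random.normal((2, 5, 7))              # (batch, seq_len, vocab=7)
    ns_logits = tf.random.normal((2, 2))                 # (batch, 2)

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )

    # MLM: per-token losses; -100 is clipped to a valid index only so the op
    # runs on CPU, since the mask removes those positions from the result
    unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(lm_labels), y_pred=lm_logits)
    lm_loss_mask = tf.cast(lm_labels != -100, dtype=unmasked_lm_losses.dtype)
    lm_loss_denominator = tf.reduce_sum(lm_loss_mask, axis=1)
    # multiply_no_nan yields 0 wherever the mask is 0, even for NaN/Inf losses
    masked_lm_losses = tf.math.multiply_no_nan(unmasked_lm_losses, lm_loss_mask)
    reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses, axis=1) / lm_loss_denominator

    # NSP: one loss per sample, zeroed (not dropped) where the label is -100
    unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(ns_labels), y_pred=ns_logits)
    ns_loss_mask = tf.cast(ns_labels != -100, dtype=unmasked_ns_loss.dtype)
    masked_ns_loss = tf.math.multiply_no_nan(unmasked_ns_loss, ns_loss_mask)

    total_loss = reduced_masked_lm_loss + masked_ns_loss  # shape (batch,)
    print(total_loss)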