Fix loss computation in TFBertForPreTraining (#17898)

Author: Matt
Date: 2022-06-28 12:44:56 +01:00 (committed by GitHub)
commit 0094565fc5
parent 1dfa03f12b


@@ -124,29 +124,20 @@ class TFBertPreTrainingLoss:
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
             from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
-        # make sure only labels that are not equal to -100
-        # are taken into account as loss
-        masked_lm_active_loss = tf.not_equal(tf.reshape(tensor=labels["labels"], shape=(-1,)), -100)
-        masked_lm_reduced_logits = tf.boolean_mask(
-            tensor=tf.reshape(tensor=logits[0], shape=(-1, shape_list(logits[0])[2])),
-            mask=masked_lm_active_loss,
-        )
-        masked_lm_labels = tf.boolean_mask(
-            tensor=tf.reshape(tensor=labels["labels"], shape=(-1,)), mask=masked_lm_active_loss
-        )
-        next_sentence_active_loss = tf.not_equal(tf.reshape(tensor=labels["next_sentence_label"], shape=(-1,)), -100)
-        next_sentence_reduced_logits = tf.boolean_mask(
-            tensor=tf.reshape(tensor=logits[1], shape=(-1, 2)), mask=next_sentence_active_loss
-        )
-        next_sentence_label = tf.boolean_mask(
-            tensor=tf.reshape(tensor=labels["next_sentence_label"], shape=(-1,)), mask=next_sentence_active_loss
-        )
-        masked_lm_loss = loss_fn(y_true=masked_lm_labels, y_pred=masked_lm_reduced_logits)
-        next_sentence_loss = loss_fn(y_true=next_sentence_label, y_pred=next_sentence_reduced_logits)
-        masked_lm_loss = tf.reshape(tensor=masked_lm_loss, shape=(-1, shape_list(next_sentence_loss)[0]))
-        masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss, axis=0)
-
-        return masked_lm_loss + next_sentence_loss
+        unmasked_lm_losses = loss_fn(y_true=labels["labels"], y_pred=logits[0])
+        # make sure only labels that are not equal to -100
+        # are taken into account for the loss computation
+        lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
+        lm_loss_denominator = tf.reduce_sum(lm_loss_mask, axis=1)
+        masked_lm_losses = tf.math.multiply_no_nan(unmasked_lm_losses, lm_loss_mask)
+        reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses, axis=1) / lm_loss_denominator
+
+        unmasked_ns_loss = loss_fn(y_true=labels["next_sentence_label"], y_pred=logits[1])
+        ns_loss_mask = tf.cast(labels["next_sentence_label"] != -100, dtype=unmasked_ns_loss.dtype)
+        # Just zero out samples where label is -100, no reduction
+        masked_ns_loss = tf.math.multiply_no_nan(unmasked_ns_loss, ns_loss_mask)
+
+        return reduced_masked_lm_loss + masked_ns_loss


 class TFBertEmbeddings(tf.keras.layers.Layer):
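After the change, the MLM loss is computed over the full (batch, sequence) grid with no reduction, the positions labelled -100 are zeroed with tf.math.multiply_no_nan, and each sample's summed loss is divided by its number of real tokens, giving one MLM loss per sample; the NSP loss is likewise zeroed (but not averaged) where the label is -100, and the two per-sample losses are added. Below is a minimal standalone sketch of that masking arithmetic using made-up toy shapes (batch of 2, sequence length of 4, vocabulary of 5); the label and logit tensors are illustrative only, and the -100 labels are clipped to 0 before the cross-entropy call purely so the op accepts them, since the mask zeroes those positions afterwards anyway.

import tensorflow as tf

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)

# Toy inputs (hypothetical): -100 marks positions/samples to ignore.
mlm_labels = tf.constant([[2, -100, 4, -100], [-100, 1, -100, -100]])
mlm_logits = tf.random.normal((2, 4, 5))
ns_labels = tf.constant([1, -100])
ns_logits = tf.random.normal((2, 2))

# Per-token MLM loss, no reduction; -100 is clipped to 0 only so the op
# accepts the label -- those positions are zeroed by the mask below.
unmasked_lm_losses = loss_fn(y_true=tf.maximum(mlm_labels, 0), y_pred=mlm_logits)
lm_loss_mask = tf.cast(mlm_labels != -100, dtype=unmasked_lm_losses.dtype)
lm_loss_denominator = tf.reduce_sum(lm_loss_mask, axis=1)
masked_lm_losses = tf.math.multiply_no_nan(unmasked_lm_losses, lm_loss_mask)
# Mean over only the real tokens of each sample -> shape (batch_size,)
reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses, axis=1) / lm_loss_denominator

# Per-sample NSP loss; samples labelled -100 are zeroed, no further reduction.
unmasked_ns_loss = loss_fn(y_true=tf.maximum(ns_labels, 0), y_pred=ns_logits)
ns_loss_mask = tf.cast(ns_labels != -100, dtype=unmasked_ns_loss.dtype)
masked_ns_loss = tf.math.multiply_no_nan(unmasked_ns_loss, ns_loss_mask)

total_loss = reduced_masked_lm_loss + masked_ns_loss  # shape (batch_size,)
print(total_loss)

Keeping both terms per-sample gives a total loss of shape (batch_size,) that Keras can reduce normally, whereas the removed code flattened the masked MLM loss and reshaped it against the NSP batch dimension before averaging.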