Fix tf boolean mask in graph mode (#6741)

This commit is contained in:
Jay Yip 2020-08-26 17:15:35 +08:00 committed by GitHub
parent 925f34bbbd
commit 461ae86812
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -137,7 +137,7 @@ class TFCausalLanguageModelingLoss:
)
# make sure only labels that are not equal to -100
# are taken into account as loss
active_loss = tf.reshape(labels, (-1,)) != -100
active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
return loss_fn(labels, reduced_logits)