Fix for sequence regression fit() in TF (#19316)

Co-authored-by: Your Name <you@example.com>
This commit is contained in:
Matt 2022-10-04 14:48:27 +01:00 committed by GitHub
parent fe10796f4f
commit 3a1a56a8fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -274,6 +274,9 @@ class TFSequenceClassificationLoss:
def hf_compute_loss(self, labels, logits):
if logits.shape.rank == 1 or logits.shape[1] == 1:
loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
if labels.shape.rank == 1:
# MeanSquaredError returns a scalar loss if the labels are 1D, so avoid that
labels = tf.expand_dims(labels, axis=-1)
else:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE