diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index c0f15592866..8bccea12b33 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -274,6 +274,9 @@ class TFSequenceClassificationLoss: def hf_compute_loss(self, labels, logits): if logits.shape.rank == 1 or logits.shape[1] == 1: loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE) + if labels.shape.rank == 1: + # MeanSquaredError returns a scalar loss if the labels are 1D, so avoid that + labels = tf.expand_dims(labels, axis=-1) else: loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( from_logits=True, reduction=tf.keras.losses.Reduction.NONE