From 3a1a56a8fef3bc8586cb150186a10f5776dcb7ef Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 4 Oct 2022 14:48:27 +0100 Subject: [PATCH] Fix for sequence regression fit() in TF (#19316) Co-authored-by: Your Name --- src/transformers/modeling_tf_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index c0f15592866..8bccea12b33 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -274,6 +274,9 @@ class TFSequenceClassificationLoss: def hf_compute_loss(self, labels, logits): if logits.shape.rank == 1 or logits.shape[1] == 1: loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE) + if labels.shape.rank == 1: + # MeanSquaredError returns a scalar loss if the labels are 1D, so avoid that + labels = tf.expand_dims(labels, axis=-1) else: loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( from_logits=True, reduction=tf.keras.losses.Reduction.NONE