mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-28 00:32:25 +06:00
Fix for sequence regression fit() in TF (#19316)
Co-authored-by: Your Name <you@example.com>
This commit is contained in:
parent
fe10796f4f
commit
3a1a56a8fe
@ -274,6 +274,9 @@ class TFSequenceClassificationLoss:
|
||||
def hf_compute_loss(self, labels, logits):
|
||||
if logits.shape.rank == 1 or logits.shape[1] == 1:
|
||||
loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
|
||||
if labels.shape.rank == 1:
|
||||
# MeanSquaredError returns a scalar loss if the labels are 1D, so avoid that
|
||||
labels = tf.expand_dims(labels, axis=-1)
|
||||
else:
|
||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
||||
|
Loading…
Reference in New Issue
Block a user