mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-28 08:42:23 +06:00
Fix for sequence regression fit() in TF (#19316)
Co-authored-by: Your Name <you@example.com>
This commit is contained in:
parent
fe10796f4f
commit
3a1a56a8fe
@ -274,6 +274,9 @@ class TFSequenceClassificationLoss:
|
|||||||
def hf_compute_loss(self, labels, logits):
|
def hf_compute_loss(self, labels, logits):
|
||||||
if logits.shape.rank == 1 or logits.shape[1] == 1:
|
if logits.shape.rank == 1 or logits.shape[1] == 1:
|
||||||
loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
|
loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
|
||||||
|
if labels.shape.rank == 1:
|
||||||
|
# MeanSquaredError returns a scalar loss if the labels are 1D, so avoid that
|
||||||
|
labels = tf.expand_dims(labels, axis=-1)
|
||||||
else:
|
else:
|
||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||||
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
||||||
|
Loading…
Reference in New Issue
Block a user