Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Rename compute_loss in TF models (#15207)
* Rename compute_loss to hf_compute_loss to avoid conflicts with the new Keras method
* make style
* Adding deprecation warning to `compute_loss`
* Fix sneaky reference to compute_loss
* Replace logger.warning with warnings.warn
* Clarifying warning and deprecation timeline
Parent: d1f5ca1afd
Commit: 2708bfa127
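For downstream code that called the loss helpers directly, a minimal migration sketch follows (the checkpoint name, dummy token ids, and label values are illustrative assumptions, not part of this commit):

```python
import tensorflow as tf
from transformers import TFBertForSequenceClassification

# Illustrative checkpoint; any TF *ForSequenceClassification model works the same way.
model = TFBertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

input_ids = tf.constant([[101, 7592, 2088, 102]])  # dummy ids: [CLS] hello world [SEP]
labels = tf.constant([1])
logits = model(input_ids).logits

# Before this commit the helper was model.compute_loss(...); that name now collides with
# the compute_loss method Keras added in TF 2.8, so call the renamed helper instead:
per_example_loss = model.hf_compute_loss(labels=labels, logits=logits)
```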
@@ -170,7 +170,7 @@ class TFCausalLanguageModelingLoss:
</Tip>
"""
-def compute_loss(self, labels, logits):
+def hf_compute_loss(self, labels, logits):
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
@@ -186,7 +186,7 @@ class TFQuestionAnsweringLoss:
Loss function suitable for question answering.
"""
-def compute_loss(self, labels, logits):
+def hf_compute_loss(self, labels, logits):
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
@@ -207,7 +207,7 @@ class TFTokenClassificationLoss:
</Tip>
"""
-def compute_loss(self, labels, logits):
+def hf_compute_loss(self, labels, logits):
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
@@ -229,7 +229,7 @@ class TFSequenceClassificationLoss:
Loss function suitable for sequence classification.
"""
-def compute_loss(self, labels, logits):
+def hf_compute_loss(self, labels, logits):
if len(shape_list(logits)) == 1 or shape_list(logits)[1] == 1:
loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
else:
@@ -243,7 +243,7 @@ class TFSequenceClassificationLoss:
class TFMultipleChoiceLoss:
"""Loss function suitable for multiple choice tasks."""
-def compute_loss(self, labels, logits):
+def hf_compute_loss(self, labels, logits):
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
@@ -273,7 +273,7 @@ class TFNextSentencePredictionLoss:
</Tip>
"""
-def compute_loss(self, labels, logits):
+def hf_compute_loss(self, labels, logits):
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
@@ -869,6 +869,20 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
**kwargs,
)
+def compute_loss(self, *args, **kwargs):
+if hasattr(tf.keras.Model, "compute_loss"):
+# This will be true in TF 2.8 or greater
+return super().compute_loss(*args, **kwargs)
+else:
+warnings.warn(
+"The old compute_loss method is deprecated as it conflicts with the Keras compute_loss "
+"method added in TF 2.8. If you want the original HF compute_loss, please call "
+"hf_compute_loss() instead. From TF versions >= 2.8, or Transformers versions >= 5, "
+"calling compute_loss() will get the Keras method instead.",
+FutureWarning,
+)
+return self.hf_compute_loss(*args, **kwargs)
def train_step(self, data):
"""
A modification of Keras's default train_step that cleans up the printed metrics when we use a dummy loss.
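A small sketch of the dispatch rule the deprecation shim above relies on; `uses_keras_compute_loss` is a hypothetical helper added here only for illustration:

```python
import tensorflow as tf

def uses_keras_compute_loss() -> bool:
    # Keras gained Model.compute_loss in TF 2.8, so this hasattr check is the shim's
    # proxy for "TF >= 2.8" rather than parsing a version string.
    return hasattr(tf.keras.Model, "compute_loss")

# model.compute_loss(...) therefore means:
#   - TF >= 2.8: the regular Keras Model.compute_loss
#   - TF < 2.8: a FutureWarning, then a forward to model.hf_compute_loss(...)
print("Keras compute_loss available:", uses_keras_compute_loss())
```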
@@ -82,7 +82,7 @@ class TFAlbertPreTrainingLoss:
MLM. .. note:: Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
"""
-def compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
+def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
@@ -941,7 +941,7 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel, TFAlbertPreTrainingLoss):
if inputs["labels"] is not None and inputs["sentence_order_label"] is not None:
d_labels = {"labels": inputs["labels"]}
d_labels["sentence_order_label"] = inputs["sentence_order_label"]
-total_loss = self.compute_loss(labels=d_labels, logits=(prediction_scores, sop_scores))
+total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, sop_scores))
if not inputs["return_dict"]:
output = (prediction_scores, sop_scores) + outputs[2:]
@@ -1058,7 +1058,9 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
sequence_output = outputs[0]
prediction_scores = self.predictions(hidden_states=sequence_output, training=inputs["training"])
loss = (
-None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=prediction_scores)
+None
+if inputs["labels"] is None
+else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores)
)
if not inputs["return_dict"]:
@@ -1163,7 +1165,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
pooled_output = outputs[1]
pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"])
logits = self.classifier(inputs=pooled_output)
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -1270,7 +1272,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
sequence_output = outputs[0]
sequence_output = self.dropout(inputs=sequence_output, training=inputs["training"])
logits = self.classifier(inputs=sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -1385,7 +1387,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
-loss = self.compute_loss(labels=labels, logits=(start_logits, end_logits))
+loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
if not inputs["return_dict"]:
output = (start_logits, end_logits) + outputs[2:]
@@ -1527,7 +1529,9 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"])
logits = self.classifier(inputs=pooled_output)
reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=reshaped_logits)
+loss = (
+None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=reshaped_logits)
+)
if not inputs["return_dict"]:
output = (reshaped_logits,) + outputs[2:]
@@ -1412,7 +1412,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
)
lm_logits = self.model.shared(outputs[0], mode="linear")
lm_logits = lm_logits + self.final_logits_bias
-masked_lm_loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], lm_logits)
+masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
if not inputs["return_dict"]:
output = (lm_logits,) + outputs[1:]
@@ -101,7 +101,7 @@ class TFBertPreTrainingLoss:
computation.
"""
-def compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
+def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
@@ -1278,7 +1278,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
if inputs["labels"] is not None and inputs["next_sentence_label"] is not None:
d_labels = {"labels": inputs["labels"]}
d_labels["next_sentence_label"] = inputs["next_sentence_label"]
-total_loss = self.compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))
+total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))
if not inputs["return_dict"]:
output = (prediction_scores, seq_relationship_score) + outputs[2:]
@@ -1392,7 +1392,9 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"])
loss = (
-None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=prediction_scores)
+None
+if inputs["labels"] is None
+else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores)
)
if not inputs["return_dict"]:
@@ -1542,7 +1544,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = inputs["labels"][:, 1:]
-loss = self.compute_loss(labels=labels, logits=logits)
+loss = self.hf_compute_loss(labels=labels, logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -1654,7 +1656,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
next_sentence_loss = (
None
if inputs["next_sentence_label"] is None
-else self.compute_loss(labels=inputs["next_sentence_label"], logits=seq_relationship_scores)
+else self.hf_compute_loss(labels=inputs["next_sentence_label"], logits=seq_relationship_scores)
)
if not inputs["return_dict"]:
@@ -1762,7 +1764,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
pooled_output = outputs[1]
pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"])
logits = self.classifier(inputs=pooled_output)
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -1903,7 +1905,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"])
logits = self.classifier(inputs=pooled_output)
reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=reshaped_logits)
+loss = (
+None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=reshaped_logits)
+)
if not inputs["return_dict"]:
output = (reshaped_logits,) + outputs[2:]
@@ -2028,7 +2032,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
sequence_output = outputs[0]
sequence_output = self.dropout(inputs=sequence_output, training=inputs["training"])
logits = self.classifier(inputs=sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -2149,7 +2153,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
-loss = self.compute_loss(labels=labels, logits=(start_logits, end_logits))
+loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
if not inputs["return_dict"]:
output = (start_logits, end_logits) + outputs[2:]
@@ -1422,7 +1422,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
)
lm_logits = self.model.shared(outputs[0], mode="linear")
lm_logits = lm_logits + self.final_logits_bias
-masked_lm_loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], lm_logits)
+masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
if not inputs["return_dict"]:
output = (lm_logits,) + outputs[1:]
@@ -1395,7 +1395,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
)
lm_logits = self.model.shared(outputs[0], mode="linear")
lm_logits = lm_logits + self.final_logits_bias
-masked_lm_loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], lm_logits)
+masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
if not inputs["return_dict"]:
output = (lm_logits,) + outputs[1:]
@@ -941,7 +941,7 @@ class TFConvBertForMaskedLM(TFConvBertPreTrainedModel, TFMaskedLanguageModelingL
generator_sequence_output = generator_hidden_states[0]
prediction_scores = self.generator_predictions(generator_sequence_output, training=inputs["training"])
prediction_scores = self.generator_lm_head(prediction_scores, training=inputs["training"])
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_scores)
if not inputs["return_dict"]:
output = (prediction_scores,) + generator_hidden_states[1:]
@@ -1063,7 +1063,7 @@ class TFConvBertForSequenceClassification(TFConvBertPreTrainedModel, TFSequenceC
training=inputs["training"],
)
logits = self.classifier(outputs[0], training=inputs["training"])
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
@@ -1196,7 +1196,7 @@ class TFConvBertForMultipleChoice(TFConvBertPreTrainedModel, TFMultipleChoiceLos
logits = self.sequence_summary(outputs[0], training=inputs["training"])
logits = self.classifier(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
if not inputs["return_dict"]:
output = (reshaped_logits,) + outputs[1:]
@@ -1309,7 +1309,7 @@ class TFConvBertForTokenClassification(TFConvBertPreTrainedModel, TFTokenClassif
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
@@ -1418,7 +1418,7 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnswer
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
-loss = self.compute_loss(labels, (start_logits, end_logits))
+loss = self.hf_compute_loss(labels, (start_logits, end_logits))
if not inputs["return_dict"]:
output = (start_logits, end_logits) + outputs[1:]
@@ -737,7 +737,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = inputs["labels"][:, 1:]
-loss = self.compute_loss(labels, logits)
+loss = self.hf_compute_loss(labels, logits)
if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:]
@@ -891,7 +891,7 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
if not tf.is_tensor(sequence_lengths):
in_logits = logits[0:batch_size, sequence_lengths]
-loss = self.compute_loss(
+loss = self.hf_compute_loss(
tf.reshape(inputs["labels"], [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels])
)
@@ -1219,7 +1219,9 @@ class TFDebertaForMaskedLM(TFDebertaPreTrainedModel, TFMaskedLanguageModelingLos
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"])
loss = (
-None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=prediction_scores)
+None
+if inputs["labels"] is None
+else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores)
)
if not inputs["return_dict"]:
@@ -1322,7 +1324,7 @@ class TFDebertaForSequenceClassification(TFDebertaPreTrainedModel, TFSequenceCla
pooled_output = self.pooler(sequence_output, training=inputs["training"])
pooled_output = self.dropout(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output)
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
@@ -1416,7 +1418,7 @@ class TFDebertaForTokenClassification(TFDebertaPreTrainedModel, TFTokenClassific
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(inputs=sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
@@ -1523,7 +1525,7 @@ class TFDebertaForQuestionAnswering(TFDebertaPreTrainedModel, TFQuestionAnswerin
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
-loss = self.compute_loss(labels=labels, logits=(start_logits, end_logits))
+loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
if not inputs["return_dict"]:
output = (start_logits, end_logits) + outputs[2:]
@@ -1345,7 +1345,9 @@ class TFDebertaV2ForMaskedLM(TFDebertaV2PreTrainedModel, TFMaskedLanguageModelin
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"])
loss = (
-None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=prediction_scores)
+None
+if inputs["labels"] is None
+else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores)
)
if not inputs["return_dict"]:
@@ -1449,7 +1451,7 @@ class TFDebertaV2ForSequenceClassification(TFDebertaV2PreTrainedModel, TFSequenc
pooled_output = self.pooler(sequence_output, training=inputs["training"])
pooled_output = self.dropout(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output)
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
@@ -1544,7 +1546,7 @@ class TFDebertaV2ForTokenClassification(TFDebertaV2PreTrainedModel, TFTokenClass
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(inputs=sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
@@ -1652,7 +1654,7 @@ class TFDebertaV2ForQuestionAnswering(TFDebertaV2PreTrainedModel, TFQuestionAnsw
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
-loss = self.compute_loss(labels=labels, logits=(start_logits, end_logits))
+loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
if not inputs["return_dict"]:
output = (start_logits, end_logits) + outputs[2:]
@@ -709,7 +709,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_projector(prediction_logits)
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_logits)
if not inputs["return_dict"]:
output = (prediction_logits,) + distilbert_output[1:]
@@ -810,7 +810,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
pooled_output = self.dropout(pooled_output, training=inputs["training"]) # (bs, dim)
logits = self.classifier(pooled_output) # (bs, dim)
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
output = (logits,) + distilbert_output[1:]
@@ -900,7 +900,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
@@ -1029,7 +1029,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
if not inputs["return_dict"]:
output = (reshaped_logits,) + distilbert_output[1:]
@@ -1148,7 +1148,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
-loss = self.compute_loss(labels, (start_logits, end_logits))
+loss = self.hf_compute_loss(labels, (start_logits, end_logits))
if not inputs["return_dict"]:
output = (start_logits, end_logits) + distilbert_output[1:]
@@ -1267,7 +1267,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
generator_sequence_output = generator_hidden_states[0]
prediction_scores = self.generator_predictions(generator_sequence_output, training=inputs["training"])
prediction_scores = self.generator_lm_head(prediction_scores, training=inputs["training"])
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_scores)
if not inputs["return_dict"]:
output = (prediction_scores,) + generator_hidden_states[1:]
@@ -1390,7 +1390,7 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
training=inputs["training"],
)
logits = self.classifier(outputs[0])
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
@@ -1522,7 +1522,7 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
logits = self.sequence_summary(outputs[0])
logits = self.classifier(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
if not inputs["return_dict"]:
output = (reshaped_logits,) + outputs[1:]
@@ -1637,7 +1637,7 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
discriminator_sequence_output = discriminator_hidden_states[0]
discriminator_sequence_output = self.dropout(discriminator_sequence_output)
logits = self.classifier(discriminator_sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
output = (logits,) + discriminator_hidden_states[1:]
@@ -1748,7 +1748,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
-loss = self.compute_loss(labels, (start_logits, end_logits))
+loss = self.hf_compute_loss(labels, (start_logits, end_logits))
if not inputs["return_dict"]:
output = (
@@ -1389,7 +1389,7 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output, training=inputs["training"])
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_scores)
if not inputs["return_dict"]:
output = (prediction_scores,) + outputs[1:]
@@ -1479,7 +1479,7 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
pooled_output = last_hidden_state[:, 0]
logits = self.classifier(pooled_output, training=inputs["training"])
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
@@ -1600,7 +1600,7 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
logits = self.classifier(pooled_output, training=inputs["training"])
reshaped_logits = tf.reshape(logits, (-1, num_choices))
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
if not inputs["return_dict"]:
output = (reshaped_logits,) + outputs[1:]
@@ -1706,7 +1706,7 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
@@ -1810,7 +1810,7 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
loss = None
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"], "end_position": inputs["end_positions"]}
-loss = self.compute_loss(labels, (start_logits, end_logits))
+loss = self.hf_compute_loss(labels, (start_logits, end_logits))
if not inputs["return_dict"]:
output = (start_logits, end_logits) + outputs[1:]
@@ -951,7 +951,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = inputs["labels"][:, 1:]
-loss = self.compute_loss(labels, logits)
+loss = self.hf_compute_loss(labels, logits)
if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:]
@@ -1275,7 +1275,9 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
if not tf.is_tensor(sequence_lengths):
in_logits = logits[0 : logits_shape[0], sequence_lengths]
-loss = self.compute_loss(tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
+loss = self.hf_compute_loss(
+tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels])
+)
pooled_logits = in_logits if in_logits is not None else logits
if not inputs["return_dict"]:
@@ -1160,7 +1160,9 @@ class TFLayoutLMForMaskedLM(TFLayoutLMPreTrainedModel, TFMaskedLanguageModelingL
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"])
loss = (
-None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=prediction_scores)
+None
+if inputs["labels"] is None
+else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores)
)
if not inputs["return_dict"]:
@@ -1302,7 +1304,7 @@ class TFLayoutLMForSequenceClassification(TFLayoutLMPreTrainedModel, TFSequenceC
pooled_output = outputs[1]
pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"])
logits = self.classifier(inputs=pooled_output)
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -1447,7 +1449,7 @@ class TFLayoutLMForTokenClassification(TFLayoutLMPreTrainedModel, TFTokenClassif
sequence_output = outputs[0]
sequence_output = self.dropout(inputs=sequence_output, training=inputs["training"])
logits = self.classifier(inputs=sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -2457,7 +2457,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
)
lm_logits = self.led.shared(outputs[0], mode="linear")
lm_logits = lm_logits + self.final_logits_bias
-masked_lm_loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], lm_logits)
+masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
if not inputs["return_dict"]:
output = (lm_logits,) + outputs[1:]
@@ -2556,7 +2556,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
)
return (past[0], reordered_past)
-def compute_loss(self, labels, logits):
+def hf_compute_loss(self, labels, logits):
"""CrossEntropyLoss that ignores pad tokens"""
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
@@ -2161,7 +2161,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output, training=inputs["training"])
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_scores)
if not inputs["return_dict"]:
output = (prediction_scores,) + outputs[2:]
@@ -2303,7 +2303,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
-loss = self.compute_loss(labels, (start_logits, end_logits))
+loss = self.hf_compute_loss(labels, (start_logits, end_logits))
if not inputs["return_dict"]:
output = (start_logits, end_logits) + outputs[2:]
@@ -2450,7 +2450,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -2596,7 +2596,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
if not inputs["return_dict"]:
output = (reshaped_logits,) + outputs[2:]
@@ -2715,7 +2715,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -1438,7 +1438,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
)
lm_logits = self.model.shared(outputs[0], mode="linear")
lm_logits = lm_logits + self.final_logits_bias
-masked_lm_loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], lm_logits)
+masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
if not inputs["return_dict"]:
output = (lm_logits,) + outputs[1:]
@@ -1423,7 +1423,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
)
lm_logits = self.model.shared(outputs[0], mode="linear")
lm_logits = lm_logits + self.final_logits_bias
-masked_lm_loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], lm_logits)
+masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
if not inputs["return_dict"]:
output = (lm_logits,) + outputs[1:]
@@ -1179,7 +1179,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
sequence_output = outputs[0]
prediction_scores = self.predictions(sequence_output, training=inputs["training"])
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_scores)
if not inputs["return_dict"]:
output = (prediction_scores,) + outputs[2:]
@@ -1293,7 +1293,7 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
next_sentence_loss = (
None
if inputs["next_sentence_label"] is None
-else self.compute_loss(labels=inputs["next_sentence_label"], logits=seq_relationship_scores)
+else self.hf_compute_loss(labels=inputs["next_sentence_label"], logits=seq_relationship_scores)
)
if not inputs["return_dict"]:
@@ -1406,7 +1406,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
pooled_output = self.dropout(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output)
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -1526,7 +1526,7 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
-loss = self.compute_loss(labels, (start_logits, end_logits))
+loss = self.hf_compute_loss(labels, (start_logits, end_logits))
if not inputs["return_dict"]:
output = (start_logits, end_logits) + outputs[2:]
@@ -1671,7 +1671,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
if not inputs["return_dict"]:
output = (reshaped_logits,) + outputs[2:]
@@ -1797,7 +1797,7 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -866,7 +866,7 @@ class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss):
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_scores)
if not inputs["return_dict"]:
output = (prediction_scores,) + outputs[2:]
@@ -988,7 +988,7 @@ class TFMPNetForSequenceClassification(TFMPNetPreTrainedModel, TFSequenceClassif
sequence_output = outputs[0]
logits = self.classifier(sequence_output, training=training)
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -1112,7 +1112,7 @@ class TFMPNetForMultipleChoice(TFMPNetPreTrainedModel, TFMultipleChoiceLoss):
pooled_output = self.dropout(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
if not inputs["return_dict"]:
output = (reshaped_logits,) + outputs[2:]
@@ -1224,7 +1224,7 @@ class TFMPNetForTokenClassification(TFMPNetPreTrainedModel, TFTokenClassificatio
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
@@ -1336,7 +1336,7 @@ class TFMPNetForQuestionAnswering(TFMPNetPreTrainedModel, TFQuestionAnsweringLos
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
-loss = self.compute_loss(labels, (start_logits, end_logits))
+loss = self.hf_compute_loss(labels, (start_logits, end_logits))
if not inputs["return_dict"]:
output = (start_logits, end_logits) + outputs[2:]
@@ -658,7 +658,7 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = inputs["labels"][:, 1:]
-loss = self.compute_loss(labels, logits)
+loss = self.hf_compute_loss(labels, logits)
if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:]
@@ -953,7 +953,7 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
if not tf.is_tensor(sequence_lengths):
in_logits = logits[0:batch_size, sequence_lengths]
-loss = self.compute_loss(
+loss = self.hf_compute_loss(
tf.reshape(inputs["labels"], [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels])
)
@@ -1446,7 +1446,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
)
lm_logits = self.model.shared(outputs[0], mode="linear")
lm_logits = lm_logits + self.final_logits_bias
-masked_lm_loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], lm_logits)
+masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
if not inputs["return_dict"]:
output = (lm_logits,) + outputs[1:]
@@ -1418,12 +1418,12 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
target = tf.concat([target[:, 1:], tf.fill([target.shape[0], 1], self.config.generator.pad_token_id)], axis=1)
rag_logprobs = self.marginalize(seq_logits, doc_scores, n_docs)
-loss = self.compute_loss(target, rag_logprobs, from_logits=True, reduce_loss=reduce_loss)
+loss = self.hf_compute_loss(target, rag_logprobs, from_logits=True, reduce_loss=reduce_loss)
return loss
# Adopted modeling_tf_bart + add smooth_loss to match with pytorch version
-def compute_loss(self, labels, y_pred, smooth_epsilon=0.0, from_logits=True, reduce_loss=False):
+def hf_compute_loss(self, labels, y_pred, smooth_epsilon=0.0, from_logits=True, reduce_loss=False):
"""CrossEntropyLoss that ignores pad tokens"""
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
@@ -1133,7 +1133,9 @@ class TFRemBertForMaskedLM(TFRemBertPreTrainedModel, TFMaskedLanguageModelingLos
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"])
loss = (
-None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=prediction_scores)
+None
+if inputs["labels"] is None
+else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores)
)
if not inputs["return_dict"]:
@@ -1275,7 +1277,7 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = inputs["labels"][:, 1:]
-loss = self.compute_loss(labels=labels, logits=logits)
+loss = self.hf_compute_loss(labels=labels, logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -1384,7 +1386,7 @@ class TFRemBertForSequenceClassification(TFRemBertPreTrainedModel, TFSequenceCla
pooled_output = outputs[1]
pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"])
logits = self.classifier(inputs=pooled_output)
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -1521,7 +1523,9 @@ class TFRemBertForMultipleChoice(TFRemBertPreTrainedModel, TFMultipleChoiceLoss)
pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"])
logits = self.classifier(inputs=pooled_output)
reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=reshaped_logits)
+loss = (
+None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=reshaped_logits)
+)
if not inputs["return_dict"]:
output = (reshaped_logits,) + outputs[2:]
@@ -1631,7 +1635,7 @@ class TFRemBertForTokenClassification(TFRemBertPreTrainedModel, TFTokenClassific
sequence_output = outputs[0]
sequence_output = self.dropout(inputs=sequence_output, training=inputs["training"])
logits = self.classifier(inputs=sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
@@ -1741,7 +1745,7 @@ class TFRemBertForQuestionAnswering(TFRemBertPreTrainedModel, TFQuestionAnswerin
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
-loss = self.compute_loss(labels=labels, logits=(start_logits, end_logits))
+loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
if not inputs["return_dict"]:
output = (start_logits, end_logits) + outputs[2:]
@@ -1164,7 +1164,7 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_scores)
if not inputs["return_dict"]:
output = (prediction_scores,) + outputs[2:]
@@ -1312,7 +1312,7 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = inputs["labels"][:, 1:]
-loss = self.compute_loss(labels=labels, logits=logits)
+loss = self.hf_compute_loss(labels=labels, logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -1447,7 +1447,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
sequence_output = outputs[0]
logits = self.classifier(sequence_output, training=inputs["training"])
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -1577,7 +1577,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
if not inputs["return_dict"]:
output = (reshaped_logits,) + outputs[2:]
@@ -1695,7 +1695,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -1809,7 +1809,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
-loss = self.compute_loss(labels, (start_logits, end_logits))
+loss = self.hf_compute_loss(labels, (start_logits, end_logits))
if not inputs["return_dict"]:
output = (start_logits, end_logits) + outputs[2:]
@@ -939,7 +939,9 @@ class TFRoFormerForMaskedLM(TFRoFormerPreTrainedModel, TFMaskedLanguageModelingL
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"])
loss = (
-None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=prediction_scores)
+None
+if inputs["labels"] is None
+else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores)
)
if not inputs["return_dict"]:
@@ -1035,7 +1037,7 @@ class TFRoFormerForCausalLM(TFRoFormerPreTrainedModel, TFCausalLanguageModelingL
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = inputs["labels"][:, 1:]
-loss = self.compute_loss(labels=labels, logits=logits)
+loss = self.hf_compute_loss(labels=labels, logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@@ -1154,7 +1156,7 @@ class TFRoFormerForSequenceClassification(TFRoFormerPreTrainedModel, TFSequenceC
training=inputs["training"],
)
logits = self.classifier(hidden_states=outputs[0], training=inputs["training"])
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
@@ -1286,7 +1288,9 @@ class TFRoFormerForMultipleChoice(TFRoFormerPreTrainedModel, TFMultipleChoiceLos
logits = self.sequence_summary(inputs=outputs[0], training=inputs["training"])
logits = self.classifier(inputs=logits)
reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=reshaped_logits)
+loss = (
+None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=reshaped_logits)
+)
if not inputs["return_dict"]:
output = (reshaped_logits,) + outputs[1:]
@@ -1394,7 +1398,7 @@ class TFRoFormerForTokenClassification(TFRoFormerPreTrainedModel, TFTokenClassif
sequence_output = outputs[0]
sequence_output = self.dropout(inputs=sequence_output, training=inputs["training"])
logits = self.classifier(inputs=sequence_output)
-loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
@@ -1501,7 +1505,7 @@ class TFRoFormerForQuestionAnswering(TFRoFormerPreTrainedModel, TFQuestionAnswer
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
-loss = self.compute_loss(labels=labels, logits=(start_logits, end_logits))
+loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
if not inputs["return_dict"]:
output = (start_logits, end_logits) + outputs[2:]
@@ -1472,7 +1472,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
logits = tf.cast(logits, tf.float32)
-loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
+loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
@ -1161,7 +1161,9 @@ class TFTapasForMaskedLM(TFTapasPreTrainedModel, TFMaskedLanguageModelingLoss):
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
loss = (
None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=prediction_scores)
None
if inputs["labels"] is None
else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores)
)

if not inputs["return_dict"]:
@ -1741,7 +1743,7 @@ class TFTapasForSequenceClassification(TFTapasPreTrainedModel, TFSequenceClassif
pooled_output = outputs[1]
pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"])
logits = self.classifier(inputs=pooled_output)
loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)

if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@ -1179,7 +1179,7 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc
if not tf.is_tensor(sequence_lengths):
in_logits = logits[0:batch_size, sequence_lengths]

loss = self.compute_loss(
loss = self.hf_compute_loss(
tf.reshape(inputs["labels"], [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels])
)

@ -844,7 +844,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification
)
sequence_output = outputs[0]
logits = self.classifier(inputs=sequence_output[:, 0, :])
loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)

if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@ -1013,7 +1013,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat

logits = self.sequence_summary(output)

loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)

if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:]
@ -1166,7 +1166,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
logits = self.logits_proj(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices))

loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)

if not inputs["return_dict"]:
output = (reshaped_logits,) + transformer_outputs[1:]
@ -1288,7 +1288,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output)

loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)

if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:]
@ -1406,7 +1406,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
loss = self.hf_compute_loss(labels, (start_logits, end_logits))

if not inputs["return_dict"]:
output = (start_logits, end_logits) + transformer_outputs[1:]
@ -1393,7 +1393,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits)
loss = self.hf_compute_loss(labels, logits)

if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:]
@ -1508,7 +1508,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
output = self.sequence_summary(output)
logits = self.logits_proj(output)

loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)

if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:]
@ -1656,7 +1656,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
logits = self.sequence_summary(output)
logits = self.logits_proj(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)

if not inputs["return_dict"]:
output = (reshaped_logits,) + transformer_outputs[1:]
@ -1778,7 +1778,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
)
output = transformer_outputs[0]
logits = self.classifier(output)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)

if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:]
@ -1900,7 +1900,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
loss = self.hf_compute_loss(labels, (start_logits, end_logits))

if not inputs["return_dict"]:
output = (start_logits, end_logits) + transformer_outputs[1:]
@ -1122,7 +1122,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"])
loss = (
None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=prediction_scores)
None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores)
)

if not inputs["return_dict"]:
@ -1264,7 +1264,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels=labels, logits=logits)
loss = self.hf_compute_loss(labels=labels, logits=logits)

if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
@ -1394,7 +1394,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
training=inputs["training"],
)
logits = self.classifier(hidden_states=outputs[0], training=inputs["training"])
loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)

if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
@ -1534,7 +1534,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
logits = self.sequence_summary(inputs=outputs[0], training=inputs["training"])
logits = self.classifier(inputs=logits)
reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=reshaped_logits)
loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=reshaped_logits)

if not inputs["return_dict"]:
output = (reshaped_logits,) + outputs[1:]
@ -1642,7 +1642,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
sequence_output = outputs[0]
sequence_output = self.dropout(inputs=sequence_output, training=inputs["training"])
logits = self.classifier(inputs=sequence_output)
loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)

if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
@ -1752,7 +1752,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels=labels, logits=(start_logits, end_logits))
loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))

if not inputs["return_dict"]:
output = (start_logits, end_logits) + outputs[2:]
@ -3152,7 +3152,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
)
lm_logits = self.model.shared(outputs[0], mode="linear")
lm_logits = lm_logits + self.final_logits_bias
masked_lm_loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], lm_logits)
masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)

if not inputs["return_dict"]:
output = (lm_logits,) + outputs[1:]
@ -3251,7 +3251,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
)
return (past[0], reordered_past)

def compute_loss(self, labels, logits):
def hf_compute_loss(self, labels, logits):
"""CrossEntropyLoss that ignores pad tokens"""
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
@ -1064,7 +1064,7 @@ class TFModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
if getattr(model, "compute_loss", None):
if getattr(model, "hf_compute_loss", None):
# The number of elements in the loss should be the same as the number of elements in the label
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
added_label = prepared_for_class[