mirror of https://github.com/huggingface/transformers.git
Slightly alter Keras dummy loss (#20232)
* Slightly alter Keras dummy loss
* Slightly alter Keras dummy loss
* Add sample weight to test_keras_fit
* Fix test_keras_fit for datasets
* Skip the sample_weight stuff for models where the model tester has no batch_size
This commit is contained in:
parent 7f74433814
commit 26ec7928d0
@@ -94,7 +94,11 @@ TFModelInputType = Union[
 def dummy_loss(y_true, y_pred):
-    return tf.reduce_mean(y_pred)
+    if y_pred.shape.rank <= 1:
+        return y_pred
+    else:
+        reduction_axes = list(range(1, y_pred.shape.rank))
+        return tf.reduce_mean(y_pred, axis=reduction_axes)
 
 
 class TFModelUtilsMixin:
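For context (not part of the diff): the old definition collapsed everything to a single scalar, so a Keras sample_weight had nothing per-sample to scale. A minimal sketch of the new behaviour, assuming a hypothetical rank-2 per-token loss tensor:

    import tensorflow as tf

    def dummy_loss(y_true, y_pred):
        # Same body as the new definition above: keep the batch axis and only
        # average over the remaining axes, producing one loss value per sample.
        if y_pred.shape.rank <= 1:
            return y_pred
        else:
            reduction_axes = list(range(1, y_pred.shape.rank))
            return tf.reduce_mean(y_pred, axis=reduction_axes)

    per_token_loss = tf.ones((4, 7))                 # hypothetical (batch, seq_len) loss
    print(dummy_loss(None, per_token_loss).shape)    # (4,) - one value per sample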
@@ -1544,6 +1544,11 @@ class TFModelTesterMixin:
             else:
                 metrics = []
 
+            if hasattr(self.model_tester, "batch_size"):
+                sample_weight = tf.convert_to_tensor([0.5] * self.model_tester.batch_size, dtype=tf.float32)
+            else:
+                sample_weight = None
+
             model(model.dummy_inputs)  # Build the model so we can get some constant weights
             model_weights = model.get_weights()
 
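A small sketch of what this guard produces, assuming the model tester exposes batch_size = 4 (hypothetical value). A uniform 0.5 weight exercises the sample-weighting code path while keeping the comparison across the different fit calls simple:

    import tensorflow as tf

    batch_size = 4  # stand-in for self.model_tester.batch_size (hypothetical value)
    sample_weight = tf.convert_to_tensor([0.5] * batch_size, dtype=tf.float32)
    print(sample_weight)  # tf.Tensor([0.5 0.5 0.5 0.5], shape=(4,), dtype=float32)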
@@ -1553,6 +1558,7 @@ class TFModelTesterMixin:
             history1 = model.fit(
                 prepared_for_class,
                 validation_data=prepared_for_class,
+                sample_weight=sample_weight,
                 steps_per_epoch=1,
                 validation_steps=1,
                 shuffle=False,
@@ -1588,6 +1594,7 @@ class TFModelTesterMixin:
                 inputs_minus_labels,
                 labels,
                 validation_data=(inputs_minus_labels, labels),
+                sample_weight=sample_weight,
                 steps_per_epoch=1,
                 validation_steps=1,
                 shuffle=False,
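Why the per-sample loss from the first hunk matters for these fit calls: when sample_weight is passed, Keras multiplies each sample's loss by its weight before reducing. A rough sketch of that reduction (assuming Keras' default sum-over-batch-size behaviour), shown only to illustrate the interaction:

    import tensorflow as tf

    per_sample_loss = tf.constant([1.0, 2.0, 3.0, 4.0])  # what the new dummy_loss returns
    sample_weight = tf.constant([0.5, 0.5, 0.5, 0.5])     # the weights passed to fit above

    # Approximation of the default weighted reduction:
    # sum(weight * loss) / number_of_samples.
    weighted = tf.reduce_sum(per_sample_loss * sample_weight) / tf.cast(tf.size(per_sample_loss), tf.float32)
    print(float(weighted))  # 1.25, i.e. half of the unweighted mean (2.5)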
@@ -1605,14 +1612,22 @@ class TFModelTesterMixin:
 
             # Make sure fit works with tf.data.Dataset and results are consistent
             dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
 
+            if sample_weight is not None:
+                # Add in the sample weight
+                weighted_dataset = dataset.map(lambda x: (x, None, tf.convert_to_tensor(0.5, dtype=tf.float32)))
+            else:
+                weighted_dataset = dataset
             # Pass in all samples as a batch to match other `fit` calls
+            weighted_dataset = weighted_dataset.batch(len(dataset))
             dataset = dataset.batch(len(dataset))
 
             # Reinitialize to fix batchnorm again
             model.set_weights(model_weights)
 
+            # To match the other calls, don't pass sample weights in the validation data
             history3 = model.fit(
-                dataset,
+                weighted_dataset,
                 validation_data=dataset,
                 steps_per_epoch=1,
                 validation_steps=1,
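For reference, Keras fit also accepts sample weights as the third element of an (inputs, labels, sample_weight) tuple yielded by a tf.data.Dataset, which is what the map call in this hunk relies on. A minimal, self-contained sketch of that dataset shape (the feature dict is a hypothetical stand-in for prepared_for_class):

    import tensorflow as tf

    features = {"input_ids": tf.ones((4, 7), dtype=tf.int32)}   # hypothetical stand-in
    dataset = tf.data.Dataset.from_tensor_slices(features)

    # (inputs, labels, sample_weight) tuples; labels are None because the model
    # is assumed to compute its loss from labels embedded in the input dict.
    weighted_dataset = dataset.map(lambda x: (x, None, tf.convert_to_tensor(0.5, dtype=tf.float32)))
    weighted_dataset = weighted_dataset.batch(len(dataset))     # one batch, matching the test
    # model.fit(weighted_dataset, ...) then applies the 0.5 weight to every sample.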