Slightly alter Keras dummy loss (#20232)

* Slightly alter Keras dummy loss

* Slightly alter Keras dummy loss

* Add sample weight to test_keras_fit

* Fix test_keras_fit for datasets

* Skip the sample_weight stuff for models where the model tester has no batch_size
Authored by Matt on 2022-11-15 16:58:43 +00:00, committed by GitHub
parent 7f74433814
commit 26ec7928d0
2 changed files with 21 additions and 2 deletions


@@ -94,7 +94,11 @@ TFModelInputType = Union[
 def dummy_loss(y_true, y_pred):
-    return tf.reduce_mean(y_pred)
+    if y_pred.shape.rank <= 1:
+        return y_pred
+    else:
+        reduction_axes = list(range(1, y_pred.shape.rank))
+        return tf.reduce_mean(y_pred, axis=reduction_axes)
 class TFModelUtilsMixin:
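
For context (not part of the diff): the removed version collapsed the loss to a single scalar for the whole batch, while the new version averages only over the non-batch axes, so Keras gets one loss value per sample to apply sample weights to. A minimal standalone sketch of the shape difference, assuming a hypothetical per-token loss output of shape (batch, seq_len):

import tensorflow as tf

def old_dummy_loss(y_true, y_pred):
    # Reduces over every axis, so the result is a single scalar.
    return tf.reduce_mean(y_pred)

def new_dummy_loss(y_true, y_pred):
    # Keep the batch axis; average only over the remaining axes.
    if y_pred.shape.rank <= 1:
        return y_pred
    reduction_axes = list(range(1, y_pred.shape.rank))
    return tf.reduce_mean(y_pred, axis=reduction_axes)

per_token_loss = tf.ones((4, 8))  # hypothetical (batch, seq_len) loss output
print(old_dummy_loss(None, per_token_loss).shape)  # ()   -> one scalar for the whole batch
print(new_dummy_loss(None, per_token_loss).shape)  # (4,) -> one value per sample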


@@ -1544,6 +1544,11 @@ class TFModelTesterMixin:
                 else:
                     metrics = []
+                if hasattr(self.model_tester, "batch_size"):
+                    sample_weight = tf.convert_to_tensor([0.5] * self.model_tester.batch_size, dtype=tf.float32)
+                else:
+                    sample_weight = None
                 model(model.dummy_inputs)  # Build the model so we can get some constant weights
                 model_weights = model.get_weights()
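
Why a constant weight of 0.5 is convenient for this test (my reading, not stated in the diff): Keras roughly multiplies each per-sample loss by its weight before reducing, so a uniform weight rescales the loss without making the different `fit` calls incomparable. A rough standalone sketch with hypothetical values:

import tensorflow as tf

batch_size = 4  # hypothetical value from a model tester
per_sample_loss = tf.constant([1.0, 2.0, 3.0, 4.0])
sample_weight = tf.convert_to_tensor([0.5] * batch_size, dtype=tf.float32)

# Keras (approximately) scales each per-sample loss by its weight before reducing.
unweighted = tf.reduce_mean(per_sample_loss)                # 2.5
weighted = tf.reduce_mean(per_sample_loss * sample_weight)  # 1.25
print(float(unweighted), float(weighted))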
@@ -1553,6 +1558,7 @@ class TFModelTesterMixin:
                 history1 = model.fit(
                     prepared_for_class,
                     validation_data=prepared_for_class,
+                    sample_weight=sample_weight,
                     steps_per_epoch=1,
                     validation_steps=1,
                     shuffle=False,
@@ -1588,6 +1594,7 @@ class TFModelTesterMixin:
                     inputs_minus_labels,
                     labels,
                     validation_data=(inputs_minus_labels, labels),
+                    sample_weight=sample_weight,
                     steps_per_epoch=1,
                     validation_steps=1,
                     shuffle=False,
@@ -1605,14 +1612,22 @@ class TFModelTesterMixin:
                 # Make sure fit works with tf.data.Dataset and results are consistent
                 dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
+                if sample_weight is not None:
+                    # Add in the sample weight
+                    weighted_dataset = dataset.map(lambda x: (x, None, tf.convert_to_tensor(0.5, dtype=tf.float32)))
+                else:
+                    weighted_dataset = dataset
                 # Pass in all samples as a batch to match other `fit` calls
+                weighted_dataset = weighted_dataset.batch(len(dataset))
                 dataset = dataset.batch(len(dataset))
                 # Reinitialize to fix batchnorm again
                 model.set_weights(model_weights)
+                # To match the other calls, don't pass sample weights in the validation data
                 history3 = model.fit(
-                    dataset,
+                    weighted_dataset,
                     validation_data=dataset,
                     steps_per_epoch=1,
                     validation_steps=1,
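
One note on the dataset branch: `model.fit` does not accept a separate `sample_weight=` argument when the input is a `tf.data.Dataset`; instead the dataset itself yields `(inputs, targets, sample_weights)` tuples, which is what the `map` call above builds (targets stay `None` because the labels already live inside `prepared_for_class`). A standalone sketch of the same pattern on a toy dict dataset; reasonably recent TF versions accept `None` as a dataset component:

import tensorflow as tf

# Toy stand-in for `prepared_for_class`: a dict of model inputs (labels included).
examples = {"input_ids": tf.ones((4, 8), dtype=tf.int32)}
dataset = tf.data.Dataset.from_tensor_slices(examples)

# Attach a scalar weight of 0.5 to every element; targets are None because the
# labels are already part of the input dict.
weighted_dataset = dataset.map(lambda x: (x, None, tf.convert_to_tensor(0.5, dtype=tf.float32)))
weighted_dataset = weighted_dataset.batch(len(dataset))

inputs, targets, weights = next(iter(weighted_dataset))
print(weights)  # [0.5 0.5 0.5 0.5]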