From 26ec7928d0077aaf4084a36ee05a253195497f51 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 15 Nov 2022 16:58:43 +0000 Subject: [PATCH] Slightly alter Keras dummy loss (#20232) * Slightly alter Keras dummy loss * Slightly alter Keras dummy loss * Add sample weight to test_keras_fit * Fix test_keras_fit for datasets * Skip the sample_weight stuff for models where the model tester has no batch_size --- src/transformers/modeling_tf_utils.py | 6 +++++- tests/test_modeling_tf_common.py | 17 ++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 6c473bfac62..f79e3dedff5 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -94,7 +94,11 @@ TFModelInputType = Union[ def dummy_loss(y_true, y_pred): - return tf.reduce_mean(y_pred) + if y_pred.shape.rank <= 1: + return y_pred + else: + reduction_axes = list(range(1, y_pred.shape.rank)) + return tf.reduce_mean(y_pred, axis=reduction_axes) class TFModelUtilsMixin: diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 2f4c1225764..6b82ea7bc06 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1544,6 +1544,11 @@ class TFModelTesterMixin: else: metrics = [] + if hasattr(self.model_tester, "batch_size"): + sample_weight = tf.convert_to_tensor([0.5] * self.model_tester.batch_size, dtype=tf.float32) + else: + sample_weight = None + model(model.dummy_inputs) # Build the model so we can get some constant weights model_weights = model.get_weights() @@ -1553,6 +1558,7 @@ class TFModelTesterMixin: history1 = model.fit( prepared_for_class, validation_data=prepared_for_class, + sample_weight=sample_weight, steps_per_epoch=1, validation_steps=1, shuffle=False, @@ -1588,6 +1594,7 @@ class TFModelTesterMixin: inputs_minus_labels, labels, validation_data=(inputs_minus_labels, labels), + sample_weight=sample_weight, steps_per_epoch=1, validation_steps=1, shuffle=False, @@ -1605,14 +1612,22 @@ class TFModelTesterMixin: # Make sure fit works with tf.data.Dataset and results are consistent dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class) + + if sample_weight is not None: + # Add in the sample weight + weighted_dataset = dataset.map(lambda x: (x, None, tf.convert_to_tensor(0.5, dtype=tf.float32))) + else: + weighted_dataset = dataset # Pass in all samples as a batch to match other `fit` calls + weighted_dataset = weighted_dataset.batch(len(dataset)) dataset = dataset.batch(len(dataset)) # Reinitialize to fix batchnorm again model.set_weights(model_weights) + # To match the other calls, don't pass sample weights in the validation data history3 = model.fit( - dataset, + weighted_dataset, validation_data=dataset, steps_per_epoch=1, validation_steps=1,