mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Slightly alter Keras dummy loss (#20232)
* Slightly alter Keras dummy loss * Slightly alter Keras dummy loss * Add sample weight to test_keras_fit * Fix test_keras_fit for datasets * Skip the sample_weight stuff for models where the model tester has no batch_size
This commit is contained in:
parent
7f74433814
commit
26ec7928d0
@ -94,7 +94,11 @@ TFModelInputType = Union[
|
||||
|
||||
|
||||
def dummy_loss(y_true, y_pred):
|
||||
return tf.reduce_mean(y_pred)
|
||||
if y_pred.shape.rank <= 1:
|
||||
return y_pred
|
||||
else:
|
||||
reduction_axes = list(range(1, y_pred.shape.rank))
|
||||
return tf.reduce_mean(y_pred, axis=reduction_axes)
|
||||
|
||||
|
||||
class TFModelUtilsMixin:
|
||||
|
@ -1544,6 +1544,11 @@ class TFModelTesterMixin:
|
||||
else:
|
||||
metrics = []
|
||||
|
||||
if hasattr(self.model_tester, "batch_size"):
|
||||
sample_weight = tf.convert_to_tensor([0.5] * self.model_tester.batch_size, dtype=tf.float32)
|
||||
else:
|
||||
sample_weight = None
|
||||
|
||||
model(model.dummy_inputs) # Build the model so we can get some constant weights
|
||||
model_weights = model.get_weights()
|
||||
|
||||
@ -1553,6 +1558,7 @@ class TFModelTesterMixin:
|
||||
history1 = model.fit(
|
||||
prepared_for_class,
|
||||
validation_data=prepared_for_class,
|
||||
sample_weight=sample_weight,
|
||||
steps_per_epoch=1,
|
||||
validation_steps=1,
|
||||
shuffle=False,
|
||||
@ -1588,6 +1594,7 @@ class TFModelTesterMixin:
|
||||
inputs_minus_labels,
|
||||
labels,
|
||||
validation_data=(inputs_minus_labels, labels),
|
||||
sample_weight=sample_weight,
|
||||
steps_per_epoch=1,
|
||||
validation_steps=1,
|
||||
shuffle=False,
|
||||
@ -1605,14 +1612,22 @@ class TFModelTesterMixin:
|
||||
|
||||
# Make sure fit works with tf.data.Dataset and results are consistent
|
||||
dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
|
||||
|
||||
if sample_weight is not None:
|
||||
# Add in the sample weight
|
||||
weighted_dataset = dataset.map(lambda x: (x, None, tf.convert_to_tensor(0.5, dtype=tf.float32)))
|
||||
else:
|
||||
weighted_dataset = dataset
|
||||
# Pass in all samples as a batch to match other `fit` calls
|
||||
weighted_dataset = weighted_dataset.batch(len(dataset))
|
||||
dataset = dataset.batch(len(dataset))
|
||||
|
||||
# Reinitialize to fix batchnorm again
|
||||
model.set_weights(model_weights)
|
||||
|
||||
# To match the other calls, don't pass sample weights in the validation data
|
||||
history3 = model.fit(
|
||||
dataset,
|
||||
weighted_dataset,
|
||||
validation_data=dataset,
|
||||
steps_per_epoch=1,
|
||||
validation_steps=1,
|
||||
|
Loading…
Reference in New Issue
Block a user