Fix broken test for models with batchnorm (#17841)

* Fix tests that broke when models used batchnorm

* Initializing the model twice does not actually give you the same weights each time. I am good at machine learning. (See the weight snapshot/restore sketch after these notes.)

* Fix speed regression
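
For context, here is a minimal, hypothetical sketch (a toy Keras model, not the test model touched by this diff) of the point above: rebuilding a model does not reproduce its weights, so the test now snapshots them with `get_weights()` and restores them with `set_weights()` instead.

```python
import tensorflow as tf

# Toy example, not the actual test model: two fresh builds of the "same"
# architecture draw different random initial values for their weights.
def build_model():
    return tf.keras.Sequential([
        tf.keras.layers.Dense(4, activation="relu", input_shape=(8,)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(1),
    ])

model_a = build_model()
model_b = build_model()
weights_differ = any(
    (wa != wb).any()
    for wa, wb in zip(model_a.get_weights(), model_b.get_weights())
)
print(weights_differ)  # True: re-building is not a reproducible "reset"

# The approach used in this test: snapshot the built model's weights once,
# then restore that exact snapshot before every fit() call.
snapshot = model_a.get_weights()
# ... model_a.fit(...) would go here ...
model_a.set_weights(snapshot)  # back to the exact saved values
```
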
Matt 2022-06-23 15:59:53 +01:00 committed by GitHub
parent 18c263c4b6
commit 1a7ef3349f

@@ -1383,6 +1383,10 @@ class TFModelTesterMixin:
else:
metrics = []
model(model.dummy_inputs) # Build the model so we can get some constant weights
model_weights = model.get_weights()
# Run eagerly to save some expensive compilation times
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
# Make sure the model fits without crashing regardless of where we pass the labels
history1 = model.fit(
@@ -1394,6 +1398,11 @@ class TFModelTesterMixin:
)
val_loss1 = history1.history["val_loss"][0]
accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
# We reinitialize the model here even though our learning rate was zero
# because BatchNorm updates weights by means other than gradient descent.
model.set_weights(model_weights)
history2 = model.fit(
inputs_minus_labels,
labels,
@@ -1403,7 +1412,7 @@ class TFModelTesterMixin:
shuffle=False,
)
val_loss2 = history2.history["val_loss"][0]
accuracy2 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
self.assertEqual(history1.history.keys(), history2.history.keys())
for key in history1.history.keys():
@@ -1416,6 +1425,10 @@ class TFModelTesterMixin:
dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
# Pass in all samples as a batch to match other `fit` calls
dataset = dataset.batch(len(dataset))
# Reinitialize to fix batchnorm again
model.set_weights(model_weights)
history3 = model.fit(
dataset,
validation_data=dataset,
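
A note on the `set_weights` calls in the diff: even with `SGD(0.0)`, fitting still mutates a BatchNorm layer's `moving_mean` / `moving_variance`, because those are updated by the layer itself during training-mode forward passes, not by the optimizer. A small standalone sketch (a made-up toy model, not part of this test file) showing the effect:

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy model, not part of the test suite in this diff.
model = tf.keras.Sequential([
    tf.keras.layers.BatchNormalization(input_shape=(4,)),
    tf.keras.layers.Dense(1),
])
# Zero learning rate: the optimizer applies no gradient updates at all.
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), loss="mse")

x = np.random.randn(32, 4).astype("float32")
y = np.random.randn(32, 1).astype("float32")

before = [w.copy() for w in model.get_weights()]
model.fit(x, y, epochs=1, verbose=0)
after = model.get_weights()

# gamma/beta and the Dense kernel/bias are untouched (lr == 0), but the
# BatchNorm moving statistics have drifted, so the weights no longer match.
for b, a in zip(before, after):
    print(np.abs(b - a).max())
```

That drift is why the test restores `model_weights` before each `fit()` call even though the learning rate is zero.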