Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Fix broken test for models with batchnorm (#17841)

* Fix tests that broke when models used batchnorm
* Initializing the model twice does not actually give you the same weights each time. I am good at machine learning.
* Fix speed regression
This commit is contained in:
parent 18c263c4b6
commit 1a7ef3349f
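The second bullet of the commit message is the heart of the fix: a Keras model cannot be reset to a known state either by re-instantiating it (random initialization differs each run) or by fitting with a zero learning rate (a BatchNormalization layer updates its moving mean and variance during training regardless of the optimizer), so the test now snapshots the weights once and restores them with set_weights(). The sketch below is a minimal standalone illustration of that BatchNorm behaviour; it is not code from the commit.

import numpy as np
import tensorflow as tf

# Standalone illustration (not from the commit): a model containing
# BatchNormalization changes its weights during fit() even with a zero
# learning rate, because the moving statistics are not updated by the optimizer.
inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(1)(tf.keras.layers.BatchNormalization()(inputs))
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), loss="mse")

x = np.random.rand(16, 4).astype("float32")
y = np.random.rand(16, 1).astype("float32")

weights_before = model.get_weights()
model.fit(x, y, epochs=1, verbose=0)
weights_after = model.get_weights()

# The Dense kernel is untouched (lr=0), but the BatchNorm moving mean/variance moved.
print([not np.allclose(a, b) for a, b in zip(weights_before, weights_after)])

# Restoring the snapshot is what the fixed test does before each fit() call.
model.set_weights(weights_before)

With the snapshot restored, repeated fit() calls over the same data in the same order produce the same losses, which is exactly what the assertions in the diff below check.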
@@ -1383,6 +1383,10 @@ class TFModelTesterMixin:
             else:
                 metrics = []
+
+            model(model.dummy_inputs)  # Build the model so we can get some constant weights
+            model_weights = model.get_weights()
+
             # Run eagerly to save some expensive compilation times
             model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
             # Make sure the model fits without crashing regardless of where we pass the labels
             history1 = model.fit(
@@ -1394,6 +1398,11 @@ class TFModelTesterMixin:
             )
             val_loss1 = history1.history["val_loss"][0]
             accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
+
+            # We reinitialize the model here even though our learning rate was zero
+            # because BatchNorm updates weights by means other than gradient descent.
+            model.set_weights(model_weights)
+
             history2 = model.fit(
                 inputs_minus_labels,
                 labels,
@@ -1403,7 +1412,7 @@ class TFModelTesterMixin:
                 shuffle=False,
             )
             val_loss2 = history2.history["val_loss"][0]
-            accuracy2 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
+            accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
             self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
             self.assertEqual(history1.history.keys(), history2.history.keys())
             for key in history1.history.keys():
@@ -1416,6 +1425,10 @@ class TFModelTesterMixin:
             dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
             # Pass in all samples as a batch to match other `fit` calls
             dataset = dataset.batch(len(dataset))
+
+            # Reinitialize to fix batchnorm again
+            model.set_weights(model_weights)
+
             history3 = model.fit(
                 dataset,
                 validation_data=dataset,
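Read end to end, the diff turns the test into the following pattern. This is a rough, self-contained sketch that uses a toy Keras model and random data in place of the library's TF model classes and prepared inputs, so every name and value here is illustrative only.

import numpy as np
import tensorflow as tf

# Toy stand-ins for the model and data the real test prepares.
inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(1)(tf.keras.layers.BatchNormalization()(inputs))
model = tf.keras.Model(inputs, outputs)
data_x = np.random.rand(8, 4).astype("float32")
data_y = np.random.rand(8, 1).astype("float32")

# Snapshot the freshly built weights once.
model_weights = model.get_weights()

# Zero learning rate so gradients change nothing; eager execution avoids
# costly compilation for the repeated fit() calls.
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), loss="mse", run_eagerly=True)

history1 = model.fit(data_x, data_y, validation_data=(data_x, data_y), shuffle=False, verbose=0)

# BatchNorm moved its running statistics, so restore the snapshot before
# every fit() call that must start from the same state.
model.set_weights(model_weights)
history2 = model.fit(data_x, data_y, validation_data=(data_x, data_y), shuffle=False, verbose=0)

# Same idea for the tf.data path; batch all samples at once so training sees
# the same single whole batch as the earlier fit() calls.
model.set_weights(model_weights)
dataset = tf.data.Dataset.from_tensor_slices((data_x, data_y)).batch(len(data_x))
history3 = model.fit(dataset, validation_data=dataset, verbose=0)

# All runs should now agree on the validation loss, within tolerance.
val1, val2, val3 = (h.history["val_loss"][0] for h in (history1, history2, history3))
assert np.allclose(val1, val2, atol=1e-2, rtol=1e-3) and np.allclose(val1, val3, atol=1e-2, rtol=1e-3)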