Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
Speed up TF tests by reducing hidden layer counts (#24595)
* hidden layers, huh, what are they good for (absolutely nothing)
* Some tests break with 1 hidden layer, use 2
* Use 1 hidden layer in a few slow models
* Use num_hidden_layers=2 everywhere
* Slightly higher tol for groupvit
* Slightly higher tol for groupvit
This commit is contained in:
parent 3441ad7d43
commit 134caef31a
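Why this helps, in one toy benchmark (hypothetical code, not from this commit): each unit of num_hidden_layers adds a whole block that every test must build and run, so cutting the default from 5 to 2 shrinks both graph construction and forward-pass time.

    import time
    import tensorflow as tf

    def build_toy_encoder(num_hidden_layers, hidden_size=32, seq_length=7):
        # Stand-in for a model under test: every hidden layer adds a block
        # that must be built and executed, so runtime grows with layer count.
        inputs = tf.keras.Input(shape=(seq_length, hidden_size))
        x = inputs
        for _ in range(num_hidden_layers):
            x = tf.keras.layers.Dense(hidden_size, activation="gelu")(x)
        return tf.keras.Model(inputs, x)

    for n_layers in (5, 2):
        start = time.time()
        model = build_toy_encoder(n_layers)
        model(tf.random.normal((13, 7, 32)))  # single eager forward pass
        print(f"{n_layers} layers: {time.time() - start:.3f}s")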
@@ -56,7 +56,7 @@ class TFAlbertModelTester:
         vocab_size=99,
         embedding_size=16,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",
@@ -80,7 +80,7 @@ class TFAlbertModelTester:
         self.vocab_size = 99
         self.embedding_size = 16
         self.hidden_size = 32
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.intermediate_size = 37
         self.hidden_act = "gelu"

@@ -52,7 +52,7 @@ class TFBartModelTester:
         use_labels=False,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_dropout_prob=0.1,

@@ -57,7 +57,7 @@ class TFBertModelTester:
         use_labels=True,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",
@@ -80,7 +80,7 @@ class TFBertModelTester:
         self.use_labels = True
         self.vocab_size = 99
         self.hidden_size = 32
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.intermediate_size = 37
         self.hidden_act = "gelu"

@@ -48,7 +48,7 @@ class TFBlenderbotModelTester:
         use_labels=False,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_dropout_prob=0.1,

@@ -48,7 +48,7 @@ class TFBlenderbotSmallModelTester:
         use_labels=False,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_dropout_prob=0.1,

@@ -64,7 +64,7 @@ class TFBlipVisionModelTester:
         is_training=True,
         hidden_size=32,
         projection_dim=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         dropout=0.1,
@@ -207,7 +207,7 @@ class TFBlipTextModelTester:
         vocab_size=99,
         hidden_size=32,
         projection_dim=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         dropout=0.1,

@@ -46,7 +46,7 @@ class BlipTextModelTester:
         vocab_size=99,
         hidden_size=32,
         projection_dim=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         dropout=0.1,

@@ -57,7 +57,7 @@ class TFCLIPVisionModelTester:
         num_channels=3,
         is_training=True,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         dropout=0.1,
@@ -328,7 +328,7 @@ class TFCLIPTextModelTester:
         use_labels=True,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         dropout=0.1,

@@ -51,7 +51,7 @@ class TFConvBertModelTester:
         use_labels=True,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",
@@ -74,7 +74,7 @@ class TFConvBertModelTester:
         self.use_labels = True
         self.vocab_size = 99
         self.hidden_size = 384
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.intermediate_size = 37
         self.hidden_act = "gelu"

@@ -52,7 +52,7 @@ class TFCTRLModelTester(object):
         self.use_mc_token_ids = True
         self.vocab_size = 99
         self.hidden_size = 32
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.intermediate_size = 37
         self.hidden_act = "gelu"

@@ -61,7 +61,7 @@ class TFData2VecVisionModelTester:
         is_training=True,
         use_labels=True,
         hidden_size=32,
-        num_hidden_layers=4,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",

@@ -50,7 +50,7 @@ class TFDebertaModelTester:
         use_labels=True,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",
@@ -73,7 +73,7 @@ class TFDebertaModelTester:
         self.use_labels = True
         self.vocab_size = 99
         self.hidden_size = 32
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.intermediate_size = 37
         self.hidden_act = "gelu"

@@ -50,7 +50,7 @@ class TFDebertaV2ModelTester:
         use_labels=True,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",

@@ -60,7 +60,7 @@ class TFDeiTModelTester:
         is_training=True,
         use_labels=True,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",

@@ -54,7 +54,7 @@ class TFDistilBertModelTester:
         self.use_labels = True
         self.vocab_size = 99
         self.hidden_size = 32
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.intermediate_size = 37
         self.hidden_act = "gelu"

@@ -53,7 +53,7 @@ class TFDPRModelTester:
         use_labels=True,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",

@@ -54,7 +54,7 @@ class TFElectraModelTester:
         self.use_labels = True
         self.vocab_size = 99
         self.hidden_size = 32
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.intermediate_size = 37
         self.hidden_act = "gelu"

@@ -53,7 +53,7 @@ class TFEsmModelTester:
         self.use_labels = True
         self.vocab_size = 99
         self.hidden_size = 32
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.intermediate_size = 37
         self.hidden_act = "gelu"

@@ -61,7 +61,7 @@ class TFFlaubertModelTester:
         self.vocab_size = 99
         self.n_special = 0
         self.hidden_size = 32
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.hidden_dropout_prob = 0.1
         self.attention_probs_dropout_prob = 0.1

@@ -55,7 +55,7 @@ class TFGPT2ModelTester:
         self.use_mc_token_ids = True
         self.vocab_size = 99
         self.hidden_size = 32
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.intermediate_size = 37
         self.hidden_act = "gelu"

@@ -51,7 +51,7 @@ class TFGPTJModelTester:
         self.vocab_size = 99
         self.hidden_size = 32
         self.rotary_dim = 4
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.intermediate_size = 37
         self.hidden_act = "gelu"
@@ -150,6 +150,10 @@ class TFGroupViTVisionModelTest(TFModelTesterMixin, unittest.TestCase):
     test_head_masking = False
     test_onnx = False
 
+    def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None):
+        # We override with a slightly higher tol value, as this model tends to diverge a bit more
+        super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
+
     def setUp(self):
         self.model_tester = TFGroupViTVisionModelTester(self)
         self.config_tester = ConfigTester(
@@ -381,7 +385,7 @@ class TFGroupViTTextModelTester:
         use_labels=True,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         dropout=0.1,
@@ -459,6 +463,10 @@ class TFGroupViTTextModelTest(TFModelTesterMixin, unittest.TestCase):
     test_head_masking = False
     test_onnx = False
 
+    def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None):
+        # We override with a slightly higher tol value, as this model tends to diverge a bit more
+        super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
+
     def setUp(self):
         self.model_tester = TFGroupViTTextModelTester(self)
         self.config_tester = ConfigTester(self, config_class=GroupViTTextConfig, hidden_size=37)
@@ -581,6 +589,10 @@ class TFGroupViTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     test_attention_outputs = False
     test_onnx = False
 
+    def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None):
+        # We override with a slightly higher tol value, as this model tends to diverge a bit more
+        super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
+
     def setUp(self):
         self.model_tester = TFGroupViTModelTester(self)
 
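The three overrides above loosen the tolerance the PT/TF cross-test applies to this model family. As a rough sketch of what that tol governs (an assumed simplification, not the mixin's actual code), the equivalence check comes down to a max-absolute-difference comparison:

    import numpy as np

    def check_close(tf_outputs, pt_outputs, tol=1e-4, name="outputs"):
        # The largest elementwise gap between the TF and PT tensors must stay
        # within tol; GroupViT needs the looser 1e-4 rather than a tighter default.
        max_diff = np.amax(np.abs(np.asarray(tf_outputs) - np.asarray(pt_outputs)))
        assert max_diff <= tol, f"{name}: max difference {max_diff:.3e} exceeds {tol:.0e}"

    check_close(np.float32([1.0, 2.0]), np.float32([1.0, 2.00005]))  # passes at tol=1e-4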
@@ -59,7 +59,7 @@ class TFHubertModelTester:
         conv_bias=False,
         num_conv_pos_embeddings=16,
         num_conv_pos_embedding_groups=2,
-        num_hidden_layers=4,
+        num_hidden_layers=2,
         num_attention_heads=2,
         hidden_dropout_prob=0.1,  # this is most likely not correctly set yet
         intermediate_size=20,

@@ -52,7 +52,7 @@ class TFLayoutLMModelTester:
         use_labels=True,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",

@@ -69,7 +69,7 @@ class TFLayoutLMv3ModelTester:
         use_labels=True,
         vocab_size=99,
         hidden_size=36,
-        num_hidden_layers=3,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",

@@ -47,7 +47,7 @@ class TFLEDModelTester:
         use_labels=False,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_dropout_prob=0.1,

@@ -56,7 +56,7 @@ class TFLongformerModelTester:
         self.use_labels = True
         self.vocab_size = 99
         self.hidden_size = 32
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.intermediate_size = 37
         self.hidden_act = "gelu"

@@ -49,7 +49,7 @@ class TFMarianModelTester:
         use_labels=False,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_dropout_prob=0.1,

@@ -47,7 +47,7 @@ class TFMBartModelTester:
         use_labels=False,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_dropout_prob=0.1,

@@ -97,7 +97,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         vocab_size=99,
         hidden_size=32,
         embedding_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",

@@ -51,7 +51,7 @@ class TFMPNetModelTester:
         use_labels=True,
         vocab_size=99,
         hidden_size=64,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=64,
         hidden_act="gelu",

@@ -53,7 +53,7 @@ class TFOpenAIGPTModelTester:
         self.use_mc_token_ids = True
         self.vocab_size = 99
         self.hidden_size = 32
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.intermediate_size = 37
         self.hidden_act = "gelu"

@@ -47,7 +47,7 @@ class TFPegasusModelTester:
         use_labels=False,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_dropout_prob=0.1,

@@ -54,7 +54,7 @@ class TFRemBertModelTester:
         hidden_size=32,
         input_embedding_size=18,
         output_embedding_size=43,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",
@@ -79,7 +79,7 @@ class TFRemBertModelTester:
         self.hidden_size = 32
         self.input_embedding_size = input_embedding_size
         self.output_embedding_size = output_embedding_size
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.intermediate_size = 37
         self.hidden_act = "gelu"

@@ -56,7 +56,7 @@ class TFRobertaModelTester:
         self.use_labels = True
         self.vocab_size = 99
         self.hidden_size = 32
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.intermediate_size = 37
         self.hidden_act = "gelu"

@@ -57,7 +57,7 @@ class TFRobertaPreLayerNormModelTester:
         self.use_labels = True
         self.vocab_size = 99
         self.hidden_size = 32
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.intermediate_size = 37
         self.hidden_act = "gelu"

@@ -56,7 +56,7 @@ class TFRoFormerModelTester:
         use_labels=True,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",
@@ -79,7 +79,7 @@ class TFRoFormerModelTester:
         self.use_labels = True
         self.vocab_size = 99
         self.hidden_size = 32
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.intermediate_size = 37
         self.hidden_act = "gelu"

@@ -46,7 +46,7 @@ class TFT5ModelTester:
         self.vocab_size = 99
         self.n_positions = 14
         self.hidden_size = 32
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.d_ff = 37
         self.relative_attention_num_buckets = 8
@@ -325,7 +325,7 @@ class TFT5EncoderOnlyModelTester:
         # For common tests
         use_attention_mask=True,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         d_ff=37,
         relative_attention_num_buckets=8,

@@ -77,7 +77,7 @@ class TFTapasModelTester:
         use_labels=True,
         vocab_size=99,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",

@@ -59,7 +59,7 @@ class TFTransfoXLModelTester:
         self.d_head = 8
         self.d_inner = 128
         self.div_val = 2
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.scope = None
         self.seed = 1
         self.eos_token_id = 0

@@ -52,7 +52,7 @@ class TFViTModelTester:
         is_training=True,
         use_labels=True,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",

@@ -60,7 +60,7 @@ class TFViTMAEModelTester:
         is_training=True,
         use_labels=True,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",

@@ -130,7 +130,7 @@ class TFWav2Vec2ModelTester:
         conv_bias=False,
         num_conv_pos_embeddings=16,
         num_conv_pos_embedding_groups=2,
-        num_hidden_layers=4,
+        num_hidden_layers=2,
         num_attention_heads=2,
         hidden_dropout_prob=0.1,  # this is most likely not correctly set yet
         intermediate_size=20,

@@ -51,7 +51,7 @@ class TFXGLMModelTester:
         use_labels=True,
         vocab_size=99,
         d_model=32,
-        num_hidden_layers=5,
+        num_hidden_layers=2,
         num_attention_heads=4,
         ffn_dim=37,
         activation_function="gelu",

@@ -61,7 +61,7 @@ class TFXLMModelTester:
         self.vocab_size = 99
         self.n_special = 0
         self.hidden_size = 32
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.num_attention_heads = 4
         self.hidden_dropout_prob = 0.1
         self.attention_probs_dropout_prob = 0.1

@@ -61,7 +61,7 @@ class TFXLNetModelTester:
         self.hidden_size = 32
         self.num_attention_heads = 4
         self.d_inner = 128
-        self.num_hidden_layers = 5
+        self.num_hidden_layers = 2
         self.type_sequence_label_size = 2
         self.untie_r = True
         self.bi_data = False
@@ -1527,36 +1527,6 @@ class TFModelTesterMixin:
             if metrics:
                 self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")
 
-            # Make sure fit works with tf.data.Dataset and results are consistent
-            dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
-
-            if sample_weight is not None:
-                # Add in the sample weight
-                weighted_dataset = dataset.map(lambda x: (x, None, tf.convert_to_tensor(0.5, dtype=tf.float32)))
-            else:
-                weighted_dataset = dataset
-            # Pass in all samples as a batch to match other `fit` calls
-            weighted_dataset = weighted_dataset.batch(len(dataset))
-            dataset = dataset.batch(len(dataset))
-            # Reinitialize to fix batchnorm again
-            model.set_weights(model_weights)
-
-            # To match the other calls, don't pass sample weights in the validation data
-            history3 = model.fit(
-                weighted_dataset,
-                validation_data=dataset,
-                steps_per_epoch=1,
-                validation_steps=1,
-                shuffle=False,
-            )
-            val_loss3 = history3.history["val_loss"][0]
-            self.assertTrue(not isnan(val_loss3))
-            accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
-            self.check_keras_fit_results(val_loss1, val_loss3)
-            self.assertEqual(history1.history.keys(), history3.history.keys())
-            if metrics:
-                self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")
-
     def test_int_support(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         for model_class in self.all_model_classes:
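For reference, the block removed in the last hunk exercised model.fit with a tf.data.Dataset carrying a constant sample weight. A self-contained sketch of that Keras pattern (toy model and data, assumed for illustration):

    import tensorflow as tf

    xs = tf.random.normal((8, 4))
    ys = tf.one_hot(tf.zeros(8, dtype=tf.int32), depth=2)
    dataset = tf.data.Dataset.from_tensor_slices((xs, ys))
    # Keras reads (inputs, labels, sample_weight) tuples from a tf.data pipeline,
    # so a constant 0.5 weight can be mapped onto every example before batching.
    weighted = dataset.map(lambda x, y: (x, y, tf.constant(0.5))).batch(8)

    model = tf.keras.Sequential([tf.keras.layers.Dense(2, activation="softmax")])
    model.compile(optimizer="sgd", loss="categorical_crossentropy")
    model.fit(weighted, epochs=1, verbose=0)  # loss is scaled by the 0.5 weight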