Rework some TF tests (#8492)

* Update some tests

* Small update

* Apply style

* Use max_position_embeddings

* Create a fake attribute

* Create a fake attribute

* Update wrong name

* Wrong TransfoXL model file

* Keep the common tests agnostic
This commit is contained in:
Julien Plu 2020-11-13 23:07:17 +01:00 committed by GitHub
parent f6cdafdec7
commit 24184e73c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 60 additions and 80 deletions

View File

@ -454,7 +454,7 @@ class TFModelTesterMixin:
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
max_input = getattr(self.model_tester, "max_position_embeddings", 512)
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
@ -463,14 +463,16 @@ class TFModelTesterMixin:
if self.is_encoder_decoder:
input_ids = {
"decoder_input_ids": tf.keras.Input(
batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"
batch_shape=(2, max_input),
name="decoder_input_ids",
dtype="int32",
),
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
"input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
}
elif model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
input_ids = tf.keras.Input(batch_shape=(4, 2, 2000), name="input_ids", dtype="int32")
input_ids = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32")
else:
input_ids = tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32")
input_ids = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32")
# Prepare our model
model = model_class(config)
@ -510,70 +512,64 @@ class TFModelTesterMixin:
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", self.model_tester.seq_length)
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length)
decoder_key_length = getattr(self.model_tester, "key_length", decoder_seq_length)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
def check_decoder_attentions_output(outputs):
out_len = len(outputs)
self.assertEqual(out_len % 2, 0)
decoder_attentions = outputs.decoder_attentions
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)
def check_encoder_attentions_output(outputs):
attentions = [
t.numpy() for t in (outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions)
]
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["use_cache"] = False
config.output_hidden_states = False
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
attentions = [
t.numpy() for t in (outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions)
]
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
self.assertEqual(config.output_hidden_states, False)
check_encoder_attentions_output(outputs)
if self.is_encoder_decoder:
self.assertEqual(out_len % 2, 0)
decoder_attentions = outputs.decoder_attentions
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
self.assertEqual(config.output_hidden_states, False)
check_decoder_attentions_output(outputs)
# Check that output attentions can also be changed via the config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
attentions = [
t.numpy() for t in (outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions)
]
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
self.assertEqual(config.output_hidden_states, False)
check_encoder_attentions_output(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
config.output_hidden_states = True
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
self.assertEqual(model.config.output_hidden_states, True)
attentions = [
t.numpy() for t in (outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions)
]
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
check_encoder_attentions_output(outputs)
def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -581,10 +577,12 @@ class TFModelTesterMixin:
def check_hidden_states_output(config, inputs_dict, model_class):
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
hidden_states = [t.numpy() for t in outputs[-1]]
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
hidden_states = outputs[-1]
self.assertEqual(config.output_attentions, False)
self.assertEqual(len(hidden_states), expected_num_layers)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),

View File

@ -133,23 +133,21 @@ class TFLongformerModelTester:
def create_and_check_longformer_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.return_dict = True
model = TFLongformerModel(config=config)
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
shape_list(result["sequence_output"]), [self.batch_size, self.seq_length, self.hidden_size]
shape_list(result.last_hidden_state), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(shape_list(result["pooled_output"]), [self.batch_size, self.hidden_size])
self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])
def create_and_check_longformer_model_with_global_attention_mask(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.return_dict = True
model = TFLongformerModel(config=config)
half_input_mask_length = shape_list(input_mask)[-1] // 2
global_attention_mask = tf.concat(
@ -160,59 +158,43 @@ class TFLongformerModelTester:
axis=-1,
)
sequence_output, pooled_output = model(
result = model(
input_ids,
attention_mask=input_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
)
sequence_output, pooled_output = model(
input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask
)
sequence_output, pooled_output = model(input_ids, global_attention_mask=global_attention_mask)
result = model(input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask)
result = model(input_ids, global_attention_mask=global_attention_mask)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
shape_list(result["sequence_output"]), [self.batch_size, self.seq_length, self.hidden_size]
shape_list(result.last_hidden_state), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(shape_list(result["pooled_output"]), [self.batch_size, self.hidden_size])
self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])
def create_and_check_longformer_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.return_dict = True
model = TFLongformerForMaskedLM(config=config)
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
shape_list(result["prediction_scores"]), [self.batch_size, self.seq_length, self.vocab_size]
)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertListEqual(shape_list(result.logits), [self.batch_size, self.seq_length, self.vocab_size])
def create_and_check_longformer_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.return_dict = True
model = TFLongformerForQuestionAnswering(config=config)
loss, start_logits, end_logits = model(
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(shape_list(result["start_logits"]), [self.batch_size, self.seq_length])
self.parent.assertListEqual(shape_list(result["end_logits"]), [self.batch_size, self.seq_length])
self.parent.assertListEqual(shape_list(result.start_logits), [self.batch_size, self.seq_length])
self.parent.assertListEqual(shape_list(result.end_logits), [self.batch_size, self.seq_length])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()