Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Rework some TF tests (#8492)
* Update some tests
* Small update
* Apply style
* Use max_position_embeddings
* Create a fake attribute
* Create a fake attribute
* Update wrong name
* Wrong TransfoXL model file
* Keep the common tests agnostic
This commit is contained in:
parent f6cdafdec7
commit 24184e73c4
@@ -454,7 +454,7 @@ class TFModelTesterMixin:
     def test_compile_tf_model(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

+        max_input = getattr(self.model_tester, "max_position_embeddings", 512)
         optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
         loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
         metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
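Note: the new max_input line reads the sequence length from the model tester when it defines max_position_embeddings and otherwise falls back to 512. A quick sketch of that fallback, with DummyTester and BoundedTester as made-up stand-ins (not part of the diff):

class DummyTester:
    """Stand-in for a model tester that defines no max_position_embeddings."""

class BoundedTester:
    max_position_embeddings = 128  # stand-in for a model with a short position-embedding table

assert getattr(DummyTester(), "max_position_embeddings", 512) == 512   # fallback value used
assert getattr(BoundedTester(), "max_position_embeddings", 512) == 128  # model-specific value wins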
@@ -463,14 +463,16 @@ class TFModelTesterMixin:
             if self.is_encoder_decoder:
                 input_ids = {
                     "decoder_input_ids": tf.keras.Input(
-                        batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"
+                        batch_shape=(2, max_input),
+                        name="decoder_input_ids",
+                        dtype="int32",
                     ),
-                    "input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
+                    "input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
                 }
             elif model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
-                input_ids = tf.keras.Input(batch_shape=(4, 2, 2000), name="input_ids", dtype="int32")
+                input_ids = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32")
             else:
-                input_ids = tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32")
+                input_ids = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32")

             # Prepare our model
             model = model_class(config)
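Note: the max_input value sizes the symbolic tf.keras.Input placeholders that the test then wires into a compiled functional Keras model. A standalone sketch of that pattern, using a toy Embedding/Dense network in place of a transformers model (all layers and sizes here are illustrative):

import tensorflow as tf

max_input = 512  # what getattr(..., "max_position_embeddings", 512) would return

# Symbolic input shaped like the test's placeholders: fixed batch of 2, sequence length max_input.
input_ids = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32")

# Toy stand-in for model(input_ids)[0]: embed the tokens and project to "logits".
hidden = tf.keras.layers.Embedding(input_dim=100, output_dim=32)(input_ids)
logits = tf.keras.layers.Dense(100)(hidden)

extended_model = tf.keras.Model(inputs=[input_ids], outputs=[logits])
extended_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy("accuracy")],
)

For a model whose max_position_embeddings is smaller than the old hard-coded 2000, a 2000-long placeholder can index past the position-embedding table; deriving the length from the tester avoids that.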
@@ -510,70 +512,64 @@ class TFModelTesterMixin:
     def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True

         decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", self.model_tester.seq_length)
         encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length)
         decoder_key_length = getattr(self.model_tester, "key_length", decoder_seq_length)
         encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

+        def check_decoder_attentions_output(outputs):
+            out_len = len(outputs)
+            self.assertEqual(out_len % 2, 0)
+            decoder_attentions = outputs.decoder_attentions
+            self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
+            self.assertListEqual(
+                list(decoder_attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
+            )
+
+        def check_encoder_attentions_output(outputs):
+            attentions = [
+                t.numpy() for t in (outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions)
+            ]
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+            self.assertListEqual(
+                list(attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+            )
+
         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
             inputs_dict["use_cache"] = False
             config.output_hidden_states = False
             model = model_class(config)
             outputs = model(self._prepare_for_class(inputs_dict, model_class))
-            attentions = [
-                t.numpy() for t in (outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions)
-            ]
-            self.assertEqual(model.config.output_hidden_states, False)
-            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-            self.assertListEqual(
-                list(attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
-            )
             out_len = len(outputs)
+            self.assertEqual(config.output_hidden_states, False)
+            check_encoder_attentions_output(outputs)

             if self.is_encoder_decoder:
-                self.assertEqual(out_len % 2, 0)
-                decoder_attentions = outputs.decoder_attentions
-                self.assertEqual(model.config.output_hidden_states, False)
-                self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
-                self.assertListEqual(
-                    list(decoder_attentions[0].shape[-3:]),
-                    [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
-                )
+                model = model_class(config)
+                outputs = model(self._prepare_for_class(inputs_dict, model_class))
+                self.assertEqual(config.output_hidden_states, False)
+                check_decoder_attentions_output(outputs)

             # Check that output attentions can also be changed via the config
             del inputs_dict["output_attentions"]
             config.output_attentions = True
             model = model_class(config)
             outputs = model(self._prepare_for_class(inputs_dict, model_class))
-            attentions = [
-                t.numpy() for t in (outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions)
-            ]
-            self.assertEqual(model.config.output_hidden_states, False)
-            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-            self.assertListEqual(
-                list(attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
-            )
+            self.assertEqual(config.output_hidden_states, False)
+            check_encoder_attentions_output(outputs)

             # Check attention is always last and order is fine
             inputs_dict["output_attentions"] = True
             config.output_hidden_states = True
             model = model_class(config)
             outputs = model(self._prepare_for_class(inputs_dict, model_class))

             self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
             self.assertEqual(model.config.output_hidden_states, True)

-            attentions = [
-                t.numpy() for t in (outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions)
-            ]
-            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-            self.assertListEqual(
-                list(attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
-            )
+            check_encoder_attentions_output(outputs)

     def test_hidden_states_output(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
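Note: the two helpers added above centralize the shape contract that the removed inline blocks were each asserting. A self-contained sketch of that contract, with random tensors standing in for a model's per-layer attentions (all sizes are made up for illustration):

import tensorflow as tf

num_hidden_layers, num_attention_heads = 4, 8
batch_size, seq_length, key_length = 2, 16, 16

# One attention tensor per layer, shaped (batch, heads, query_len, key_len),
# mimicking outputs.attentions / outputs.encoder_attentions in the tests.
attentions = [
    tf.random.uniform((batch_size, num_attention_heads, seq_length, key_length))
    for _ in range(num_hidden_layers)
]

assert len(attentions) == num_hidden_layers
assert list(attentions[0].shape[-3:]) == [num_attention_heads, seq_length, key_length]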
@@ -581,10 +577,12 @@ class TFModelTesterMixin:
         def check_hidden_states_output(config, inputs_dict, model_class):
             model = model_class(config)
             outputs = model(self._prepare_for_class(inputs_dict, model_class))
-            hidden_states = [t.numpy() for t in outputs[-1]]
             expected_num_layers = getattr(
                 self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
             )

+            hidden_states = outputs[-1]
+            self.assertEqual(config.output_attentions, False)
             self.assertEqual(len(hidden_states), expected_num_layers)
             self.assertListEqual(
                 list(hidden_states[0].shape[-2:]),
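Note: the hidden-states check now works on outputs[-1] directly and, unless a tester overrides expected_num_hidden_layers, expects num_hidden_layers + 1 entries, i.e. the embedding output plus one tensor per layer. A quick sketch of that counting convention (sizes illustrative):

import tensorflow as tf

num_hidden_layers, batch_size, seq_length, hidden_size = 4, 2, 16, 32

# Embedding output plus one hidden-state tensor per transformer layer.
hidden_states = [
    tf.zeros((batch_size, seq_length, hidden_size)) for _ in range(num_hidden_layers + 1)
]

expected_num_layers = num_hidden_layers + 1  # the getattr fallback used by the test
assert len(hidden_states) == expected_num_layers
assert list(hidden_states[0].shape[-2:]) == [seq_length, hidden_size]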
@@ -133,23 +133,21 @@ class TFLongformerModelTester:
     def create_and_check_longformer_model(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
+        config.return_dict = True
         model = TFLongformerModel(config=config)
-        sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
-        sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
-        sequence_output, pooled_output = model(input_ids)
+        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
+        result = model(input_ids, token_type_ids=token_type_ids)
+        result = model(input_ids)

-        result = {
-            "sequence_output": sequence_output,
-            "pooled_output": pooled_output,
-        }
         self.parent.assertListEqual(
-            shape_list(result["sequence_output"]), [self.batch_size, self.seq_length, self.hidden_size]
+            shape_list(result.last_hidden_state), [self.batch_size, self.seq_length, self.hidden_size]
         )
-        self.parent.assertListEqual(shape_list(result["pooled_output"]), [self.batch_size, self.hidden_size])
+        self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])

     def create_and_check_longformer_model_with_global_attention_mask(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
+        config.return_dict = True
         model = TFLongformerModel(config=config)
         half_input_mask_length = shape_list(input_mask)[-1] // 2
         global_attention_mask = tf.concat(
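Note: the Longformer tester switches from unpacking a returned tuple and hand-building a result dict to setting config.return_dict = True and reading named fields off the model output. A sketch of that access pattern, using the generic TFBaseModelOutputWithPooling for illustration (the exact output class TFLongformerModel returns may differ):

import tensorflow as tf
from transformers.modeling_tf_outputs import TFBaseModelOutputWithPooling

batch_size, seq_length, hidden_size = 2, 8, 32
result = TFBaseModelOutputWithPooling(
    last_hidden_state=tf.zeros((batch_size, seq_length, hidden_size)),
    pooler_output=tf.zeros((batch_size, hidden_size)),
)

# New style in the reworked tests: named attribute access.
assert result.last_hidden_state.shape == (batch_size, seq_length, hidden_size)
assert result.pooler_output.shape == (batch_size, hidden_size)

# Indexing by position also works (result[0] is the last hidden state); the old tests
# instead unpacked the plain tuple the model returned when return_dict was not set.
assert result[0] is result.last_hidden_state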
@@ -160,59 +158,43 @@ class TFLongformerModelTester:
             axis=-1,
         )

-        sequence_output, pooled_output = model(
+        result = model(
             input_ids,
             attention_mask=input_mask,
             global_attention_mask=global_attention_mask,
             token_type_ids=token_type_ids,
         )
-        sequence_output, pooled_output = model(
-            input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask
-        )
-        sequence_output, pooled_output = model(input_ids, global_attention_mask=global_attention_mask)
+        result = model(input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask)
+        result = model(input_ids, global_attention_mask=global_attention_mask)

-        result = {
-            "sequence_output": sequence_output,
-            "pooled_output": pooled_output,
-        }
         self.parent.assertListEqual(
-            shape_list(result["sequence_output"]), [self.batch_size, self.seq_length, self.hidden_size]
+            shape_list(result.last_hidden_state), [self.batch_size, self.seq_length, self.hidden_size]
         )
-        self.parent.assertListEqual(shape_list(result["pooled_output"]), [self.batch_size, self.hidden_size])
+        self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])

     def create_and_check_longformer_for_masked_lm(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
+        config.return_dict = True
         model = TFLongformerForMaskedLM(config=config)
-        loss, prediction_scores = model(
-            input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
-        )
-        result = {
-            "loss": loss,
-            "prediction_scores": prediction_scores,
-        }
-        self.parent.assertListEqual(
-            shape_list(result["prediction_scores"]), [self.batch_size, self.seq_length, self.vocab_size]
-        )
+        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
+        self.parent.assertListEqual(shape_list(result.logits), [self.batch_size, self.seq_length, self.vocab_size])

     def create_and_check_longformer_for_question_answering(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
+        config.return_dict = True
         model = TFLongformerForQuestionAnswering(config=config)
-        loss, start_logits, end_logits = model(
+        result = model(
             input_ids,
             attention_mask=input_mask,
             token_type_ids=token_type_ids,
             start_positions=sequence_labels,
             end_positions=sequence_labels,
         )
-        result = {
-            "loss": loss,
-            "start_logits": start_logits,
-            "end_logits": end_logits,
-        }
-        self.parent.assertListEqual(shape_list(result["start_logits"]), [self.batch_size, self.seq_length])
-        self.parent.assertListEqual(shape_list(result["end_logits"]), [self.batch_size, self.seq_length])
+        self.parent.assertListEqual(shape_list(result.start_logits), [self.batch_size, self.seq_length])
+        self.parent.assertListEqual(shape_list(result.end_logits), [self.batch_size, self.seq_length])

     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
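Note: shape_list, used in these assertions, returns a tensor's shape as a plain Python list, preferring static dimensions and falling back to dynamic ones. A rough approximation of the helper, for reference only (the real implementation lives in the library's TF utilities):

import tensorflow as tf

def shape_list_sketch(tensor: tf.Tensor) -> list:
    """Approximation of transformers' shape_list: static dims where known, dynamic tensors otherwise."""
    static = tensor.shape.as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

x = tf.zeros((2, 8, 32))
print(shape_list_sketch(x))  # [2, 8, 32] (all dims static here)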