From dfc76b25426d75d5dce489bd18cfd6a51fb01b97 Mon Sep 17 00:00:00 2001 From: amyeroberts Date: Thu, 9 Jun 2022 09:50:03 +0200 Subject: [PATCH] has_attentions - consistent test skipping logic and tf tests (#17495) --- .../models/convnext/test_modeling_convnext.py | 4 + tests/models/cvt/test_modeling_cvt.py | 4 + tests/models/flava/test_modeling_flava.py | 4 + .../poolformer/test_modeling_poolformer.py | 4 + tests/models/regnet/test_modeling_regnet.py | 4 + tests/models/resnet/test_modeling_resnet.py | 4 + tests/models/van/test_modeling_van.py | 4 + tests/test_modeling_common.py | 204 +++++++++--------- tests/test_modeling_tf_common.py | 24 ++- 9 files changed, 141 insertions(+), 115 deletions(-) diff --git a/tests/models/convnext/test_modeling_convnext.py b/tests/models/convnext/test_modeling_convnext.py index f12a21bfe64..46ef3ce7170 100644 --- a/tests/models/convnext/test_modeling_convnext.py +++ b/tests/models/convnext/test_modeling_convnext.py @@ -158,6 +158,10 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase): def create_and_test_config_common_properties(self): return + @unittest.skip(reason="ConvNext does not output attentions") + def test_attention_outputs(self): + pass + @unittest.skip(reason="ConvNext does not use inputs_embeds") def test_inputs_embeds(self): pass diff --git a/tests/models/cvt/test_modeling_cvt.py b/tests/models/cvt/test_modeling_cvt.py index 3791c75e8c9..b88f22d982b 100644 --- a/tests/models/cvt/test_modeling_cvt.py +++ b/tests/models/cvt/test_modeling_cvt.py @@ -173,6 +173,10 @@ class CvtModelTest(ModelTesterMixin, unittest.TestCase): def create_and_test_config_common_properties(self): return + @unittest.skip(reason="Cvt does not output attentions") + def test_attention_outputs(self): + pass + @unittest.skip(reason="Cvt does not use inputs_embeds") def test_inputs_embeds(self): pass diff --git a/tests/models/flava/test_modeling_flava.py b/tests/models/flava/test_modeling_flava.py index 8829c55ac8b..62b89e3977c 100644 --- a/tests/models/flava/test_modeling_flava.py +++ b/tests/models/flava/test_modeling_flava.py @@ -695,6 +695,10 @@ class FlavaImageCodebookTest(ModelTesterMixin, unittest.TestCase): expected_arg_names = ["pixel_values"] self.assertListEqual(arg_names[:1], expected_arg_names) + @unittest.skip(reason="Flava does not output attentions") + def test_attention_outputs(self): + pass + def test_model_common_attributes(self): # No embedding in multimodal model pass diff --git a/tests/models/poolformer/test_modeling_poolformer.py b/tests/models/poolformer/test_modeling_poolformer.py index 9bb8fa2e29c..7dc47d2c77f 100644 --- a/tests/models/poolformer/test_modeling_poolformer.py +++ b/tests/models/poolformer/test_modeling_poolformer.py @@ -142,6 +142,10 @@ class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + @unittest.skip(reason="PoolFormer does not output attentions") + def test_attention_outputs(self): + pass + @unittest.skip("PoolFormer does not use inputs_embeds") def test_inputs_embeds(self): pass diff --git a/tests/models/regnet/test_modeling_regnet.py b/tests/models/regnet/test_modeling_regnet.py index 02695dbf643..4879bf259ef 100644 --- a/tests/models/regnet/test_modeling_regnet.py +++ b/tests/models/regnet/test_modeling_regnet.py @@ -147,6 +147,10 @@ class RegNetModelTest(ModelTesterMixin, unittest.TestCase): def create_and_test_config_common_properties(self): return + @unittest.skip(reason="RegNet 
does not output attentions") + def test_attention_outputs(self): + pass + @unittest.skip(reason="RegNet does not use inputs_embeds") def test_inputs_embeds(self): pass diff --git a/tests/models/resnet/test_modeling_resnet.py b/tests/models/resnet/test_modeling_resnet.py index f289c5c3df8..83f08b68afb 100644 --- a/tests/models/resnet/test_modeling_resnet.py +++ b/tests/models/resnet/test_modeling_resnet.py @@ -147,6 +147,10 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase): def create_and_test_config_common_properties(self): return + @unittest.skip(reason="ResNet does not output attentions") + def test_attention_outputs(self): + pass + @unittest.skip(reason="ResNet does not use inputs_embeds") def test_inputs_embeds(self): pass diff --git a/tests/models/van/test_modeling_van.py b/tests/models/van/test_modeling_van.py index 3e5b7fb1dfc..6b6a672b9b4 100644 --- a/tests/models/van/test_modeling_van.py +++ b/tests/models/van/test_modeling_van.py @@ -144,6 +144,10 @@ class VanModelTest(ModelTesterMixin, unittest.TestCase): def create_and_test_config_common_properties(self): return + @unittest.skip(reason="Van does not output attentions") + def test_attention_outputs(self): + pass + @unittest.skip(reason="Van does not use inputs_embeds") def test_inputs_embeds(self): pass diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 83927bb27fd..747647874c2 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -485,123 +485,119 @@ class ModelTesterMixin: loss.backward() def test_attention_outputs(self): - if not self.has_attentions: - pass + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True - else: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + seq_len = getattr(self.model_tester, "seq_length", None) + decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + chunk_length = getattr(self.model_tester, "chunk_length", None) + if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): + encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - seq_len = getattr(self.model_tester, "seq_length", None) - decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len) - encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) - decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length) - encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) - chunk_length = getattr(self.model_tester, "chunk_length", None) - if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): - encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes + # check that 
output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - for model_class in self.all_model_classes: - inputs_dict["output_attentions"] = True - inputs_dict["output_hidden_states"] = False - config.return_dict = True - model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + if chunk_length is not None: + self.assertListEqual( + list(attentions[0].shape[-4:]), + [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], + ) + else: + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + out_len = len(outputs) - # check that output_attentions also work using config - del inputs_dict["output_attentions"] - config.output_attentions = True - model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + if self.is_encoder_decoder: + correct_outlen = 5 - if chunk_length is not None: - self.assertListEqual( - list(attentions[0].shape[-4:]), - [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], - ) - else: - self.assertListEqual( - list(attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], - ) - out_len = len(outputs) + # loss is at first position + if "labels" in inputs_dict: + correct_outlen += 1 # loss is added to beginning + # Question Answering model returns start_logits and end_logits + if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING): + correct_outlen += 1 # start_logits and end_logits instead of only 1 output + if "past_key_values" in outputs: + correct_outlen += 1 # past_key_values have been returned - if self.is_encoder_decoder: - correct_outlen = 5 + self.assertEqual(out_len, correct_outlen) - # loss is at first position - if "labels" in inputs_dict: - correct_outlen += 1 # loss is added to beginning - # Question Answering model returns start_logits and end_logits - if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING): - correct_outlen += 1 # start_logits and end_logits instead of only 1 output - if "past_key_values" in outputs: - correct_outlen += 1 # past_key_values have been returned + # decoder attentions + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length], + ) - self.assertEqual(out_len, correct_outlen) + # cross attentions + 
cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + decoder_seq_length, + encoder_key_length, + ], + ) - # decoder attentions - decoder_attentions = outputs.decoder_attentions - self.assertIsInstance(decoder_attentions, (list, tuple)) - self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) - self.assertListEqual( - list(decoder_attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length], - ) + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - # cross attentions - cross_attentions = outputs.cross_attentions - self.assertIsInstance(cross_attentions, (list, tuple)) - self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers) - self.assertListEqual( - list(cross_attentions[0].shape[-3:]), - [ - self.model_tester.num_attention_heads, - decoder_seq_length, - encoder_key_length, - ], - ) + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + added_hidden_states = 2 + else: + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) - # Check attention is always last and order is fine - inputs_dict["output_attentions"] = True - inputs_dict["output_hidden_states"] = True - model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - if hasattr(self.model_tester, "num_hidden_states_types"): - added_hidden_states = self.model_tester.num_hidden_states_types - elif self.is_encoder_decoder: - added_hidden_states = 2 - else: - added_hidden_states = 1 - self.assertEqual(out_len + added_hidden_states, len(outputs)) - - self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - - self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) - if chunk_length is not None: - self.assertListEqual( - list(self_attentions[0].shape[-4:]), - [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], - ) - else: - self.assertListEqual( - list(self_attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], - ) + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + if chunk_length is not None: + self.assertListEqual( + list(self_attentions[0].shape[-4:]), + [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], + ) + else: + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) @slow def test_torchscript_simple(self): diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 65eebdf269b..fa439704a88 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ 
-978,9 +978,10 @@ class TFModelTesterMixin: dict_inputs = self._prepare_for_class(inputs_dict, model_class) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) - tuple_inputs = self._prepare_for_class(inputs_dict, model_class) - dict_inputs = self._prepare_for_class(inputs_dict, model_class) - check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True}) + if self.has_attentions: + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True}) # Not all models accept "labels" in the forward pass (yet :) ) if "labels" in inspect.signature(model.call).parameters.keys(): @@ -992,15 +993,16 @@ class TFModelTesterMixin: dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) - tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) - dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) - check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True}) + if self.has_attentions: + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True}) - tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) - dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) - check_equivalence( - model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True} - ) + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence( + model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True} + ) def test_inputs_embeds(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
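
Notes (editorial, not part of the patch):

The convention this patch establishes is that attention-free models opt out
of attention tests explicitly, instead of the common test passing silently
via the old "if not self.has_attentions: pass" shape. A minimal sketch of
the per-model pattern, using a hypothetical model name (only the
attention-related pieces are shown; a real test class would also inherit
ModelTesterMixin and define all_model_classes):

    import unittest

    class HypotheticalNoAttentionModelTest(unittest.TestCase):
        # has_attentions lets the common mixins guard attention-specific
        # checks; the explicit skip makes the opt-out visible in reports.
        has_attentions = False

        @unittest.skip(reason="HypotheticalModel does not output attentions")
        def test_attention_outputs(self):
            pass

With the decorator, the runner reports the test as skipped, together with
its reason, rather than as a pass that exercised nothing.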
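The TF-side hunks apply the same flag inside the common equivalence test:
the hidden-states check stays unconditional, while the attention-only and
combined checks now run under "if self.has_attentions:". A self-contained
sketch of that guard, where check_equivalence is a simplified stand-in and
not the real helper's signature:

    import unittest

    class EquivalenceGuardSketch(unittest.TestCase):
        """Illustrative only: mirrors the guard added to TFModelTesterMixin."""

        has_attentions = True  # set False for attention-free architectures

        def check_equivalence(self, extra_kwargs):
            # Stand-in for the real tuple-vs-dict output comparison.
            self.assertIsInstance(extra_kwargs, dict)

        def test_model_outputs_equivalence(self):
            # The hidden-states check always runs.
            self.check_equivalence({"output_hidden_states": True})
            # Attention checks, including the combined one, sit inside the
            # guard, matching the second TF hunk above.
            if self.has_attentions:
                self.check_equivalence({"output_attentions": True})
                self.check_equivalence(
                    {"output_hidden_states": True, "output_attentions": True}
                )

    if __name__ == "__main__":
        unittest.main()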