diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py
index 9b858fc3c13..94bf5dcfbe4 100644
--- a/src/transformers/models/deit/modeling_deit.py
+++ b/src/transformers/models/deit/modeling_deit.py
@@ -542,7 +542,8 @@ class DeiTModel(DeiTPreTrainedModel):
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
         if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]
 
         return BaseModelOutputWithPooling(
             last_hidden_state=sequence_output,
@@ -662,7 +663,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
             masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
 
         if not return_dict:
-            output = (reconstructed_pixel_values,) + outputs[2:]
+            output = (reconstructed_pixel_values,) + outputs[1:]
             return ((masked_im_loss,) + output) if masked_im_loss is not None else output
 
         return MaskedLMOutput(
@@ -775,7 +776,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(logits, labels)
         if not return_dict:
-            output = (logits,) + outputs[2:]
+            output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
 
         return ImageClassifierOutput(
@@ -882,7 +883,7 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
         logits = (cls_logits + distillation_logits) / 2
 
         if not return_dict:
-            output = (logits, cls_logits, distillation_logits) + outputs[2:]
+            output = (logits, cls_logits, distillation_logits) + outputs[1:]
             return output
 
         return DeiTForImageClassificationWithTeacherOutput(
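
The DeiT hunks above change the tuple (return_dict=False) code path: when the model is built without a pooling layer, the pooled output is now dropped from the returned tuple instead of being kept as a None placeholder, so heads built on top of DeiTModel(..., add_pooling_layer=False) find the encoder extras one position earlier (outputs[1:] instead of outputs[2:]). A minimal sketch of the resulting layout; `model` and `pixel_values` here are assumed to exist and are not part of the patch:

    # Hypothetical usage illustrating the tuple layout after this change.
    outputs = model(pixel_values, output_hidden_states=True, output_attentions=True, return_dict=False)
    sequence_output = outputs[0]  # last_hidden_state
    hidden_states = outputs[1]    # was outputs[2] when a None pooled output padded the tuple
    attentions = outputs[2]       # was outputs[3]
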
diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py
index ab742a1279c..6c5fd238523 100755
--- a/src/transformers/models/dpt/modeling_dpt.py
+++ b/src/transformers/models/dpt/modeling_dpt.py
@@ -750,7 +750,8 @@ class DPTModel(DPTPreTrainedModel):
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
         if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]
 
         return BaseModelOutputWithPooling(
             last_hidden_state=sequence_output,
@@ -938,7 +939,7 @@ class DPTForDepthEstimation(DPTPreTrainedModel):
             return_dict=return_dict,
         )
 
-        hidden_states = outputs.hidden_states if return_dict else outputs[2]
+        hidden_states = outputs.hidden_states if return_dict else outputs[1]
 
         # only keep certain features based on config.backbone_out_indices
         # note that the hidden_states also include the initial embeddings
@@ -956,9 +957,9 @@ class DPTForDepthEstimation(DPTPreTrainedModel):
 
         if not return_dict:
             if output_hidden_states:
-                output = (predicted_depth,) + outputs[2:]
+                output = (predicted_depth,) + outputs[1:]
             else:
-                output = (predicted_depth,) + outputs[3:]
+                output = (predicted_depth,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
 
         return DepthEstimatorOutput(
@@ -1083,7 +1084,7 @@ class DPTForSemanticSegmentation(DPTPreTrainedModel):
             return_dict=return_dict,
         )
 
-        hidden_states = outputs.hidden_states if return_dict else outputs[2]
+        hidden_states = outputs.hidden_states if return_dict else outputs[1]
 
         # only keep certain features based on config.backbone_out_indices
         # note that the hidden_states also include the initial embeddings
@@ -1120,9 +1121,9 @@ class DPTForSemanticSegmentation(DPTPreTrainedModel):
 
         if not return_dict:
             if output_hidden_states:
-                output = (logits,) + outputs[2:]
+                output = (logits,) + outputs[1:]
             else:
-                output = (logits,) + outputs[3:]
+                output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
 
         return SemanticSegmenterOutput(
diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py
index f4a88e51b51..b2fc044fcb0 100644
--- a/src/transformers/models/vit/modeling_vit.py
+++ b/src/transformers/models/vit/modeling_vit.py
@@ -585,7 +585,8 @@ class ViTModel(ViTPreTrainedModel):
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
         if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]
 
         return BaseModelOutputWithPooling(
             last_hidden_state=sequence_output,
@@ -706,7 +707,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
             masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
 
         if not return_dict:
-            output = (reconstructed_pixel_values,) + outputs[2:]
+            output = (reconstructed_pixel_values,) + outputs[1:]
             return ((masked_im_loss,) + output) if masked_im_loss is not None else output
 
         return MaskedLMOutput(
@@ -798,8 +799,9 @@ class ViTForImageClassification(ViTPreTrainedModel):
             elif self.config.problem_type == "multi_label_classification":
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(logits, labels)
+
         if not return_dict:
-            output = (logits,) + outputs[2:]
+            output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
 
         return ImageClassifierOutput(
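
The DPT and ViT edits apply the same convention. The DPT heads always request hidden states from their pooler-less backbone (they need the intermediate features), so in the tuple path those hidden states now sit at index 1, and when the caller did not ask for hidden states themselves the heads splice the extras starting one element later so that the internal hidden states are not leaked into the user-facing output. A rough sketch of that index arithmetic, using a hypothetical backbone_outputs tuple rather than the real modules:

    # Hypothetical tuple from a pooler-less backbone that was asked for hidden states:
    # (sequence_output, hidden_states) or (sequence_output, hidden_states, attentions)
    hidden_states = backbone_outputs[1]                                              # previously backbone_outputs[2]
    extras_for_user = backbone_outputs[1:] if output_hidden_states else backbone_outputs[2:]
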
diff --git a/tests/beit/test_modeling_beit.py b/tests/beit/test_modeling_beit.py
index 5b4421b4d77..3f375d3a310 100644
--- a/tests/beit/test_modeling_beit.py
+++ b/tests/beit/test_modeling_beit.py
@@ -41,7 +41,7 @@ if is_torch_available():
         BeitForSemanticSegmentation,
         BeitModel,
     )
-    from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
+    from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST
 
 
 if is_vision_available():
@@ -96,6 +96,10 @@ class BeitModelTester:
         self.out_indices = out_indices
         self.num_labels = num_labels
 
+        # in BeiT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 1
+
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
 
@@ -132,22 +136,16 @@ class BeitModelTester:
         model.to(torch_device)
         model.eval()
         result = model(pixel_values)
-        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.image_size)
-        patch_size = to_2tuple(self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
+        )
 
     def create_and_check_for_masked_lm(self, config, pixel_values, labels, pixel_labels):
         model = BeitForMaskedImageModeling(config=config)
         model.to(torch_device)
         model.eval()
         result = model(pixel_values)
-        # expected sequence length = num_patches
-        image_size = to_2tuple(self.image_size)
-        patch_size = to_2tuple(self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, self.vocab_size))
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.expected_seq_length - 1, self.vocab_size))
 
     def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels):
         config.num_labels = self.type_sequence_label_size
@@ -312,16 +310,8 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True
 
-        # in BEiT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.model_tester.image_size)
-        patch_size = to_2tuple(self.model_tester.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        seq_len = num_patches + 1
-        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
-        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
-        chunk_length = getattr(self.model_tester, "chunk_length", None)
-        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
-            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
+        # BEiT has a different seq_length
+        seq_len = self.model_tester.expected_seq_length
 
         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -332,7 +322,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
 
             # check that output_attentions also work using config
@@ -349,7 +339,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
 
             self.assertListEqual(
                 list(attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
             )
             out_len = len(outputs)
 
@@ -369,7 +359,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
             self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
             self.assertListEqual(
                 list(self_attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
             )
 
     def test_hidden_states_output(self):
@@ -381,7 +371,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
 
-            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+            hidden_states = outputs.hidden_states
 
             expected_num_layers = getattr(
                 self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
@@ -389,10 +379,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
             self.assertEqual(len(hidden_states), expected_num_layers)
 
             # BEiT has a different seq_length
-            image_size = to_2tuple(self.model_tester.image_size)
-            patch_size = to_2tuple(self.model_tester.patch_size)
-            num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-            seq_length = num_patches + 1
+            seq_length = self.model_tester.expected_seq_length
 
             self.assertListEqual(
                 list(hidden_states[0].shape[-2:]),
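
In the BEiT tests the per-test to_2tuple arithmetic is replaced by a single expected_seq_length computed once in the tester's __init__. Since these testers use square images and square patches, (image_size // patch_size) ** 2 gives the patch count and the extra token is the [CLS] embedding. A worked example with illustrative sizes (the concrete tester defaults are not shown in the hunks above):

    # Illustrative values only, not taken from the diff.
    image_size, patch_size = 30, 2
    num_patches = (image_size // patch_size) ** 2  # 15 ** 2 = 225
    expected_seq_length = num_patches + 1          # + [CLS] token -> 226
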
diff --git a/tests/beit/test_modeling_flax_beit.py b/tests/beit/test_modeling_flax_beit.py
index a1d1fe093b9..8977ab6542e 100644
--- a/tests/beit/test_modeling_flax_beit.py
+++ b/tests/beit/test_modeling_flax_beit.py
@@ -75,6 +75,10 @@ class FlaxBeitModelTester(unittest.TestCase):
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range
 
+        # in BeiT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 1
+
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
 
@@ -104,20 +108,14 @@ class FlaxBeitModelTester(unittest.TestCase):
         model = FlaxBeitModel(config=config)
 
         result = model(pixel_values)
-        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
-        image_size = (self.image_size, self.image_size)
-        patch_size = (self.patch_size, self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
+        )
 
     def create_and_check_for_masked_lm(self, config, pixel_values, labels):
         model = FlaxBeitForMaskedImageModeling(config=config)
         result = model(pixel_values)
-        # expected sequence length = num_patches
-        image_size = (self.image_size, self.image_size)
-        patch_size = (self.patch_size, self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, self.vocab_size))
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.expected_seq_length - 1, self.vocab_size))
 
     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size
@@ -151,13 +149,11 @@ class FlaxBeitModelTest(FlaxModelTesterMixin, unittest.TestCase):
         self.config_tester.run_common_tests()
 
     # We need to override this test because in Beit, the seq_len equals the number of patches + 1
-    # we compute that here
     def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True
 
-        num_patches = (config.image_size // config.patch_size) ** 2
-        seq_length = num_patches + 1
+        seq_length = self.model_tester.expected_seq_length
 
         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -209,7 +205,7 @@ class FlaxBeitModelTest(FlaxModelTesterMixin, unittest.TestCase):
             expected_arg_names = ["pixel_values"]
             self.assertListEqual(arg_names[:1], expected_arg_names)
 
-    # We neeed to override this test because Beit expects pixel_values instead of input_ids
+    # We need to override this test because Beit expects pixel_values instead of input_ids
    def test_jit_compilation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
@@ -234,12 +230,10 @@ class FlaxBeitModelTest(FlaxModelTesterMixin, unittest.TestCase):
                 self.assertEqual(jitted_output.shape, output.shape)
 
     # We need to override this test because in Beit, the seq_len equals the number of patches + 1
-    # we compute that here
     def test_hidden_states_output(self):
         def check_hidden_states_output(inputs_dict, config, model_class):
             model = model_class(config)
-            num_patches = (config.image_size // config.patch_size) ** 2
-            seq_length = num_patches + 1  # we add 1 for the [CLS] token
+            seq_length = self.model_tester.expected_seq_length
 
             outputs = model(**self._prepare_for_class(inputs_dict, model_class))
             hidden_states = outputs.hidden_states
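
The Flax BEiT tester mirrors the PyTorch one, and the masked-image-modeling check relies on the same quantity: BEiT's MIM head predicts one visual token per patch, with no prediction for the [CLS] position, hence the expected_seq_length - 1 in the logits shape. Expressed as a sketch (names are illustrative, not from the diff):

    num_patches = (image_size // patch_size) ** 2
    expected_seq_length = num_patches + 1                                          # patches + [CLS]
    expected_mim_logits_shape = (batch_size, expected_seq_length - 1, vocab_size)  # [CLS] excluded
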
diff --git a/tests/deit/test_modeling_deit.py b/tests/deit/test_modeling_deit.py
index f0d97c1369c..f8723c18756 100644
--- a/tests/deit/test_modeling_deit.py
+++ b/tests/deit/test_modeling_deit.py
@@ -41,7 +41,7 @@ if is_torch_available():
         DeiTForMaskedImageModeling,
         DeiTModel,
     )
-    from transformers.models.deit.modeling_deit import DEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
+    from transformers.models.deit.modeling_deit import DEIT_PRETRAINED_MODEL_ARCHIVE_LIST
 
 
 if is_vision_available():
@@ -92,6 +92,10 @@ class DeiTModelTester:
         self.scope = scope
         self.encoder_stride = encoder_stride
 
+        # in DeiT, the expected seq_len equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 2
+
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
 
@@ -125,11 +129,9 @@ class DeiTModelTester:
         model.to(torch_device)
         model.eval()
         result = model(pixel_values)
-        # expected sequence length = num_patches + 2 (we add 2 for the [CLS] and distillation tokens)
-        image_size = to_2tuple(self.image_size)
-        patch_size = to_2tuple(self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 2, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
+        )
 
     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size
@@ -212,16 +214,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True
 
-        # in DeiT, the seq_len equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
-        image_size = to_2tuple(self.model_tester.image_size)
-        patch_size = to_2tuple(self.model_tester.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        seq_len = num_patches + 2
-        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
-        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
-        chunk_length = getattr(self.model_tester, "chunk_length", None)
-        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
-            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
+        seq_len = self.model_tester.expected_seq_length
 
         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -232,7 +225,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
 
             # check that output_attentions also work using config
@@ -243,19 +236,13 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
 
-            if chunk_length is not None:
-                self.assertListEqual(
-                    list(attentions[0].shape[-4:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
-                )
-            else:
-                self.assertListEqual(
-                    list(attentions[0].shape[-3:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
-                )
+            self.assertListEqual(
+                list(attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
+            )
             out_len = len(outputs)
 
             # Check attention is always last and order is fine
@@ -267,27 +254,15 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
 
-            if hasattr(self.model_tester, "num_hidden_states_types"):
-                added_hidden_states = self.model_tester.num_hidden_states_types
-            elif self.is_encoder_decoder:
-                added_hidden_states = 2
-            else:
-                added_hidden_states = 1
-            self.assertEqual(out_len + added_hidden_states, len(outputs))
+            self.assertEqual(out_len + 1, len(outputs))
 
-            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            self_attentions = outputs.attentions
 
             self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
-            if chunk_length is not None:
-                self.assertListEqual(
-                    list(self_attentions[0].shape[-4:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
-                )
-            else:
-                self.assertListEqual(
-                    list(self_attentions[0].shape[-3:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
-                )
+            self.assertListEqual(
+                list(self_attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
+            )
 
     def test_hidden_states_output(self):
         def check_hidden_states_output(inputs_dict, config, model_class):
@@ -298,18 +273,14 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
 
-            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+            hidden_states = outputs.hidden_states
 
             expected_num_layers = getattr(
                 self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
             )
             self.assertEqual(len(hidden_states), expected_num_layers)
 
-            # DeiT has a different seq_length
-            image_size = to_2tuple(self.model_tester.image_size)
-            patch_size = to_2tuple(self.model_tester.patch_size)
-            num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-            seq_length = num_patches + 2
+            seq_length = self.model_tester.expected_seq_length
 
             self.assertListEqual(
                 list(hidden_states[0].shape[-2:]),
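
DeiT differs from BEiT and ViT in that it prepends two special tokens, [CLS] and the distillation token, so its tester uses num_patches + 2. The common-test overrides above also drop the encoder-decoder and chunking branches: for these encoder-only vision models, enabling output_attentions adds exactly one element to the output, hence assertEqual(out_len + 1, len(outputs)), and every layer's attention map is square over the token sequence. In short (illustrative names, not from the diff):

    num_patches = (image_size // patch_size) ** 2
    expected_seq_length = num_patches + 2  # [CLS] token + distillation token
    # per layer, per sample: attentions[i].shape[-3:] == (num_attention_heads, expected_seq_length, expected_seq_length)
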
diff --git a/tests/dpt/test_modeling_dpt.py b/tests/dpt/test_modeling_dpt.py
index aaa0c66f2ee..08bb550e0e5 100644
--- a/tests/dpt/test_modeling_dpt.py
+++ b/tests/dpt/test_modeling_dpt.py
@@ -81,6 +81,9 @@ class DPTModelTester:
         self.initializer_range = initializer_range
         self.num_labels = num_labels
         self.scope = scope
+        # expected sequence length of DPT = num_patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 1
 
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -115,9 +118,9 @@ class DPTModelTester:
         model.to(torch_device)
         model.eval()
         result = model(pixel_values)
-        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
-        num_patches = (config.image_size // config.patch_size) ** 2
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
+        )
 
     def create_and_check_for_depth_estimation(self, config, pixel_values, labels):
         config.num_labels = self.num_labels
@@ -206,8 +209,7 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase):
         config.return_dict = True
 
         # in DPT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
-        num_patches = (config.image_size // config.patch_size) ** 2
-        seq_len = num_patches + 1
+        seq_len = self.model_tester.expected_seq_length
 
         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -274,8 +276,7 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase):
             self.assertEqual(len(hidden_states), expected_num_layers)
 
             # DPT has a different seq_length
-            num_patches = (config.image_size // config.patch_size) ** 2
-            seq_len = num_patches + 1
+            seq_len = self.model_tester.expected_seq_length
 
             self.assertListEqual(
                 list(hidden_states[0].shape[-2:]),
diff --git a/tests/vit/test_modeling_flax_vit.py b/tests/vit/test_modeling_flax_vit.py
index 63808a3cdf4..0af2123c905 100644
--- a/tests/vit/test_modeling_flax_vit.py
+++ b/tests/vit/test_modeling_flax_vit.py
@@ -67,6 +67,10 @@ class FlaxViTModelTester(unittest.TestCase):
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range
 
+        # in ViT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 1
+
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
 
@@ -120,13 +124,11 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
         self.config_tester.run_common_tests()
 
     # We need to override this test because in ViT, the seq_len equals the number of patches + 1
-    # we compute that here
     def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True
 
-        num_patches = (config.image_size // config.patch_size) ** 2
-        seq_length = num_patches + 1
+        seq_length = self.model_tester.expected_seq_length
 
         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -203,12 +205,11 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
                 self.assertEqual(jitted_output.shape, output.shape)
 
     # We need to override this test because in ViT, the seq_len equals the number of patches + 1
-    # we compute that here
     def test_hidden_states_output(self):
         def check_hidden_states_output(inputs_dict, config, model_class):
             model = model_class(config)
-            num_patches = (config.image_size // config.patch_size) ** 2
-            seq_length = num_patches + 1  # we add 1 for the [CLS] token
+
+            seq_length = self.model_tester.expected_seq_length
 
             outputs = model(**self._prepare_for_class(inputs_dict, model_class))
             hidden_states = outputs.hidden_states
diff --git a/tests/vit/test_modeling_tf_vit.py b/tests/vit/test_modeling_tf_vit.py
index f40580d733a..9ad64e82370 100644
--- a/tests/vit/test_modeling_tf_vit.py
+++ b/tests/vit/test_modeling_tf_vit.py
@@ -32,7 +32,6 @@ if is_tf_available():
     import tensorflow as tf
 
     from transformers import TFViTForImageClassification, TFViTModel
-    from transformers.models.vit.modeling_tf_vit import to_2tuple
 
 
 if is_vision_available():
@@ -81,6 +80,10 @@ class TFViTModelTester:
         self.initializer_range = initializer_range
         self.scope = scope
 
+        # in ViT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 1
+
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
 
@@ -111,20 +114,18 @@
     def create_and_check_model(self, config, pixel_values, labels):
         model = TFViTModel(config=config)
         result = model(pixel_values, training=False)
-        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.image_size)
-        patch_size = to_2tuple(self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
+        )
 
         # Test with an image with different size than the one specified in config.
         image_size = self.image_size // 2
         pixel_values = pixel_values[:, :, :image_size, :image_size]
         result = model(pixel_values, interpolate_pos_encoding=True, training=False)
-        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(image_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        expected_seq_length = (image_size // self.patch_size) ** 2 + 1
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, expected_seq_length, self.hidden_size)
+        )
 
     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size
@@ -210,12 +211,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
         config.use_cache = True
 
         # in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.model_tester.image_size)
-        patch_size = to_2tuple(self.model_tester.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        seq_len = num_patches + 1
-        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
-        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+        seq_len = self.model_tester.expected_seq_length
 
         for model_class in self.all_model_classes:
             class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
@@ -228,12 +224,8 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
             model = tf.keras.models.load_model(saved_model_dir)
             outputs = model(class_inputs_dict)
 
-            if self.is_encoder_decoder:
-                output_hidden_states = outputs["encoder_hidden_states"]
-                output_attentions = outputs["encoder_attentions"]
-            else:
-                output_hidden_states = outputs["hidden_states"]
-                output_attentions = outputs["attentions"]
+            output_hidden_states = outputs["hidden_states"]
+            output_attentions = outputs["attentions"]
 
             self.assertEqual(len(outputs), num_out)
 
@@ -250,7 +242,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
             self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers)
             self.assertListEqual(
                 list(output_attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
             )
 
     def test_attention_outputs(self):
@@ -258,12 +250,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
         config.return_dict = True
 
         # in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.model_tester.image_size)
-        patch_size = to_2tuple(self.model_tester.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        seq_len = num_patches + 1
-        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
-        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+        seq_len = self.model_tester.expected_seq_length
 
         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -271,7 +258,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
             config.return_dict = True
             model = model_class(config)
             outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
 
             # check that output_attentions also work using config
@@ -279,12 +266,12 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
             config.output_attentions = True
             model = model_class(config)
             outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
 
             self.assertListEqual(
                 list(attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
             )
             out_len = len(outputs)
 
@@ -294,20 +281,14 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
             model = model_class(config)
             outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
 
-            if hasattr(self.model_tester, "num_hidden_states_types"):
-                added_hidden_states = self.model_tester.num_hidden_states_types
-            elif self.is_encoder_decoder:
-                added_hidden_states = 2
-            else:
-                added_hidden_states = 1
-            self.assertEqual(out_len + added_hidden_states, len(outputs))
+            self.assertEqual(out_len + 1, len(outputs))
 
-            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            self_attentions = outputs.attentions
 
             self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
             self.assertListEqual(
                 list(self_attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
             )
 
     def test_hidden_states_output(self):
@@ -316,7 +297,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
 
             outputs = model(**self._prepare_for_class(inputs_dict, model_class))
 
-            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+            hidden_states = outputs.hidden_states
 
             expected_num_layers = getattr(
                 self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
@@ -324,10 +305,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
             self.assertEqual(len(hidden_states), expected_num_layers)
 
             # ViT has a different seq_length
-            image_size = to_2tuple(self.model_tester.image_size)
-            patch_size = to_2tuple(self.model_tester.patch_size)
-            num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-            seq_length = num_patches + 1
+            seq_length = self.model_tester.expected_seq_length
 
             self.assertListEqual(
                 list(hidden_states[0].shape[-2:]),
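
The TF ViT tester additionally exercises interpolate_pos_encoding by feeding an image half the configured size, so the expected length has to be recomputed from the resized input rather than read from the precomputed attribute. A worked example with hypothetical numbers (not the tester's actual configuration):

    image_size, patch_size = 32, 4
    resized = image_size // 2                               # 16
    expected_seq_length = (resized // patch_size) ** 2 + 1  # (16 // 4) ** 2 + 1 = 17
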
diff --git a/tests/vit/test_modeling_vit.py b/tests/vit/test_modeling_vit.py
index db304aa815c..117815fa6db 100644
--- a/tests/vit/test_modeling_vit.py
+++ b/tests/vit/test_modeling_vit.py
@@ -31,7 +31,7 @@ if is_torch_available():
     from torch import nn
 
     from transformers import ViTForImageClassification, ViTForMaskedImageModeling, ViTModel
-    from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
+    from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST
 
 
 if is_vision_available():
@@ -59,7 +59,6 @@ class ViTModelTester:
         attention_probs_dropout_prob=0.1,
         type_sequence_label_size=10,
         initializer_range=0.02,
-        num_labels=3,
         scope=None,
         encoder_stride=2,
     ):
@@ -82,6 +81,10 @@ class ViTModelTester:
         self.scope = scope
         self.encoder_stride = encoder_stride
 
+        # in ViT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 1
+
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
 
@@ -115,11 +118,9 @@ class ViTModelTester:
         model.to(torch_device)
         model.eval()
         result = model(pixel_values)
-        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.image_size)
-        patch_size = to_2tuple(self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
+        )
 
     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size
@@ -201,16 +202,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True
 
-        # in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.model_tester.image_size)
-        patch_size = to_2tuple(self.model_tester.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        seq_len = num_patches + 1
-        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
-        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
-        chunk_length = getattr(self.model_tester, "chunk_length", None)
-        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
-            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
+        seq_len = self.model_tester.expected_seq_length
 
         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -221,7 +213,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
 
             # check that output_attentions also work using config
@@ -232,19 +224,13 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
 
-            if chunk_length is not None:
-                self.assertListEqual(
-                    list(attentions[0].shape[-4:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
-                )
-            else:
-                self.assertListEqual(
-                    list(attentions[0].shape[-3:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
-                )
+            self.assertListEqual(
+                list(attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
+            )
             out_len = len(outputs)
 
             # Check attention is always last and order is fine
@@ -256,27 +242,15 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
 
-            if hasattr(self.model_tester, "num_hidden_states_types"):
-                added_hidden_states = self.model_tester.num_hidden_states_types
-            elif self.is_encoder_decoder:
-                added_hidden_states = 2
-            else:
-                added_hidden_states = 1
-            self.assertEqual(out_len + added_hidden_states, len(outputs))
+            self.assertEqual(out_len + 1, len(outputs))
 
-            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            self_attentions = outputs.attentions
 
             self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
-            if chunk_length is not None:
-                self.assertListEqual(
-                    list(self_attentions[0].shape[-4:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
-                )
-            else:
-                self.assertListEqual(
-                    list(self_attentions[0].shape[-3:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
-                )
+            self.assertListEqual(
+                list(self_attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
+            )
 
     def test_hidden_states_output(self):
         def check_hidden_states_output(inputs_dict, config, model_class):
@@ -287,22 +261,16 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
 
-            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+            hidden_states = outputs.hidden_states
 
             expected_num_layers = getattr(
                 self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
             )
             self.assertEqual(len(hidden_states), expected_num_layers)
 
-            # ViT has a different seq_length
-            image_size = to_2tuple(self.model_tester.image_size)
-            patch_size = to_2tuple(self.model_tester.patch_size)
-            num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-            seq_length = num_patches + 1
-
             self.assertListEqual(
                 list(hidden_states[0].shape[-2:]),
-                [seq_length, self.model_tester.hidden_size],
+                [self.model_tester.expected_seq_length, self.model_tester.hidden_size],
             )
 
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()