[ViT, BEiT, DeiT, DPT] Improve code (#16799)

* Improve code

* Fix bugs

* Fix another bug

* Clean up DPT as well

* Update DPT model outputs

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
NielsRogge 2022-04-18 15:25:08 +02:00 committed by GitHub
parent 3785f4665a
commit d3c9d0e55f
10 changed files with 130 additions and 226 deletions
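Note on the central change: when a head model builds its backbone without a pooling layer, the base models (ViTModel, DeiTModel, DPTModel) no longer insert a None placeholder into their tuple outputs, so the heads below slice from index 1 instead of 2. A minimal, framework-free sketch of that convention, using illustrative names rather than the actual transformers return values:

# Minimal sketch of the tuple-output convention this commit standardizes.
# `encoder_outputs` stands in for (last_hidden_state, hidden_states, attentions).
def base_forward_old(sequence_output, pooled_output, encoder_outputs):
    # old behaviour: a None pooled_output still occupied slot 1 of the tuple
    return (sequence_output, pooled_output) + encoder_outputs[1:]

def base_forward_new(sequence_output, pooled_output, encoder_outputs):
    # new behaviour: drop the slot entirely when there is no pooler
    head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
    return head_outputs + encoder_outputs[1:]

encoder_outputs = ("last_hidden_state", "hidden_states", "attentions")
old = base_forward_old("last_hidden_state", None, encoder_outputs)
new = base_forward_new("last_hidden_state", None, encoder_outputs)

# Heads consume everything after the hidden states, so their slice moves
# from outputs[2:] to outputs[1:] while yielding the same elements.
assert old[2:] == new[1:] == ("hidden_states", "attentions")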

modeling_deit.py

@ -542,7 +542,8 @@ class DeiTModel(DeiTPreTrainedModel):
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
return head_outputs + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
@ -662,7 +663,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
if not return_dict:
output = (reconstructed_pixel_values,) + outputs[2:]
output = (reconstructed_pixel_values,) + outputs[1:]
return ((masked_im_loss,) + output) if masked_im_loss is not None else output
return MaskedLMOutput(
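For context, the unchanged loss line above averages a per-pixel L1 reconstruction loss over the masked region only, with 1e-5 guarding against an empty mask. A hedged PyTorch sketch of that computation, with assumed tensor shapes rather than the model's real ones:

import torch

# Assumed shapes: pixel_values / reconstruction are (batch, channels, height, width),
# mask is (batch, 1, height, width) with 1.0 where a patch was masked out.
batch, channels, height, width = 2, 3, 32, 32
pixel_values = torch.rand(batch, channels, height, width)
reconstruction = torch.rand(batch, channels, height, width)
mask = (torch.rand(batch, 1, height, width) > 0.5).float()

# Unreduced per-pixel L1 loss so it can be weighted by the mask.
reconstruction_loss = torch.nn.functional.l1_loss(reconstruction, pixel_values, reduction="none")

# Average only over masked pixels; 1e-5 avoids division by zero when nothing is masked.
masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / channels
print(masked_im_loss.item())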
@ -775,7 +776,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return ImageClassifierOutput(
@ -882,7 +883,7 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
logits = (cls_logits + distillation_logits) / 2
if not return_dict:
output = (logits, cls_logits, distillation_logits) + outputs[2:]
output = (logits, cls_logits, distillation_logits) + outputs[1:]
return output
return DeiTForImageClassificationWithTeacherOutput(

modeling_dpt.py

@ -750,7 +750,8 @@ class DPTModel(DPTPreTrainedModel):
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
return head_outputs + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
@ -938,7 +939,7 @@ class DPTForDepthEstimation(DPTPreTrainedModel):
return_dict=return_dict,
)
hidden_states = outputs.hidden_states if return_dict else outputs[2]
hidden_states = outputs.hidden_states if return_dict else outputs[1]
# only keep certain features based on config.backbone_out_indices
# note that the hidden_states also include the initial embeddings
@ -956,9 +957,9 @@ class DPTForDepthEstimation(DPTPreTrainedModel):
if not return_dict:
if output_hidden_states:
output = (predicted_depth,) + outputs[2:]
output = (predicted_depth,) + outputs[1:]
else:
output = (predicted_depth,) + outputs[3:]
output = (predicted_depth,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return DepthEstimatorOutput(
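The index changes in this hunk (and in the DPTForSemanticSegmentation hunk below) follow from the base-model fix: without a pooler, DPTModel's tuple is (last_hidden_state, hidden_states, attentions) instead of carrying a None pooled output at slot 1, so hidden_states now lives at index 1. A small sketch of the indexing logic, using placeholder strings rather than real tensors:

# Illustrative layout of the backbone's tuple output after this commit
# (no pooler, hidden states always requested by the head):
outputs = ("last_hidden_state", "hidden_states", "attentions")

# The head reads the intermediate features from slot 1 ...
hidden_states = outputs[1]
assert hidden_states == "hidden_states"

# ... and only keeps them in its own tuple when the caller asked for them,
# which is why the slice start differs by one between the two branches.
def head_tuple(predicted_depth, outputs, output_hidden_states):
    if output_hidden_states:
        return (predicted_depth,) + outputs[1:]  # keep hidden_states and attentions
    return (predicted_depth,) + outputs[2:]      # drop hidden_states, keep attentions

assert head_tuple("depth", outputs, True) == ("depth", "hidden_states", "attentions")
assert head_tuple("depth", outputs, False) == ("depth", "attentions")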
@ -1083,7 +1084,7 @@ class DPTForSemanticSegmentation(DPTPreTrainedModel):
return_dict=return_dict,
)
hidden_states = outputs.hidden_states if return_dict else outputs[2]
hidden_states = outputs.hidden_states if return_dict else outputs[1]
# only keep certain features based on config.backbone_out_indices
# note that the hidden_states also include the initial embeddings
@ -1120,9 +1121,9 @@ class DPTForSemanticSegmentation(DPTPreTrainedModel):
if not return_dict:
if output_hidden_states:
output = (logits,) + outputs[2:]
output = (logits,) + outputs[1:]
else:
output = (logits,) + outputs[3:]
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SemanticSegmenterOutput(

modeling_vit.py

@ -585,7 +585,8 @@ class ViTModel(ViTPreTrainedModel):
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
return head_outputs + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
@ -706,7 +707,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
if not return_dict:
output = (reconstructed_pixel_values,) + outputs[2:]
output = (reconstructed_pixel_values,) + outputs[1:]
return ((masked_im_loss,) + output) if masked_im_loss is not None else output
return MaskedLMOutput(
@ -798,8 +799,9 @@ class ViTForImageClassification(ViTPreTrainedModel):
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return ImageClassifierOutput(
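The classification branch shown above is unchanged apart from the slice; for the multi_label_classification problem type, BCEWithLogitsLoss is applied to raw logits with float multi-hot targets. A short hedged sketch with illustrative shapes:

import torch
from torch.nn import BCEWithLogitsLoss

# Illustrative shapes: one row per example, one column per label.
logits = torch.randn(2, 3)
labels = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])

loss_fct = BCEWithLogitsLoss()  # applies sigmoid internally, so logits stay raw
loss = loss_fct(logits, labels)
print(loss.item())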

test_modeling_beit.py

@ -41,7 +41,7 @@ if is_torch_available():
BeitForSemanticSegmentation,
BeitModel,
)
from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
@ -96,6 +96,10 @@ class BeitModelTester:
self.out_indices = out_indices
self.num_labels = num_labels
# in BeiT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.expected_seq_length = num_patches + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
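The testers now compute the expected sequence length once in __init__ instead of re-deriving it with to_2tuple in every check. For square images the arithmetic is just the number of patches plus one [CLS] token; the sketch below uses a standard 224/16 configuration for illustration, not the small sizes the tests actually use:

# Sequence-length bookkeeping the testers now centralize (illustrative values).
image_size, patch_size = 224, 16

num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196 patch tokens
expected_seq_length = num_patches + 1          # + 1 for the [CLS] token

assert (num_patches, expected_seq_length) == (196, 197)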
@ -132,22 +136,16 @@ class BeitModelTester:
model.to(torch_device)
model.eval()
result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(self.image_size)
patch_size = to_2tuple(self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
)
def create_and_check_for_masked_lm(self, config, pixel_values, labels, pixel_labels):
model = BeitForMaskedImageModeling(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# expected sequence length = num_patches
image_size = to_2tuple(self.image_size)
patch_size = to_2tuple(self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, self.vocab_size))
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.expected_seq_length - 1, self.vocab_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels):
config.num_labels = self.type_sequence_label_size
@ -312,16 +310,8 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
# in BEiT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = num_patches + 1
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
chunk_length = getattr(self.model_tester, "chunk_length", None)
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
# BEiT has a different seq_length
seq_len = self.model_tester.expected_seq_length
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
@ -332,7 +322,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
@ -349,7 +339,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
[self.model_tester.num_attention_heads, seq_len, seq_len],
)
out_len = len(outputs)
@ -369,7 +359,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
[self.model_tester.num_attention_heads, seq_len, seq_len],
)
def test_hidden_states_output(self):
@ -381,7 +371,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
hidden_states = outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
@ -389,10 +379,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
self.assertEqual(len(hidden_states), expected_num_layers)
# BEiT has a different seq_length
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_length = num_patches + 1
seq_length = self.model_tester.expected_seq_length
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
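With the encoder-decoder branches removed, the assertions in this file reduce to two invariants for an encoder-only vision transformer: each attention map is (batch, num_heads, seq_len, seq_len) and each hidden state is (batch, seq_len, hidden_size), with num_hidden_layers attention maps and num_hidden_layers + 1 hidden states (the extra one being the embeddings). A hedged sketch with made-up dimensions:

import torch

# Made-up dimensions; the real tests read these from the model tester.
batch_size, num_attention_heads, seq_len, hidden_size, num_hidden_layers = 2, 4, 197, 32, 5

# Stand-ins for outputs.attentions and outputs.hidden_states of a BEiT/ViT encoder.
attentions = [torch.rand(batch_size, num_attention_heads, seq_len, seq_len) for _ in range(num_hidden_layers)]
hidden_states = [torch.rand(batch_size, seq_len, hidden_size) for _ in range(num_hidden_layers + 1)]

assert len(attentions) == num_hidden_layers
assert list(attentions[0].shape[-3:]) == [num_attention_heads, seq_len, seq_len]

assert len(hidden_states) == num_hidden_layers + 1  # + 1 for the embedding output
assert list(hidden_states[0].shape[-2:]) == [seq_len, hidden_size]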

test_modeling_flax_beit.py

@ -75,6 +75,10 @@ class FlaxBeitModelTester(unittest.TestCase):
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
# in BeiT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.expected_seq_length = num_patches + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@ -104,20 +108,14 @@ class FlaxBeitModelTester(unittest.TestCase):
model = FlaxBeitModel(config=config)
result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = (self.image_size, self.image_size)
patch_size = (self.patch_size, self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
)
def create_and_check_for_masked_lm(self, config, pixel_values, labels):
model = FlaxBeitForMaskedImageModeling(config=config)
result = model(pixel_values)
# expected sequence length = num_patches
image_size = (self.image_size, self.image_size)
patch_size = (self.patch_size, self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, self.vocab_size))
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.expected_seq_length - 1, self.vocab_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
@ -151,13 +149,11 @@ class FlaxBeitModelTest(FlaxModelTesterMixin, unittest.TestCase):
self.config_tester.run_common_tests()
# We need to override this test because in Beit, the seq_len equals the number of patches + 1
# we compute that here
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
num_patches = (config.image_size // config.patch_size) ** 2
seq_length = num_patches + 1
seq_length = self.model_tester.expected_seq_length
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
@ -209,7 +205,7 @@ class FlaxBeitModelTest(FlaxModelTesterMixin, unittest.TestCase):
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
# We neeed to override this test because Beit expects pixel_values instead of input_ids
# We need to override this test because Beit expects pixel_values instead of input_ids
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -234,12 +230,10 @@ class FlaxBeitModelTest(FlaxModelTesterMixin, unittest.TestCase):
self.assertEqual(jitted_output.shape, output.shape)
# We need to override this test because in Beit, the seq_len equals the number of patches + 1
# we compute that here
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
num_patches = (config.image_size // config.patch_size) ** 2
seq_length = num_patches + 1 # we add 1 for the [CLS] token
seq_length = self.model_tester.expected_seq_length
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states

test_modeling_deit.py

@ -41,7 +41,7 @@ if is_torch_available():
DeiTForMaskedImageModeling,
DeiTModel,
)
from transformers.models.deit.modeling_deit import DEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
from transformers.models.deit.modeling_deit import DEIT_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
@ -92,6 +92,10 @@ class DeiTModelTester:
self.scope = scope
self.encoder_stride = encoder_stride
# in DeiT, the expected seq_len equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
num_patches = (image_size // patch_size) ** 2
self.expected_seq_length = num_patches + 2
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
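DeiT differs from ViT and BEiT only in the offset: the sequence carries a distillation token in addition to [CLS], hence the + 2. With illustrative DeiT-Base style numbers:

# DeiT prepends a [CLS] token and a distillation token to the patch tokens.
image_size, patch_size = 224, 16  # illustrative, not the tester's small defaults

num_patches = (image_size // patch_size) ** 2  # 196
expected_seq_length = num_patches + 2          # 196 patches + [CLS] + distillation token

assert expected_seq_length == 198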
@ -125,11 +129,9 @@ class DeiTModelTester:
model.to(torch_device)
model.eval()
result = model(pixel_values)
# expected sequence length = num_patches + 2 (we add 2 for the [CLS] and distillation tokens)
image_size = to_2tuple(self.image_size)
patch_size = to_2tuple(self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 2, self.hidden_size))
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
)
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
@ -212,16 +214,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
# in DeiT, the seq_len equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = num_patches + 2
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
chunk_length = getattr(self.model_tester, "chunk_length", None)
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
seq_len = self.model_tester.expected_seq_length
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
@ -232,7 +225,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
@ -243,19 +236,13 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
if chunk_length is not None:
self.assertListEqual(
list(attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, seq_len, seq_len],
)
out_len = len(outputs)
# Check attention is always last and order is fine
@ -267,27 +254,15 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self.assertEqual(out_len + 1, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self_attentions = outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
if chunk_length is not None:
self.assertListEqual(
list(self_attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, seq_len, seq_len],
)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
@ -298,18 +273,14 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
hidden_states = outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
# DeiT has a different seq_length
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_length = num_patches + 2
seq_length = self.model_tester.expected_seq_length
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
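The simplified out_len + 1 check above relies on the fact that, for an encoder-only model like DeiT, enabling output_hidden_states adds exactly one extra element to the tuple output; there is no decoder, so the old is_encoder_decoder branches never fire. A minimal sketch, not the real test harness:

# Placeholder tuples standing in for real model outputs.
with_attentions = ("logits", "attentions")
with_attentions_and_hidden_states = ("logits", "hidden_states", "attentions")

out_len = len(with_attentions)
# Requesting hidden states contributes exactly one extra element for an
# encoder-only vision model (no decoder or cross-attention tuples).
assert len(with_attentions_and_hidden_states) == out_len + 1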

test_modeling_dpt.py

@ -81,6 +81,9 @@ class DPTModelTester:
self.initializer_range = initializer_range
self.num_labels = num_labels
self.scope = scope
# expected sequence length of DPT = num_patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.expected_seq_length = num_patches + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@ -115,9 +118,9 @@ class DPTModelTester:
model.to(torch_device)
model.eval()
result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
num_patches = (config.image_size // config.patch_size) ** 2
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
)
def create_and_check_for_depth_estimation(self, config, pixel_values, labels):
config.num_labels = self.num_labels
@ -206,8 +209,7 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase):
config.return_dict = True
# in DPT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (config.image_size // config.patch_size) ** 2
seq_len = num_patches + 1
seq_len = self.model_tester.expected_seq_length
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
@ -274,8 +276,7 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase):
self.assertEqual(len(hidden_states), expected_num_layers)
# DPT has a different seq_length
num_patches = (config.image_size // config.patch_size) ** 2
seq_len = num_patches + 1
seq_len = self.model_tester.expected_seq_length
self.assertListEqual(
list(hidden_states[0].shape[-2:]),

test_modeling_flax_vit.py

@ -67,6 +67,10 @@ class FlaxViTModelTester(unittest.TestCase):
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
# in ViT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.expected_seq_length = num_patches + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@ -120,13 +124,11 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
self.config_tester.run_common_tests()
# We need to override this test because in ViT, the seq_len equals the number of patches + 1
# we compute that here
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
num_patches = (config.image_size // config.patch_size) ** 2
seq_length = num_patches + 1
seq_length = self.model_tester.expected_seq_length
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
@ -203,12 +205,11 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
self.assertEqual(jitted_output.shape, output.shape)
# We need to override this test because in ViT, the seq_len equals the number of patches + 1
# we compute that here
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
num_patches = (config.image_size // config.patch_size) ** 2
seq_length = num_patches + 1 # we add 1 for the [CLS] token
seq_length = self.model_tester.expected_seq_length
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states

test_modeling_tf_vit.py

@ -32,7 +32,6 @@ if is_tf_available():
import tensorflow as tf
from transformers import TFViTForImageClassification, TFViTModel
from transformers.models.vit.modeling_tf_vit import to_2tuple
if is_vision_available():
@ -81,6 +80,10 @@ class TFViTModelTester:
self.initializer_range = initializer_range
self.scope = scope
# in ViT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.expected_seq_length = num_patches + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@ -111,20 +114,18 @@ class TFViTModelTester:
def create_and_check_model(self, config, pixel_values, labels):
model = TFViTModel(config=config)
result = model(pixel_values, training=False)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(self.image_size)
patch_size = to_2tuple(self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
)
# Test with an image with different size than the one specified in config.
image_size = self.image_size // 2
pixel_values = pixel_values[:, :, :image_size, :image_size]
result = model(pixel_values, interpolate_pos_encoding=True, training=False)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(image_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
expected_seq_length = (image_size // self.patch_size) ** 2 + 1
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, expected_seq_length, self.hidden_size)
)
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
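The TF tester also feeds an image half the configured size with interpolate_pos_encoding=True; the position embeddings are interpolated to the new patch grid, so the expected sequence length is recomputed from the smaller image. The arithmetic, with illustrative full-size values rather than the tester's:

# Illustrative values; the patch size stays fixed while the image shrinks.
patch_size = 16
image_size = 224
expected_seq_length = (image_size // patch_size) ** 2 + 1          # 197 for the full-size image

smaller_image_size = image_size // 2
smaller_seq_length = (smaller_image_size // patch_size) ** 2 + 1   # (112 // 16) ** 2 + 1 = 50

assert (expected_seq_length, smaller_seq_length) == (197, 50)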
@ -210,12 +211,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
config.use_cache = True
# in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = num_patches + 1
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
seq_len = self.model_tester.expected_seq_length
for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
@ -228,12 +224,8 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
model = tf.keras.models.load_model(saved_model_dir)
outputs = model(class_inputs_dict)
if self.is_encoder_decoder:
output_hidden_states = outputs["encoder_hidden_states"]
output_attentions = outputs["encoder_attentions"]
else:
output_hidden_states = outputs["hidden_states"]
output_attentions = outputs["attentions"]
output_hidden_states = outputs["hidden_states"]
output_attentions = outputs["attentions"]
self.assertEqual(len(outputs), num_out)
@ -250,7 +242,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(output_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
[self.model_tester.num_attention_heads, seq_len, seq_len],
)
def test_attention_outputs(self):
@ -258,12 +250,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
config.return_dict = True
# in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = num_patches + 1
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
seq_len = self.model_tester.expected_seq_length
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
@ -271,7 +258,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
config.return_dict = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
@ -279,12 +266,12 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
config.output_attentions = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
[self.model_tester.num_attention_heads, seq_len, seq_len],
)
out_len = len(outputs)
@ -294,20 +281,14 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self.assertEqual(out_len + 1, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self_attentions = outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
[self.model_tester.num_attention_heads, seq_len, seq_len],
)
def test_hidden_states_output(self):
@ -316,7 +297,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
hidden_states = outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
@ -324,10 +305,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
self.assertEqual(len(hidden_states), expected_num_layers)
# ViT has a different seq_length
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_length = num_patches + 1
seq_length = self.model_tester.expected_seq_length
self.assertListEqual(
list(hidden_states[0].shape[-2:]),

test_modeling_vit.py

@ -31,7 +31,7 @@ if is_torch_available():
from torch import nn
from transformers import ViTForImageClassification, ViTForMaskedImageModeling, ViTModel
from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
@ -59,7 +59,6 @@ class ViTModelTester:
attention_probs_dropout_prob=0.1,
type_sequence_label_size=10,
initializer_range=0.02,
num_labels=3,
scope=None,
encoder_stride=2,
):
@ -82,6 +81,10 @@ class ViTModelTester:
self.scope = scope
self.encoder_stride = encoder_stride
# in ViT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.expected_seq_length = num_patches + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@ -115,11 +118,9 @@ class ViTModelTester:
model.to(torch_device)
model.eval()
result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(self.image_size)
patch_size = to_2tuple(self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
)
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
@ -201,16 +202,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
# in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = num_patches + 1
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
chunk_length = getattr(self.model_tester, "chunk_length", None)
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
seq_len = self.model_tester.expected_seq_length
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
@ -221,7 +213,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
@ -232,19 +224,13 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
if chunk_length is not None:
self.assertListEqual(
list(attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, seq_len, seq_len],
)
out_len = len(outputs)
# Check attention is always last and order is fine
@ -256,27 +242,15 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self.assertEqual(out_len + 1, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self_attentions = outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
if chunk_length is not None:
self.assertListEqual(
list(self_attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, seq_len, seq_len],
)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
@ -287,22 +261,16 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
hidden_states = outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
# ViT has a different seq_length
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_length = num_patches + 1
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
[self.model_tester.expected_seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()