Improve Swin for VisionEncoderDecoder (#16070)

* Add Swin2Bart test

* Fix swin tests

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge 2022-03-15 09:59:48 +01:00 committed by GitHub
parent 0a057201a9
commit a7aca42fc4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 94 additions and 6 deletions

View File

@ -94,6 +94,7 @@ class SwinConfig(PretrainedConfig):
attribute_map = {
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
}
def __init__(
@ -141,4 +142,4 @@ class SwinConfig(PretrainedConfig):
self.encoder_stride = encoder_stride
# we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
# this indicates the channel dimension after the last stage of the model
self.hidden_size = embed_dim * 8
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))

View File

@ -56,8 +56,8 @@ class SwinModelTester:
patch_size=2,
num_channels=3,
embed_dim=16,
depths=[1],
num_heads=[2],
depths=[1, 2, 1],
num_heads=[2, 2, 4],
window_size=2,
mlp_ratio=2.0,
qkv_bias=True,
@ -73,7 +73,7 @@ class SwinModelTester:
scope=None,
use_labels=True,
type_sequence_label_size=10,
encoder_stride=2,
encoder_stride=8,
):
self.parent = parent
self.batch_size = batch_size
@ -139,8 +139,7 @@ class SwinModelTester:
model.eval()
result = model(pixel_values)
# since the model we're testing only consists of a single layer, expected_seq_len = number of patches
expected_seq_len = (config.image_size // config.patch_size) ** 2
expected_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))

View File

@ -22,8 +22,10 @@ from datasets import load_dataset
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from ..bart.test_modeling_bart import BartModelTester
from ..bert.test_modeling_bert import BertModelTester
from ..deit.test_modeling_deit import DeiTModelTester
from ..swin.test_modeling_swin import SwinModelTester
from ..test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from ..trocr.test_modeling_trocr import TrOCRStandaloneDecoderModelTester
from ..vit.test_modeling_vit import ViTModelTester
@ -35,8 +37,10 @@ if is_torch_available():
from transformers import (
AutoTokenizer,
BartForCausalLM,
BertLMHeadModel,
DeiTModel,
SwinModel,
TrOCRForCausalLM,
VisionEncoderDecoderConfig,
VisionEncoderDecoderModel,
@ -514,6 +518,90 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
}
@require_torch
class Swin2BartModelTest(EncoderDecoderMixin, unittest.TestCase):
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = SwinModel(config).eval()
decoder_model = BartForCausalLM(decoder_config).eval()
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = SwinModelTester(self, batch_size=13, embed_dim=32)
model_tester_decoder = BartModelTester(self, batch_size=13, hidden_size=32, max_position_embeddings=512)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
config, pixel_values, _ = encoder_config_and_inputs
decoder_config, decoder_inputs_dict = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
# disable cache for now
decoder_config.use_cache = False
return {
"config": config,
"pixel_values": pixel_values,
"decoder_config": decoder_config,
**decoder_inputs_dict,
}
def check_encoder_decoder_model_output_attentions(
self,
config,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
labels=None,
pixel_values=None,
**kwargs
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
# in Swin, the seq_len equals:
seq_len = encoder_model.config.window_size**2
self.assertEqual(encoder_attentions[0].shape[-3:], (config.num_attention_heads[0], seq_len, seq_len))
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
num_decoder_layers = (
decoder_config.num_decoder_layers
if hasattr(decoder_config, "num_decoder_layers")
else decoder_config.num_hidden_layers
)
self.assertEqual(len(decoder_attentions), num_decoder_layers)
self.assertEqual(
decoder_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
)
cross_attentions = outputs_encoder_decoder["cross_attentions"]
self.assertEqual(len(cross_attentions), num_decoder_layers)
encoder_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
cross_attention_input_seq_len = decoder_input_ids.shape[-1]
self.assertEqual(
cross_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, cross_attention_input_seq_len, encoder_seq_len),
)
# there are no published pretrained BART-causal checkpoints for now
def test_real_model_save_load_from_pretrained(self):
pass
@require_torch
class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
def get_encoder_decoder_model(self, config, decoder_config):