mirror of https://github.com/huggingface/transformers.git
Improve Swin for VisionEncoderDecoder (#16070)
* Add Swin2Bart test
* Fix swin tests

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 0a057201a9
commit a7aca42fc4
@@ -94,6 +94,7 @@ class SwinConfig(PretrainedConfig):
 
     attribute_map = {
         "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
     }
 
     def __init__(
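The mapping added here lets code that only knows the generic config names, such as the shared VisionEncoderDecoder tests further down, read `config.num_hidden_layers` even though Swin stores the value under `num_layers`. A purely illustrative sketch of the mechanism; `ToyConfig` is a simplified stand-in, not the real `PretrainedConfig`:

```python
# Toy illustration (assumption: simplified stand-in, not the actual PretrainedConfig code).
class ToyConfig:
    attribute_map = {"num_hidden_layers": "num_layers"}

    def __init__(self, num_layers):
        self.num_layers = num_layers

    def __getattr__(self, name):
        # only reached when normal attribute lookup fails, e.g. for "num_hidden_layers"
        if name in self.attribute_map:
            return getattr(self, self.attribute_map[name])
        raise AttributeError(name)


config = ToyConfig(num_layers=3)
print(config.num_hidden_layers)  # 3, resolved through attribute_map
```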
@@ -141,4 +142,4 @@ class SwinConfig(PretrainedConfig):
         self.encoder_stride = encoder_stride
         # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
         # this indicates the channel dimension after the last stage of the model
-        self.hidden_size = embed_dim * 8
+        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
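Swin doubles the channel dimension at every patch-merging step, once per stage after the first, so the channel dimension after the last stage is `embed_dim * 2 ** (len(depths) - 1)`. The old `embed_dim * 8` only holds for four-stage models. A quick check, using the library's default-style Swin values and the three-stage test config from this PR as illustrations:

```python
def final_hidden_size(embed_dim, depths):
    # channel dimension after the last Swin stage
    return int(embed_dim * 2 ** (len(depths) - 1))


# four stages (default-style config): old and new formulas agree
assert final_hidden_size(96, [2, 2, 6, 2]) == 96 * 8 == 768
# three stages (the test config below): the old formula would give 128, not 64
assert final_hidden_size(16, [1, 2, 1]) == 64
```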
@@ -56,8 +56,8 @@ class SwinModelTester:
         patch_size=2,
         num_channels=3,
         embed_dim=16,
-        depths=[1],
-        num_heads=[2],
+        depths=[1, 2, 1],
+        num_heads=[2, 2, 4],
         window_size=2,
         mlp_ratio=2.0,
         qkv_bias=True,
@@ -73,7 +73,7 @@ class SwinModelTester:
         scope=None,
         use_labels=True,
         type_sequence_label_size=10,
-        encoder_stride=2,
+        encoder_stride=8,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -139,8 +139,7 @@ class SwinModelTester:
         model.eval()
         result = model(pixel_values)
 
-        # since the model we're testing only consists of a single layer, expected_seq_len = number of patches
-        expected_seq_len = (config.image_size // config.patch_size) ** 2
+        expected_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
         expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
 
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
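The new expected sequence length accounts for patch merging: every merge combines 2x2 neighbouring tokens, shrinking the sequence by a factor of 4, with one merge per stage after the first. A small worked example; `image_size=32` is assumed here, since it is set elsewhere in the tester and not in this hunk:

```python
def expected_seq_len(image_size, patch_size, depths):
    num_patches = (image_size // patch_size) ** 2
    # one patch-merging step per stage after the first, each dividing the token count by 4
    return num_patches // (4 ** (len(depths) - 1))


print(expected_seq_len(32, 2, [1, 2, 1]))  # (32 // 2) ** 2 = 256 patches -> 256 // 4 ** 2 = 16 tokens
```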
@@ -22,8 +22,10 @@ from datasets import load_dataset
 from transformers.file_utils import cached_property, is_torch_available, is_vision_available
 from transformers.testing_utils import require_torch, require_vision, slow, torch_device
 
+from ..bart.test_modeling_bart import BartModelTester
 from ..bert.test_modeling_bert import BertModelTester
 from ..deit.test_modeling_deit import DeiTModelTester
+from ..swin.test_modeling_swin import SwinModelTester
 from ..test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
 from ..trocr.test_modeling_trocr import TrOCRStandaloneDecoderModelTester
 from ..vit.test_modeling_vit import ViTModelTester
@@ -35,8 +37,10 @@ if is_torch_available():
 
     from transformers import (
         AutoTokenizer,
+        BartForCausalLM,
         BertLMHeadModel,
         DeiTModel,
+        SwinModel,
         TrOCRForCausalLM,
         VisionEncoderDecoderConfig,
         VisionEncoderDecoderModel,
@@ -514,6 +518,90 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
         }
 
 
+@require_torch
+class Swin2BartModelTest(EncoderDecoderMixin, unittest.TestCase):
+    def get_encoder_decoder_model(self, config, decoder_config):
+        encoder_model = SwinModel(config).eval()
+        decoder_model = BartForCausalLM(decoder_config).eval()
+        return encoder_model, decoder_model
+
+    def prepare_config_and_inputs(self):
+        model_tester_encoder = SwinModelTester(self, batch_size=13, embed_dim=32)
+        model_tester_decoder = BartModelTester(self, batch_size=13, hidden_size=32, max_position_embeddings=512)
+        encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
+        decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
+        config, pixel_values, _ = encoder_config_and_inputs
+        decoder_config, decoder_inputs_dict = decoder_config_and_inputs
+
+        # make sure that cross attention layers are added
+        decoder_config.add_cross_attention = True
+        # disable cache for now
+        decoder_config.use_cache = False
+        return {
+            "config": config,
+            "pixel_values": pixel_values,
+            "decoder_config": decoder_config,
+            **decoder_inputs_dict,
+        }
+
+    def check_encoder_decoder_model_output_attentions(
+        self,
+        config,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        labels=None,
+        pixel_values=None,
+        **kwargs
+    ):
+        # make the decoder inputs a different shape from the encoder inputs to harden the test
+        decoder_input_ids = decoder_input_ids[:, :-1]
+        decoder_attention_mask = decoder_attention_mask[:, :-1]
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        enc_dec_model.to(torch_device)
+        outputs_encoder_decoder = enc_dec_model(
+            pixel_values=pixel_values,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            output_attentions=True,
+        )
+
+        encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
+        self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
+
+        # in Swin, the seq_len equals the number of tokens in one local attention window:
+        seq_len = encoder_model.config.window_size**2
+        self.assertEqual(encoder_attentions[0].shape[-3:], (config.num_attention_heads[0], seq_len, seq_len))
+
+        decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
+        num_decoder_layers = (
+            decoder_config.num_decoder_layers
+            if hasattr(decoder_config, "num_decoder_layers")
+            else decoder_config.num_hidden_layers
+        )
+        self.assertEqual(len(decoder_attentions), num_decoder_layers)
+
+        self.assertEqual(
+            decoder_attentions[0].shape[-3:],
+            (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
+        )
+
+        cross_attentions = outputs_encoder_decoder["cross_attentions"]
+        self.assertEqual(len(cross_attentions), num_decoder_layers)
+
+        encoder_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
+        cross_attention_input_seq_len = decoder_input_ids.shape[-1]
+        self.assertEqual(
+            cross_attentions[0].shape[-3:],
+            (decoder_config.num_attention_heads, cross_attention_input_seq_len, encoder_seq_len),
+        )
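For concreteness, here are the shapes this method checks, worked out with the tester values shown earlier in this diff (window_size=2, patch_size=2, depths=[1, 2, 1], num_heads=[2, 2, 4]) and an assumed image_size of 32; the decoder-side numbers come from BartModelTester and are left symbolic:

```python
window_size, image_size, patch_size, depths = 2, 32, 2, [1, 2, 1]

# Swin computes self-attention inside local windows, so each encoder attention map
# covers window_size**2 tokens:
seq_len = window_size**2  # 4
# after all patch-merging steps, this many encoder tokens feed the cross-attention:
encoder_seq_len = ((image_size // patch_size) ** 2) // (4 ** (len(depths) - 1))  # 16

print(seq_len, encoder_seq_len)
# encoder_attentions[0].shape[-3:] == (num_heads[0], 4, 4)
# cross_attentions[0].shape[-3:]   == (decoder_heads, decoder_seq_len, 16)
```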
+
+    # there are no published pretrained BART-causal checkpoints for now
+    def test_real_model_save_load_from_pretrained(self):
+        pass
+
+
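Outside the test harness, wiring a Swin encoder to a BART causal decoder follows the same pattern. A minimal sketch, not part of the PR: both models are randomly initialized from default configs and the sizes are illustrative. Because SwinConfig now exposes `hidden_size`, VisionEncoderDecoderModel can compare encoder and decoder widths and insert its projection layer when they differ:

```python
import torch
from transformers import BartConfig, BartForCausalLM, SwinConfig, SwinModel, VisionEncoderDecoderModel

# randomly initialized encoder and decoder with default configs (illustrative sizes)
encoder = SwinModel(SwinConfig())  # config.hidden_size is derived from embed_dim and depths
decoder = BartForCausalLM(BartConfig(is_decoder=True, add_cross_attention=True))

model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)

pixel_values = torch.randn(1, 3, encoder.config.image_size, encoder.config.image_size)
decoder_input_ids = torch.tensor([[2, 0, 42]])  # illustrative token ids
outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
print(outputs.logits.shape)  # (1, 3, decoder vocab size)
```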
 @require_torch
 class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
     def get_encoder_decoder_model(self, config, decoder_config):