Add Tensorflow Swin model (#16988)

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
amyeroberts 2022-05-16 22:19:53 +01:00 committed by GitHub
parent 6cb7187324
commit f6a6388972
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1977 additions and 59 deletions

View File

@ -256,7 +256,7 @@ Flax), PyTorch, and/or TensorFlow.
| Speech2Text2 | ✅ | ❌ | ❌ | ❌ | ❌ |
| Splinter | ✅ | ✅ | ✅ | ❌ | ❌ |
| SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| Swin | ❌ | ❌ | ✅ | | ❌ |
| Swin | ❌ | ❌ | ✅ | | ❌ |
| T5 | ✅ | ✅ | ✅ | ✅ | ✅ |
| TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ |
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |

View File

@ -14,22 +14,22 @@ specific language governing permissions and limitations under the License.
## Overview
The Swin Transformer was proposed in [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)
by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo.
The Swin Transformer was proposed in [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)
by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo.
The abstract from the paper is the following:
*This paper presents a new vision Transformer, called Swin Transformer, that capably serves as a general-purpose backbone
for computer vision. Challenges in adapting Transformer from language to vision arise from differences between the two domains,
such as large variations in the scale of visual entities and the high resolution of pixels in images compared to words in text.
To address these differences, we propose a hierarchical Transformer whose representation is computed with \bold{S}hifted
\bold{win}dows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping
local windows while also allowing for cross-window connection. This hierarchical architecture has the flexibility to model at
various scales and has linear computational complexity with respect to image size. These qualities of Swin Transformer make it
compatible with a broad range of vision tasks, including image classification (87.3 top-1 accuracy on ImageNet-1K) and dense
prediction tasks such as object detection (58.7 box AP and 51.1 mask AP on COCO test-dev) and semantic segmentation
(53.5 mIoU on ADE20K val). Its performance surpasses the previous state-of-the-art by a large margin of +2.7 box AP and
+2.6 mask AP on COCO, and +3.2 mIoU on ADE20K, demonstrating the potential of Transformer-based models as vision backbones.
*This paper presents a new vision Transformer, called Swin Transformer, that capably serves as a general-purpose backbone
for computer vision. Challenges in adapting Transformer from language to vision arise from differences between the two domains,
such as large variations in the scale of visual entities and the high resolution of pixels in images compared to words in text.
To address these differences, we propose a hierarchical Transformer whose representation is computed with \bold{S}hifted
\bold{win}dows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping
local windows while also allowing for cross-window connection. This hierarchical architecture has the flexibility to model at
various scales and has linear computational complexity with respect to image size. These qualities of Swin Transformer make it
compatible with a broad range of vision tasks, including image classification (87.3 top-1 accuracy on ImageNet-1K) and dense
prediction tasks such as object detection (58.7 box AP and 51.1 mask AP on COCO test-dev) and semantic segmentation
(53.5 mIoU on ADE20K val). Its performance surpasses the previous state-of-the-art by a large margin of +2.7 box AP and
+2.6 mask AP on COCO, and +3.2 mIoU on ADE20K, demonstrating the potential of Transformer-based models as vision backbones.
The hierarchical design and the shifted window approach also prove beneficial for all-MLP architectures.*
Tips:
@ -38,11 +38,11 @@ Tips:
- Swin can be used as a *backbone*. When `output_hidden_states = True`, it will output both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`.
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/swin_transformer_architecture.png"
alt="drawing" width="600"/>
alt="drawing" width="600"/>
<small> Swin Transformer architecture. Taken from the <a href="https://arxiv.org/abs/2102.03334">original paper</a>.</small>
This model was contributed by [novice03](https://huggingface.co/novice03>). The original code can be found [here](https://github.com/microsoft/Swin-Transformer).
This model was contributed by [novice03](https://huggingface.co/novice03>). The Tensorflow version of this model was contributed by [amyeroberts](https://huggingface.co/amyeroberts). The original code can be found [here](https://github.com/microsoft/Swin-Transformer).
## SwinConfig
@ -63,4 +63,19 @@ This model was contributed by [novice03](https://huggingface.co/novice03>). The
## SwinForImageClassification
[[autodoc]] transformers.SwinForImageClassification
- forward
- forward
## TFSwinModel
[[autodoc]] TFSwinModel
- call
## TFSwinForMaskedImageModeling
[[autodoc]] TFSwinForMaskedImageModeling
- call
## TFSwinForImageClassification
[[autodoc]] transformers.TFSwinForImageClassification
- call

View File

@ -1841,6 +1841,7 @@ else:
[
"TF_MODEL_FOR_CAUSAL_LM_MAPPING",
"TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
"TF_MODEL_FOR_MASKED_LM_MAPPING",
"TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
@ -2208,6 +2209,15 @@ else:
"TFSpeech2TextPreTrainedModel",
]
)
_import_structure["models.swin"].extend(
[
"TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSwinForImageClassification",
"TFSwinForMaskedImageModeling",
"TFSwinModel",
"TFSwinPreTrainedModel",
]
)
_import_structure["models.t5"].extend(
[
"TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -4071,6 +4081,7 @@ if TYPE_CHECKING:
from .models.auto import (
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
TF_MODEL_FOR_MASKED_LM_MAPPING,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
@ -4363,6 +4374,13 @@ if TYPE_CHECKING:
TFSpeech2TextModel,
TFSpeech2TextPreTrainedModel,
)
from .models.swin import (
TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSwinForImageClassification,
TFSwinForMaskedImageModeling,
TFSwinModel,
TFSwinPreTrainedModel,
)
from .models.t5 import (
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
TFT5EncoderModel,

View File

@ -101,6 +101,7 @@ else:
_import_structure["modeling_tf_auto"] = [
"TF_MODEL_FOR_CAUSAL_LM_MAPPING",
"TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
"TF_MODEL_FOR_MASKED_LM_MAPPING",
"TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
@ -238,6 +239,7 @@ if TYPE_CHECKING:
from .modeling_tf_auto import (
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
TF_MODEL_FOR_MASKED_LM_MAPPING,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,

View File

@ -65,6 +65,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("roberta", "TFRobertaModel"),
("roformer", "TFRoFormerModel"),
("speech_to_text", "TFSpeech2TextModel"),
("swin", "TFSwinModel"),
("t5", "TFT5Model"),
("tapas", "TFTapasModel"),
("transfo-xl", "TFTransfoXLModel"),
@ -159,11 +160,18 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
]
)
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
[
("swin", "TFSwinForMaskedImageModeling"),
]
)
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Image-classsification
("convnext", "TFConvNextForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
("swin", "TFSwinForImageClassification"),
("vit", "TFViTForImageClassification"),
]
)
@ -349,6 +357,9 @@ TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
)
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
@ -409,6 +420,15 @@ class TFAutoModelForCausalLM(_BaseAutoModelClass):
TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")
class TFAutoModelForMaskedImageModeling(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING
TFAutoModelForMaskedImageModeling = auto_class_update(
TFAutoModelForMaskedImageModeling, head_doc="masked image modeling"
)
class TFAutoModelForImageClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING

View File

@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {
@ -40,6 +40,19 @@ else:
"SwinPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_swin"] = [
"TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSwinForImageClassification",
"TFSwinForMaskedImageModeling",
"TFSwinModel",
"TFSwinPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig
@ -58,6 +71,19 @@ if TYPE_CHECKING:
SwinPreTrainedModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_swin import (
TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSwinForImageClassification,
TFSwinForMaskedImageModeling,
TFSwinModel,
TFSwinPreTrainedModel,
)
else:
import sys

View File

@ -1068,7 +1068,6 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
)
sequence_output = outputs[0]
# Reshape to (batch_size, num_channels, height, width)
sequence_output = sequence_output.transpose(1, 2)
batch_size, num_channels, sequence_length = sequence_output.shape

File diff suppressed because it is too large Load Diff

View File

@ -261,6 +261,9 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING = None
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = None
TF_MODEL_FOR_MASKED_LM_MAPPING = None
@ -1887,6 +1890,37 @@ class TFSpeech2TextPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFSwinForImageClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSwinForMaskedImageModeling(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSwinModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSwinPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST = None

View File

@ -286,56 +286,76 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
)
def check_hidden_states_output(self, inputs_dict, config, model_class, image_size):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
# Swin has a different seq_length
patch_size = to_2tuple(config.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[num_patches, self.model_tester.embed_dim],
)
reshaped_hidden_states = outputs.reshaped_hidden_states
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
reshaped_hidden_states = (
reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
)
self.assertListEqual(
list(reshaped_hidden_states.shape[-2:]),
[num_patches, self.model_tester.embed_dim],
)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
# Swin has a different seq_length
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[num_patches, self.model_tester.embed_dim],
)
reshaped_hidden_states = outputs.reshaped_hidden_states
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
reshaped_hidden_states = (
reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
)
self.assertListEqual(
list(reshaped_hidden_states.shape[-2:]),
[num_patches, self.model_tester.embed_dim],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
image_size = to_2tuple(self.model_tester.image_size)
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
def test_hidden_states_output_with_padding(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.patch_size = 3
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(config.patch_size)
padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -0,0 +1,376 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TF 2.0 Swin model. """
import inspect
import unittest
import numpy as np
from transformers import SwinConfig
from transformers.testing_utils import require_tf, require_vision, slow
from transformers.utils import cached_property, is_tf_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
import tensorflow as tf
from transformers.models.swin.modeling_tf_swin import (
TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSwinForImageClassification,
TFSwinForMaskedImageModeling,
TFSwinModel,
to_2tuple,
)
if is_vision_available():
from PIL import Image
from transformers import AutoFeatureExtractor
class TFSwinModelTester:
def __init__(
self,
parent,
batch_size=13,
image_size=32,
patch_size=2,
num_channels=3,
embed_dim=16,
depths=[1, 2, 1],
num_heads=[2, 2, 4],
window_size=2,
mlp_ratio=2.0,
qkv_bias=True,
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
drop_path_rate=0.1,
hidden_act="gelu",
use_absolute_embeddings=False,
patch_norm=True,
initializer_range=0.02,
layer_norm_eps=1e-5,
is_training=True,
scope=None,
use_labels=True,
type_sequence_label_size=10,
encoder_stride=8,
) -> None:
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.embed_dim = embed_dim
self.depths = depths
self.num_heads = num_heads
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.drop_path_rate = drop_path_rate
self.hidden_act = hidden_act
self.use_absolute_embeddings = use_absolute_embeddings
self.patch_norm = patch_norm
self.layer_norm_eps = layer_norm_eps
self.initializer_range = initializer_range
self.is_training = is_training
self.scope = scope
self.use_labels = use_labels
self.type_sequence_label_size = type_sequence_label_size
self.encoder_stride = encoder_stride
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return SwinConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
embed_dim=self.embed_dim,
depths=self.depths,
num_heads=self.num_heads,
window_size=self.window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=self.qkv_bias,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
drop_path_rate=self.drop_path_rate,
hidden_act=self.hidden_act,
use_absolute_embeddings=self.use_absolute_embeddings,
path_norm=self.patch_norm,
layer_norm_eps=self.layer_norm_eps,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
)
def create_and_check_model(self, config, pixel_values, labels):
model = TFSwinModel(config=config)
result = model(pixel_values)
expected_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = TFSwinForImageClassification(config)
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_tf
class TFSwinModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
TFSwinModel,
TFSwinForImageClassification,
TFSwinForMaskedImageModeling,
)
if is_tf_available()
else ()
)
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFSwinModelTester(self)
self.config_tester = ConfigTester(self, config_class=SwinConfig, embed_dim=37)
def test_config(self):
self.create_and_test_config_common_properties()
self.config_tester.create_and_test_config_to_json_string()
self.config_tester.create_and_test_config_to_json_file()
self.config_tester.create_and_test_config_from_and_save_pretrained()
self.config_tester.create_and_test_config_with_num_labels()
self.config_tester.check_config_can_be_init_without_params()
self.config_tester.check_config_arguments_init()
def create_and_test_config_common_properties(self):
return
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="Swin does not use inputs_embeds")
def test_inputs_embeds(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), tf.keras.layers.Layer)
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, tf.keras.layers.Dense))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
expected_num_attentions = len(self.model_tester.depths)
self.assertEqual(len(attentions), expected_num_attentions)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
window_size_squared = config.window_size**2
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), expected_num_attentions)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
)
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
else:
# also another +1 for reshaped_hidden_states
added_hidden_states = 2
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.attentions
self.assertEqual(len(self_attentions), expected_num_attentions)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
)
def check_hidden_states_output(self, inputs_dict, config, model_class, image_size):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
# Swin has a different seq_length
patch_size = to_2tuple(config.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[num_patches, self.model_tester.embed_dim],
)
reshaped_hidden_states = outputs.reshaped_hidden_states
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
reshaped_hidden_states = tf.reshape(reshaped_hidden_states[0], (batch_size, num_channels, height * width))
reshaped_hidden_states = tf.transpose(reshaped_hidden_states, (0, 2, 1))
self.assertListEqual(
list(reshaped_hidden_states.shape[-2:]),
[num_patches, self.model_tester.embed_dim],
)
def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
image_size = to_2tuple(self.model_tester.image_size)
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
def test_inputs_requiring_padding(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.patch_size = 3
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(config.patch_size)
padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFSwinModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_vision
@require_tf
class TFSwinModelIntegrationTest(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return (
AutoFeatureExtractor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
if is_vision_available()
else None
)
@slow
def test_inference_image_classification_head(self):
model = TFSwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
feature_extractor = self.default_feature_extractor
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
inputs = feature_extractor(images=image, return_tensors="tf")
# forward pass
outputs = model(inputs)
# verify the logits
expected_shape = tf.TensorShape((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = tf.constant([-0.0948, -0.6454, -0.0921])
self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))