mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add Tensorflow Swin model (#16988)
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
6cb7187324
commit
f6a6388972
@ -256,7 +256,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| Speech2Text2 | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| Splinter | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| Swin | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| Swin | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| T5 | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
|
@ -14,22 +14,22 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
## Overview
|
||||
|
||||
The Swin Transformer was proposed in [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)
|
||||
by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo.
|
||||
The Swin Transformer was proposed in [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)
|
||||
by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*This paper presents a new vision Transformer, called Swin Transformer, that capably serves as a general-purpose backbone
|
||||
for computer vision. Challenges in adapting Transformer from language to vision arise from differences between the two domains,
|
||||
such as large variations in the scale of visual entities and the high resolution of pixels in images compared to words in text.
|
||||
To address these differences, we propose a hierarchical Transformer whose representation is computed with \bold{S}hifted
|
||||
\bold{win}dows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping
|
||||
local windows while also allowing for cross-window connection. This hierarchical architecture has the flexibility to model at
|
||||
various scales and has linear computational complexity with respect to image size. These qualities of Swin Transformer make it
|
||||
compatible with a broad range of vision tasks, including image classification (87.3 top-1 accuracy on ImageNet-1K) and dense
|
||||
prediction tasks such as object detection (58.7 box AP and 51.1 mask AP on COCO test-dev) and semantic segmentation
|
||||
(53.5 mIoU on ADE20K val). Its performance surpasses the previous state-of-the-art by a large margin of +2.7 box AP and
|
||||
+2.6 mask AP on COCO, and +3.2 mIoU on ADE20K, demonstrating the potential of Transformer-based models as vision backbones.
|
||||
*This paper presents a new vision Transformer, called Swin Transformer, that capably serves as a general-purpose backbone
|
||||
for computer vision. Challenges in adapting Transformer from language to vision arise from differences between the two domains,
|
||||
such as large variations in the scale of visual entities and the high resolution of pixels in images compared to words in text.
|
||||
To address these differences, we propose a hierarchical Transformer whose representation is computed with \bold{S}hifted
|
||||
\bold{win}dows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping
|
||||
local windows while also allowing for cross-window connection. This hierarchical architecture has the flexibility to model at
|
||||
various scales and has linear computational complexity with respect to image size. These qualities of Swin Transformer make it
|
||||
compatible with a broad range of vision tasks, including image classification (87.3 top-1 accuracy on ImageNet-1K) and dense
|
||||
prediction tasks such as object detection (58.7 box AP and 51.1 mask AP on COCO test-dev) and semantic segmentation
|
||||
(53.5 mIoU on ADE20K val). Its performance surpasses the previous state-of-the-art by a large margin of +2.7 box AP and
|
||||
+2.6 mask AP on COCO, and +3.2 mIoU on ADE20K, demonstrating the potential of Transformer-based models as vision backbones.
|
||||
The hierarchical design and the shifted window approach also prove beneficial for all-MLP architectures.*
|
||||
|
||||
Tips:
|
||||
@ -38,11 +38,11 @@ Tips:
|
||||
- Swin can be used as a *backbone*. When `output_hidden_states = True`, it will output both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`.
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/swin_transformer_architecture.png"
|
||||
alt="drawing" width="600"/>
|
||||
alt="drawing" width="600"/>
|
||||
|
||||
<small> Swin Transformer architecture. Taken from the <a href="https://arxiv.org/abs/2102.03334">original paper</a>.</small>
|
||||
|
||||
This model was contributed by [novice03](https://huggingface.co/novice03>). The original code can be found [here](https://github.com/microsoft/Swin-Transformer).
|
||||
This model was contributed by [novice03](https://huggingface.co/novice03>). The Tensorflow version of this model was contributed by [amyeroberts](https://huggingface.co/amyeroberts). The original code can be found [here](https://github.com/microsoft/Swin-Transformer).
|
||||
|
||||
|
||||
## SwinConfig
|
||||
@ -63,4 +63,19 @@ This model was contributed by [novice03](https://huggingface.co/novice03>). The
|
||||
## SwinForImageClassification
|
||||
|
||||
[[autodoc]] transformers.SwinForImageClassification
|
||||
- forward
|
||||
- forward
|
||||
|
||||
## TFSwinModel
|
||||
|
||||
[[autodoc]] TFSwinModel
|
||||
- call
|
||||
|
||||
## TFSwinForMaskedImageModeling
|
||||
|
||||
[[autodoc]] TFSwinForMaskedImageModeling
|
||||
- call
|
||||
|
||||
## TFSwinForImageClassification
|
||||
|
||||
[[autodoc]] transformers.TFSwinForImageClassification
|
||||
- call
|
||||
|
@ -1841,6 +1841,7 @@ else:
|
||||
[
|
||||
"TF_MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||
"TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
|
||||
"TF_MODEL_FOR_MASKED_LM_MAPPING",
|
||||
"TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||
"TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||
@ -2208,6 +2209,15 @@ else:
|
||||
"TFSpeech2TextPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.swin"].extend(
|
||||
[
|
||||
"TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TFSwinForImageClassification",
|
||||
"TFSwinForMaskedImageModeling",
|
||||
"TFSwinModel",
|
||||
"TFSwinPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.t5"].extend(
|
||||
[
|
||||
"TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -4071,6 +4081,7 @@ if TYPE_CHECKING:
|
||||
from .models.auto import (
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||
TF_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
@ -4363,6 +4374,13 @@ if TYPE_CHECKING:
|
||||
TFSpeech2TextModel,
|
||||
TFSpeech2TextPreTrainedModel,
|
||||
)
|
||||
from .models.swin import (
|
||||
TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFSwinForImageClassification,
|
||||
TFSwinForMaskedImageModeling,
|
||||
TFSwinModel,
|
||||
TFSwinPreTrainedModel,
|
||||
)
|
||||
from .models.t5 import (
|
||||
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFT5EncoderModel,
|
||||
|
@ -101,6 +101,7 @@ else:
|
||||
_import_structure["modeling_tf_auto"] = [
|
||||
"TF_MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||
"TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
|
||||
"TF_MODEL_FOR_MASKED_LM_MAPPING",
|
||||
"TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||
"TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||
@ -238,6 +239,7 @@ if TYPE_CHECKING:
|
||||
from .modeling_tf_auto import (
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||
TF_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
|
@ -65,6 +65,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("roberta", "TFRobertaModel"),
|
||||
("roformer", "TFRoFormerModel"),
|
||||
("speech_to_text", "TFSpeech2TextModel"),
|
||||
("swin", "TFSwinModel"),
|
||||
("t5", "TFT5Model"),
|
||||
("tapas", "TFTapasModel"),
|
||||
("transfo-xl", "TFTransfoXLModel"),
|
||||
@ -159,11 +160,18 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("swin", "TFSwinForMaskedImageModeling"),
|
||||
]
|
||||
)
|
||||
|
||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Image-classsification
|
||||
("convnext", "TFConvNextForImageClassification"),
|
||||
("data2vec-vision", "TFData2VecVisionForImageClassification"),
|
||||
("swin", "TFSwinForImageClassification"),
|
||||
("vit", "TFViTForImageClassification"),
|
||||
]
|
||||
)
|
||||
@ -349,6 +357,9 @@ TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES
|
||||
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
|
||||
TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
|
||||
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
|
||||
)
|
||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||
)
|
||||
@ -409,6 +420,15 @@ class TFAutoModelForCausalLM(_BaseAutoModelClass):
|
||||
TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")
|
||||
|
||||
|
||||
class TFAutoModelForMaskedImageModeling(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING
|
||||
|
||||
|
||||
TFAutoModelForMaskedImageModeling = auto_class_update(
|
||||
TFAutoModelForMaskedImageModeling, head_doc="masked image modeling"
|
||||
)
|
||||
|
||||
|
||||
class TFAutoModelForImageClassification(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||
|
||||
|
@ -18,7 +18,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# rely on isort to merge the imports
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
@ -40,6 +40,19 @@ else:
|
||||
"SwinPreTrainedModel",
|
||||
]
|
||||
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_tf_swin"] = [
|
||||
"TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TFSwinForImageClassification",
|
||||
"TFSwinForMaskedImageModeling",
|
||||
"TFSwinModel",
|
||||
"TFSwinPreTrainedModel",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig
|
||||
@ -58,6 +71,19 @@ if TYPE_CHECKING:
|
||||
SwinPreTrainedModel,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_tf_swin import (
|
||||
TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFSwinForImageClassification,
|
||||
TFSwinForMaskedImageModeling,
|
||||
TFSwinModel,
|
||||
TFSwinPreTrainedModel,
|
||||
)
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
@ -1068,7 +1068,6 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
# Reshape to (batch_size, num_channels, height, width)
|
||||
sequence_output = sequence_output.transpose(1, 2)
|
||||
batch_size, num_channels, sequence_length = sequence_output.shape
|
||||
|
1408
src/transformers/models/swin/modeling_tf_swin.py
Normal file
1408
src/transformers/models/swin/modeling_tf_swin.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -261,6 +261,9 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING = None
|
||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None
|
||||
|
||||
|
||||
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = None
|
||||
|
||||
|
||||
TF_MODEL_FOR_MASKED_LM_MAPPING = None
|
||||
|
||||
|
||||
@ -1887,6 +1890,37 @@ class TFSpeech2TextPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class TFSwinForImageClassification(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFSwinForMaskedImageModeling(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFSwinModel(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFSwinPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
@ -286,56 +286,76 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
|
||||
)
|
||||
|
||||
def check_hidden_states_output(self, inputs_dict, config, model_class, image_size):
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
hidden_states = outputs.hidden_states
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
|
||||
)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
# Swin has a different seq_length
|
||||
patch_size = to_2tuple(config.patch_size)
|
||||
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[num_patches, self.model_tester.embed_dim],
|
||||
)
|
||||
|
||||
reshaped_hidden_states = outputs.reshaped_hidden_states
|
||||
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
|
||||
|
||||
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
|
||||
reshaped_hidden_states = (
|
||||
reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
|
||||
)
|
||||
self.assertListEqual(
|
||||
list(reshaped_hidden_states.shape[-2:]),
|
||||
[num_patches, self.model_tester.embed_dim],
|
||||
)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
hidden_states = outputs.hidden_states
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
|
||||
)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
# Swin has a different seq_length
|
||||
image_size = to_2tuple(self.model_tester.image_size)
|
||||
patch_size = to_2tuple(self.model_tester.patch_size)
|
||||
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[num_patches, self.model_tester.embed_dim],
|
||||
)
|
||||
|
||||
reshaped_hidden_states = outputs.reshaped_hidden_states
|
||||
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
|
||||
|
||||
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
|
||||
reshaped_hidden_states = (
|
||||
reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
|
||||
)
|
||||
self.assertListEqual(
|
||||
list(reshaped_hidden_states.shape[-2:]),
|
||||
[num_patches, self.model_tester.embed_dim],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
image_size = to_2tuple(self.model_tester.image_size)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
|
||||
|
||||
def test_hidden_states_output_with_padding(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.patch_size = 3
|
||||
|
||||
image_size = to_2tuple(self.model_tester.image_size)
|
||||
patch_size = to_2tuple(config.patch_size)
|
||||
|
||||
padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
|
||||
padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
|
||||
|
||||
def test_for_image_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
|
376
tests/models/swin/test_modeling_tf_swin.py
Normal file
376
tests/models/swin/test_modeling_tf_swin.py
Normal file
@ -0,0 +1,376 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the TF 2.0 Swin model. """
|
||||
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import SwinConfig
|
||||
from transformers.testing_utils import require_tf, require_vision, slow
|
||||
from transformers.utils import cached_property, is_tf_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers.models.swin.modeling_tf_swin import (
|
||||
TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFSwinForImageClassification,
|
||||
TFSwinForMaskedImageModeling,
|
||||
TFSwinModel,
|
||||
to_2tuple,
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
|
||||
class TFSwinModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
image_size=32,
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
embed_dim=16,
|
||||
depths=[1, 2, 1],
|
||||
num_heads=[2, 2, 4],
|
||||
window_size=2,
|
||||
mlp_ratio=2.0,
|
||||
qkv_bias=True,
|
||||
hidden_dropout_prob=0.0,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
drop_path_rate=0.1,
|
||||
hidden_act="gelu",
|
||||
use_absolute_embeddings=False,
|
||||
patch_norm=True,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-5,
|
||||
is_training=True,
|
||||
scope=None,
|
||||
use_labels=True,
|
||||
type_sequence_label_size=10,
|
||||
encoder_stride=8,
|
||||
) -> None:
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.embed_dim = embed_dim
|
||||
self.depths = depths
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.qkv_bias = qkv_bias
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.drop_path_rate = drop_path_rate
|
||||
self.hidden_act = hidden_act
|
||||
self.use_absolute_embeddings = use_absolute_embeddings
|
||||
self.patch_norm = patch_norm
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.initializer_range = initializer_range
|
||||
self.is_training = is_training
|
||||
self.scope = scope
|
||||
self.use_labels = use_labels
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.encoder_stride = encoder_stride
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
labels = None
|
||||
if self.use_labels:
|
||||
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values, labels
|
||||
|
||||
def get_config(self):
|
||||
return SwinConfig(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
embed_dim=self.embed_dim,
|
||||
depths=self.depths,
|
||||
num_heads=self.num_heads,
|
||||
window_size=self.window_size,
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
qkv_bias=self.qkv_bias,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
drop_path_rate=self.drop_path_rate,
|
||||
hidden_act=self.hidden_act,
|
||||
use_absolute_embeddings=self.use_absolute_embeddings,
|
||||
path_norm=self.patch_norm,
|
||||
layer_norm_eps=self.layer_norm_eps,
|
||||
initializer_range=self.initializer_range,
|
||||
encoder_stride=self.encoder_stride,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
model = TFSwinModel(config=config)
|
||||
result = model(pixel_values)
|
||||
|
||||
expected_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
|
||||
expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
|
||||
|
||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||
config.num_labels = self.type_sequence_label_size
|
||||
model = TFSwinForImageClassification(config)
|
||||
result = model(pixel_values, labels=labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values, labels = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFSwinModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
TFSwinModel,
|
||||
TFSwinForImageClassification,
|
||||
TFSwinForMaskedImageModeling,
|
||||
)
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFSwinModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=SwinConfig, embed_dim=37)
|
||||
|
||||
def test_config(self):
|
||||
self.create_and_test_config_common_properties()
|
||||
self.config_tester.create_and_test_config_to_json_string()
|
||||
self.config_tester.create_and_test_config_to_json_file()
|
||||
self.config_tester.create_and_test_config_from_and_save_pretrained()
|
||||
self.config_tester.create_and_test_config_with_num_labels()
|
||||
self.config_tester.check_config_can_be_init_without_params()
|
||||
self.config_tester.check_config_arguments_init()
|
||||
|
||||
def create_and_test_config_common_properties(self):
|
||||
return
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Swin does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, tf.keras.layers.Dense))
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.call)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.attentions
|
||||
expected_num_attentions = len(self.model_tester.depths)
|
||||
self.assertEqual(len(attentions), expected_num_attentions)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
window_size_squared = config.window_size**2
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.attentions
|
||||
self.assertEqual(len(attentions), expected_num_attentions)
|
||||
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
if hasattr(self.model_tester, "num_hidden_states_types"):
|
||||
added_hidden_states = self.model_tester.num_hidden_states_types
|
||||
else:
|
||||
# also another +1 for reshaped_hidden_states
|
||||
added_hidden_states = 2
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
|
||||
self_attentions = outputs.attentions
|
||||
|
||||
self.assertEqual(len(self_attentions), expected_num_attentions)
|
||||
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
|
||||
)
|
||||
|
||||
def check_hidden_states_output(self, inputs_dict, config, model_class, image_size):
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
hidden_states = outputs.hidden_states
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
|
||||
)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
# Swin has a different seq_length
|
||||
patch_size = to_2tuple(config.patch_size)
|
||||
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[num_patches, self.model_tester.embed_dim],
|
||||
)
|
||||
|
||||
reshaped_hidden_states = outputs.reshaped_hidden_states
|
||||
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
|
||||
|
||||
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
|
||||
|
||||
reshaped_hidden_states = tf.reshape(reshaped_hidden_states[0], (batch_size, num_channels, height * width))
|
||||
reshaped_hidden_states = tf.transpose(reshaped_hidden_states, (0, 2, 1))
|
||||
|
||||
self.assertListEqual(
|
||||
list(reshaped_hidden_states.shape[-2:]),
|
||||
[num_patches, self.model_tester.embed_dim],
|
||||
)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
image_size = to_2tuple(self.model_tester.image_size)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
|
||||
|
||||
def test_inputs_requiring_padding(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.patch_size = 3
|
||||
|
||||
image_size = to_2tuple(self.model_tester.image_size)
|
||||
patch_size = to_2tuple(config.patch_size)
|
||||
|
||||
padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
|
||||
padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
|
||||
|
||||
def test_for_image_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = TFSwinModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_tf
|
||||
class TFSwinModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_feature_extractor(self):
|
||||
return (
|
||||
AutoFeatureExtractor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
|
||||
if is_vision_available()
|
||||
else None
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_inference_image_classification_head(self):
|
||||
model = TFSwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
|
||||
feature_extractor = self.default_feature_extractor
|
||||
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
inputs = feature_extractor(images=image, return_tensors="tf")
|
||||
|
||||
# forward pass
|
||||
outputs = model(inputs)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = tf.TensorShape((1, 1000))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
expected_slice = tf.constant([-0.0948, -0.6454, -0.0921])
|
||||
self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
Loading…
Reference in New Issue
Block a user