Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
[Maskformer] Add MaskFormerSwin backbone (#20344)
* First draft
* Fix backwards compatibility
* More fixes
* More fixes
* Make backbone more general
* Improve backbone
* Improve test
* Fix config checkpoint
* Address comments
* Use model_type
* Address more comments
* Fix special model names
* Remove MaskFormerSwinModel and MaskFormerSwinPreTrainedModel from main init
* Fix typo
* Update backbone
* Apply suggestion

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in: parent 955780d3ab, commit 6dc884abc8
@@ -288,6 +288,7 @@ Flax), PyTorch, and/or TensorFlow.
| Marian | ✅ | ❌ | ✅ | ✅ | ✅ |
| MarkupLM | ✅ | ✅ | ✅ | ❌ | ❌ |
| MaskFormer | ❌ | ❌ | ✅ | ❌ | ❌ |
| MaskFormerSwin | ❌ | ❌ | ❌ | ❌ | ❌ |
| mBART | ✅ | ✅ | ✅ | ✅ | ✅ |
| Megatron-BERT | ❌ | ❌ | ✅ | ❌ | ❌ |
| MobileBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
@@ -295,7 +295,7 @@ _import_structure = {
"MarkupLMProcessor",
"MarkupLMTokenizer",
],
"models.maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"],
"models.maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig", "MaskFormerSwinConfig"],
"models.mbart": ["MBartConfig"],
"models.mbart50": [],
"models.mctct": ["MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MCTCTConfig", "MCTCTProcessor"],
@@ -1622,6 +1622,7 @@ else:
"MaskFormerForInstanceSegmentation",
"MaskFormerModel",
"MaskFormerPreTrainedModel",
"MaskFormerSwinBackbone",
]
)
_import_structure["models.markuplm"].extend(
@@ -3479,7 +3480,7 @@ if TYPE_CHECKING:
MarkupLMProcessor,
MarkupLMTokenizer,
)
from .models.maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig
from .models.maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig, MaskFormerSwinConfig
from .models.mbart import MBartConfig
from .models.mctct import MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP, MCTCTConfig, MCTCTProcessor
from .models.megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig
@@ -4573,6 +4574,7 @@ if TYPE_CHECKING:
MaskFormerForInstanceSegmentation,
MaskFormerModel,
MaskFormerPreTrainedModel,
MaskFormerSwinBackbone,
)
from .models.mbart import (
MBartForCausalLM,
@@ -98,6 +98,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("marian", "MarianConfig"),
("markuplm", "MarkupLMConfig"),
("maskformer", "MaskFormerConfig"),
("maskformer-swin", "MaskFormerSwinConfig"),
("mbart", "MBartConfig"),
("mctct", "MCTCTConfig"),
("megatron-bert", "MegatronBertConfig"),
@@ -395,6 +396,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("marian", "Marian"),
("markuplm", "MarkupLM"),
("maskformer", "MaskFormer"),
("maskformer-swin", "MaskFormerSwin"),
("mbart", "mBART"),
("mbart50", "mBART-50"),
("mctct", "M-CTC-T"),
@@ -491,6 +493,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
("data2vec-text", "data2vec"),
("data2vec-vision", "data2vec"),
("donut-swin", "donut"),
("maskformer-swin", "maskformer"),
("xclip", "x_clip"),
]
)
@@ -97,6 +97,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("marian", "MarianModel"),
("markuplm", "MarkupLMModel"),
("maskformer", "MaskFormerModel"),
("maskformer-swin", "MaskFormerSwinModel"),
("mbart", "MBartModel"),
("mctct", "MCTCTModel"),
("megatron-bert", "MegatronBertModel"),
@@ -847,6 +848,7 @@ _MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
[
# Backbone mapping
("maskformer-swin", "MaskFormerSwinBackbone"),
("resnet", "ResNetBackbone"),
]
)
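To make the effect of the new backbone mapping concrete, here is a minimal sketch (illustration only, not part of the diff; it assumes a transformers install that already contains this commit) of building a randomly initialized MaskFormerSwin backbone through the Auto API:

```python
import torch
from transformers import AutoBackbone, MaskFormerSwinConfig

# "maskformer-swin" now resolves to MaskFormerSwinBackbone via MODEL_FOR_BACKBONE_MAPPING_NAMES
config = MaskFormerSwinConfig(out_features=["stage1", "stage2", "stage3", "stage4"])
backbone = AutoBackbone.from_config(config)

pixel_values = torch.randn(1, 3, 224, 224)
outputs = backbone(pixel_values)
# one feature map per requested stage, each of shape (batch_size, channels, height, width)
print([tuple(feature_map.shape) for feature_map in outputs.feature_maps])
```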
@@ -20,7 +20,10 @@ from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available


_import_structure = {"configuration_maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"]}
_import_structure = {
"configuration_maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"],
"configuration_maskformer_swin": ["MaskFormerSwinConfig"],
}

try:
if not is_vision_available():
@@ -43,9 +46,15 @@ else:
"MaskFormerModel",
"MaskFormerPreTrainedModel",
]
_import_structure["modeling_maskformer_swin"] = [
"MaskFormerSwinBackbone",
"MaskFormerSwinModel",
"MaskFormerSwinPreTrainedModel",
]

if TYPE_CHECKING:
from .configuration_maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig
from .configuration_maskformer_swin import MaskFormerSwinConfig

try:
if not is_vision_available():
@@ -66,6 +75,11 @@ if TYPE_CHECKING:
MaskFormerModel,
MaskFormerPreTrainedModel,
)
from .modeling_maskformer_swin import (
MaskFormerSwinBackbone,
MaskFormerSwinModel,
MaskFormerSwinPreTrainedModel,
)

else:
src/transformers/models/maskformer/configuration_maskformer_swin.py (new file, 149 lines)
@@ -0,0 +1,149 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" MaskFormer Swin Transformer model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class MaskFormerSwinConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MaskFormerSwinModel`]. It is used to instantiate
    a MaskFormerSwin model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Swin
    [microsoft/swin-tiny-patch4-window7-224](https://huggingface.co/microsoft/swin-tiny-patch4-window7-224)
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 4):
            The size (resolution) of each patch.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        embed_dim (`int`, *optional*, defaults to 96):
            Dimensionality of patch embedding.
        depths (`List[int]`, *optional*, defaults to `[2, 2, 6, 2]`):
            Depth of each layer in the Transformer encoder.
        num_heads (`List[int]`, *optional*, defaults to `[3, 6, 12, 24]`):
            Number of attention heads in each layer of the Transformer encoder.
        window_size (`int`, *optional*, defaults to 7):
            Size of windows.
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            Ratio of MLP hidden dimensionality to embedding dimensionality.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether or not a learnable bias should be added to the queries, keys and values.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings and encoder.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        drop_path_rate (`float`, *optional*, defaults to 0.1):
            Stochastic depth rate.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
            `"selu"` and `"gelu_new"` are supported.
        use_absolute_embeddings (`bool`, *optional*, defaults to `False`):
            Whether or not to add absolute position embeddings to the patch embeddings.
        patch_norm (`bool`, *optional*, defaults to `True`):
            Whether or not to add layer normalization after patch embedding.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        out_features (`List[str]`, *optional*):
            If used as a backbone, list of feature names to output, e.g. `["stem", "stage1"]`.

    Example:

    ```python
    >>> from transformers import MaskFormerSwinConfig, MaskFormerSwinModel

    >>> # Initializing a microsoft/swin-tiny-patch4-window7-224 style configuration
    >>> configuration = MaskFormerSwinConfig()

    >>> # Initializing a model (with random weights) from the microsoft/swin-tiny-patch4-window7-224 style configuration
    >>> model = MaskFormerSwinModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "maskformer-swin"

    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        image_size=224,
        patch_size=4,
        num_channels=3,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        drop_path_rate=0.1,
        hidden_act="gelu",
        use_absolute_embeddings=False,
        patch_norm=True,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        out_features=None,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        self.depths = depths
        self.num_layers = len(depths)
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.use_absolute_embeddings = use_absolute_embeddings
        self.patch_norm = patch_norm
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
        # this indicates the channel dimension after the last stage of the model
        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
        if out_features is not None:
            if not isinstance(out_features, list):
                raise ValueError("out_features should be a list")
            for feature in out_features:
                if feature not in self.stage_names:
                    raise ValueError(
                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
                    )
        self.out_features = out_features
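As a quick illustration of the `out_features` machinery above (a hedged sketch, not part of the new file), the derived attributes for the default configuration look like this:

```python
from transformers import MaskFormerSwinConfig

config = MaskFormerSwinConfig(out_features=["stage1", "stage2", "stage3", "stage4"])
print(config.stage_names)  # ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
print(config.num_layers)   # 4, one per entry in `depths`
print(config.hidden_size)  # 768 = 96 * 2 ** (4 - 1), channel dimension after the last stage

# An unknown stage name is rejected by the validation loop above:
# MaskFormerSwinConfig(out_features=["stage5"])  # raises ValueError
```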
@@ -14,7 +14,6 @@
# limitations under the License.
""" PyTorch MaskFormer model."""

import collections.abc
import math
import random
from dataclasses import dataclass
@@ -25,12 +24,12 @@ import numpy as np
import torch
from torch import Tensor, nn

from transformers import AutoBackbone
from transformers.utils import logging

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithCrossAttentions
from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_start_docstrings,
@@ -40,8 +39,8 @@ from ...utils import (
requires_backends,
)
from ..detr import DetrConfig
from ..swin import SwinConfig
from .configuration_maskformer import MaskFormerConfig
from .configuration_maskformer_swin import MaskFormerSwinConfig


if is_scipy_available():
@@ -60,71 +59,6 @@ MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
]

@dataclass
|
||||
class MaskFormerSwinModelOutputWithPooling(ModelOutput):
|
||||
"""
|
||||
Class for MaskFormerSwinModel's outputs that also contains the spatial dimensions of the hidden states.
|
||||
|
||||
Args:
|
||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
|
||||
Last layer hidden-state after a mean pooling operation.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*):
|
||||
A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to
|
||||
`batch, channels, height, width`. Due to padding, their spatial size cannot be inferred before the
|
||||
`forward` method.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor = None
|
||||
pooler_output: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaskFormerSwinBaseModelOutput(ModelOutput):
|
||||
"""
|
||||
Class for SwinEncoder's outputs.
|
||||
|
||||
Args:
|
||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*):
|
||||
A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to
|
||||
`batch, channels, height, width`. Due to padding, their spatial size cannot be inferred before the `forward`
|
||||
method.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrDecoderOutput
|
||||
class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
|
||||
@@ -471,714 +405,6 @@ def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float =
|
||||
return loss / height_and_width
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.window_partition
|
||||
def window_partition(input_feature, window_size):
|
||||
"""
|
||||
Partitions the given input into windows.
|
||||
"""
|
||||
batch_size, height, width, num_channels = input_feature.shape
|
||||
input_feature = input_feature.view(
|
||||
batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
|
||||
)
|
||||
windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
|
||||
return windows
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.window_reverse
|
||||
def window_reverse(windows, window_size, height, width):
|
||||
"""
|
||||
Merges windows to produce higher resolution features.
|
||||
"""
|
||||
num_channels = windows.shape[-1]
|
||||
windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
|
||||
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
|
||||
return windows
|
||||
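The two helpers above are exact inverses; a small hedged sketch (illustration only, not part of the diff) of the shapes involved:

```python
import torch

batch_size, height, width, num_channels, window_size = 2, 14, 14, 96, 7

features = torch.randn(batch_size, height, width, num_channels)
windows = window_partition(features, window_size)
# 2 * (14 // 7) * (14 // 7) = 8 windows, each window_size x window_size x num_channels
print(windows.shape)  # torch.Size([8, 7, 7, 96])

restored = window_reverse(windows, window_size, height, width)
print(torch.equal(restored, features))  # True: window_reverse undoes window_partition
```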
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.drop_path
|
||||
def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
|
||||
"""
|
||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
|
||||
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
||||
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
||||
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
||||
argument.
|
||||
"""
|
||||
if drop_prob == 0.0 or not training:
|
||||
return input
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
||||
random_tensor.floor_() # binarize
|
||||
output = input.div(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
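A short hedged illustration (not part of the diff) of what stochastic depth does at training time versus inference time:

```python
import torch

torch.manual_seed(0)
hidden_states = torch.ones(4, 3, 2)

# During training, each sample is either zeroed entirely or rescaled by 1 / keep_prob,
# so the expected value of the output matches the input.
out = drop_path(hidden_states, drop_prob=0.5, training=True)
print(out[:, 0, 0])  # e.g. tensor([2., 0., 2., 2.]), depending on the seed

# At inference time (training=False) the input passes through unchanged.
print(torch.equal(drop_path(hidden_states, drop_prob=0.5, training=False), hidden_states))  # True
```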
class MaskFormerSwinEmbeddings(nn.Module):
|
||||
"""
|
||||
Construct the patch and position embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.patch_embeddings = MaskFormerSwinPatchEmbeddings(config)
|
||||
num_patches = self.patch_embeddings.num_patches
|
||||
self.patch_grid = self.patch_embeddings.grid_size
|
||||
|
||||
if config.use_absolute_embeddings:
|
||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
|
||||
else:
|
||||
self.position_embeddings = None
|
||||
|
||||
self.norm = nn.LayerNorm(config.embed_dim)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
|
||||
embeddings = self.norm(embeddings)
|
||||
|
||||
if self.position_embeddings is not None:
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
|
||||
embeddings = self.dropout(embeddings)
|
||||
|
||||
return embeddings, output_dimensions
|
||||
|
||||
|
||||
class MaskFormerSwinPatchEmbeddings(nn.Module):
|
||||
"""
|
||||
Image to Patch Embedding, including padding.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
image_size, patch_size = config.image_size, config.patch_size
|
||||
num_channels, hidden_size = config.num_channels, config.embed_dim
|
||||
|
||||
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
||||
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.num_patches = num_patches
|
||||
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
||||
|
||||
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def maybe_pad(self, pixel_values, height, width):
|
||||
if width % self.patch_size[1] != 0:
|
||||
pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
|
||||
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
||||
if height % self.patch_size[0] != 0:
|
||||
pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
|
||||
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
||||
return pixel_values
|
||||
|
||||
def forward(self, pixel_values):
|
||||
_, num_channels, height, width = pixel_values.shape
|
||||
if num_channels != self.num_channels:
|
||||
raise ValueError(
|
||||
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||
)
|
||||
# pad the input to be divisible by self.patch_size, if needed
|
||||
pixel_values = self.maybe_pad(pixel_values, height, width)
|
||||
embeddings = self.projection(pixel_values)
|
||||
_, _, height, width = embeddings.shape
|
||||
output_dimensions = (height, width)
|
||||
embeddings_flat = embeddings.flatten(2).transpose(1, 2)
|
||||
|
||||
return embeddings_flat, output_dimensions
|
||||
|
||||
|
||||
class MaskFormerSwinPatchMerging(nn.Module):
|
||||
"""
|
||||
Patch Merging Layer for maskformer model.
|
||||
|
||||
Args:
|
||||
input_resolution (`Tuple[int]`):
|
||||
Resolution of input feature.
|
||||
dim (`int`):
|
||||
Number of input channels.
|
||||
norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
|
||||
Normalization layer class.
|
||||
"""
|
||||
|
||||
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
||||
self.norm = norm_layer(4 * dim)
|
||||
|
||||
def maybe_pad(self, input_feature, width, height):
|
||||
should_pad = (height % 2 == 1) or (width % 2 == 1)
|
||||
if should_pad:
|
||||
pad_values = (0, 0, 0, width % 2, 0, height % 2)
|
||||
input_feature = nn.functional.pad(input_feature, pad_values)
|
||||
|
||||
return input_feature
|
||||
|
||||
def forward(self, input_feature, input_dimensions):
|
||||
height, width = input_dimensions
|
||||
# `dim` is height * width
|
||||
batch_size, dim, num_channels = input_feature.shape
|
||||
|
||||
input_feature = input_feature.view(batch_size, height, width, num_channels)
|
||||
# pad input to be divisible by width and height, if needed
|
||||
input_feature = self.maybe_pad(input_feature, height, width)
|
||||
# [batch_size, height/2, width/2, num_channels]
|
||||
input_feature_0 = input_feature[:, 0::2, 0::2, :]
|
||||
# [batch_size, height/2, width/2, num_channels]
|
||||
input_feature_1 = input_feature[:, 1::2, 0::2, :]
|
||||
# [batch_size, height/2, width/2, num_channels]
|
||||
input_feature_2 = input_feature[:, 0::2, 1::2, :]
|
||||
# [batch_size, height/2, width/2, num_channels]
|
||||
input_feature_3 = input_feature[:, 1::2, 1::2, :]
|
||||
# batch_size height/2 width/2 4*num_channels
|
||||
input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
|
||||
input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C
|
||||
|
||||
input_feature = self.norm(input_feature)
|
||||
input_feature = self.reduction(input_feature)
|
||||
|
||||
return input_feature
|
||||
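A hedged shape sketch (illustration only) of the downsampling this layer performs: spatial resolution is halved while the channel dimension is doubled.

```python
import torch

height, width, num_channels = 14, 14, 96
merging = MaskFormerSwinPatchMerging(input_resolution=(height, width), dim=num_channels)

# (batch_size, height * width, num_channels) -> (batch_size, height/2 * width/2, 2 * num_channels)
input_feature = torch.randn(2, height * width, num_channels)
output = merging(input_feature, (height, width))
print(output.shape)  # torch.Size([2, 49, 192])
```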
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->MaskFormerSwin
|
||||
class MaskFormerSwinDropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||
|
||||
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
||||
super().__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return "p={}".format(self.drop_prob)
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->MaskFormerSwin
|
||||
class MaskFormerSwinSelfAttention(nn.Module):
|
||||
def __init__(self, config, dim, num_heads, window_size):
|
||||
super().__init__()
|
||||
if dim % num_heads != 0:
|
||||
raise ValueError(
|
||||
f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
|
||||
)
|
||||
|
||||
self.num_attention_heads = num_heads
|
||||
self.attention_head_size = int(dim / num_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.window_size = (
|
||||
window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
|
||||
)
|
||||
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
|
||||
)
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(self.window_size[0])
|
||||
coords_w = torch.arange(self.window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
|
||||
coords_flatten = torch.flatten(coords, 1)
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
||||
relative_coords[:, :, 0] += self.window_size[0] - 1
|
||||
relative_coords[:, :, 1] += self.window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
||||
relative_position_index = relative_coords.sum(-1)
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
|
||||
self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
batch_size, dim, num_channels = hidden_states.shape
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
|
||||
relative_position_bias = relative_position_bias.view(
|
||||
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
|
||||
)
|
||||
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
|
||||
attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in MaskFormerSwinModel forward() function)
|
||||
mask_shape = attention_mask.shape[0]
|
||||
attention_scores = attention_scores.view(
|
||||
batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
|
||||
)
|
||||
attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
|
||||
attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->MaskFormerSwin
|
||||
class MaskFormerSwinSelfOutput(nn.Module):
|
||||
def __init__(self, config, dim):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(dim, dim)
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->MaskFormerSwin
|
||||
class MaskFormerSwinAttention(nn.Module):
|
||||
def __init__(self, config, dim, num_heads, window_size):
|
||||
super().__init__()
|
||||
self.self = MaskFormerSwinSelfAttention(config, dim, num_heads, window_size)
|
||||
self.output = MaskFormerSwinSelfOutput(config, dim)
|
||||
self.pruned_heads = set()
|
||||
|
||||
def prune_heads(self, heads):
|
||||
if len(heads) == 0:
|
||||
return
|
||||
heads, index = find_pruneable_heads_and_indices(
|
||||
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
||||
)
|
||||
|
||||
# Prune linear layers
|
||||
self.self.query = prune_linear_layer(self.self.query, index)
|
||||
self.self.key = prune_linear_layer(self.self.key, index)
|
||||
self.self.value = prune_linear_layer(self.self.value, index)
|
||||
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
||||
|
||||
# Update hyper params and store pruned heads
|
||||
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->MaskFormerSwin
|
||||
class MaskFormerSwinIntermediate(nn.Module):
|
||||
def __init__(self, config, dim):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->MaskFormerSwin
|
||||
class MaskFormerSwinOutput(nn.Module):
|
||||
def __init__(self, config, dim):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MaskFormerSwinLayer(nn.Module):
|
||||
def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.shift_size = shift_size
|
||||
self.window_size = config.window_size
|
||||
self.input_resolution = input_resolution
|
||||
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
|
||||
self.attention = MaskFormerSwinAttention(config, dim, num_heads, self.window_size)
|
||||
self.drop_path = (
|
||||
MaskFormerSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
|
||||
)
|
||||
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
|
||||
self.intermediate = MaskFormerSwinIntermediate(config, dim)
|
||||
self.output = MaskFormerSwinOutput(config, dim)
|
||||
|
||||
def get_attn_mask(self, input_resolution):
|
||||
if self.shift_size > 0:
|
||||
# calculate attention mask for SW-MSA
|
||||
height, width = input_resolution
|
||||
img_mask = torch.zeros((1, height, width, 1))
|
||||
height_slices = (
|
||||
slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
slice(-self.shift_size, None),
|
||||
)
|
||||
width_slices = (
|
||||
slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
slice(-self.shift_size, None),
|
||||
)
|
||||
count = 0
|
||||
for height_slice in height_slices:
|
||||
for width_slice in width_slices:
|
||||
img_mask[:, height_slice, width_slice, :] = count
|
||||
count += 1
|
||||
|
||||
mask_windows = window_partition(img_mask, self.window_size)
|
||||
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||
else:
|
||||
attn_mask = None
|
||||
return attn_mask
|
||||
|
||||
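To see what this mask looks like, here is a hedged standalone sketch (not part of the diff, reusing the module-level `window_partition` helper defined earlier in this file) for a 14x14 feature map with `window_size=7` and `shift_size=3`: tokens that come from different image regions inside the same window receive a -100 additive bias so they cannot attend to each other.

```python
import torch

height = width = 14
window_size, shift_size = 7, 3

img_mask = torch.zeros((1, height, width, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
count = 0
for height_slice in slices:
    for width_slice in slices:
        img_mask[:, height_slice, width_slice, :] = count
        count += 1

mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)  # torch.Size([4, 49, 49]): one additive mask per 7x7 window
```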
def maybe_pad(self, hidden_states, height, width):
|
||||
pad_left = pad_top = 0
|
||||
pad_right = (self.window_size - width % self.window_size) % self.window_size
pad_bottom = (self.window_size - height % self.window_size) % self.window_size
pad_values = (0, 0, pad_left, pad_right, pad_top, pad_bottom)
|
||||
hidden_states = nn.functional.pad(hidden_states, pad_values)
|
||||
return hidden_states, pad_values
|
||||
|
||||
def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
|
||||
height, width = input_dimensions
|
||||
batch_size, dim, channels = hidden_states.size()
|
||||
shortcut = hidden_states
|
||||
|
||||
hidden_states = self.layernorm_before(hidden_states)
|
||||
hidden_states = hidden_states.view(batch_size, height, width, channels)
|
||||
# pad hidden_states to multiples of window size
|
||||
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
|
||||
|
||||
_, height_pad, width_pad, _ = hidden_states.shape
|
||||
# cyclic shift
|
||||
if self.shift_size > 0:
|
||||
shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
||||
else:
|
||||
shifted_hidden_states = hidden_states
|
||||
|
||||
# partition windows
|
||||
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
|
||||
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
|
||||
attn_mask = self.get_attn_mask((height_pad, width_pad))
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.to(hidden_states_windows.device)
|
||||
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
|
||||
)
|
||||
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
|
||||
shifted_windows = window_reverse(
|
||||
attention_windows, self.window_size, height_pad, width_pad
|
||||
) # B height' width' C
|
||||
|
||||
# reverse cyclic shift
|
||||
if self.shift_size > 0:
|
||||
attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
||||
else:
|
||||
attention_windows = shifted_windows
|
||||
|
||||
was_padded = pad_values[3] > 0 or pad_values[5] > 0
|
||||
if was_padded:
|
||||
attention_windows = attention_windows[:, :height, :width, :].contiguous()
|
||||
|
||||
attention_windows = attention_windows.view(batch_size, height * width, channels)
|
||||
|
||||
hidden_states = shortcut + self.drop_path(attention_windows)
|
||||
|
||||
layer_output = self.layernorm_after(hidden_states)
|
||||
layer_output = self.intermediate(layer_output)
|
||||
layer_output = hidden_states + self.output(layer_output)
|
||||
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class MaskFormerSwinStage(nn.Module):
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinStage.__init__ with Swin->MaskFormerSwin
|
||||
def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.dim = dim
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
MaskFormerSwinLayer(
|
||||
config=config,
|
||||
dim=dim,
|
||||
input_resolution=input_resolution,
|
||||
num_heads=num_heads,
|
||||
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
self.pointing = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states, input_dimensions, head_mask=None, output_attentions=False, output_hidden_states=False
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
height, width = input_dimensions
|
||||
for i, block_module in enumerate(self.blocks):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
block_hidden_states = block_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
|
||||
|
||||
hidden_states = block_hidden_states[0]
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.downsample is not None:
|
||||
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
|
||||
output_dimensions = (height, width, height_downsampled, width_downsampled)
|
||||
hidden_states = self.downsample(hidden_states, input_dimensions)
|
||||
else:
|
||||
output_dimensions = (height, width, height, width)
|
||||
|
||||
return hidden_states, output_dimensions, all_hidden_states
|
||||
|
||||
|
||||
class MaskFormerSwinEncoder(nn.Module):
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinEncoder.__init__ with Swin->MaskFormerSwin
|
||||
def __init__(self, config, grid_size):
|
||||
super().__init__()
|
||||
self.num_layers = len(config.depths)
|
||||
self.config = config
|
||||
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
MaskFormerSwinStage(
|
||||
config=config,
|
||||
dim=int(config.embed_dim * 2**i_layer),
|
||||
input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
|
||||
depth=config.depths[i_layer],
|
||||
num_heads=config.num_heads[i_layer],
|
||||
drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
|
||||
downsample=MaskFormerSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||
)
|
||||
for i_layer in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
input_dimensions,
|
||||
head_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_input_dimensions = ()
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
for i, layer_module in enumerate(self.layers):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_hidden_states, output_dimensions, layer_all_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(layer_module), hidden_states, layer_head_mask
|
||||
)
|
||||
else:
|
||||
layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module(
|
||||
hidden_states,
|
||||
input_dimensions,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
)
|
||||
|
||||
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
|
||||
all_input_dimensions += (input_dimensions,)
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (layer_all_hidden_states,)
|
||||
|
||||
hidden_states = layer_hidden_states
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_all_hidden_states[1],)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
||||
|
||||
return MaskFormerSwinBaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
hidden_states_spatial_dimensions=all_input_dimensions,
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
|
||||
class MaskFormerSwinModel(nn.Module, ModuleUtilsMixin):
|
||||
def __init__(self, config, add_pooling_layer=True):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_layers = len(config.depths)
|
||||
self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
|
||||
|
||||
self.embeddings = MaskFormerSwinEmbeddings(config)
|
||||
self.encoder = MaskFormerSwinEncoder(config, self.embeddings.patch_grid)
|
||||
|
||||
self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
|
||||
self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.patch_embeddings
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""
|
||||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
||||
class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
|
||||
|
||||
embedding_output, input_dimensions = self.embeddings(pixel_values)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
input_dimensions,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
sequence_output = encoder_outputs.last_hidden_state
|
||||
sequence_output = self.layernorm(sequence_output)
|
||||
|
||||
pooled_output = None
|
||||
if self.pooler is not None:
|
||||
pooled_output = self.pooler(sequence_output.transpose(1, 2))
|
||||
pooled_output = torch.flatten(pooled_output, 1)
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
hidden_states_spatial_dimensions = (input_dimensions,) + encoder_outputs.hidden_states_spatial_dimensions
|
||||
|
||||
return MaskFormerSwinModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
hidden_states_spatial_dimensions=hidden_states_spatial_dimensions,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrAttention
|
||||
class DetrAttention(nn.Module):
|
||||
"""
|
||||
@@ -1923,52 +1149,6 @@ class MaskFormerLoss(nn.Module):
|
||||
return num_masks_pt
|
||||
|
||||
|
||||
class MaskFormerSwinTransformerBackbone(nn.Module):
|
||||
"""
|
||||
This class uses [`MaskFormerSwinModel`] to reshape its `hidden_states` from (`batch_size, sequence_length,
|
||||
hidden_size)` to (`batch_size, num_channels, height, width)`).
|
||||
|
||||
Args:
|
||||
config (`SwinConfig`):
|
||||
The configuration used by [`MaskFormerSwinModel`].
|
||||
"""
|
||||
|
||||
def __init__(self, config: SwinConfig):
|
||||
super().__init__()
|
||||
self.model = MaskFormerSwinModel(config)
|
||||
self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(out_shape) for out_shape in self.outputs_shapes])
|
||||
|
||||
def forward(self, *args, **kwargs) -> List[Tensor]:
|
||||
output = self.model(*args, **kwargs, output_hidden_states=True)
|
||||
hidden_states_permuted: List[Tensor] = []
|
||||
# we need to reshape the hidden state to their original spatial dimensions
|
||||
# skipping the embeddings
|
||||
hidden_states: Tuple[Tuple[Tensor]] = output.hidden_states[1:]
|
||||
# spatial dimensions contains all the heights and widths of each stage, including after the embeddings
|
||||
spatial_dimensions: Tuple[Tuple[int, int]] = output.hidden_states_spatial_dimensions
|
||||
for i, (hidden_state, (height, width)) in enumerate(zip(hidden_states, spatial_dimensions)):
|
||||
norm = self.hidden_states_norms[i]
|
||||
# the last element corresponds to the layer's last block output but before patch merging
hidden_state_unpooled = hidden_state[-1]
hidden_state_norm = norm(hidden_state_unpooled)
|
||||
# our pixel decoder (FPN) expect 3D tensors (features)
|
||||
batch_size, _, hidden_size = hidden_state_norm.shape
|
||||
# reshape our tensor "b (h w) d -> b d h w"
|
||||
hidden_state_permuted = (
|
||||
hidden_state_norm.permute(0, 2, 1).view((batch_size, hidden_size, height, width)).contiguous()
|
||||
)
|
||||
hidden_states_permuted.append(hidden_state_permuted)
|
||||
return hidden_states_permuted
|
||||
|
||||
@property
|
||||
def input_resolutions(self) -> List[int]:
|
||||
return [layer.input_resolution for layer in self.model.encoder.layers]
|
||||
|
||||
@property
|
||||
def outputs_shapes(self) -> List[int]:
|
||||
return [layer.dim for layer in self.model.encoder.layers]
|
||||
|
||||
|
||||
class MaskFormerFPNConvLayer(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, kernel_size: int = 3, padding: int = 1):
|
||||
"""
|
||||
@@ -2065,7 +1245,7 @@ class MaskFormerPixelDecoder(nn.Module):
|
||||
def __init__(self, *args, feature_size: int = 256, mask_feature_size: int = 256, **kwargs):
|
||||
"""
|
||||
Pixel Decoder Module proposed in [Per-Pixel Classification is Not All You Need for Semantic
|
||||
Segmentation](https://arxiv.org/abs/2107.06278). It first runs the backbone's feature into a Feature Pyramid
|
||||
Segmentation](https://arxiv.org/abs/2107.06278). It first runs the backbone's features into a Feature Pyramid
|
||||
Network creating a list of feature maps. Then, it projects the last one to the correct `mask_size`.
|
||||
|
||||
Args:
|
||||
@@ -2075,13 +1255,15 @@ class MaskFormerPixelDecoder(nn.Module):
|
||||
The features (channels) of the target masks size \\(C_{\epsilon}\\) in the paper.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.fpn = MaskFormerFPNModel(*args, feature_size=feature_size, **kwargs)
|
||||
self.mask_projection = nn.Conv2d(feature_size, mask_feature_size, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, features: List[Tensor], output_hidden_states: bool = False) -> MaskFormerPixelDecoderOutput:
|
||||
fpn_features: List[Tensor] = self.fpn(features)
|
||||
fpn_features = self.fpn(features)
|
||||
# we use the last feature map
|
||||
last_feature_projected = self.mask_projection(fpn_features[-1])
|
||||
|
||||
return MaskFormerPixelDecoderOutput(
|
||||
last_hidden_state=last_feature_projected, hidden_states=tuple(fpn_features) if output_hidden_states else ()
|
||||
)
|
||||
@@ -2193,17 +1375,26 @@ class MaskFormerPixelLevelModule(nn.Module):
|
||||
The configuration used to instantiate this model.
|
||||
"""
|
||||
super().__init__()
|
||||
self.encoder = MaskFormerSwinTransformerBackbone(config.backbone_config)
|
||||
|
||||
# TODO: add method to load pretrained weights of backbone
|
||||
backbone_config = config.backbone_config
|
||||
if backbone_config.model_type == "swin":
|
||||
# for backwards compatibility
|
||||
backbone_config = MaskFormerSwinConfig.from_dict(backbone_config.to_dict())
|
||||
backbone_config.out_features = ["stage1", "stage2", "stage3", "stage4"]
|
||||
self.encoder = AutoBackbone.from_config(backbone_config)
|
||||
|
||||
feature_channels = self.encoder.channels
|
||||
self.decoder = MaskFormerPixelDecoder(
|
||||
in_features=self.encoder.outputs_shapes[-1],
|
||||
in_features=feature_channels[-1],
|
||||
feature_size=config.fpn_feature_size,
|
||||
mask_feature_size=config.mask_feature_size,
|
||||
lateral_widths=self.encoder.outputs_shapes[:-1],
|
||||
lateral_widths=feature_channels[:-1],
|
||||
)
|
||||
|
||||
def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> MaskFormerPixelLevelModuleOutput:
|
||||
features: List[Tensor] = self.encoder(pixel_values)
|
||||
decoder_output: MaskFormerPixelDecoderOutput = self.decoder(features, output_hidden_states)
|
||||
features = self.encoder(pixel_values).feature_maps
|
||||
decoder_output = self.decoder(features, output_hidden_states)
|
||||
return MaskFormerPixelLevelModuleOutput(
|
||||
# the last feature is actually the output from the last layer
|
||||
encoder_last_hidden_state=features[-1],
|
||||
@@ -2335,8 +1526,8 @@ class MaskFormerPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, MaskFormerSwinEncoder):
|
||||
module.gradient_checkpointing = value
|
||||
if isinstance(module, MaskFormerPixelLevelModule):
|
||||
module.encoder.gradient_checkpointing = value
|
||||
if isinstance(module, DetrDecoder):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@@ -2350,7 +1541,7 @@ class MaskFormerModel(MaskFormerPreTrainedModel):
|
||||
super().__init__(config)
|
||||
self.pixel_level_module = MaskFormerPixelLevelModule(config)
|
||||
self.transformer_module = MaskFormerTransformerModule(
|
||||
in_features=self.pixel_level_module.encoder.outputs_shapes[-1], config=config
|
||||
in_features=self.pixel_level_module.encoder.channels[-1], config=config
|
||||
)
|
||||
|
||||
self.post_init()
|
||||
@@ -2407,15 +1598,11 @@ class MaskFormerModel(MaskFormerPreTrainedModel):
|
||||
if pixel_mask is None:
|
||||
pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)
|
||||
|
||||
pixel_level_module_output: MaskFormerPixelLevelModuleOutput = self.pixel_level_module(
|
||||
pixel_values, output_hidden_states
|
||||
)
|
||||
pixel_level_module_output = self.pixel_level_module(pixel_values, output_hidden_states)
|
||||
image_features = pixel_level_module_output.encoder_last_hidden_state
|
||||
pixel_embeddings = pixel_level_module_output.decoder_last_hidden_state
|
||||
|
||||
transformer_module_output: DetrDecoderOutput = self.transformer_module(
|
||||
image_features, output_hidden_states, output_attentions
|
||||
)
|
||||
transformer_module_output = self.transformer_module(image_features, output_hidden_states, output_attentions)
|
||||
queries = transformer_module_output.last_hidden_state
|
||||
|
||||
encoder_hidden_states = None
|
||||
|
src/transformers/models/maskformer/modeling_maskformer_swin.py (new file, 925 lines)
@@ -0,0 +1,925 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MaskFormer Swin Transformer. The reason Swin Transformer is implemented here is because MaskFormer uses the hidden
|
||||
states before downsampling, which is different from the default Swin Transformer."""
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import ModelOutput
|
||||
from ...modeling_outputs import BackboneOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from .configuration_maskformer_swin import MaskFormerSwinConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaskFormerSwinModelOutputWithPooling(ModelOutput):
|
||||
"""
|
||||
Class for MaskFormerSwinModel's outputs that also contains the spatial dimensions of the hidden states.
|
||||
|
||||
Args:
|
||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
|
||||
Last layer hidden-state after a mean pooling operation.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*):
|
||||
A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to
|
||||
`batch, channels, height, width`. Due to padding, their spatial size cannot be inferred before the
|
||||
`forward` method.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor = None
|
||||
pooler_output: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaskFormerSwinBaseModelOutput(ModelOutput):
|
||||
"""
|
||||
Class for MaskFormerSwinEncoder's outputs.
|
||||
|
||||
Args:
|
||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*):
|
||||
A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to
|
||||
`batch, channels, height, width`. Due to padding, their spatial size cannot be inferred before the `forward`
|
||||
method.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.window_partition
|
||||
def window_partition(input_feature, window_size):
|
||||
"""
|
||||
Partitions the given input into windows.
|
||||
"""
|
||||
batch_size, height, width, num_channels = input_feature.shape
|
||||
input_feature = input_feature.view(
|
||||
batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
|
||||
)
|
||||
windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
|
||||
return windows
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.window_reverse
|
||||
def window_reverse(windows, window_size, height, width):
|
||||
"""
|
||||
Merges windows to produce higher resolution features.
|
||||
"""
|
||||
num_channels = windows.shape[-1]
|
||||
windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
|
||||
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
|
||||
return windows
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.drop_path
|
||||
def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
|
||||
"""
|
||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
|
||||
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
||||
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
||||
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
||||
argument.
|
||||
"""
|
||||
if drop_prob == 0.0 or not training:
|
||||
return input
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
||||
random_tensor.floor_() # binarize
|
||||
output = input.div(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class MaskFormerSwinEmbeddings(nn.Module):
|
||||
"""
|
||||
Construct the patch and position embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.patch_embeddings = MaskFormerSwinPatchEmbeddings(config)
|
||||
num_patches = self.patch_embeddings.num_patches
|
||||
self.patch_grid = self.patch_embeddings.grid_size
|
||||
|
||||
if config.use_absolute_embeddings:
|
||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
|
||||
else:
|
||||
self.position_embeddings = None
|
||||
|
||||
self.norm = nn.LayerNorm(config.embed_dim)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
|
||||
embeddings = self.norm(embeddings)
|
||||
|
||||
if self.position_embeddings is not None:
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
|
||||
embeddings = self.dropout(embeddings)
|
||||
|
||||
return embeddings, output_dimensions
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
|
||||
class MaskFormerSwinPatchEmbeddings(nn.Module):
|
||||
"""
|
||||
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
||||
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
||||
Transformer.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
image_size, patch_size = config.image_size, config.patch_size
|
||||
num_channels, hidden_size = config.num_channels, config.embed_dim
|
||||
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
||||
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.num_patches = num_patches
|
||||
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
||||
|
||||
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def maybe_pad(self, pixel_values, height, width):
|
||||
if width % self.patch_size[1] != 0:
|
||||
pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
|
||||
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
||||
if height % self.patch_size[0] != 0:
|
||||
pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
|
||||
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
||||
return pixel_values
|
||||
|
||||
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||
_, num_channels, height, width = pixel_values.shape
|
||||
if num_channels != self.num_channels:
|
||||
raise ValueError(
|
||||
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||
)
|
||||
# pad the input to be divisible by self.patch_size, if needed
|
||||
pixel_values = self.maybe_pad(pixel_values, height, width)
|
||||
embeddings = self.projection(pixel_values)
|
||||
_, _, height, width = embeddings.shape
|
||||
output_dimensions = (height, width)
|
||||
embeddings = embeddings.flatten(2).transpose(1, 2)
|
||||
|
||||
return embeddings, output_dimensions
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
|
||||
class MaskFormerSwinPatchMerging(nn.Module):
|
||||
"""
|
||||
Patch Merging Layer.
|
||||
|
||||
Args:
|
||||
input_resolution (`Tuple[int]`):
|
||||
Resolution of input feature.
|
||||
dim (`int`):
|
||||
Number of input channels.
|
||||
norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
|
||||
Normalization layer class.
|
||||
"""
|
||||
|
||||
def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
|
||||
super().__init__()
|
||||
self.input_resolution = input_resolution
|
||||
self.dim = dim
|
||||
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
||||
self.norm = norm_layer(4 * dim)
|
||||
|
||||
def maybe_pad(self, input_feature, height, width):
|
||||
should_pad = (height % 2 == 1) or (width % 2 == 1)
|
||||
if should_pad:
|
||||
pad_values = (0, 0, 0, width % 2, 0, height % 2)
|
||||
input_feature = nn.functional.pad(input_feature, pad_values)
|
||||
|
||||
return input_feature
|
||||
|
||||
def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
|
||||
height, width = input_dimensions
|
||||
# `dim` is height * width
|
||||
batch_size, dim, num_channels = input_feature.shape
|
||||
|
||||
input_feature = input_feature.view(batch_size, height, width, num_channels)
|
||||
# pad input so that height and width are divisible by 2, if needed
|
||||
input_feature = self.maybe_pad(input_feature, height, width)
|
||||
# [batch_size, height/2, width/2, num_channels]
|
||||
input_feature_0 = input_feature[:, 0::2, 0::2, :]
|
||||
# [batch_size, height/2, width/2, num_channels]
|
||||
input_feature_1 = input_feature[:, 1::2, 0::2, :]
|
||||
# [batch_size, height/2, width/2, num_channels]
|
||||
input_feature_2 = input_feature[:, 0::2, 1::2, :]
|
||||
# [batch_size, height/2, width/2, num_channels]
|
||||
input_feature_3 = input_feature[:, 1::2, 1::2, :]
|
||||
# batch_size height/2 width/2 4*num_channels
|
||||
input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
|
||||
input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C
|
||||
|
||||
input_feature = self.norm(input_feature)
|
||||
input_feature = self.reduction(input_feature)
|
||||
|
||||
return input_feature
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->MaskFormerSwin
|
||||
class MaskFormerSwinDropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||
|
||||
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
||||
super().__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return "p={}".format(self.drop_prob)
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->MaskFormerSwin
|
||||
class MaskFormerSwinSelfAttention(nn.Module):
|
||||
def __init__(self, config, dim, num_heads, window_size):
|
||||
super().__init__()
|
||||
if dim % num_heads != 0:
|
||||
raise ValueError(
|
||||
f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
|
||||
)
|
||||
|
||||
self.num_attention_heads = num_heads
|
||||
self.attention_head_size = int(dim / num_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.window_size = (
|
||||
window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
|
||||
)
|
||||
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
|
||||
)
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(self.window_size[0])
|
||||
coords_w = torch.arange(self.window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
|
||||
coords_flatten = torch.flatten(coords, 1)
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
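# shift the relative coordinates to be non-negative and merge the (dy, dx) pair into a single flat index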
|
||||
relative_coords[:, :, 0] += self.window_size[0] - 1
|
||||
relative_coords[:, :, 1] += self.window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
||||
relative_position_index = relative_coords.sum(-1)
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
|
||||
self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
batch_size, dim, num_channels = hidden_states.shape
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
|
||||
relative_position_bias = relative_position_bias.view(
|
||||
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
|
||||
)
|
||||
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
|
||||
attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask (precomputed for all layers in the MaskFormerSwinModel forward() function)
|
||||
mask_shape = attention_mask.shape[0]
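# the attention mask holds one mask per window; unflatten the window dimension before adding it, then flatten back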
|
||||
attention_scores = attention_scores.view(
|
||||
batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
|
||||
)
|
||||
attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
|
||||
attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->MaskFormerSwin
|
||||
class MaskFormerSwinSelfOutput(nn.Module):
|
||||
def __init__(self, config, dim):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(dim, dim)
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->MaskFormerSwin
|
||||
class MaskFormerSwinAttention(nn.Module):
|
||||
def __init__(self, config, dim, num_heads, window_size):
|
||||
super().__init__()
|
||||
self.self = MaskFormerSwinSelfAttention(config, dim, num_heads, window_size)
|
||||
self.output = MaskFormerSwinSelfOutput(config, dim)
|
||||
self.pruned_heads = set()
|
||||
|
||||
def prune_heads(self, heads):
|
||||
if len(heads) == 0:
|
||||
return
|
||||
heads, index = find_pruneable_heads_and_indices(
|
||||
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
||||
)
|
||||
|
||||
# Prune linear layers
|
||||
self.self.query = prune_linear_layer(self.self.query, index)
|
||||
self.self.key = prune_linear_layer(self.self.key, index)
|
||||
self.self.value = prune_linear_layer(self.self.value, index)
|
||||
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
||||
|
||||
# Update hyper params and store pruned heads
|
||||
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->MaskFormerSwin
|
||||
class MaskFormerSwinIntermediate(nn.Module):
|
||||
def __init__(self, config, dim):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->MaskFormerSwin
|
||||
class MaskFormerSwinOutput(nn.Module):
|
||||
def __init__(self, config, dim):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MaskFormerSwinLayer(nn.Module):
|
||||
def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
|
||||
super().__init__()
|
||||
self.shift_size = shift_size
|
||||
self.window_size = config.window_size
|
||||
self.input_resolution = input_resolution
|
||||
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
|
||||
self.attention = MaskFormerSwinAttention(config, dim, num_heads, self.window_size)
|
||||
self.drop_path = (
|
||||
MaskFormerSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
|
||||
)
|
||||
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
|
||||
self.intermediate = MaskFormerSwinIntermediate(config, dim)
|
||||
self.output = MaskFormerSwinOutput(config, dim)
|
||||
|
||||
def get_attn_mask(self, input_resolution):
|
||||
if self.shift_size > 0:
|
||||
# calculate attention mask for SW-MSA
|
||||
height, width = input_resolution
|
||||
img_mask = torch.zeros((1, height, width, 1))
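# label each region created by the cyclic shift with a distinct id so that tokens from different regions cannot attend to each other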
|
||||
height_slices = (
|
||||
slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
slice(-self.shift_size, None),
|
||||
)
|
||||
width_slices = (
|
||||
slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
slice(-self.shift_size, None),
|
||||
)
|
||||
count = 0
|
||||
for height_slice in height_slices:
|
||||
for width_slice in width_slices:
|
||||
img_mask[:, height_slice, width_slice, :] = count
|
||||
count += 1
|
||||
|
||||
mask_windows = window_partition(img_mask, self.window_size)
|
||||
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||
else:
|
||||
attn_mask = None
|
||||
return attn_mask
|
||||
|
||||
def maybe_pad(self, hidden_states, height, width):
|
||||
pad_left = pad_top = 0
|
||||
pad_right = (self.window_size - width % self.window_size) % self.window_size
pad_bottom = (self.window_size - height % self.window_size) % self.window_size
pad_values = (0, 0, pad_left, pad_right, pad_top, pad_bottom)
|
||||
hidden_states = nn.functional.pad(hidden_states, pad_values)
|
||||
return hidden_states, pad_values
|
||||
|
||||
def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
|
||||
height, width = input_dimensions
|
||||
batch_size, dim, channels = hidden_states.size()
|
||||
shortcut = hidden_states
|
||||
|
||||
hidden_states = self.layernorm_before(hidden_states)
|
||||
hidden_states = hidden_states.view(batch_size, height, width, channels)
|
||||
# pad hidden_states to multiples of window size
|
||||
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
|
||||
|
||||
_, height_pad, width_pad, _ = hidden_states.shape
|
||||
# cyclic shift
|
||||
if self.shift_size > 0:
|
||||
shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
||||
else:
|
||||
shifted_hidden_states = hidden_states
|
||||
|
||||
# partition windows
|
||||
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
|
||||
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
|
||||
attn_mask = self.get_attn_mask((height_pad, width_pad))
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.to(hidden_states_windows.device)
|
||||
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
|
||||
)
|
||||
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
|
||||
shifted_windows = window_reverse(
|
||||
attention_windows, self.window_size, height_pad, width_pad
|
||||
) # B height' width' C
|
||||
|
||||
# reverse cyclic shift
|
||||
if self.shift_size > 0:
|
||||
attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
||||
else:
|
||||
attention_windows = shifted_windows
|
||||
|
||||
was_padded = pad_values[3] > 0 or pad_values[5] > 0
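# crop away the rows/columns that were only added to make height and width divisible by the window size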
|
||||
if was_padded:
|
||||
attention_windows = attention_windows[:, :height, :width, :].contiguous()
|
||||
|
||||
attention_windows = attention_windows.view(batch_size, height * width, channels)
|
||||
|
||||
hidden_states = shortcut + self.drop_path(attention_windows)
|
||||
|
||||
layer_output = self.layernorm_after(hidden_states)
|
||||
layer_output = self.intermediate(layer_output)
|
||||
layer_output = hidden_states + self.output(layer_output)
|
||||
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class MaskFormerSwinStage(nn.Module):
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinStage.__init__ with Swin->MaskFormerSwin
|
||||
def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.dim = dim
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
MaskFormerSwinLayer(
|
||||
config=config,
|
||||
dim=dim,
|
||||
input_resolution=input_resolution,
|
||||
num_heads=num_heads,
|
||||
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
self.pointing = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states, input_dimensions, head_mask=None, output_attentions=False, output_hidden_states=False
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
height, width = input_dimensions
|
||||
for i, block_module in enumerate(self.blocks):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
block_hidden_states = block_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
|
||||
|
||||
hidden_states = block_hidden_states[0]
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.downsample is not None:
|
||||
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
|
||||
output_dimensions = (height, width, height_downsampled, width_downsampled)
|
||||
hidden_states = self.downsample(hidden_states, input_dimensions)
|
||||
else:
|
||||
output_dimensions = (height, width, height, width)
|
||||
|
||||
return hidden_states, output_dimensions, all_hidden_states
|
||||
|
||||
|
||||
class MaskFormerSwinEncoder(nn.Module):
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinEncoder.__init__ with Swin->MaskFormerSwin
|
||||
def __init__(self, config, grid_size):
|
||||
super().__init__()
|
||||
self.num_layers = len(config.depths)
|
||||
self.config = config
|
||||
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
MaskFormerSwinStage(
|
||||
config=config,
|
||||
dim=int(config.embed_dim * 2**i_layer),
|
||||
input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
|
||||
depth=config.depths[i_layer],
|
||||
num_heads=config.num_heads[i_layer],
|
||||
drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
|
||||
downsample=MaskFormerSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||
)
|
||||
for i_layer in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
input_dimensions,
|
||||
head_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_input_dimensions = ()
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
for i, layer_module in enumerate(self.layers):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_hidden_states, output_dimensions, layer_all_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(layer_module), hidden_states, layer_head_mask
|
||||
)
|
||||
else:
|
||||
layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module(
|
||||
hidden_states,
|
||||
input_dimensions,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
)
|
||||
|
||||
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
|
||||
all_input_dimensions += (input_dimensions,)
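# unlike the default Swin encoder, keep the hidden states of every block in the stage (taken before patch merging), since MaskFormer consumes them later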
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (layer_all_hidden_states,)
|
||||
|
||||
hidden_states = layer_hidden_states
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_all_hidden_states[1],)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
||||
|
||||
return MaskFormerSwinBaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
hidden_states_spatial_dimensions=all_input_dimensions,
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->MaskFormerSwin, swin->model
|
||||
class MaskFormerSwinPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = MaskFormerSwinConfig
|
||||
base_model_prefix = "model"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, MaskFormerSwinEncoder):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
|
||||
def __init__(self, config, add_pooling_layer=True):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.num_layers = len(config.depths)
|
||||
self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
|
||||
|
||||
self.embeddings = MaskFormerSwinEmbeddings(config)
|
||||
self.encoder = MaskFormerSwinEncoder(config, self.embeddings.patch_grid)
|
||||
|
||||
self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
|
||||
self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.patch_embeddings
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""
|
||||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
||||
class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
|
||||
|
||||
embedding_output, input_dimensions = self.embeddings(pixel_values)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
input_dimensions,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
sequence_output = encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
|
||||
sequence_output = self.layernorm(sequence_output)
|
||||
|
||||
pooled_output = None
|
||||
if self.pooler is not None:
|
||||
pooled_output = self.pooler(sequence_output.transpose(1, 2))
|
||||
pooled_output = torch.flatten(pooled_output, 1)
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
hidden_states_spatial_dimensions = (input_dimensions,) + encoder_outputs.hidden_states_spatial_dimensions
|
||||
|
||||
return MaskFormerSwinModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
hidden_states_spatial_dimensions=hidden_states_spatial_dimensions,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel):
|
||||
"""
|
||||
MaskFormerSwin backbone, designed especially for the MaskFormer framework.
|
||||
|
||||
This class reshapes `hidden_states` from `(batch_size, sequence_length, hidden_size)` to `(batch_size,
num_channels, height, width)`. It also adds additional layernorms after each stage.
|
||||
|
||||
Args:
|
||||
config (`MaskFormerSwinConfig`):
|
||||
The configuration used by [`MaskFormerSwinModel`].
|
||||
"""
|
||||
|
||||
def __init__(self, config: MaskFormerSwinConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.stage_names = config.stage_names
|
||||
self.model = MaskFormerSwinModel(config)
|
||||
|
||||
self.out_features = config.out_features
|
||||
if "stem" in self.out_features:
|
||||
raise ValueError("This backbone does not support 'stem' in the `out_features`.")
|
||||
|
||||
num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
|
||||
self.out_feature_channels = {}
|
||||
for i, stage in enumerate(self.stage_names[1:]):
|
||||
self.out_feature_channels[stage] = num_features[i]
|
||||
|
||||
self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(num_channels) for num_channels in self.channels])
|
||||
|
||||
@property
|
||||
def channels(self):
|
||||
return [self.out_feature_channels[name] for name in self.out_features]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Tensor,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> BackboneOutput:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
||||
outputs = self.model(
|
||||
pixel_values, output_hidden_states=True, output_attentions=output_attentions, return_dict=True
|
||||
)
|
||||
|
||||
# we skip the stem
|
||||
hidden_states = outputs.hidden_states[1:]
|
||||
|
||||
feature_maps = ()
|
||||
# we need to reshape the hidden states to their original spatial dimensions
|
||||
# spatial_dimensions contains the height and width of each stage, including the output of the embeddings
|
||||
spatial_dimensions: Tuple[Tuple[int, int]] = outputs.hidden_states_spatial_dimensions
|
||||
for i, (hidden_state, stage, (height, width)) in enumerate(
|
||||
zip(hidden_states, self.stage_names[1:], spatial_dimensions)
|
||||
):
|
||||
norm = self.hidden_states_norms[i]
|
||||
# the last element corresponds to the layer's last block output, before patch merging
hidden_state_unpooled = hidden_state[-1]
hidden_state_norm = norm(hidden_state_unpooled)
|
||||
# the pixel decoder (FPN) expects 3D tensors (features)
|
||||
batch_size, _, hidden_size = hidden_state_norm.shape
|
||||
# reshape "b (h w) d -> b d h w"
|
||||
hidden_state_permuted = (
|
||||
hidden_state_norm.permute(0, 2, 1).view((batch_size, hidden_size, height, width)).contiguous()
|
||||
)
|
||||
if stage in self.out_features:
|
||||
feature_maps += (hidden_state_permuted,)
|
||||
|
||||
if not return_dict:
|
||||
output = (feature_maps,)
|
||||
if output_hidden_states:
|
||||
output += (outputs.hidden_states,)
|
||||
if output_attentions:
|
||||
output += (outputs.attentions,)
|
||||
return output
|
||||
|
||||
return BackboneOutput(
|
||||
feature_maps=feature_maps,
|
||||
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
||||
attentions=outputs.attentions,
|
||||
)
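A minimal usage sketch of the new backbone (illustrative only; the stage names and configuration arguments below are assumed from the tests added further down, not prescribed by this commit):
import torch
from transformers import MaskFormerSwinBackbone, MaskFormerSwinConfig

config = MaskFormerSwinConfig(out_features=["stage1", "stage2", "stage3", "stage4"])
backbone = MaskFormerSwinBackbone(config)
pixel_values = torch.randn(1, config.num_channels, config.image_size, config.image_size)
outputs = backbone(pixel_values)
# one feature map per requested stage, each of shape (batch_size, num_channels, height, width)
for name, feature_map in zip(backbone.out_features, outputs.feature_maps):
    print(name, tuple(feature_map.shape))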
|
@ -3361,6 +3361,13 @@ class MaskFormerPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MaskFormerSwinBackbone(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MBartForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
tests/models/maskformer/test_modeling_maskformer_swin.py (new file, 386 lines)
@ -0,0 +1,386 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch MaskFormer Swin model. """
|
||||
|
||||
import collections
|
||||
import inspect
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from transformers import MaskFormerSwinConfig
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
from transformers.utils import is_torch_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import MaskFormerSwinBackbone
|
||||
from transformers.models.maskformer import MaskFormerSwinModel
|
||||
|
||||
|
||||
class MaskFormerSwinModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
image_size=32,
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
embed_dim=16,
|
||||
depths=[1, 2, 1],
|
||||
num_heads=[2, 2, 4],
|
||||
window_size=2,
|
||||
mlp_ratio=2.0,
|
||||
qkv_bias=True,
|
||||
hidden_dropout_prob=0.0,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
drop_path_rate=0.1,
|
||||
hidden_act="gelu",
|
||||
use_absolute_embeddings=False,
|
||||
patch_norm=True,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-5,
|
||||
is_training=True,
|
||||
scope=None,
|
||||
use_labels=True,
|
||||
type_sequence_label_size=10,
|
||||
encoder_stride=8,
|
||||
out_features=["stage1", "stage2", "stage3"],
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.embed_dim = embed_dim
|
||||
self.depths = depths
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.qkv_bias = qkv_bias
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.drop_path_rate = drop_path_rate
|
||||
self.hidden_act = hidden_act
|
||||
self.use_absolute_embeddings = use_absolute_embeddings
|
||||
self.patch_norm = patch_norm
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.initializer_range = initializer_range
|
||||
self.is_training = is_training
|
||||
self.scope = scope
|
||||
self.use_labels = use_labels
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.encoder_stride = encoder_stride
|
||||
self.out_features = out_features
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
labels = None
|
||||
if self.use_labels:
|
||||
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values, labels
|
||||
|
||||
def get_config(self):
|
||||
return MaskFormerSwinConfig(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
embed_dim=self.embed_dim,
|
||||
depths=self.depths,
|
||||
num_heads=self.num_heads,
|
||||
window_size=self.window_size,
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
qkv_bias=self.qkv_bias,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
drop_path_rate=self.drop_path_rate,
|
||||
hidden_act=self.hidden_act,
|
||||
use_absolute_embeddings=self.use_absolute_embeddings,
|
||||
path_norm=self.patch_norm,
|
||||
layer_norm_eps=self.layer_norm_eps,
|
||||
initializer_range=self.initializer_range,
|
||||
encoder_stride=self.encoder_stride,
|
||||
out_features=self.out_features,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
model = MaskFormerSwinModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
expected_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
|
||||
expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
|
||||
|
||||
def create_and_check_backbone(self, config, pixel_values, labels):
|
||||
model = MaskFormerSwinBackbone(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
# verify feature maps
|
||||
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
|
||||
self.parent.assertListEqual(list(result.feature_maps[0].shape), [13, 16, 16, 16])
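# batch_size=13, embed_dim=16 channels, and a 16x16 spatial resolution (image_size=32 with patch_size=2)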
|
||||
|
||||
# verify channels
|
||||
self.parent.assertEqual(len(model.channels), len(config.out_features))
|
||||
self.parent.assertListEqual(model.channels, [16, 32, 64])
|
||||
|
||||
# verify ValueError
|
||||
with self.parent.assertRaises(ValueError):
|
||||
config.out_features = ["stem"]
|
||||
model = MaskFormerSwinBackbone(config=config)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values, labels = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class MaskFormerSwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
MaskFormerSwinModel,
|
||||
MaskFormerSwinBackbone,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
fx_compatible = False
|
||||
test_torchscript = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MaskFormerSwinModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=MaskFormerSwinConfig, embed_dim=37)
|
||||
|
||||
def test_config(self):
|
||||
self.create_and_test_config_common_properties()
|
||||
self.config_tester.create_and_test_config_to_json_string()
|
||||
self.config_tester.create_and_test_config_to_json_file()
|
||||
self.config_tester.create_and_test_config_from_and_save_pretrained()
|
||||
self.config_tester.create_and_test_config_with_num_labels()
|
||||
self.config_tester.check_config_can_be_init_without_params()
|
||||
self.config_tester.check_config_arguments_init()
|
||||
|
||||
def create_and_test_config_common_properties(self):
|
||||
return
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_backbone(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_backbone(*config_and_inputs)
|
||||
|
||||
@unittest.skip("Swin does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Swin does not support feedforward chunking")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, nn.Linear))
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
@unittest.skip(reason="MaskFormerSwin is only used as backbone and doesn't support output_attentions")
|
||||
def test_attention_outputs(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="MaskFormerSwin is only used as an internal backbone")
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
pass
|
||||
|
||||
def check_hidden_states_output(self, inputs_dict, config, model_class, image_size):
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
hidden_states = outputs.hidden_states
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
|
||||
)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
# Swin has a different seq_length
|
||||
patch_size = (
|
||||
config.patch_size
|
||||
if isinstance(config.patch_size, collections.abc.Iterable)
|
||||
else (config.patch_size, config.patch_size)
|
||||
)
|
||||
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[num_patches, self.model_tester.embed_dim],
|
||||
)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
image_size = (
|
||||
self.model_tester.image_size
|
||||
if isinstance(self.model_tester.image_size, collections.abc.Iterable)
|
||||
else (self.model_tester.image_size, self.model_tester.image_size)
|
||||
)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
|
||||
|
||||
def test_hidden_states_output_with_padding(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.patch_size = 3
|
||||
|
||||
image_size = (
|
||||
self.model_tester.image_size
|
||||
if isinstance(self.model_tester.image_size, collections.abc.Iterable)
|
||||
else (self.model_tester.image_size, self.model_tester.image_size)
|
||||
)
|
||||
patch_size = (
|
||||
config.patch_size
|
||||
if isinstance(config.patch_size, collections.abc.Iterable)
|
||||
else (config.patch_size, config.patch_size)
|
||||
)
|
||||
|
||||
padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
|
||||
padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
|
||||
|
||||
@unittest.skip(reason="MaskFormerSwin doesn't have pretrained checkpoints")
|
||||
def test_model_from_pretrained(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="This will be fixed once MaskFormerSwin is replaced by native Swin")
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="This will be fixed once MaskFormerSwin is replaced by native Swin")
|
||||
def test_gradient_checkpointing_backward_compatibility(self):
|
||||
pass
|
||||
|
||||
def test_model_outputs_equivalence(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def set_nan_tensor_to_zero(t):
|
||||
t[t != t] = 0
|
||||
return t
|
||||
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||
with torch.no_grad():
|
||||
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (List, Tuple)):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif isinstance(tuple_object, Dict):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(
|
||||
tuple_object.values(), dict_object.values()
|
||||
):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
return
|
||||
else:
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
|
||||
),
|
||||
msg=(
|
||||
"Tuple and dict output are not equal. Difference:"
|
||||
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
|
||||
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
|
||||
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
|
||||
),
|
||||
)
|
||||
|
||||
recursive_check(tuple_output, dict_output)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
@ -492,8 +492,9 @@ SPECIAL_MODEL_NAMES = {
|
||||
"Data2VecAudio": "Data2Vec",
|
||||
"Data2VecText": "Data2Vec",
|
||||
"Data2VecVision": "Data2Vec",
|
||||
"DonutSwin": "Donut",
|
||||
"DonutSwin": "Swin Transformer",
|
||||
"Marian": "MarianMT",
|
||||
"MaskFormerSwin": "Swin Transformer",
|
||||
"OpenAI GPT-2": "GPT-2",
|
||||
"OpenAI GPT": "GPT",
|
||||
"Perceiver": "Perceiver IO",
|
||||
|
@ -41,6 +41,8 @@ PRIVATE_MODELS = [
|
||||
"T5Stack",
|
||||
"SwitchTransformersStack",
|
||||
"TFDPRSpanPredictor",
|
||||
"MaskFormerSwinModel",
|
||||
"MaskFormerSwinPreTrainedModel",
|
||||
]
|
||||
|
||||
# Update this list for models that are not tested with a comment explaining the reason it should not be.
|
||||
@ -668,8 +670,11 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
|
||||
"PyTorchBenchmarkArguments",
|
||||
"TensorFlowBenchmark",
|
||||
"TensorFlowBenchmarkArguments",
|
||||
"MaskFormerSwinBackbone",
|
||||
"ResNetBackbone",
|
||||
"AutoBackbone",
|
||||
"MaskFormerSwinConfig",
|
||||
"MaskFormerSwinModel",
|
||||
]