[Maskformer] Add MaskFormerSwin backbone (#20344)

* First draft

* Fix backwards compatibility

* More fixes

* More fixes

* Make backbone more general

* Improve backbone

* Improve test

* Fix config checkpoint

* Address comments

* Use model_type

* Address more comments

* Fix special model names

* Remove MaskFormerSwinModel and MaskFormerSwinPreTrainedModel from main init

* Fix typo

* Update backbone

* Apply suggestion

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge 2022-11-28 20:33:49 +01:00 committed by GitHub
parent 955780d3ab
commit 6dc884abc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 1525 additions and 843 deletions

View File

@ -288,6 +288,7 @@ Flax), PyTorch, and/or TensorFlow.
| Marian | ✅ | ❌ | ✅ | ✅ | ✅ |
| MarkupLM | ✅ | ✅ | ✅ | ❌ | ❌ |
| MaskFormer | ❌ | ❌ | ✅ | ❌ | ❌ |
| MaskFormerSwin | ❌ | ❌ | ❌ | ❌ | ❌ |
| mBART | ✅ | ✅ | ✅ | ✅ | ✅ |
| Megatron-BERT | ❌ | ❌ | ✅ | ❌ | ❌ |
| MobileBERT | ✅ | ✅ | ✅ | ✅ | ❌ |

View File

@ -295,7 +295,7 @@ _import_structure = {
"MarkupLMProcessor",
"MarkupLMTokenizer",
],
"models.maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"],
"models.maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig", "MaskFormerSwinConfig"],
"models.mbart": ["MBartConfig"],
"models.mbart50": [],
"models.mctct": ["MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MCTCTConfig", "MCTCTProcessor"],
@ -1622,6 +1622,7 @@ else:
"MaskFormerForInstanceSegmentation",
"MaskFormerModel",
"MaskFormerPreTrainedModel",
"MaskFormerSwinBackbone",
]
)
_import_structure["models.markuplm"].extend(
@ -3479,7 +3480,7 @@ if TYPE_CHECKING:
MarkupLMProcessor,
MarkupLMTokenizer,
)
from .models.maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig
from .models.maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig, MaskFormerSwinConfig
from .models.mbart import MBartConfig
from .models.mctct import MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP, MCTCTConfig, MCTCTProcessor
from .models.megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig
@ -4573,6 +4574,7 @@ if TYPE_CHECKING:
MaskFormerForInstanceSegmentation,
MaskFormerModel,
MaskFormerPreTrainedModel,
MaskFormerSwinBackbone,
)
from .models.mbart import (
MBartForCausalLM,

View File

@ -98,6 +98,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("marian", "MarianConfig"),
("markuplm", "MarkupLMConfig"),
("maskformer", "MaskFormerConfig"),
("maskformer-swin", "MaskFormerSwinConfig"),
("mbart", "MBartConfig"),
("mctct", "MCTCTConfig"),
("megatron-bert", "MegatronBertConfig"),
@ -395,6 +396,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("marian", "Marian"),
("markuplm", "MarkupLM"),
("maskformer", "MaskFormer"),
("maskformer-swin", "MaskFormerSwin"),
("mbart", "mBART"),
("mbart50", "mBART-50"),
("mctct", "M-CTC-T"),
@ -491,6 +493,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
("data2vec-text", "data2vec"),
("data2vec-vision", "data2vec"),
("donut-swin", "donut"),
("maskformer-swin", "maskformer"),
("xclip", "x_clip"),
]
)

View File

@ -97,6 +97,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("marian", "MarianModel"),
("markuplm", "MarkupLMModel"),
("maskformer", "MaskFormerModel"),
("maskformer-swin", "MaskFormerSwinModel"),
("mbart", "MBartModel"),
("mctct", "MCTCTModel"),
("megatron-bert", "MegatronBertModel"),
@ -847,6 +848,7 @@ _MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
[
# Backbone mapping
("maskformer-swin", "MaskFormerSwinBackbone"),
("resnet", "ResNetBackbone"),
]
)

View File

@ -20,7 +20,10 @@ from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {"configuration_maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"]}
_import_structure = {
"configuration_maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"],
"configuration_maskformer_swin": ["MaskFormerSwinConfig"],
}
try:
if not is_vision_available():
@ -43,9 +46,15 @@ else:
"MaskFormerModel",
"MaskFormerPreTrainedModel",
]
_import_structure["modeling_maskformer_swin"] = [
"MaskFormerSwinBackbone",
"MaskFormerSwinModel",
"MaskFormerSwinPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig
from .configuration_maskformer_swin import MaskFormerSwinConfig
try:
if not is_vision_available():
@ -66,6 +75,11 @@ if TYPE_CHECKING:
MaskFormerModel,
MaskFormerPreTrainedModel,
)
from .modeling_maskformer_swin import (
MaskFormerSwinBackbone,
MaskFormerSwinModel,
MaskFormerSwinPreTrainedModel,
)
else:

View File

@ -0,0 +1,149 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" MaskFormer Swin Transformer model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class MaskFormerSwinConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MaskFormerSwinModel`]. It is used to instantiate
a Donut model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Swin
[microsoft/swin-tiny-patch4-window7-224](https://huggingface.co/microsoft/swin-tiny-patch4-window7-224)
architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 4):
The size (resolution) of each patch.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
embed_dim (`int`, *optional*, defaults to 96):
Dimensionality of patch embedding.
depths (`List[int]`, *optional*, defaults to `[2, 2, 6, 2]`):
Depth of each layer in the Transformer encoder.
num_heads (`List[int]`, *optional*, defaults to `[3, 6, 12, 24]`):
Number of attention heads in each layer of the Transformer encoder.
window_size (`int`, *optional*, defaults to 7):
Size of windows.
mlp_ratio (`float`, *optional*, defaults to 4.0):
Ratio of MLP hidden dimensionality to embedding dimensionality.
qkv_bias (`bool`, *optional*, defaults to True):
Whether or not a learnable bias should be added to the queries, keys and values.
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings and encoder.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
drop_path_rate (`float`, *optional*, defaults to 0.1):
Stochastic depth rate.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
`"selu"` and `"gelu_new"` are supported.
use_absolute_embeddings (`bool`, *optional*, defaults to False):
Whether or not to add absolute position embeddings to the patch embeddings.
patch_norm (`bool`, *optional*, defaults to True):
Whether or not to add layer normalization after patch embedding.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
out_features (`List[str]`, *optional*):
If used as a backbone, list of feature names to output, e.g. `["stem", "stage1"]`.
Example:
```python
>>> from transformers import MaskFormerSwinConfig, MaskFormerSwinModel
>>> # Initializing a microsoft/swin-tiny-patch4-window7-224 style configuration
>>> configuration = MaskFormerSwinConfig()
>>> # Initializing a model (with random weights) from the microsoft/swin-tiny-patch4-window7-224 style configuration
>>> model = MaskFormerSwinModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "maskformer-swin"
attribute_map = {
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
}
def __init__(
self,
image_size=224,
patch_size=4,
num_channels=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
drop_path_rate=0.1,
hidden_act="gelu",
use_absolute_embeddings=False,
patch_norm=True,
initializer_range=0.02,
layer_norm_eps=1e-5,
out_features=None,
**kwargs
):
super().__init__(**kwargs)
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.embed_dim = embed_dim
self.depths = depths
self.num_layers = len(depths)
self.num_heads = num_heads
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.drop_path_rate = drop_path_rate
self.hidden_act = hidden_act
self.use_absolute_embeddings = use_absolute_embeddings
self.path_norm = patch_norm
self.layer_norm_eps = layer_norm_eps
self.initializer_range = initializer_range
# we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
# this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
self.out_features = out_features

View File

@ -14,7 +14,6 @@
# limitations under the License.
""" PyTorch MaskFormer model."""
import collections.abc
import math
import random
from dataclasses import dataclass
@ -25,12 +24,12 @@ import numpy as np
import torch
from torch import Tensor, nn
from transformers import AutoBackbone
from transformers.utils import logging
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithCrossAttentions
from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_start_docstrings,
@ -40,8 +39,8 @@ from ...utils import (
requires_backends,
)
from ..detr import DetrConfig
from ..swin import SwinConfig
from .configuration_maskformer import MaskFormerConfig
from .configuration_maskformer_swin import MaskFormerSwinConfig
if is_scipy_available():
@ -60,71 +59,6 @@ MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
@dataclass
class MaskFormerSwinModelOutputWithPooling(ModelOutput):
"""
Class for MaskFormerSwinModel's outputs that also contains the spatial dimensions of the hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
Last layer hidden-state after a mean pooling operation.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*):
A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to
`batch, channels, height, width`. Due to padding, their spatial size cannot be inferred before the
`forward` method.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MaskFormerSwinBaseModelOutput(ModelOutput):
"""
Class for SwinEncoder's outputs.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*):
A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to
`batch, channels, height, width`. Due to padding, their spatial size cannot inferred before the `forward`
method.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
# Copied from transformers.models.detr.modeling_detr.DetrDecoderOutput
class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
@ -471,714 +405,6 @@ def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float =
return loss / height_and_width
# Copied from transformers.models.swin.modeling_swin.window_partition
def window_partition(input_feature, window_size):
"""
Partitions the given input into windows.
"""
batch_size, height, width, num_channels = input_feature.shape
input_feature = input_feature.view(
batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
)
windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
return windows
# Copied from transformers.models.swin.modeling_swin.window_reverse
def window_reverse(windows, window_size, height, width):
"""
Merges windows to produce higher resolution features.
"""
num_channels = windows.shape[-1]
windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
return windows
# Copied from transformers.models.swin.modeling_swin.drop_path
def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
argument.
"""
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_() # binarize
output = input.div(keep_prob) * random_tensor
return output
class MaskFormerSwinEmbeddings(nn.Module):
"""
Construct the patch and position embeddings.
"""
def __init__(self, config):
super().__init__()
self.patch_embeddings = MaskFormerSwinPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.patch_grid = self.patch_embeddings.grid_size
if config.use_absolute_embeddings:
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
else:
self.position_embeddings = None
self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, pixel_values):
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
embeddings = self.norm(embeddings)
if self.position_embeddings is not None:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings, output_dimensions
class MaskFormerSwinPatchEmbeddings(nn.Module):
"""
Image to Patch Embedding, including padding.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.embed_dim
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def maybe_pad(self, pixel_values, height, width):
if width % self.patch_size[1] != 0:
pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
pixel_values = nn.functional.pad(pixel_values, pad_values)
if height % self.patch_size[0] != 0:
pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values
def forward(self, pixel_values):
_, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
# pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width)
embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape
output_dimensions = (height, width)
embeddings_flat = embeddings.flatten(2).transpose(1, 2)
return embeddings_flat, output_dimensions
class MaskFormerSwinPatchMerging(nn.Module):
"""
Patch Merging Layer for maskformer model.
Args:
input_resolution (`Tuple[int]`):
Resolution of input feature.
dim (`int`):
Number of input channels.
norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
Normalization layer class.
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def maybe_pad(self, input_feature, width, height):
should_pad = (height % 2 == 1) or (width % 2 == 1)
if should_pad:
pad_values = (0, 0, 0, width % 2, 0, height % 2)
input_feature = nn.functional.pad(input_feature, pad_values)
return input_feature
def forward(self, input_feature, input_dimensions):
height, width = input_dimensions
# `dim` is height * width
batch_size, dim, num_channels = input_feature.shape
input_feature = input_feature.view(batch_size, height, width, num_channels)
# pad input to be disible by width and height, if needed
input_feature = self.maybe_pad(input_feature, height, width)
# [batch_size, height/2, width/2, num_channels]
input_feature_0 = input_feature[:, 0::2, 0::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_1 = input_feature[:, 1::2, 0::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_2 = input_feature[:, 0::2, 1::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_3 = input_feature[:, 1::2, 1::2, :]
# batch_size height/2 width/2 4*num_channels
input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C
input_feature = self.norm(input_feature)
input_feature = self.reduction(input_feature)
return input_feature
# Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->MaskFormerSwin
class MaskFormerSwinDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->MaskFormerSwin
class MaskFormerSwinSelfAttention(nn.Module):
def __init__(self, config, dim, num_heads, window_size):
super().__init__()
if dim % num_heads != 0:
raise ValueError(
f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
)
self.num_attention_heads = num_heads
self.attention_head_size = int(dim / num_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.window_size = (
window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
)
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
self.register_buffer("relative_position_index", relative_position_index)
self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
relative_position_bias = relative_position_bias.view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in MaskFormerSwinModel forward() function)
mask_shape = attention_mask.shape[0]
attention_scores = attention_scores.view(
batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
)
attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->MaskFormerSwin
class MaskFormerSwinSelfOutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, dim)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->MaskFormerSwin
class MaskFormerSwinAttention(nn.Module):
def __init__(self, config, dim, num_heads, window_size):
super().__init__()
self.self = MaskFormerSwinSelfAttention(config, dim, num_heads, window_size)
self.output = MaskFormerSwinSelfOutput(config, dim)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
# Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->MaskFormerSwin
class MaskFormerSwinIntermediate(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->MaskFormerSwin
class MaskFormerSwinOutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class MaskFormerSwinLayer(nn.Module):
def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.shift_size = shift_size
self.window_size = config.window_size
self.input_resolution = input_resolution
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = MaskFormerSwinAttention(config, dim, num_heads, self.window_size)
self.drop_path = (
MaskFormerSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
)
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.intermediate = MaskFormerSwinIntermediate(config, dim)
self.output = MaskFormerSwinOutput(config, dim)
def get_attn_mask(self, input_resolution):
if self.shift_size > 0:
# calculate attention mask for SW-MSA
height, width = input_resolution
img_mask = torch.zeros((1, height, width, 1))
height_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
width_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
count = 0
for height_slice in height_slices:
for width_slice in width_slices:
img_mask[:, height_slice, width_slice, :] = count
count += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
return attn_mask
def maybe_pad(self, hidden_states, height, width):
pad_left = pad_top = 0
pad_rigth = (self.window_size - width % self.window_size) % self.window_size
pad_bottom = (self.window_size - height % self.window_size) % self.window_size
pad_values = (0, 0, pad_left, pad_rigth, pad_top, pad_bottom)
hidden_states = nn.functional.pad(hidden_states, pad_values)
return hidden_states, pad_values
def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
height, width = input_dimensions
batch_size, dim, channels = hidden_states.size()
shortcut = hidden_states
hidden_states = self.layernorm_before(hidden_states)
hidden_states = hidden_states.view(batch_size, height, width, channels)
# pad hidden_states to multiples of window size
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
_, height_pad, width_pad, _ = hidden_states.shape
# cyclic shift
if self.shift_size > 0:
shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_hidden_states = hidden_states
# partition windows
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
attn_mask = self.get_attn_mask((height_pad, width_pad))
if attn_mask is not None:
attn_mask = attn_mask.to(hidden_states_windows.device)
self_attention_outputs = self.attention(
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
shifted_windows = window_reverse(
attention_windows, self.window_size, height_pad, width_pad
) # B height' width' C
# reverse cyclic shift
if self.shift_size > 0:
attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
attention_windows = shifted_windows
was_padded = pad_values[3] > 0 or pad_values[5] > 0
if was_padded:
attention_windows = attention_windows[:, :height, :width, :].contiguous()
attention_windows = attention_windows.view(batch_size, height * width, channels)
hidden_states = shortcut + self.drop_path(attention_windows)
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = hidden_states + self.output(layer_output)
outputs = (layer_output,) + outputs
return outputs
class MaskFormerSwinStage(nn.Module):
# Copied from transformers.models.swin.modeling_swin.SwinStage.__init__ with Swin->MaskFormerSwin
def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
super().__init__()
self.config = config
self.dim = dim
self.blocks = nn.ModuleList(
[
MaskFormerSwinLayer(
config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
else:
self.downsample = None
self.pointing = False
def forward(
self, hidden_states, input_dimensions, head_mask=None, output_attentions=False, output_hidden_states=False
):
all_hidden_states = () if output_hidden_states else None
height, width = input_dimensions
for i, block_module in enumerate(self.blocks):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
block_hidden_states = block_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
hidden_states = block_hidden_states[0]
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(hidden_states, input_dimensions)
else:
output_dimensions = (height, width, height, width)
return hidden_states, output_dimensions, all_hidden_states
class MaskFormerSwinEncoder(nn.Module):
# Copied from transformers.models.swin.modeling_swin.SwinEncoder.__init__ with Swin->MaskFormerSwin
def __init__(self, config, grid_size):
super().__init__()
self.num_layers = len(config.depths)
self.config = config
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
self.layers = nn.ModuleList(
[
MaskFormerSwinStage(
config=config,
dim=int(config.embed_dim * 2**i_layer),
input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
depth=config.depths[i_layer],
num_heads=config.num_heads[i_layer],
drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
downsample=MaskFormerSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
)
for i_layer in range(self.num_layers)
]
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
input_dimensions,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_input_dimensions = ()
all_self_attentions = () if output_attentions else None
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_hidden_states, output_dimensions, layer_all_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module), hidden_states, layer_head_mask
)
else:
layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
output_hidden_states,
)
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
all_input_dimensions += (input_dimensions,)
if output_hidden_states:
all_hidden_states += (layer_all_hidden_states,)
hidden_states = layer_hidden_states
if output_attentions:
all_self_attentions = all_self_attentions + (layer_all_hidden_states[1],)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return MaskFormerSwinBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
hidden_states_spatial_dimensions=all_input_dimensions,
attentions=all_self_attentions,
)
class MaskFormerSwinModel(nn.Module, ModuleUtilsMixin):
def __init__(self, config, add_pooling_layer=True):
super().__init__()
self.config = config
self.num_layers = len(config.depths)
self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
self.embeddings = MaskFormerSwinEmbeddings(config)
self.encoder = MaskFormerSwinEncoder(config, self.embeddings.patch_grid)
self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(
self,
pixel_values=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
embedding_output, input_dimensions = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs.last_hidden_state
sequence_output = self.layernorm(sequence_output)
pooled_output = None
if self.pooler is not None:
pooled_output = self.pooler(sequence_output.transpose(1, 2))
pooled_output = torch.flatten(pooled_output, 1)
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
hidden_states_spatial_dimensions = (input_dimensions,) + encoder_outputs.hidden_states_spatial_dimensions
return MaskFormerSwinModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
hidden_states_spatial_dimensions=hidden_states_spatial_dimensions,
attentions=encoder_outputs.attentions,
)
# Copied from transformers.models.detr.modeling_detr.DetrAttention
class DetrAttention(nn.Module):
"""
@ -1923,52 +1149,6 @@ class MaskFormerLoss(nn.Module):
return num_masks_pt
class MaskFormerSwinTransformerBackbone(nn.Module):
"""
This class uses [`MaskFormerSwinModel`] to reshape its `hidden_states` from (`batch_size, sequence_length,
hidden_size)` to (`batch_size, num_channels, height, width)`).
Args:
config (`SwinConfig`):
The configuration used by [`MaskFormerSwinModel`].
"""
def __init__(self, config: SwinConfig):
super().__init__()
self.model = MaskFormerSwinModel(config)
self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(out_shape) for out_shape in self.outputs_shapes])
def forward(self, *args, **kwargs) -> List[Tensor]:
output = self.model(*args, **kwargs, output_hidden_states=True)
hidden_states_permuted: List[Tensor] = []
# we need to reshape the hidden state to their original spatial dimensions
# skipping the embeddings
hidden_states: Tuple[Tuple[Tensor]] = output.hidden_states[1:]
# spatial dimensions contains all the heights and widths of each stage, including after the embeddings
spatial_dimensions: Tuple[Tuple[int, int]] = output.hidden_states_spatial_dimensions
for i, (hidden_state, (height, width)) in enumerate(zip(hidden_states, spatial_dimensions)):
norm = self.hidden_states_norms[i]
# the last element corespond to the layer's last block output but before patch merging
hidden_state_unpolled = hidden_state[-1]
hidden_state_norm = norm(hidden_state_unpolled)
# our pixel decoder (FPN) expect 3D tensors (features)
batch_size, _, hidden_size = hidden_state_norm.shape
# reshape our tensor "b (h w) d -> b d h w"
hidden_state_permuted = (
hidden_state_norm.permute(0, 2, 1).view((batch_size, hidden_size, height, width)).contiguous()
)
hidden_states_permuted.append(hidden_state_permuted)
return hidden_states_permuted
@property
def input_resolutions(self) -> List[int]:
return [layer.input_resolution for layer in self.model.encoder.layers]
@property
def outputs_shapes(self) -> List[int]:
return [layer.dim for layer in self.model.encoder.layers]
class MaskFormerFPNConvLayer(nn.Module):
def __init__(self, in_features: int, out_features: int, kernel_size: int = 3, padding: int = 1):
"""
@ -2065,7 +1245,7 @@ class MaskFormerPixelDecoder(nn.Module):
def __init__(self, *args, feature_size: int = 256, mask_feature_size: int = 256, **kwargs):
"""
Pixel Decoder Module proposed in [Per-Pixel Classification is Not All You Need for Semantic
Segmentation](https://arxiv.org/abs/2107.06278). It first runs the backbone's feature into a Feature Pyramid
Segmentation](https://arxiv.org/abs/2107.06278). It first runs the backbone's features into a Feature Pyramid
Network creating a list of feature maps. Then, it projects the last one to the correct `mask_size`.
Args:
@ -2075,13 +1255,15 @@ class MaskFormerPixelDecoder(nn.Module):
The features (channels) of the target masks size \\C_{\epsilon}\\ in the paper.
"""
super().__init__()
self.fpn = MaskFormerFPNModel(*args, feature_size=feature_size, **kwargs)
self.mask_projection = nn.Conv2d(feature_size, mask_feature_size, kernel_size=3, padding=1)
def forward(self, features: List[Tensor], output_hidden_states: bool = False) -> MaskFormerPixelDecoderOutput:
fpn_features: List[Tensor] = self.fpn(features)
fpn_features = self.fpn(features)
# we use the last feature map
last_feature_projected = self.mask_projection(fpn_features[-1])
return MaskFormerPixelDecoderOutput(
last_hidden_state=last_feature_projected, hidden_states=tuple(fpn_features) if output_hidden_states else ()
)
@ -2193,17 +1375,26 @@ class MaskFormerPixelLevelModule(nn.Module):
The configuration used to instantiate this model.
"""
super().__init__()
self.encoder = MaskFormerSwinTransformerBackbone(config.backbone_config)
# TODD: add method to load pretrained weights of backbone
backbone_config = config.backbone_config
if backbone_config.model_type == "swin":
# for backwards compatibility
backbone_config = MaskFormerSwinConfig.from_dict(backbone_config.to_dict())
backbone_config.out_features = ["stage1", "stage2", "stage3", "stage4"]
self.encoder = AutoBackbone.from_config(backbone_config)
feature_channels = self.encoder.channels
self.decoder = MaskFormerPixelDecoder(
in_features=self.encoder.outputs_shapes[-1],
in_features=feature_channels[-1],
feature_size=config.fpn_feature_size,
mask_feature_size=config.mask_feature_size,
lateral_widths=self.encoder.outputs_shapes[:-1],
lateral_widths=feature_channels[:-1],
)
def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> MaskFormerPixelLevelModuleOutput:
features: List[Tensor] = self.encoder(pixel_values)
decoder_output: MaskFormerPixelDecoderOutput = self.decoder(features, output_hidden_states)
features = self.encoder(pixel_values).feature_maps
decoder_output = self.decoder(features, output_hidden_states)
return MaskFormerPixelLevelModuleOutput(
# the last feature is actually the output from the last layer
encoder_last_hidden_state=features[-1],
@ -2335,8 +1526,8 @@ class MaskFormerPreTrainedModel(PreTrainedModel):
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, MaskFormerSwinEncoder):
module.gradient_checkpointing = value
if isinstance(module, MaskFormerPixelLevelModule):
module.encoder.gradient_checkpointing = value
if isinstance(module, DetrDecoder):
module.gradient_checkpointing = value
@ -2350,7 +1541,7 @@ class MaskFormerModel(MaskFormerPreTrainedModel):
super().__init__(config)
self.pixel_level_module = MaskFormerPixelLevelModule(config)
self.transformer_module = MaskFormerTransformerModule(
in_features=self.pixel_level_module.encoder.outputs_shapes[-1], config=config
in_features=self.pixel_level_module.encoder.channels[-1], config=config
)
self.post_init()
@ -2407,15 +1598,11 @@ class MaskFormerModel(MaskFormerPreTrainedModel):
if pixel_mask is None:
pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)
pixel_level_module_output: MaskFormerPixelLevelModuleOutput = self.pixel_level_module(
pixel_values, output_hidden_states
)
pixel_level_module_output = self.pixel_level_module(pixel_values, output_hidden_states)
image_features = pixel_level_module_output.encoder_last_hidden_state
pixel_embeddings = pixel_level_module_output.decoder_last_hidden_state
transformer_module_output: DetrDecoderOutput = self.transformer_module(
image_features, output_hidden_states, output_attentions
)
transformer_module_output = self.transformer_module(image_features, output_hidden_states, output_attentions)
queries = transformer_module_output.last_hidden_state
encoder_hidden_states = None

View File

@ -0,0 +1,925 @@
# coding=utf-8
# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MaskFormer Swin Transformer. The reason Swin Transformer is implemented here is because MaskFormer uses the hidden
states before downsampling, which is different from the default Swin Transformer."""
import collections.abc
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from ...activations import ACT2FN
from ...file_utils import ModelOutput
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from .configuration_maskformer_swin import MaskFormerSwinConfig
@dataclass
class MaskFormerSwinModelOutputWithPooling(ModelOutput):
"""
Class for MaskFormerSwinModel's outputs that also contains the spatial dimensions of the hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
Last layer hidden-state after a mean pooling operation.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*):
A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to
`batch, channels, height, width`. Due to padding, their spatial size cannot be inferred before the
`forward` method.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MaskFormerSwinBaseModelOutput(ModelOutput):
"""
Class for SwinEncoder's outputs.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*):
A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to
`batch, channels, height, width`. Due to padding, their spatial size cannot inferred before the `forward`
method.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
# Copied from transformers.models.swin.modeling_swin.window_partition
def window_partition(input_feature, window_size):
"""
Partitions the given input into windows.
"""
batch_size, height, width, num_channels = input_feature.shape
input_feature = input_feature.view(
batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
)
windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
return windows
# Copied from transformers.models.swin.modeling_swin.window_reverse
def window_reverse(windows, window_size, height, width):
"""
Merges windows to produce higher resolution features.
"""
num_channels = windows.shape[-1]
windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
return windows
# Copied from transformers.models.swin.modeling_swin.drop_path
def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
argument.
"""
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_() # binarize
output = input.div(keep_prob) * random_tensor
return output
class MaskFormerSwinEmbeddings(nn.Module):
"""
Construct the patch and position embeddings.
"""
def __init__(self, config):
super().__init__()
self.patch_embeddings = MaskFormerSwinPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.patch_grid = self.patch_embeddings.grid_size
if config.use_absolute_embeddings:
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
else:
self.position_embeddings = None
self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, pixel_values):
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
embeddings = self.norm(embeddings)
if self.position_embeddings is not None:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings, output_dimensions
# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
class MaskFormerSwinPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.embed_dim
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def maybe_pad(self, pixel_values, height, width):
if width % self.patch_size[1] != 0:
pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
pixel_values = nn.functional.pad(pixel_values, pad_values)
if height % self.patch_size[0] != 0:
pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
_, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
# pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width)
embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape
output_dimensions = (height, width)
embeddings = embeddings.flatten(2).transpose(1, 2)
return embeddings, output_dimensions
# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
class MaskFormerSwinPatchMerging(nn.Module):
"""
Patch Merging Layer.
Args:
input_resolution (`Tuple[int]`):
Resolution of input feature.
dim (`int`):
Number of input channels.
norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
Normalization layer class.
"""
def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def maybe_pad(self, input_feature, height, width):
should_pad = (height % 2 == 1) or (width % 2 == 1)
if should_pad:
pad_values = (0, 0, 0, width % 2, 0, height % 2)
input_feature = nn.functional.pad(input_feature, pad_values)
return input_feature
def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
height, width = input_dimensions
# `dim` is height * width
batch_size, dim, num_channels = input_feature.shape
input_feature = input_feature.view(batch_size, height, width, num_channels)
# pad input to be disible by width and height, if needed
input_feature = self.maybe_pad(input_feature, height, width)
# [batch_size, height/2, width/2, num_channels]
input_feature_0 = input_feature[:, 0::2, 0::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_1 = input_feature[:, 1::2, 0::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_2 = input_feature[:, 0::2, 1::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_3 = input_feature[:, 1::2, 1::2, :]
# batch_size height/2 width/2 4*num_channels
input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C
input_feature = self.norm(input_feature)
input_feature = self.reduction(input_feature)
return input_feature
# Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->MaskFormerSwin
class MaskFormerSwinDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->MaskFormerSwin
class MaskFormerSwinSelfAttention(nn.Module):
def __init__(self, config, dim, num_heads, window_size):
super().__init__()
if dim % num_heads != 0:
raise ValueError(
f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
)
self.num_attention_heads = num_heads
self.attention_head_size = int(dim / num_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.window_size = (
window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
)
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
self.register_buffer("relative_position_index", relative_position_index)
self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
relative_position_bias = relative_position_bias.view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in MaskFormerSwinModel forward() function)
mask_shape = attention_mask.shape[0]
attention_scores = attention_scores.view(
batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
)
attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->MaskFormerSwin
class MaskFormerSwinSelfOutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, dim)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->MaskFormerSwin
class MaskFormerSwinAttention(nn.Module):
def __init__(self, config, dim, num_heads, window_size):
super().__init__()
self.self = MaskFormerSwinSelfAttention(config, dim, num_heads, window_size)
self.output = MaskFormerSwinSelfOutput(config, dim)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
# Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->MaskFormerSwin
class MaskFormerSwinIntermediate(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->MaskFormerSwin
class MaskFormerSwinOutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class MaskFormerSwinLayer(nn.Module):
def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
super().__init__()
self.shift_size = shift_size
self.window_size = config.window_size
self.input_resolution = input_resolution
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = MaskFormerSwinAttention(config, dim, num_heads, self.window_size)
self.drop_path = (
MaskFormerSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
)
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.intermediate = MaskFormerSwinIntermediate(config, dim)
self.output = MaskFormerSwinOutput(config, dim)
def get_attn_mask(self, input_resolution):
if self.shift_size > 0:
# calculate attention mask for SW-MSA
height, width = input_resolution
img_mask = torch.zeros((1, height, width, 1))
height_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
width_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
count = 0
for height_slice in height_slices:
for width_slice in width_slices:
img_mask[:, height_slice, width_slice, :] = count
count += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
return attn_mask
def maybe_pad(self, hidden_states, height, width):
pad_left = pad_top = 0
pad_rigth = (self.window_size - width % self.window_size) % self.window_size
pad_bottom = (self.window_size - height % self.window_size) % self.window_size
pad_values = (0, 0, pad_left, pad_rigth, pad_top, pad_bottom)
hidden_states = nn.functional.pad(hidden_states, pad_values)
return hidden_states, pad_values
def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
height, width = input_dimensions
batch_size, dim, channels = hidden_states.size()
shortcut = hidden_states
hidden_states = self.layernorm_before(hidden_states)
hidden_states = hidden_states.view(batch_size, height, width, channels)
# pad hidden_states to multiples of window size
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
_, height_pad, width_pad, _ = hidden_states.shape
# cyclic shift
if self.shift_size > 0:
shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_hidden_states = hidden_states
# partition windows
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
attn_mask = self.get_attn_mask((height_pad, width_pad))
if attn_mask is not None:
attn_mask = attn_mask.to(hidden_states_windows.device)
self_attention_outputs = self.attention(
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
shifted_windows = window_reverse(
attention_windows, self.window_size, height_pad, width_pad
) # B height' width' C
# reverse cyclic shift
if self.shift_size > 0:
attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
attention_windows = shifted_windows
was_padded = pad_values[3] > 0 or pad_values[5] > 0
if was_padded:
attention_windows = attention_windows[:, :height, :width, :].contiguous()
attention_windows = attention_windows.view(batch_size, height * width, channels)
hidden_states = shortcut + self.drop_path(attention_windows)
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = hidden_states + self.output(layer_output)
outputs = (layer_output,) + outputs
return outputs
class MaskFormerSwinStage(nn.Module):
# Copied from transformers.models.swin.modeling_swin.SwinStage.__init__ with Swin->MaskFormerSwin
def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
super().__init__()
self.config = config
self.dim = dim
self.blocks = nn.ModuleList(
[
MaskFormerSwinLayer(
config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
else:
self.downsample = None
self.pointing = False
def forward(
self, hidden_states, input_dimensions, head_mask=None, output_attentions=False, output_hidden_states=False
):
all_hidden_states = () if output_hidden_states else None
height, width = input_dimensions
for i, block_module in enumerate(self.blocks):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
block_hidden_states = block_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
hidden_states = block_hidden_states[0]
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(hidden_states, input_dimensions)
else:
output_dimensions = (height, width, height, width)
return hidden_states, output_dimensions, all_hidden_states
class MaskFormerSwinEncoder(nn.Module):
# Copied from transformers.models.swin.modeling_swin.SwinEncoder.__init__ with Swin->MaskFormerSwin
def __init__(self, config, grid_size):
super().__init__()
self.num_layers = len(config.depths)
self.config = config
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
self.layers = nn.ModuleList(
[
MaskFormerSwinStage(
config=config,
dim=int(config.embed_dim * 2**i_layer),
input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
depth=config.depths[i_layer],
num_heads=config.num_heads[i_layer],
drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
downsample=MaskFormerSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
)
for i_layer in range(self.num_layers)
]
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
input_dimensions,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_input_dimensions = ()
all_self_attentions = () if output_attentions else None
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_hidden_states, output_dimensions, layer_all_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module), hidden_states, layer_head_mask
)
else:
layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
output_hidden_states,
)
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
all_input_dimensions += (input_dimensions,)
if output_hidden_states:
all_hidden_states += (layer_all_hidden_states,)
hidden_states = layer_hidden_states
if output_attentions:
all_self_attentions = all_self_attentions + (layer_all_hidden_states[1],)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return MaskFormerSwinBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
hidden_states_spatial_dimensions=all_input_dimensions,
attentions=all_self_attentions,
)
# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->MaskFormerSwin, swin->model
class MaskFormerSwinPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = MaskFormerSwinConfig
base_model_prefix = "model"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, MaskFormerSwinEncoder):
module.gradient_checkpointing = value
class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
self.num_layers = len(config.depths)
self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
self.embeddings = MaskFormerSwinEmbeddings(config)
self.encoder = MaskFormerSwinEncoder(config, self.embeddings.patch_grid)
self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(
self,
pixel_values=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
embedding_output, input_dimensions = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
pooled_output = None
if self.pooler is not None:
pooled_output = self.pooler(sequence_output.transpose(1, 2))
pooled_output = torch.flatten(pooled_output, 1)
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
hidden_states_spatial_dimensions = (input_dimensions,) + encoder_outputs.hidden_states_spatial_dimensions
return MaskFormerSwinModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
hidden_states_spatial_dimensions=hidden_states_spatial_dimensions,
attentions=encoder_outputs.attentions,
)
class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel):
"""
MaskFormerSwin backbone, designed especially for the MaskFormer framework.
This classes reshapes `hidden_states` from (`batch_size, sequence_length, hidden_size)` to (`batch_size,
num_channels, height, width)`). It also adds additional layernorms after each stage.
Args:
config (`MaskFormerSwinConfig`):
The configuration used by [`MaskFormerSwinModel`].
"""
def __init__(self, config: MaskFormerSwinConfig):
super().__init__(config)
self.stage_names = config.stage_names
self.model = MaskFormerSwinModel(config)
self.out_features = config.out_features
if "stem" in self.out_features:
raise ValueError("This backbone does not support 'stem' in the `out_features`.")
num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
self.out_feature_channels = {}
for i, stage in enumerate(self.stage_names[1:]):
self.out_feature_channels[stage] = num_features[i]
self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(num_channels) for num_channels in self.channels])
@property
def channels(self):
return [self.out_feature_channels[name] for name in self.out_features]
def forward(
self,
pixel_values: Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> BackboneOutput:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
outputs = self.model(
pixel_values, output_hidden_states=True, output_attentions=output_attentions, return_dict=True
)
# we skip the stem
hidden_states = outputs.hidden_states[1:]
feature_maps = ()
# we need to reshape the hidden states to their original spatial dimensions
# spatial dimensions contains all the heights and widths of each stage, including after the embeddings
spatial_dimensions: Tuple[Tuple[int, int]] = outputs.hidden_states_spatial_dimensions
for i, (hidden_state, stage, (height, width)) in enumerate(
zip(hidden_states, self.stage_names[1:], spatial_dimensions)
):
norm = self.hidden_states_norms[i]
# the last element corespond to the layer's last block output but before patch merging
hidden_state_unpolled = hidden_state[-1]
hidden_state_norm = norm(hidden_state_unpolled)
# the pixel decoder (FPN) expects 3D tensors (features)
batch_size, _, hidden_size = hidden_state_norm.shape
# reshape "b (h w) d -> b d h w"
hidden_state_permuted = (
hidden_state_norm.permute(0, 2, 1).view((batch_size, hidden_size, height, width)).contiguous()
)
if stage in self.out_features:
feature_maps += (hidden_state_permuted,)
if not return_dict:
output = (feature_maps,)
if output_hidden_states:
output += (outputs.hidden_states,)
if output_attentions:
output += (outputs.attentions,)
return output
return BackboneOutput(
feature_maps=feature_maps,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)

View File

@ -3361,6 +3361,13 @@ class MaskFormerPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class MaskFormerSwinBackbone(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MBartForCausalLM(metaclass=DummyObject):
_backends = ["torch"]

View File

@ -0,0 +1,386 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch MaskFormer Swin model. """
import collections
import inspect
import unittest
from typing import Dict, List, Tuple
from transformers import MaskFormerSwinConfig
from transformers.testing_utils import require_torch, torch_device
from transformers.utils import is_torch_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
from torch import nn
from transformers import MaskFormerSwinBackbone
from transformers.models.maskformer import MaskFormerSwinModel
class MaskFormerSwinModelTester:
def __init__(
self,
parent,
batch_size=13,
image_size=32,
patch_size=2,
num_channels=3,
embed_dim=16,
depths=[1, 2, 1],
num_heads=[2, 2, 4],
window_size=2,
mlp_ratio=2.0,
qkv_bias=True,
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
drop_path_rate=0.1,
hidden_act="gelu",
use_absolute_embeddings=False,
patch_norm=True,
initializer_range=0.02,
layer_norm_eps=1e-5,
is_training=True,
scope=None,
use_labels=True,
type_sequence_label_size=10,
encoder_stride=8,
out_features=["stage1", "stage2", "stage3"],
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.embed_dim = embed_dim
self.depths = depths
self.num_heads = num_heads
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.drop_path_rate = drop_path_rate
self.hidden_act = hidden_act
self.use_absolute_embeddings = use_absolute_embeddings
self.patch_norm = patch_norm
self.layer_norm_eps = layer_norm_eps
self.initializer_range = initializer_range
self.is_training = is_training
self.scope = scope
self.use_labels = use_labels
self.type_sequence_label_size = type_sequence_label_size
self.encoder_stride = encoder_stride
self.out_features = out_features
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return MaskFormerSwinConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
embed_dim=self.embed_dim,
depths=self.depths,
num_heads=self.num_heads,
window_size=self.window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=self.qkv_bias,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
drop_path_rate=self.drop_path_rate,
hidden_act=self.hidden_act,
use_absolute_embeddings=self.use_absolute_embeddings,
path_norm=self.patch_norm,
layer_norm_eps=self.layer_norm_eps,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
out_features=self.out_features,
)
def create_and_check_model(self, config, pixel_values, labels):
model = MaskFormerSwinModel(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
expected_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
def create_and_check_backbone(self, config, pixel_values, labels):
model = MaskFormerSwinBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [13, 16, 16, 16])
# verify channels
self.parent.assertEqual(len(model.channels), len(config.out_features))
self.parent.assertListEqual(model.channels, [16, 32, 64])
# verify ValueError
with self.parent.assertRaises(ValueError):
config.out_features = ["stem"]
model = MaskFormerSwinBackbone(config=config)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class MaskFormerSwinModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
MaskFormerSwinModel,
MaskFormerSwinBackbone,
)
if is_torch_available()
else ()
)
fx_compatible = False
test_torchscript = False
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
def setUp(self):
self.model_tester = MaskFormerSwinModelTester(self)
self.config_tester = ConfigTester(self, config_class=MaskFormerSwinConfig, embed_dim=37)
def test_config(self):
self.create_and_test_config_common_properties()
self.config_tester.create_and_test_config_to_json_string()
self.config_tester.create_and_test_config_to_json_file()
self.config_tester.create_and_test_config_from_and_save_pretrained()
self.config_tester.create_and_test_config_with_num_labels()
self.config_tester.check_config_can_be_init_without_params()
self.config_tester.check_config_arguments_init()
def create_and_test_config_common_properties(self):
return
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_backbone(*config_and_inputs)
@unittest.skip("Swin does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip("Swin does not support feedforward chunking")
def test_feed_forward_chunking(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
@unittest.skip(reason="MaskFormerSwin is only used as backbone and doesn't support output_attentions")
def test_attention_outputs(self):
pass
@unittest.skip(reason="MaskFormerSwin is only used as an internal backbone")
def test_save_load_fast_init_to_base(self):
pass
def check_hidden_states_output(self, inputs_dict, config, model_class, image_size):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
# Swin has a different seq_length
patch_size = (
config.patch_size
if isinstance(config.patch_size, collections.abc.Iterable)
else (config.patch_size, config.patch_size)
)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[num_patches, self.model_tester.embed_dim],
)
def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
image_size = (
self.model_tester.image_size
if isinstance(self.model_tester.image_size, collections.abc.Iterable)
else (self.model_tester.image_size, self.model_tester.image_size)
)
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
def test_hidden_states_output_with_padding(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.patch_size = 3
image_size = (
self.model_tester.image_size
if isinstance(self.model_tester.image_size, collections.abc.Iterable)
else (self.model_tester.image_size, self.model_tester.image_size)
)
patch_size = (
config.patch_size
if isinstance(config.patch_size, collections.abc.Iterable)
else (config.patch_size, config.patch_size)
)
padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
@unittest.skip(reason="MaskFormerSwin doesn't have pretrained checkpoints")
def test_model_from_pretrained(self):
pass
@unittest.skip(reason="This will be fixed once MaskFormerSwin is replaced by native Swin")
def test_initialization(self):
pass
@unittest.skip(reason="This will be fixed once MaskFormerSwin is replaced by native Swin")
def test_gradient_checkpointing_backward_compatibility(self):
pass
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
with torch.no_grad():
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(
tuple_object.values(), dict_object.values()
):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

View File

@ -492,8 +492,9 @@ SPECIAL_MODEL_NAMES = {
"Data2VecAudio": "Data2Vec",
"Data2VecText": "Data2Vec",
"Data2VecVision": "Data2Vec",
"DonutSwin": "Donut",
"DonutSwin": "Swin Transformer",
"Marian": "MarianMT",
"MaskFormerSwin": "Swin Transformer",
"OpenAI GPT-2": "GPT-2",
"OpenAI GPT": "GPT",
"Perceiver": "Perceiver IO",

View File

@ -41,6 +41,8 @@ PRIVATE_MODELS = [
"T5Stack",
"SwitchTransformersStack",
"TFDPRSpanPredictor",
"MaskFormerSwinModel",
"MaskFormerSwinPreTrainedModel",
]
# Update this list for models that are not tested with a comment explaining the reason it should not be.
@ -668,8 +670,11 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
"PyTorchBenchmarkArguments",
"TensorFlowBenchmark",
"TensorFlowBenchmarkArguments",
"MaskFormerSwinBackbone",
"ResNetBackbone",
"AutoBackbone",
"MaskFormerSwinConfig",
"MaskFormerSwinModel",
]