Swin support for any input size (#15986)

* padding done

* correctly return one attention per layer

* almost correct, attentions are not flattened: one tuple per stage

* tests green

* doc

* addressed review conversations

* reshaping hidden_states

* view in the test

* reshape_hidden_states in Encoder and Model

* new outputs with reshaped_hidden_states

* addressed review conversations

* doc

* Update docs/source/model_doc/swin.mdx

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* addressed review conversations

* fix tests

* minor changes

* resolved conversations

* attentions one per stage

* typo

* typos

* typos

* function signature

* CI

* clean up tests

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Authored by Francesco Saverio Zuppichini on 2022-03-16 18:38:25 +01:00, committed by GitHub
parent 204c54d411
commit 667b823b89
3 changed files with 317 additions and 118 deletions

docs/source/model_doc/swin.mdx

@@ -34,6 +34,8 @@ The hierarchical design and the shifted window approach also prove beneficial fo
Tips:
- One can use the [`AutoFeatureExtractor`] API to prepare images for the model.
- Swin pads the inputs, so any input height and width divisible by `32` is supported.
- Swin can be used as a *backbone*. When `output_hidden_states = True`, it will output both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`.
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/swin_transformer_architecture.png"
alt="drawing" width="600"/>

src/transformers/models/swin/modeling_swin.py

@@ -17,6 +17,8 @@
import collections.abc
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
@@ -25,12 +27,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from .configuration_swin import SwinConfig
@@ -56,10 +58,150 @@ SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [
# See all Swin models at https://huggingface.co/models?filter=swin
]
# to_2tuple, drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library.
@dataclass
class SwinEncoderOutput(ModelOutput):
"""
Swin encoder's outputs, with potential hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, hidden_size, height, width)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
include the spatial dimensions.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class SwinModelOutput(ModelOutput):
"""
Swin model's outputs that also contain a pooling of the last hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
Average pooling of the last layer hidden-state.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, hidden_size, height, width)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
include the spatial dimensions.
"""
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class SwinMaskedImageModelingOutput(ModelOutput):
"""
Swin masked image modeling outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
Masked image modeling (MIM) loss.
logits (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Reconstructed pixel values.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, hidden_size, height, width)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
include the spatial dimensions.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class SwinImageClassifierOutput(ModelOutput):
"""
Swin outputs for image classification.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, hidden_size, height, width)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
include the spatial dimensions.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
# Copied from transformers.models.vit.modeling_vit.to_2tuple
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
@@ -130,7 +272,7 @@ class SwinEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, pixel_values, bool_masked_pos=None):
embeddings = self.patch_embeddings(pixel_values)
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
embeddings = self.norm(embeddings)
batch_size, seq_len, _ = embeddings.size()
@@ -145,7 +287,7 @@ class SwinEmbeddings(nn.Module):
embeddings = self.dropout(embeddings)
return embeddings
return embeddings, output_dimensions
class SwinPatchEmbeddings(nn.Module):
@@ -165,9 +307,25 @@ class SwinPatchEmbeddings(nn.Module):
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def maybe_pad(self, pixel_values, height, width):
if width % self.patch_size[1] != 0:
pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
pixel_values = nn.functional.pad(pixel_values, pad_values)
if height % self.patch_size[0] != 0:
pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values
def forward(self, pixel_values):
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
return embeddings
_, _, height, width = pixel_values.shape
# pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width)
embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape
output_dimensions = (height, width)
embeddings = embeddings.flatten(2).transpose(1, 2)
return embeddings, output_dimensions
class SwinPatchMerging(nn.Module):
@@ -190,17 +348,30 @@ class SwinPatchMerging(nn.Module):
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, input_feature):
height, width = self.input_resolution
def maybe_pad(self, input_feature, height, width):
should_pad = (height % 2 == 1) or (width % 2 == 1)
if should_pad:
pad_values = (0, 0, 0, width % 2, 0, height % 2)
input_feature = nn.functional.pad(input_feature, pad_values)
return input_feature
def forward(self, input_feature, input_dimensions):
height, width = input_dimensions
# `dim` is height * width
batch_size, dim, num_channels = input_feature.shape
input_feature = input_feature.view(batch_size, height, width, num_channels)
input_feature_0 = input_feature[:, 0::2, 0::2, :] # batch_size height/2 width/2 num_channels
input_feature_1 = input_feature[:, 1::2, 0::2, :] # batch_size height/2 width/2 num_channels
input_feature_2 = input_feature[:, 0::2, 1::2, :] # batch_size height/2 width/2 num_channels
input_feature_3 = input_feature[:, 1::2, 1::2, :] # batch_size height/2 width/2 num_channels
# pad input so that its height and width are divisible by 2, if needed
input_feature = self.maybe_pad(input_feature, height, width)
# [batch_size, height/2, width/2, num_channels]
input_feature_0 = input_feature[:, 0::2, 0::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_1 = input_feature[:, 1::2, 0::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_2 = input_feature[:, 0::2, 1::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_3 = input_feature[:, 1::2, 1::2, :]
# batch_size height/2 width/2 4*num_channels
input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C
@@ -393,19 +564,14 @@ class SwinOutput(nn.Module):
return hidden_states
class SwinBlock(nn.Module):
class SwinLayer(nn.Module):
def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.shift_size = shift_size
self.window_size = config.window_size
self.input_resolution = input_resolution
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
self.set_shift_and_window_size(input_resolution)
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = SwinAttention(config, dim, num_heads)
self.drop_path = SwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
@@ -413,9 +579,15 @@ class SwinBlock(nn.Module):
self.intermediate = SwinIntermediate(config, dim)
self.output = SwinOutput(config, dim)
def set_shift_and_window_size(self, input_resolution):
if min(input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(input_resolution)
def get_attn_mask(self, height, width):
if self.shift_size > 0:
# calculate attention mask for SW-MSA
height, width = self.input_resolution
img_mask = torch.zeros((1, height, width, 1))
height_slices = (
slice(0, -self.window_size),
@@ -439,17 +611,27 @@ class SwinBlock(nn.Module):
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
return attn_mask
self.attn_mask = attn_mask
def maybe_pad(self, hidden_states, height, width):
pad_right = (self.window_size - width % self.window_size) % self.window_size
pad_bottom = (self.window_size - height % self.window_size) % self.window_size
pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
hidden_states = nn.functional.pad(hidden_states, pad_values)
return hidden_states, pad_values
def forward(self, hidden_states, head_mask=None, output_attentions=False):
height, width = self.input_resolution
batch_size, dim, channels = hidden_states.size()
def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
self.set_shift_and_window_size(input_dimensions)
height, width = input_dimensions
batch_size, _, channels = hidden_states.size()
shortcut = hidden_states
hidden_states = self.layernorm_before(hidden_states)
hidden_states = hidden_states.view(batch_size, height, width, channels)
# pad hidden_states to multiples of window size
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
_, height_pad, width_pad, _ = hidden_states.shape
# cyclic shift
if self.shift_size > 0:
shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
@@ -459,23 +641,18 @@ class SwinBlock(nn.Module):
# partition windows
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
attn_mask = self.get_attn_mask(height_pad, width_pad)
if attn_mask is not None:
attn_mask = attn_mask.to(hidden_states_windows.device)
if self.attn_mask is not None:
self.attn_mask = self.attn_mask.to(hidden_states_windows.device)
self_attention_outputs = self.attention(
hidden_states_windows,
self.attn_mask,
head_mask,
output_attentions=output_attentions,
attention_outputs = self.attention(
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
attention_output = attention_outputs[0]
attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
shifted_windows = window_reverse(attention_windows, self.window_size, height, width) # B H' W' C
shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
# reverse cyclic shift
if self.shift_size > 0:
@@ -483,6 +660,10 @@ class SwinBlock(nn.Module):
else:
attention_windows = shifted_windows
was_padded = pad_values[3] > 0 or pad_values[5] > 0
if was_padded:
attention_windows = attention_windows[:, :height, :width, :].contiguous()
attention_windows = attention_windows.view(batch_size, height * width, channels)
hidden_states = shortcut + self.drop_path(attention_windows)
@@ -491,19 +672,18 @@ class SwinBlock(nn.Module):
layer_output = self.intermediate(layer_output)
layer_output = hidden_states + self.output(layer_output)
outputs = (layer_output,) + outputs
return outputs
layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
return layer_outputs
class SwinLayer(nn.Module):
class SwinStage(nn.Module):
def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
super().__init__()
self.config = config
self.dim = dim
self.blocks = nn.ModuleList(
[
SwinBlock(
SwinLayer(
config=config,
dim=dim,
input_resolution=input_resolution,
@@ -522,29 +702,28 @@ class SwinLayer(nn.Module):
self.pointing = False
def forward(self, hidden_states, head_mask=None, output_attentions=False, output_hidden_states=False):
all_hidden_states = () if output_hidden_states else None
for i, block_module in enumerate(self.blocks):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
height, width = input_dimensions
for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = block_module(
hidden_states,
layer_head_mask,
output_attentions,
)
layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if self.downsample is not None:
layer_outputs_list = list(layer_outputs)
layer_outputs_list[0] = self.downsample(layer_outputs[0])
layer_outputs = tuple(layer_outputs_list)
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(layer_outputs[0], input_dimensions)
else:
output_dimensions = (height, width, height, width)
return layer_outputs
stage_outputs = (hidden_states, output_dimensions)
if output_attentions:
stage_outputs += layer_outputs[1:]
return stage_outputs
class SwinEncoder(nn.Module):
@@ -555,7 +734,7 @@ class SwinEncoder(nn.Module):
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
self.layers = nn.ModuleList(
[
SwinLayer(
SwinStage(
config=config,
dim=int(config.embed_dim * 2**i_layer),
input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
@@ -573,18 +752,26 @@ class SwinEncoder(nn.Module):
def forward(
self,
hidden_states,
input_dimensions,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_input_dimensions = ()
all_hidden_states = () if output_hidden_states else None
all_reshaped_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if output_hidden_states:
batch_size, _, hidden_size = hidden_states.shape
# rearrange b (h w) c -> b c h w
reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
@@ -596,23 +783,36 @@ class SwinEncoder(nn.Module):
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module), hidden_states, layer_head_mask
create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
output_dimensions = layer_outputs[1]
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
all_input_dimensions += (input_dimensions,)
if output_hidden_states:
batch_size, _, hidden_size = hidden_states.shape
# rearrange b (h w) c -> b c h w
reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
if output_attentions:
all_self_attentions += layer_outputs[2:]
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
return SwinEncoderOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
reshaped_hidden_states=all_reshaped_hidden_states,
)
@@ -712,7 +912,7 @@ class SwinModel(SwinPreTrainedModel):
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPooling,
output_type=SwinModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@@ -742,10 +942,11 @@ class SwinModel(SwinPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
encoder_outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -761,13 +962,16 @@ class SwinModel(SwinPreTrainedModel):
pooled_output = torch.flatten(pooled_output, 1)
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
output = (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
return output
return SwinModelOutput(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
)
@@ -791,7 +995,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
self.post_init()
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=SwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values=None,
@@ -869,11 +1073,12 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
output = (reconstructed_pixel_values,) + outputs[2:]
return ((masked_im_loss,) + output) if masked_im_loss is not None else output
return MaskedLMOutput(
return SwinMaskedImageModelingOutput(
loss=masked_im_loss,
logits=reconstructed_pixel_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states,
)
@@ -903,7 +1108,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
output_type=SwinImageClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
@@ -963,9 +1168,10 @@ class SwinForImageClassification(SwinPreTrainedModel):
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
return SwinImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states,
)

test_modeling_swin.py

@@ -230,15 +230,6 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = num_patches
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
chunk_length = getattr(self.model_tester, "chunk_length", None)
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
@@ -248,8 +239,9 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), len(self.model_tester.depths))
attentions = outputs.attentions
expected_num_attentions = len(self.model_tester.depths)
self.assertEqual(len(attentions), expected_num_attentions)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
@@ -260,19 +252,13 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), len(self.model_tester.depths))
attentions = outputs.attentions
self.assertEqual(len(attentions), expected_num_attentions)
if chunk_length is not None:
self.assertListEqual(
list(attentions[0].shape[-4:]),
[self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared],
)
else:
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
)
out_len = len(outputs)
# Check attention is always last and order is fine
@@ -286,25 +272,19 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
# also another +1 for reshaped_hidden_states
added_hidden_states = 2
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self_attentions = outputs.attentions
self.assertEqual(len(self_attentions), len(self.model_tester.depths))
if chunk_length is not None:
self.assertListEqual(
list(self_attentions[0].shape[-4:]),
[self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared],
)
else:
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
)
self.assertEqual(len(self_attentions), expected_num_attentions)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
@@ -315,7 +295,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
hidden_states = outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
@@ -325,6 +305,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
# Swin has a different seq_length
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.assertListEqual(
@@ -332,6 +313,18 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
[num_patches, self.model_tester.embed_dim],
)
reshaped_hidden_states = outputs.reshaped_hidden_states
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
reshaped_hidden_states = (
reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
)
self.assertListEqual(
list(reshaped_hidden_states.shape[-2:]),
[num_patches, self.model_tester.embed_dim],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
@@ -395,7 +388,5 @@ class SwinModelIntegrationTest(unittest.TestCase):
# verify the logits
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))