mirror of https://github.com/huggingface/transformers.git
Swin support for any input size (#15986)
* padding done
* correctly return one attention per layer
* almost correct, attentions are not flatten one tuple per stage
* tests green
* doc
* conversations
* reshaping hidden_states
* view in the test
* reshape_hidden_states in Encoder and Model
* new outputs with reshaped_hidden_states
* conversations
* doc
* Update docs/source/model_doc/swin.mdx

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* conversations
* fix tests
* minor changes
* resolved conversations
* attentions one per stage
* typo
* typos
* typos
* function signature
* CI
* clean up tests

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
parent 204c54d411
commit 667b823b89
@@ -34,6 +34,8 @@ The hierarchical design and the shifted window approach also prove beneficial fo

Tips:

- One can use the [`AutoFeatureExtractor`] API to prepare images for the model.
- Swin pads the inputs, so it supports any input height and width that is divisible by `32`.
- Swin can be used as a *backbone*. When `output_hidden_states = True`, it will output both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch_size, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`; see the example below.

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/swin_transformer_architecture.png"
alt="drawing" width="600"/>
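As an illustration of the backbone-style outputs described in the tips above, here is a minimal sketch; the checkpoint name and the test image URL are examples, not part of this change:

```python
from PIL import Image
import requests
from transformers import AutoFeatureExtractor, SwinModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs, output_hidden_states=True)

# sequence-shaped hidden states: (batch_size, sequence_length, num_channels)
print(outputs.hidden_states[-1].shape)
# spatially reshaped hidden states: (batch_size, num_channels, height, width)
print(outputs.reshaped_hidden_states[-1].shape)
```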
@@ -17,6 +17,8 @@

import collections.abc
import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.utils.checkpoint

@@ -25,12 +27,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...file_utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from .configuration_swin import SwinConfig

@@ -56,10 +58,150 @@ SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [
    # See all Swin models at https://huggingface.co/models?filter=swin
]


# to_2tuple, drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library.


@dataclass
class SwinEncoderOutput(ModelOutput):
    """
    Swin encoder's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class SwinModelOutput(ModelOutput):
    """
    Swin model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Average pooling of the last layer hidden-state.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class SwinMaskedImageModelingOutput(ModelOutput):
    """
    Swin masked image model outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Masked image modeling (MLM) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed pixel values.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class SwinImageClassifierOutput(ModelOutput):
    """
    Swin outputs for image classification.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


# Copied from transformers.models.vit.modeling_vit.to_2tuple
def to_2tuple(x):
    if isinstance(x, collections.abc.Iterable):
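These output classes are ordinary `ModelOutput` dataclasses, so the new `reshaped_hidden_states` field can be read either by attribute or by tuple index. A minimal sketch with dummy tensors, assuming the `SwinEncoderOutput` class defined above is in scope; all shapes are illustrative only:

```python
import torch

batch_size, seq_len, hidden_size, height, width = 2, 49, 96, 7, 7

outputs = SwinEncoderOutput(
    last_hidden_state=torch.randn(batch_size, seq_len, hidden_size),
    reshaped_hidden_states=(torch.randn(batch_size, hidden_size, height, width),),
)

print(outputs.last_hidden_state.shape)          # torch.Size([2, 49, 96])
print(outputs[0].shape)                         # index access returns the same tensor
print(outputs.reshaped_hidden_states[0].shape)  # torch.Size([2, 96, 7, 7])
print(outputs.hidden_states)                    # None: fields that were not requested stay unset
```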
@@ -130,7 +272,7 @@ class SwinEmbeddings(nn.Module):
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, pixel_values, bool_masked_pos=None):
        embeddings = self.patch_embeddings(pixel_values)
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()

@@ -145,7 +287,7 @@ class SwinEmbeddings(nn.Module):

        embeddings = self.dropout(embeddings)

        return embeddings
        return embeddings, output_dimensions


class SwinPatchEmbeddings(nn.Module):

@@ -165,9 +307,25 @@ class SwinPatchEmbeddings(nn.Module):

        self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def maybe_pad(self, pixel_values, height, width):
        if width % self.patch_size[1] != 0:
            pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        if height % self.patch_size[0] != 0:
            pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        return pixel_values

    def forward(self, pixel_values):
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings
        _, _, height, width = pixel_values.shape
        # pad the input to be divisible by self.patch_size, if needed
        pixel_values = self.maybe_pad(pixel_values, height, width)
        embeddings = self.projection(pixel_values)
        _, _, height, width = embeddings.shape
        output_dimensions = (height, width)
        embeddings = embeddings.flatten(2).transpose(1, 2)

        return embeddings, output_dimensions


class SwinPatchMerging(nn.Module):
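The `maybe_pad` helper above relies on `torch.nn.functional.pad`, whose padding tuple lists the last dimension first. A rough sketch of the patch-embedding padding arithmetic; the patch size, image size and embedding width below are illustrative, not taken from the diff:

```python
import torch
import torch.nn as nn

patch_size = (4, 4)
pixel_values = torch.randn(1, 3, 30, 30)  # height and width are not multiples of the patch size

# pad the width (last dim) on the right, then the height (second-to-last dim) at the bottom
pad_width = (patch_size[1] - 30 % patch_size[1]) % patch_size[1]   # 2
pad_height = (patch_size[0] - 30 % patch_size[0]) % patch_size[0]  # 2
pixel_values = nn.functional.pad(pixel_values, (0, pad_width, 0, pad_height))
print(pixel_values.shape)  # torch.Size([1, 3, 32, 32])

# after the conv projection, the spatial grid becomes the patch grid
projection = nn.Conv2d(3, 96, kernel_size=patch_size, stride=patch_size)
embeddings = projection(pixel_values)
print(embeddings.shape)  # torch.Size([1, 96, 8, 8]) -> output_dimensions == (8, 8)
```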
@@ -190,17 +348,30 @@ class SwinPatchMerging(nn.Module):
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, input_feature):
        height, width = self.input_resolution
    def maybe_pad(self, input_feature, height, width):
        should_pad = (height % 2 == 1) or (width % 2 == 1)
        if should_pad:
            pad_values = (0, 0, 0, width % 2, 0, height % 2)
            input_feature = nn.functional.pad(input_feature, pad_values)

        return input_feature

    def forward(self, input_feature, input_dimensions):
        height, width = input_dimensions
        # `dim` is height * width
        batch_size, dim, num_channels = input_feature.shape

        input_feature = input_feature.view(batch_size, height, width, num_channels)

        input_feature_0 = input_feature[:, 0::2, 0::2, :]  # batch_size height/2 width/2 num_channels
        input_feature_1 = input_feature[:, 1::2, 0::2, :]  # batch_size height/2 width/2 num_channels
        input_feature_2 = input_feature[:, 0::2, 1::2, :]  # batch_size height/2 width/2 num_channels
        input_feature_3 = input_feature[:, 1::2, 1::2, :]  # batch_size height/2 width/2 num_channels
        # pad input to be divisible by width and height, if needed
        input_feature = self.maybe_pad(input_feature, height, width)
        # [batch_size, height/2, width/2, num_channels]
        input_feature_0 = input_feature[:, 0::2, 0::2, :]
        # [batch_size, height/2, width/2, num_channels]
        input_feature_1 = input_feature[:, 1::2, 0::2, :]
        # [batch_size, height/2, width/2, num_channels]
        input_feature_2 = input_feature[:, 0::2, 1::2, :]
        # [batch_size, height/2, width/2, num_channels]
        input_feature_3 = input_feature[:, 1::2, 1::2, :]
        # batch_size height/2 width/2 4*num_channels
        input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
        input_feature = input_feature.view(batch_size, -1, 4 * num_channels)  # batch_size height/2*width/2 4*C
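The 2x2 neighborhood gather in `SwinPatchMerging` can be checked on a dummy tensor. A rough sketch, with shapes invented only for illustration:

```python
import torch

batch_size, height, width, num_channels = 1, 8, 8, 96
input_feature = torch.randn(batch_size, height * width, num_channels)

# reshape the flattened sequence back to a spatial grid
input_feature = input_feature.view(batch_size, height, width, num_channels)

# gather the four pixels of every 2x2 neighborhood and stack them along the channel axis
top_left = input_feature[:, 0::2, 0::2, :]
bottom_left = input_feature[:, 1::2, 0::2, :]
top_right = input_feature[:, 0::2, 1::2, :]
bottom_right = input_feature[:, 1::2, 1::2, :]
merged = torch.cat([top_left, bottom_left, top_right, bottom_right], -1)

print(merged.shape)  # torch.Size([1, 4, 4, 384]): half the resolution, four times the channels
merged = merged.view(batch_size, -1, 4 * num_channels)
print(merged.shape)  # torch.Size([1, 16, 384]), later projected down to 2 * num_channels by `reduction`
```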
@@ -393,19 +564,14 @@ class SwinOutput(nn.Module):
        return hidden_states


class SwinBlock(nn.Module):
class SwinLayer(nn.Module):
    def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.shift_size = shift_size
        self.window_size = config.window_size
        self.input_resolution = input_resolution

        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)

        self.set_shift_and_window_size(input_resolution)
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = SwinAttention(config, dim, num_heads)
        self.drop_path = SwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()

@@ -413,9 +579,15 @@ class SwinBlock(nn.Module):
        self.intermediate = SwinIntermediate(config, dim)
        self.output = SwinOutput(config, dim)

    def set_shift_and_window_size(self, input_resolution):
        if min(input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(input_resolution)

    def get_attn_mask(self, height, width):
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            height, width = self.input_resolution
            img_mask = torch.zeros((1, height, width, 1))
            height_slices = (
                slice(0, -self.window_size),

@@ -439,17 +611,27 @@ class SwinBlock(nn.Module):
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        return attn_mask

        self.attn_mask = attn_mask

    def maybe_pad(self, hidden_states, height, width):
        pad_right = (self.window_size - width % self.window_size) % self.window_size
        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
        hidden_states = nn.functional.pad(hidden_states, pad_values)
        return hidden_states, pad_values

    def forward(self, hidden_states, head_mask=None, output_attentions=False):
        height, width = self.input_resolution
        batch_size, dim, channels = hidden_states.size()
    def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
        self.set_shift_and_window_size(input_dimensions)
        height, width = input_dimensions
        batch_size, _, channels = hidden_states.size()
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)
        hidden_states = hidden_states.view(batch_size, height, width, channels)
        # pad hidden_states to multiples of window size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)

        _, height_pad, width_pad, _ = hidden_states.shape
        # cyclic shift
        if self.shift_size > 0:
            shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
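`SwinLayer.maybe_pad` does the same trick one level up, on a channels-last feature map, so that `window_partition` always sees a multiple of the window size, and the attention mask is rebuilt from the padded `height_pad`/`width_pad` instead of a fixed `input_resolution`. A small sketch of the arithmetic, with the window size and feature-map size made up for illustration:

```python
import torch
import torch.nn as nn

window_size = 7
height, width = 9, 11  # feature map that is not a multiple of the window size

pad_right = (window_size - width % window_size) % window_size    # 3
pad_bottom = (window_size - height % window_size) % window_size  # 5

hidden_states = torch.randn(1, height, width, 96)
# pad tuple for a (batch, height, width, channels) tensor:
# (channels-left, channels-right, width-left, width-right, height-top, height-bottom)
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_right, 0, pad_bottom))
print(hidden_states.shape)  # torch.Size([1, 14, 14, 96]) -> height_pad, width_pad = 14, 14
```

The padded height and width are what `get_attn_mask(height_pad, width_pad)` receives further down in the forward pass.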
@@ -459,23 +641,18 @@ class SwinBlock(nn.Module):
        # partition windows
        hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
        hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
        attn_mask = self.get_attn_mask(height_pad, width_pad)
        if attn_mask is not None:
            attn_mask = attn_mask.to(hidden_states_windows.device)

        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(hidden_states_windows.device)

        self_attention_outputs = self.attention(
            hidden_states_windows,
            self.attn_mask,
            head_mask,
            output_attentions=output_attentions,
        attention_outputs = self.attention(
            hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
        )

        attention_output = self_attention_outputs[0]

        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
        attention_output = attention_outputs[0]

        attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
        shifted_windows = window_reverse(attention_windows, self.window_size, height, width)  # B H' W' C
        shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)

        # reverse cyclic shift
        if self.shift_size > 0:

@@ -483,6 +660,10 @@ class SwinBlock(nn.Module):
        else:
            attention_windows = shifted_windows

        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_windows = attention_windows[:, :height, :width, :].contiguous()

        attention_windows = attention_windows.view(batch_size, height * width, channels)

        hidden_states = shortcut + self.drop_path(attention_windows)

@@ -491,19 +672,18 @@ class SwinBlock(nn.Module):
        layer_output = self.intermediate(layer_output)
        layer_output = hidden_states + self.output(layer_output)

        outputs = (layer_output,) + outputs

        return outputs
        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
        return layer_outputs


class SwinLayer(nn.Module):
class SwinStage(nn.Module):
    def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
        super().__init__()
        self.config = config
        self.dim = dim
        self.blocks = nn.ModuleList(
            [
                SwinBlock(
                SwinLayer(
                    config=config,
                    dim=dim,
                    input_resolution=input_resolution,

@@ -522,29 +702,28 @@ class SwinLayer(nn.Module):

        self.pointing = False

    def forward(self, hidden_states, head_mask=None, output_attentions=False, output_hidden_states=False):
        all_hidden_states = () if output_hidden_states else None

        for i, block_module in enumerate(self.blocks):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
    def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
        height, width = input_dimensions
        for i, layer_module in enumerate(self.blocks):

            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = block_module(
                hidden_states,
                layer_head_mask,
                output_attentions,
            )
            layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

        if self.downsample is not None:
            layer_outputs_list = list(layer_outputs)
            layer_outputs_list[0] = self.downsample(layer_outputs[0])
            layer_outputs = tuple(layer_outputs_list)
            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
            output_dimensions = (height, width, height_downsampled, width_downsampled)
            hidden_states = self.downsample(layer_outputs[0], input_dimensions)
        else:
            output_dimensions = (height, width, height, width)

        return layer_outputs
        stage_outputs = (hidden_states, output_dimensions)

        if output_attentions:
            stage_outputs += layer_outputs[1:]
        return stage_outputs


class SwinEncoder(nn.Module):
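The `(height + 1) // 2` ceil-division in `SwinStage.forward` mirrors the one-row/one-column padding that `SwinPatchMerging.maybe_pad` applies to odd feature maps. A tiny sketch, with dimensions invented for illustration:

```python
# ceil-division: odd feature maps are padded by one row or column inside SwinPatchMerging,
# so the downsampled grid is always rounded up
for height, width in [(14, 14), (15, 9)]:
    height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
    output_dimensions = (height, width, height_downsampled, width_downsampled)
    print(output_dimensions)
# (14, 14, 7, 7)
# (15, 9, 8, 5)
```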
@@ -555,7 +734,7 @@ class SwinEncoder(nn.Module):
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
        self.layers = nn.ModuleList(
            [
                SwinLayer(
                SwinStage(
                    config=config,
                    dim=int(config.embed_dim * 2**i_layer),
                    input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),

@@ -573,18 +752,26 @@ class SwinEncoder(nn.Module):
    def forward(
        self,
        hidden_states,
        input_dimensions,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_input_dimensions = ()
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
        if output_hidden_states:
            batch_size, _, hidden_size = hidden_states.shape
            # rearrange b (h w) c -> b c h w
            reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (reshaped_hidden_state,)

        for i, layer_module in enumerate(self.layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
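The `rearrange b (h w) c -> b c h w` comment is implemented with a plain `view` followed by `permute`. A minimal sketch of the round trip, with illustrative shapes:

```python
import torch

batch_size, height, width, hidden_size = 2, 7, 7, 96
hidden_states = torch.randn(batch_size, height * width, hidden_size)  # (b, h*w, c)

reshaped = hidden_states.view(batch_size, height, width, hidden_size)  # (b, h, w, c)
reshaped = reshaped.permute(0, 3, 1, 2)                                 # (b, c, h, w)
print(reshaped.shape)  # torch.Size([2, 96, 7, 7])

# round-trip back to the sequence layout used by hidden_states
restored = reshaped.permute(0, 2, 3, 1).reshape(batch_size, height * width, hidden_size)
print(torch.equal(restored, hidden_states))  # True
```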
@@ -596,23 +783,36 @@ class SwinEncoder(nn.Module):
                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module), hidden_states, layer_head_mask
                    create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask
                )
            else:
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
                layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
            output_dimensions = layer_outputs[1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            input_dimensions = (output_dimensions[-2], output_dimensions[-1])
            all_input_dimensions += (input_dimensions,)

            if output_hidden_states:
                batch_size, _, hidden_size = hidden_states.shape
                # rearrange b (h w) c -> b c h w
                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
                all_self_attentions += layer_outputs[2:]

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
        return SwinEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )

@@ -712,7 +912,7 @@ class SwinModel(SwinPreTrainedModel):
    @add_code_sample_docstrings(
        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        output_type=SwinModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,

@@ -742,10 +942,11 @@ class SwinModel(SwinPreTrainedModel):
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, len(self.config.depths))

        embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)

        encoder_outputs = self.encoder(
            embedding_output,
            input_dimensions,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,

@@ -761,13 +962,16 @@ class SwinModel(SwinPreTrainedModel):
            pooled_output = torch.flatten(pooled_output, 1)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]
            output = (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            return output

        return SwinModelOutput(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
        )

@@ -791,7 +995,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
        self.post_init()

    @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
    @replace_return_docstrings(output_type=SwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values=None,

@@ -869,11 +1073,12 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
            output = (reconstructed_pixel_values,) + outputs[2:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return MaskedLMOutput(
        return SwinMaskedImageModelingOutput(
            loss=masked_im_loss,
            logits=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )

@@ -903,7 +1108,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
    @add_code_sample_docstrings(
        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=SequenceClassifierOutput,
        output_type=SwinImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
@@ -963,9 +1168,10 @@ class SwinForImageClassification(SwinPreTrainedModel):
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
        return SwinImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )
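Putting the pieces together, the padding added in this commit means inference is no longer tied to the 224x224 pre-training resolution. A minimal sketch; the checkpoint name and the 384x256 input are only examples, and, as the docs note, height and width should be divisible by 32:

```python
import torch
from transformers import SwinForImageClassification

model = SwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
model.eval()

# 384 x 256 instead of the 224 x 224 used during pre-training
pixel_values = torch.randn(1, 3, 384, 256)

with torch.no_grad():
    outputs = model(pixel_values)

print(outputs.logits.shape)  # torch.Size([1, 1000])
```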
@@ -230,15 +230,6 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        image_size = to_2tuple(self.model_tester.image_size)
        patch_size = to_2tuple(self.model_tester.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        seq_len = num_patches
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        chunk_length = getattr(self.model_tester, "chunk_length", None)
        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False

@@ -248,8 +239,9 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), len(self.model_tester.depths))
            attentions = outputs.attentions
            expected_num_attentions = len(self.model_tester.depths)
            self.assertEqual(len(attentions), expected_num_attentions)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]

@@ -260,19 +252,13 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), len(self.model_tester.depths))
            attentions = outputs.attentions
            self.assertEqual(len(attentions), expected_num_attentions)

            if chunk_length is not None:
                self.assertListEqual(
                    list(attentions[0].shape[-4:]),
                    [self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared],
                )
            else:
                self.assertListEqual(
                    list(attentions[0].shape[-3:]),
                    [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
                )
            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
            )
            out_len = len(outputs)

            # Check attention is always last and order is fine

@@ -286,25 +272,19 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):

            if hasattr(self.model_tester, "num_hidden_states_types"):
                added_hidden_states = self.model_tester.num_hidden_states_types
            elif self.is_encoder_decoder:
                added_hidden_states = 2
            else:
                added_hidden_states = 1
            # also another +1 for reshaped_hidden_states
            added_hidden_states = 2
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self_attentions = outputs.attentions

            self.assertEqual(len(self_attentions), len(self.model_tester.depths))
            if chunk_length is not None:
                self.assertListEqual(
                    list(self_attentions[0].shape[-4:]),
                    [self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared],
                )
            else:
                self.assertListEqual(
                    list(self_attentions[0].shape[-3:]),
                    [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
                )
            self.assertEqual(len(self_attentions), expected_num_attentions)

            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
            )

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):

@@ -315,7 +295,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
            hidden_states = outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1

@@ -325,6 +305,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
            # Swin has a different seq_length
            image_size = to_2tuple(self.model_tester.image_size)
            patch_size = to_2tuple(self.model_tester.patch_size)

            num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

            self.assertListEqual(

@@ -332,6 +313,18 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
                [num_patches, self.model_tester.embed_dim],
            )

            reshaped_hidden_states = outputs.reshaped_hidden_states
            self.assertEqual(len(reshaped_hidden_states), expected_num_layers)

            batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
            reshaped_hidden_states = (
                reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
            )
            self.assertListEqual(
                list(reshaped_hidden_states.shape[-2:]),
                [num_patches, self.model_tester.embed_dim],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:

@@ -395,7 +388,5 @@ class SwinModelIntegrationTest(unittest.TestCase):
        # verify the logits
        expected_shape = torch.Size((1, 1000))
        self.assertEqual(outputs.logits.shape, expected_shape)

        expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)

        self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))