mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add Swin backbone (#20769)
* Add Swin backbone
* Remove line
* Add code example

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
parent
94f8e21c70
commit
67acb07e9e
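Before the file-by-file changes, a minimal usage sketch of the backbone API this commit introduces. It is adapted from the docstring example added further down in the diff (the checkpoint name, `out_features`, and the expected `[1, 768, 7, 7]` shape come from that example); treat it as illustrative rather than canonical:

```python
from transformers import AutoBackbone, AutoImageProcessor
from PIL import Image
import requests

# COCO sample image used by the docstring example below.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# AutoBackbone resolves "swin" to the new SwinBackbone via MODEL_FOR_BACKBONE_MAPPING_NAMES.
processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
model = AutoBackbone.from_pretrained(
    "microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"]
)

inputs = processor(image, return_tensors="pt")
outputs = model(**inputs)

# One feature map per requested stage, channels-first; the last is [1, 768, 7, 7] for swin-tiny.
for name, feature_map in zip(model.out_features, outputs.feature_maps):
    print(name, list(feature_map.shape))
```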
@@ -2078,6 +2078,7 @@ else:
    _import_structure["models.swin"].extend(
        [
            "SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
            "SwinBackbone",
            "SwinForImageClassification",
            "SwinForMaskedImageModeling",
            "SwinModel",

@@ -5041,6 +5042,7 @@ if TYPE_CHECKING:
        )
        from .models.swin import (
            SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
            SwinBackbone,
            SwinForImageClassification,
            SwinForMaskedImageModeling,
            SwinModel,
@@ -869,6 +869,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
        ("maskformer-swin", "MaskFormerSwinBackbone"),
        ("nat", "NatBackbone"),
        ("resnet", "ResNetBackbone"),
        ("swin", "SwinBackbone"),
    ]
)
@@ -523,7 +523,6 @@ class DonutSwinLayer(nn.Module):
        self.shift_size = shift_size
        self.window_size = config.window_size
        self.input_resolution = input_resolution
        self.set_shift_and_window_size(input_resolution)
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = DonutSwinAttention(config, dim, num_heads, window_size=self.window_size)
        self.drop_path = DonutSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()

@@ -585,7 +584,9 @@ class DonutSwinLayer(nn.Module):
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)

        hidden_states = hidden_states.view(batch_size, height, width, channels)

        # pad hidden_states to multiples of window size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)

@@ -677,14 +678,15 @@ class DonutSwinStage(nn.Module):

            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
            output_dimensions = (height, width, height_downsampled, width_downsampled)
            hidden_states = self.downsample(layer_outputs[0], input_dimensions)
            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
        else:
            output_dimensions = (height, width, height, width)

        stage_outputs = (hidden_states, output_dimensions)
        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

        if output_attentions:
            stage_outputs += layer_outputs[1:]

@@ -722,9 +724,9 @@ class DonutSwinEncoder(nn.Module):
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, DonutSwinEncoderOutput]:
        all_input_dimensions = ()
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

@@ -755,12 +757,22 @@ class DonutSwinEncoder(nn.Module):
            layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]
            output_dimensions = layer_outputs[1]
            hidden_states_before_downsampling = layer_outputs[1]
            output_dimensions = layer_outputs[2]

            input_dimensions = (output_dimensions[-2], output_dimensions[-1])
            all_input_dimensions += (input_dimensions,)

            if output_hidden_states:
            if output_hidden_states and output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
                # rearrange b (h w) c -> b c h w
                # here we use the original (not downsampled) height and width
                reshaped_hidden_state = hidden_states_before_downsampling.view(
                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
                )
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states_before_downsampling,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)
            elif output_hidden_states and not output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states.shape
                # rearrange b (h w) c -> b c h w
                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)

@@ -769,7 +781,7 @@ class DonutSwinEncoder(nn.Module):
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
                all_self_attentions += layer_outputs[2:]
                all_self_attentions += layer_outputs[3:]

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
@@ -36,6 +36,7 @@ else:
        "SwinForMaskedImageModeling",
        "SwinModel",
        "SwinPreTrainedModel",
        "SwinBackbone",
    ]

try:

@@ -63,6 +64,7 @@ if TYPE_CHECKING:
    else:
        from .modeling_swin import (
            SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
            SwinBackbone,
            SwinForImageClassification,
            SwinForMaskedImageModeling,
            SwinModel,
@@ -83,6 +83,9 @@ class SwinConfig(PretrainedConfig):
            The epsilon used by the layer normalization layers.
        encoder_stride (`int`, `optional`, defaults to 32):
            Factor to increase the spatial resolution by in the decoder head for masked image modeling.
        out_features (`List[str]`, *optional*):
            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
            (depending on how many stages the model has). Will default to the last stage if unset.

    Example:

@@ -125,6 +128,7 @@ class SwinConfig(PretrainedConfig):
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        encoder_stride=32,
        out_features=None,
        **kwargs
    ):
        super().__init__(**kwargs)

@@ -151,6 +155,16 @@ class SwinConfig(PretrainedConfig):
        # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
        # this indicates the channel dimension after the last stage of the model
        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
        if out_features is not None:
            if not isinstance(out_features, list):
                raise ValueError("out_features should be a list")
            for feature in out_features:
                if feature not in self.stage_names:
                    raise ValueError(
                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
                    )
        self.out_features = out_features


class SwinOnnxConfig(OnnxConfig):
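The new `out_features` argument is validated against `stage_names`, which the config derives from `depths`. A minimal sketch of how a backbone-style config would be built and validated; the values are chosen for illustration and are not taken from the diff:

```python
from transformers import SwinConfig

# Default depths [2, 2, 6, 2] give stage_names ["stem", "stage1", ..., "stage4"].
config = SwinConfig(out_features=["stage2", "stage4"])
print(config.stage_names)   # ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
print(config.out_features)  # ['stage2', 'stage4']

# A name outside stage_names triggers the ValueError added above.
try:
    SwinConfig(out_features=["stage5"])
except ValueError as err:
    print(err)
```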
@@ -26,7 +26,8 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import (
    ModelOutput,
@@ -589,7 +590,6 @@ class SwinLayer(nn.Module):
        self.shift_size = shift_size
        self.window_size = config.window_size
        self.input_resolution = input_resolution
        self.set_shift_and_window_size(input_resolution)
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = SwinAttention(config, dim, num_heads, window_size=self.window_size)
        self.drop_path = SwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()

@@ -651,7 +651,9 @@ class SwinLayer(nn.Module):
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)

        hidden_states = hidden_states.view(batch_size, height, width, channels)

        # pad hidden_states to multiples of window size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)

@@ -742,14 +744,15 @@ class SwinStage(nn.Module):

            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
            output_dimensions = (height, width, height_downsampled, width_downsampled)
            hidden_states = self.downsample(layer_outputs[0], input_dimensions)
            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
        else:
            output_dimensions = (height, width, height, width)

        stage_outputs = (hidden_states, output_dimensions)
        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

        if output_attentions:
            stage_outputs += layer_outputs[1:]

@@ -786,9 +789,9 @@ class SwinEncoder(nn.Module):
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, SwinEncoderOutput]:
        all_input_dimensions = ()
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
@@ -819,12 +822,22 @@ class SwinEncoder(nn.Module):
            layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]
            output_dimensions = layer_outputs[1]
            hidden_states_before_downsampling = layer_outputs[1]
            output_dimensions = layer_outputs[2]

            input_dimensions = (output_dimensions[-2], output_dimensions[-1])
            all_input_dimensions += (input_dimensions,)

            if output_hidden_states:
            if output_hidden_states and output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
                # rearrange b (h w) c -> b c h w
                # here we use the original (not downsampled) height and width
                reshaped_hidden_state = hidden_states_before_downsampling.view(
                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
                )
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states_before_downsampling,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)
            elif output_hidden_states and not output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states.shape
                # rearrange b (h w) c -> b c h w
                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)

@@ -833,7 +846,7 @@ class SwinEncoder(nn.Module):
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
                all_self_attentions += layer_outputs[2:]
                all_self_attentions += layer_outputs[3:]

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
@@ -1214,3 +1227,118 @@ class SwinForImageClassification(SwinPreTrainedModel):
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )


@add_start_docstrings(
    """
    Swin backbone, to be used with frameworks like DETR and MaskFormer.
    """,
    SWIN_START_DOCSTRING,
)
class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
    def __init__(self, config: SwinConfig):
        super().__init__(config)

        self.stage_names = config.stage_names

        self.embeddings = SwinEmbeddings(config)
        self.encoder = SwinEncoder(config, self.embeddings.patch_grid)

        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]

        num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
        self.out_feature_channels = {}
        self.out_feature_channels["stem"] = config.embed_dim
        for i, stage in enumerate(self.stage_names[1:]):
            self.out_feature_channels[stage] = num_features[i]

        # Add layer norms to hidden states of out_features
        hidden_states_norms = dict()
        for stage, num_channels in zip(self.out_features, self.channels):
            hidden_states_norms[stage] = nn.LayerNorm(num_channels)
        self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    @property
    def channels(self):
        return [self.out_feature_channels[name] for name in self.out_features]

    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
        >>> model = AutoBackbone.from_pretrained(
        ...     "microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 7, 7]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        embedding_output, input_dimensions = self.embeddings(pixel_values)

        outputs = self.encoder(
            embedding_output,
            input_dimensions,
            head_mask=None,
            output_attentions=output_attentions,
            output_hidden_states=True,
            output_hidden_states_before_downsampling=True,
            return_dict=True,
        )

        hidden_states = outputs.reshaped_hidden_states

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                batch_size, num_channels, height, width = hidden_state.shape
                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
                hidden_state = self.hidden_states_norms[stage](hidden_state)
                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps += (hidden_state,)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (outputs.hidden_states,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
@@ -817,14 +817,15 @@ class Swinv2Stage(nn.Module):

            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
            output_dimensions = (height, width, height_downsampled, width_downsampled)
            hidden_states = self.downsample(layer_outputs[0], input_dimensions)
            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
        else:
            output_dimensions = (height, width, height, width)

        stage_outputs = (hidden_states, output_dimensions)
        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

        if output_attentions:
            stage_outputs += layer_outputs[1:]

@@ -865,9 +866,9 @@ class Swinv2Encoder(nn.Module):
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, Swinv2EncoderOutput]:
        all_input_dimensions = ()
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

@@ -898,12 +899,22 @@ class Swinv2Encoder(nn.Module):
            layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]
            output_dimensions = layer_outputs[1]
            hidden_states_before_downsampling = layer_outputs[1]
            output_dimensions = layer_outputs[2]

            input_dimensions = (output_dimensions[-2], output_dimensions[-1])
            all_input_dimensions += (input_dimensions,)

            if output_hidden_states:
            if output_hidden_states and output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
                # rearrange b (h w) c -> b c h w
                # here we use the original (not downsampled) height and width
                reshaped_hidden_state = hidden_states_before_downsampling.view(
                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
                )
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states_before_downsampling,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)
            elif output_hidden_states and not output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states.shape
                # rearrange b (h w) c -> b c h w
                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)

@@ -912,7 +923,7 @@ class Swinv2Encoder(nn.Module):
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
                all_self_attentions += layer_outputs[2:]
                all_self_attentions += layer_outputs[3:]

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
@@ -5243,6 +5243,13 @@ class SqueezeBertPreTrainedModel(metaclass=DummyObject):
SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = None


class SwinBackbone(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class SwinForImageClassification(metaclass=DummyObject):
    _backends = ["torch"]

@@ -30,7 +30,7 @@ if is_torch_available():
    import torch
    from torch import nn

    from transformers import SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
    from transformers import SwinBackbone, SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
    from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST

if is_vision_available():
@@ -66,6 +66,7 @@ class SwinModelTester:
        use_labels=True,
        type_sequence_label_size=10,
        encoder_stride=8,
        out_features=["stage1", "stage2"],
    ):
        self.parent = parent
        self.batch_size = batch_size

@@ -91,6 +92,7 @@ class SwinModelTester:
        self.use_labels = use_labels
        self.type_sequence_label_size = type_sequence_label_size
        self.encoder_stride = encoder_stride
        self.out_features = out_features

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

@@ -123,6 +125,7 @@ class SwinModelTester:
            layer_norm_eps=self.layer_norm_eps,
            initializer_range=self.initializer_range,
            encoder_stride=self.encoder_stride,
            out_features=self.out_features,
        )

    def create_and_check_model(self, config, pixel_values, labels):
@@ -136,6 +139,33 @@ class SwinModelTester:

        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))

    def create_and_check_backbone(self, config, pixel_values, labels):
        model = SwinBackbone(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)

        # verify hidden states
        self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
        self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], 16, 16])

        # verify channels
        self.parent.assertEqual(len(model.channels), len(config.out_features))

        # verify backbone works with out_features=None
        config.out_features = None
        model = SwinBackbone(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)

        # verify feature maps
        self.parent.assertEqual(len(result.feature_maps), 1)
        self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[-1], 4, 4])

        # verify channels
        self.parent.assertEqual(len(model.channels), 1)

    def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
        model = SwinForMaskedImageModeling(config=config)
        model.to(torch_device)
@@ -190,6 +220,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (
        (
            SwinModel,
            SwinBackbone,
            SwinForImageClassification,
            SwinForMaskedImageModeling,
        )
@@ -222,6 +253,10 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_backbone(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_backbone(*config_and_inputs)

    def test_for_masked_image_modeling(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
@@ -230,8 +265,12 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_image_classification(*config_and_inputs)

    @unittest.skip(reason="Swin does not use inputs_embeds")
    def test_inputs_embeds(self):
        # Swin does not use inputs_embeds
        pass

    @unittest.skip(reason="Swin Transformer does not use feedforward chunking")
    def test_feed_forward_chunking(self):
        pass

    def test_model_common_attributes(self):
@@ -299,11 +338,8 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            if hasattr(self.model_tester, "num_hidden_states_types"):
                added_hidden_states = self.model_tester.num_hidden_states_types
            else:
                # also another +1 for reshaped_hidden_states
                added_hidden_states = 2
            # also another +1 for reshaped_hidden_states
            added_hidden_states = 1 if model_class.__name__ == "SwinBackbone" else 2
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.attentions
@@ -344,17 +380,18 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
                [num_patches, self.model_tester.embed_dim],
            )

            reshaped_hidden_states = outputs.reshaped_hidden_states
            self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
            if not model_class.__name__ == "SwinBackbone":
                reshaped_hidden_states = outputs.reshaped_hidden_states
                self.assertEqual(len(reshaped_hidden_states), expected_num_layers)

            batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
            reshaped_hidden_states = (
                reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
            )
            self.assertListEqual(
                list(reshaped_hidden_states.shape[-2:]),
                [num_patches, self.model_tester.embed_dim],
            )
                batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
                reshaped_hidden_states = (
                    reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
                )
                self.assertListEqual(
                    list(reshaped_hidden_states.shape[-2:]),
                    [num_patches, self.model_tester.embed_dim],
                )

    def test_hidden_states_output(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -681,6 +681,7 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
    "NatBackbone",
    "MaskFormerSwinConfig",
    "MaskFormerSwinModel",
    "SwinBackbone",
]
