[MaskFormer] Add support for ResNet backbone (#20483)

* Add SwinBackbone

* Add hidden_states_before_downsampling support

* Fix Swin tests

* Improve conversion script

* Add id2label mappings

* Add vistas mapping

* Update comments

* Fix backbone

* Improve tests

* Extend conversion script

* Add Swin conversion script

* Fix style

* Revert config attribute

* Remove SwinBackbone from main init

* Remove unused attribute

* Use encoder for ResNet backbone

* Improve conversion script and add integration test

* Apply suggestion

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge 2022-12-07 09:42:38 +01:00 committed by GitHub
parent 6c1a0b3931
commit b610c47f89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 808 additions and 24 deletions

View File

@ -18,7 +18,7 @@ from typing import Dict, Optional
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..auto import CONFIG_MAPPING
from ..detr import DetrConfig
from ..swin import SwinConfig
@ -97,7 +97,7 @@ class MaskFormerConfig(PretrainedConfig):
"""
model_type = "maskformer"
attribute_map = {"hidden_size": "mask_feature_size"}
backbones_supported = ["swin"]
backbones_supported = ["resnet", "swin"]
decoders_supported = ["detr"]
def __init__(
@ -127,27 +127,38 @@ class MaskFormerConfig(PretrainedConfig):
num_heads=[4, 8, 16, 32],
window_size=12,
drop_path_rate=0.3,
out_features=["stage1", "stage2", "stage3", "stage4"],
)
else:
backbone_model_type = backbone_config.pop("model_type")
# verify that the backbone is supported
backbone_model_type = (
backbone_config.pop("model_type") if isinstance(backbone_config, dict) else backbone_config.model_type
)
if backbone_model_type not in self.backbones_supported:
raise ValueError(
f"Backbone {backbone_model_type} not supported, please use one of"
f" {','.join(self.backbones_supported)}"
)
backbone_config = AutoConfig.for_model(backbone_model_type, **backbone_config)
if isinstance(backbone_config, dict):
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
if decoder_config is None:
# fall back to https://huggingface.co/facebook/detr-resnet-50
decoder_config = DetrConfig()
else:
decoder_type = decoder_config.pop("model_type")
# verify that the decoder is supported
decoder_type = (
decoder_config.pop("model_type") if isinstance(decoder_config, dict) else decoder_config.model_type
)
if decoder_type not in self.decoders_supported:
raise ValueError(
f"Transformer Decoder {decoder_type} not supported, please use one of"
f" {','.join(self.decoders_supported)}"
)
decoder_config = AutoConfig.for_model(decoder_type, **decoder_config)
if isinstance(decoder_config, dict):
config_class = CONFIG_MAPPING[decoder_type]
decoder_config = config_class.from_dict(decoder_config)
self.backbone_config = backbone_config
self.decoder_config = decoder_config
@ -186,8 +197,8 @@ class MaskFormerConfig(PretrainedConfig):
[`MaskFormerConfig`]: An instance of a configuration object
"""
return cls(
backbone_config=backbone_config.to_dict(),
decoder_config=decoder_config.to_dict(),
backbone_config=backbone_config,
decoder_config=decoder_config,
**kwargs,
)

View File

@ -69,7 +69,7 @@ class MaskFormerSwinConfig(PretrainedConfig):
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
out_features (`List[str]`, *optional*):
If used as a backbone, list of feature names to output, e.g. `["stem", "stage1"]`.
If used as a backbone, list of feature names to output, e.g. `["stage1", "stage2"]`.
Example:

View File

@ -0,0 +1,390 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert MaskFormer checkpoints with ResNet backbone from the original repository. URL:
https://github.com/facebookresearch/MaskFormer"""
import argparse
import json
import pickle
from pathlib import Path
import torch
from PIL import Image
import requests
from huggingface_hub import hf_hub_download
from transformers import MaskFormerConfig, MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation, ResNetConfig
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_maskformer_config(model_name: str):
if "resnet101c" in model_name:
# TODO add support for ResNet-C backbone, which uses a "deeplab" stem
raise NotImplementedError("To do")
elif "resnet101" in model_name:
backbone_config = ResNetConfig.from_pretrained(
"microsoft/resnet-101", out_features=["stage1", "stage2", "stage3", "stage4"]
)
else:
backbone_config = ResNetConfig.from_pretrained(
"microsoft/resnet-50", out_features=["stage1", "stage2", "stage3", "stage4"]
)
config = MaskFormerConfig(backbone_config=backbone_config)
repo_id = "huggingface/label-files"
if "ade20k-full" in model_name:
config.num_labels = 847
filename = "maskformer-ade20k-full-id2label.json"
elif "ade" in model_name:
config.num_labels = 150
filename = "ade20k-id2label.json"
elif "coco-stuff" in model_name:
config.num_labels = 171
filename = "maskformer-coco-stuff-id2label.json"
elif "coco" in model_name:
# TODO
config.num_labels = 133
filename = "coco-panoptic-id2label.json"
elif "cityscapes" in model_name:
config.num_labels = 19
filename = "cityscapes-id2label.json"
elif "vistas" in model_name:
config.num_labels = 65
filename = "mapillary-vistas-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
return config
def create_rename_keys(config):
rename_keys = []
# stem
# fmt: off
rename_keys.append(("backbone.stem.conv1.weight", "model.pixel_level_module.encoder.embedder.embedder.convolution.weight"))
rename_keys.append(("backbone.stem.conv1.norm.weight", "model.pixel_level_module.encoder.embedder.embedder.normalization.weight"))
rename_keys.append(("backbone.stem.conv1.norm.bias", "model.pixel_level_module.encoder.embedder.embedder.normalization.bias"))
rename_keys.append(("backbone.stem.conv1.norm.running_mean", "model.pixel_level_module.encoder.embedder.embedder.normalization.running_mean"))
rename_keys.append(("backbone.stem.conv1.norm.running_var", "model.pixel_level_module.encoder.embedder.embedder.normalization.running_var"))
# fmt: on
# stages
for stage_idx in range(len(config.backbone_config.depths)):
for layer_idx in range(config.backbone_config.depths[stage_idx]):
# shortcut
if layer_idx == 0:
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.shortcut.weight",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.convolution.weight",
)
)
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.shortcut.norm.weight",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.weight",
)
)
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.shortcut.norm.bias",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.bias",
)
)
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.shortcut.norm.running_mean",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_mean",
)
)
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.shortcut.norm.running_var",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_var",
)
)
# 3 convs
for i in range(3):
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.conv{i+1}.weight",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.convolution.weight",
)
)
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.conv{i+1}.norm.weight",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.weight",
)
)
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.conv{i+1}.norm.bias",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.bias",
)
)
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.conv{i+1}.norm.running_mean",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.running_mean",
)
)
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.conv{i+1}.norm.running_var",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.running_var",
)
)
# FPN
# fmt: off
rename_keys.append(("sem_seg_head.layer_4.weight", "model.pixel_level_module.decoder.fpn.stem.0.weight"))
rename_keys.append(("sem_seg_head.layer_4.norm.weight", "model.pixel_level_module.decoder.fpn.stem.1.weight"))
rename_keys.append(("sem_seg_head.layer_4.norm.bias", "model.pixel_level_module.decoder.fpn.stem.1.bias"))
for source_index, target_index in zip(range(3, 0, -1), range(0, 3)):
rename_keys.append((f"sem_seg_head.adapter_{source_index}.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.0.weight"))
rename_keys.append((f"sem_seg_head.adapter_{source_index}.norm.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.1.weight"))
rename_keys.append((f"sem_seg_head.adapter_{source_index}.norm.bias", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.1.bias"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.0.weight"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.norm.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.1.weight"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.norm.bias", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.1.bias"))
rename_keys.append(("sem_seg_head.mask_features.weight", "model.pixel_level_module.decoder.mask_projection.weight"))
rename_keys.append(("sem_seg_head.mask_features.bias", "model.pixel_level_module.decoder.mask_projection.bias"))
# fmt: on
# Transformer decoder
# fmt: off
for idx in range(config.decoder_config.decoder_layers):
# self-attention out projection
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.out_proj.weight", f"model.transformer_module.decoder.layers.{idx}.self_attn.out_proj.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.out_proj.bias", f"model.transformer_module.decoder.layers.{idx}.self_attn.out_proj.bias"))
# cross-attention out projection
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.out_proj.weight", f"model.transformer_module.decoder.layers.{idx}.encoder_attn.out_proj.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.out_proj.bias", f"model.transformer_module.decoder.layers.{idx}.encoder_attn.out_proj.bias"))
# MLP 1
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear1.weight", f"model.transformer_module.decoder.layers.{idx}.fc1.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear1.bias", f"model.transformer_module.decoder.layers.{idx}.fc1.bias"))
# MLP 2
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear2.weight", f"model.transformer_module.decoder.layers.{idx}.fc2.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear2.bias", f"model.transformer_module.decoder.layers.{idx}.fc2.bias"))
# layernorm 1 (self-attention layernorm)
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm1.weight", f"model.transformer_module.decoder.layers.{idx}.self_attn_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm1.bias", f"model.transformer_module.decoder.layers.{idx}.self_attn_layer_norm.bias"))
# layernorm 2 (cross-attention layernorm)
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm2.weight", f"model.transformer_module.decoder.layers.{idx}.encoder_attn_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm2.bias", f"model.transformer_module.decoder.layers.{idx}.encoder_attn_layer_norm.bias"))
# layernorm 3 (final layernorm)
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm3.weight", f"model.transformer_module.decoder.layers.{idx}.final_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm3.bias", f"model.transformer_module.decoder.layers.{idx}.final_layer_norm.bias"))
rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.weight", "model.transformer_module.decoder.layernorm.weight"))
rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.bias", "model.transformer_module.decoder.layernorm.bias"))
# fmt: on
# heads on top
# fmt: off
rename_keys.append(("sem_seg_head.predictor.query_embed.weight", "model.transformer_module.queries_embedder.weight"))
rename_keys.append(("sem_seg_head.predictor.input_proj.weight", "model.transformer_module.input_projection.weight"))
rename_keys.append(("sem_seg_head.predictor.input_proj.bias", "model.transformer_module.input_projection.bias"))
rename_keys.append(("sem_seg_head.predictor.class_embed.weight", "class_predictor.weight"))
rename_keys.append(("sem_seg_head.predictor.class_embed.bias", "class_predictor.bias"))
for i in range(3):
rename_keys.append((f"sem_seg_head.predictor.mask_embed.layers.{i}.weight", f"mask_embedder.{i}.0.weight"))
rename_keys.append((f"sem_seg_head.predictor.mask_embed.layers.{i}.bias", f"mask_embedder.{i}.0.bias"))
# fmt: on
return rename_keys
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_decoder_q_k_v(state_dict, config):
# fmt: off
hidden_size = config.decoder_config.hidden_size
for idx in range(config.decoder_config.decoder_layers):
# read in weights + bias of self-attention input projection layer (in the original implementation, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.q_proj.weight"] = in_proj_weight[: hidden_size, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.q_proj.bias"] = in_proj_bias[:config.hidden_size]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.k_proj.weight"] = in_proj_weight[hidden_size : hidden_size * 2, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.v_proj.weight"] = in_proj_weight[-hidden_size :, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.v_proj.bias"] = in_proj_bias[-hidden_size :]
# read in weights + bias of cross-attention input projection layer (in the original implementation, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.q_proj.weight"] = in_proj_weight[: hidden_size, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.q_proj.bias"] = in_proj_bias[:config.hidden_size]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.k_proj.weight"] = in_proj_weight[hidden_size : hidden_size * 2, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.v_proj.weight"] = in_proj_weight[-hidden_size :, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.v_proj.bias"] = in_proj_bias[-hidden_size :]
# fmt: on
# We will verify our results on an image of cute cats
def prepare_img() -> torch.Tensor:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_maskformer_checkpoint(
model_name: str, checkpoint_path: str, pytorch_dump_folder_path: str, push_to_hub: bool = False
):
"""
Copy/paste/tweak model's weights to our MaskFormer structure.
"""
config = get_maskformer_config(model_name)
# load original state_dict
with open(checkpoint_path, "rb") as f:
data = pickle.load(f)
state_dict = data["model"]
# rename keys
rename_keys = create_rename_keys(config)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
read_in_decoder_q_k_v(state_dict, config)
# update to torch tensors
for key, value in state_dict.items():
state_dict[key] = torch.from_numpy(value)
# load 🤗 model
model = MaskFormerForInstanceSegmentation(config)
model.eval()
model.load_state_dict(state_dict)
# verify results
image = prepare_img()
if "vistas" in model_name:
ignore_index = 65
elif "cityscapes" in model_name:
ignore_index = 65535
else:
ignore_index = 255
reduce_labels = True if "ade" in model_name else False
feature_extractor = MaskFormerFeatureExtractor(ignore_index=ignore_index, reduce_labels=reduce_labels)
inputs = feature_extractor(image, return_tensors="pt")
outputs = model(**inputs)
if model_name == "maskformer-resnet50-ade":
expected_logits = torch.tensor(
[[6.7710, -0.1452, -3.5687], [1.9165, -1.0010, -1.8614], [3.6209, -0.2950, -1.3813]]
)
elif model_name == "maskformer-resnet101-ade":
expected_logits = torch.tensor(
[[4.0381, -1.1483, -1.9688], [2.7083, -1.9147, -2.2555], [3.4367, -1.3711, -2.1609]]
)
elif model_name == "maskformer-resnet50-coco-stuff":
expected_logits = torch.tensor(
[[3.2309, -3.0481, -2.8695], [5.4986, -5.4242, -2.4211], [6.2100, -5.2279, -2.7786]]
)
elif model_name == "maskformer-resnet101-coco-stuff":
expected_logits = torch.tensor(
[[4.7188, -3.2585, -2.8857], [6.6871, -2.9181, -1.2487], [7.2449, -2.2764, -2.1874]]
)
elif model_name == "maskformer-resnet101-cityscapes":
expected_logits = torch.tensor(
[[-1.8861, -1.5465, 0.6749], [-2.3677, -1.6707, -0.0867], [-2.2314, -1.9530, -0.9132]]
)
elif model_name == "maskformer-resnet50-vistas":
expected_logits = torch.tensor(
[[-6.3917, -1.5216, -1.1392], [-5.5335, -4.5318, -1.8339], [-4.3576, -4.0301, 0.2162]]
)
elif model_name == "maskformer-resnet50-ade20k-full":
expected_logits = torch.tensor(
[[3.6146, -1.9367, -3.2534], [4.0099, 0.2027, -2.7576], [3.3913, -2.3644, -3.9519]]
)
elif model_name == "maskformer-resnet101-ade20k-full":
expected_logits = torch.tensor(
[[3.2211, -1.6550, -2.7605], [2.8559, -2.4512, -2.9574], [2.6331, -2.6775, -2.1844]]
)
assert torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_logits, atol=1e-4)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
print(f"Saving model and feature extractor of {model_name} to {pytorch_dump_folder_path}")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
feature_extractor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print(f"Pushing model and feature extractor of {model_name} to the hub...")
model.push_to_hub(f"facebook/{model_name}")
feature_extractor.push_to_hub(f"facebook/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="maskformer-resnet50-ade",
type=str,
required=True,
choices=[
"maskformer-resnet50-ade",
"maskformer-resnet101-ade",
"maskformer-resnet50-coco-stuff",
"maskformer-resnet101-coco-stuff",
"maskformer-resnet101-cityscapes",
"maskformer-resnet50-vistas",
"maskformer-resnet50-ade20k-full",
"maskformer-resnet101-ade20k-full",
],
help=("Name of the MaskFormer model you'd like to convert",),
)
parser.add_argument(
"--checkpoint_path",
type=str,
required=True,
help=("Path to the original pickle file (.pkl) of the original checkpoint.",),
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
args = parser.parse_args()
convert_maskformer_checkpoint(
args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub
)

View File

@ -0,0 +1,333 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert MaskFormer checkpoints with Swin backbone from the original repository. URL:
https://github.com/facebookresearch/MaskFormer"""
import argparse
import json
import pickle
from pathlib import Path
import torch
from PIL import Image
import requests
from huggingface_hub import hf_hub_download
from transformers import MaskFormerConfig, MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation, SwinConfig
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_maskformer_config(model_name: str):
backbone_config = SwinConfig.from_pretrained(
"microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"]
)
config = MaskFormerConfig(backbone_config=backbone_config)
repo_id = "huggingface/label-files"
if "ade20k-full" in model_name:
# this should be ok
config.num_labels = 847
filename = "maskformer-ade20k-full-id2label.json"
elif "ade" in model_name:
# this should be ok
config.num_labels = 150
filename = "ade20k-id2label.json"
elif "coco-stuff" in model_name:
# this should be ok
config.num_labels = 171
filename = "maskformer-coco-stuff-id2label.json"
elif "coco" in model_name:
# TODO
config.num_labels = 133
filename = "coco-panoptic-id2label.json"
elif "cityscapes" in model_name:
# this should be ok
config.num_labels = 19
filename = "cityscapes-id2label.json"
elif "vistas" in model_name:
# this should be ok
config.num_labels = 65
filename = "mapillary-vistas-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
return config
def create_rename_keys(config):
rename_keys = []
# stem
# fmt: off
rename_keys.append(("backbone.patch_embed.proj.weight", "model.pixel_level_module.encoder.model.embeddings.patch_embeddings.projection.weight"))
rename_keys.append(("backbone.patch_embed.proj.bias", "model.pixel_level_module.encoder.model.embeddings.patch_embeddings.projection.bias"))
rename_keys.append(("backbone.patch_embed.norm.weight", "model.pixel_level_module.encoder.model.embeddings.norm.weight"))
rename_keys.append(("backbone.patch_embed.norm.bias", "model.pixel_level_module.encoder.model.embeddings.norm.bias"))
# stages
for i in range(len(config.backbone_config.depths)):
for j in range(config.backbone_config.depths[i]):
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.norm1.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.layernorm_before.weight"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.norm1.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.layernorm_before.bias"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.attn.relative_position_bias_table", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.relative_position_bias_table"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.attn.relative_position_index", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.relative_position_index"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.attn.proj.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.output.dense.weight"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.attn.proj.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.output.dense.bias"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.norm2.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.layernorm_after.weight"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.norm2.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.layernorm_after.bias"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.mlp.fc1.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.intermediate.dense.weight"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.mlp.fc1.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.intermediate.dense.bias"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.mlp.fc2.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.output.dense.weight"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.mlp.fc2.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.output.dense.bias"))
if i < 3:
rename_keys.append((f"backbone.layers.{i}.downsample.reduction.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.downsample.reduction.weight"))
rename_keys.append((f"backbone.layers.{i}.downsample.norm.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.downsample.norm.weight"))
rename_keys.append((f"backbone.layers.{i}.downsample.norm.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.downsample.norm.bias"))
rename_keys.append((f"backbone.norm{i}.weight", f"model.pixel_level_module.encoder.hidden_states_norms.{i}.weight"))
rename_keys.append((f"backbone.norm{i}.bias", f"model.pixel_level_module.encoder.hidden_states_norms.{i}.bias"))
# FPN
rename_keys.append(("sem_seg_head.layer_4.weight", "model.pixel_level_module.decoder.fpn.stem.0.weight"))
rename_keys.append(("sem_seg_head.layer_4.norm.weight", "model.pixel_level_module.decoder.fpn.stem.1.weight"))
rename_keys.append(("sem_seg_head.layer_4.norm.bias", "model.pixel_level_module.decoder.fpn.stem.1.bias"))
for source_index, target_index in zip(range(3, 0, -1), range(0, 3)):
rename_keys.append((f"sem_seg_head.adapter_{source_index}.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.0.weight"))
rename_keys.append((f"sem_seg_head.adapter_{source_index}.norm.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.1.weight"))
rename_keys.append((f"sem_seg_head.adapter_{source_index}.norm.bias", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.1.bias"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.0.weight"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.norm.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.1.weight"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.norm.bias", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.1.bias"))
rename_keys.append(("sem_seg_head.mask_features.weight", "model.pixel_level_module.decoder.mask_projection.weight"))
rename_keys.append(("sem_seg_head.mask_features.bias", "model.pixel_level_module.decoder.mask_projection.bias"))
# Transformer decoder
for idx in range(config.decoder_config.decoder_layers):
# self-attention out projection
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.out_proj.weight", f"model.transformer_module.decoder.layers.{idx}.self_attn.out_proj.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.out_proj.bias", f"model.transformer_module.decoder.layers.{idx}.self_attn.out_proj.bias"))
# cross-attention out projection
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.out_proj.weight", f"model.transformer_module.decoder.layers.{idx}.encoder_attn.out_proj.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.out_proj.bias", f"model.transformer_module.decoder.layers.{idx}.encoder_attn.out_proj.bias"))
# MLP 1
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear1.weight", f"model.transformer_module.decoder.layers.{idx}.fc1.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear1.bias", f"model.transformer_module.decoder.layers.{idx}.fc1.bias"))
# MLP 2
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear2.weight", f"model.transformer_module.decoder.layers.{idx}.fc2.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear2.bias", f"model.transformer_module.decoder.layers.{idx}.fc2.bias"))
# layernorm 1 (self-attention layernorm)
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm1.weight", f"model.transformer_module.decoder.layers.{idx}.self_attn_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm1.bias", f"model.transformer_module.decoder.layers.{idx}.self_attn_layer_norm.bias"))
# layernorm 2 (cross-attention layernorm)
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm2.weight", f"model.transformer_module.decoder.layers.{idx}.encoder_attn_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm2.bias", f"model.transformer_module.decoder.layers.{idx}.encoder_attn_layer_norm.bias"))
# layernorm 3 (final layernorm)
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm3.weight", f"model.transformer_module.decoder.layers.{idx}.final_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm3.bias", f"model.transformer_module.decoder.layers.{idx}.final_layer_norm.bias"))
rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.weight", "model.transformer_module.decoder.layernorm.weight"))
rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.bias", "model.transformer_module.decoder.layernorm.bias"))
# heads on top
rename_keys.append(("sem_seg_head.predictor.query_embed.weight", "model.transformer_module.queries_embedder.weight"))
rename_keys.append(("sem_seg_head.predictor.input_proj.weight", "model.transformer_module.input_projection.weight"))
rename_keys.append(("sem_seg_head.predictor.input_proj.bias", "model.transformer_module.input_projection.bias"))
rename_keys.append(("sem_seg_head.predictor.class_embed.weight", "class_predictor.weight"))
rename_keys.append(("sem_seg_head.predictor.class_embed.bias", "class_predictor.bias"))
for i in range(3):
rename_keys.append((f"sem_seg_head.predictor.mask_embed.layers.{i}.weight", f"mask_embedder.{i}.0.weight"))
rename_keys.append((f"sem_seg_head.predictor.mask_embed.layers.{i}.bias", f"mask_embedder.{i}.0.bias"))
# fmt: on
return rename_keys
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_swin_q_k_v(state_dict, backbone_config):
num_features = [int(backbone_config.embed_dim * 2**i) for i in range(len(backbone_config.depths))]
for i in range(len(backbone_config.depths)):
dim = num_features[i]
for j in range(backbone_config.depths[i]):
# fmt: off
# read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"backbone.layers.{i}.blocks.{j}.attn.qkv.weight")
in_proj_bias = state_dict.pop(f"backbone.layers.{i}.blocks.{j}.attn.qkv.bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.query.weight"] = in_proj_weight[:dim, :]
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.query.bias"] = in_proj_bias[: dim]
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.key.weight"] = in_proj_weight[
dim : dim * 2, :
]
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.key.bias"] = in_proj_bias[
dim : dim * 2
]
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.value.weight"] = in_proj_weight[
-dim :, :
]
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.value.bias"] = in_proj_bias[-dim :]
# fmt: on
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_decoder_q_k_v(state_dict, config):
# fmt: off
hidden_size = config.decoder_config.hidden_size
for idx in range(config.decoder_config.decoder_layers):
# read in weights + bias of self-attention input projection layer (in the original implementation, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.q_proj.weight"] = in_proj_weight[: hidden_size, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.q_proj.bias"] = in_proj_bias[:config.hidden_size]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.k_proj.weight"] = in_proj_weight[hidden_size : hidden_size * 2, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.v_proj.weight"] = in_proj_weight[-hidden_size :, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.v_proj.bias"] = in_proj_bias[-hidden_size :]
# read in weights + bias of cross-attention input projection layer (in the original implementation, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.q_proj.weight"] = in_proj_weight[: hidden_size, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.q_proj.bias"] = in_proj_bias[:config.hidden_size]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.k_proj.weight"] = in_proj_weight[hidden_size : hidden_size * 2, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.v_proj.weight"] = in_proj_weight[-hidden_size :, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.v_proj.bias"] = in_proj_bias[-hidden_size :]
# fmt: on
# We will verify our results on an image of cute cats
def prepare_img() -> torch.Tensor:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_maskformer_checkpoint(
model_name: str, checkpoint_path: str, pytorch_dump_folder_path: str, push_to_hub: bool = False
):
"""
Copy/paste/tweak model's weights to our MaskFormer structure.
"""
config = get_maskformer_config(model_name)
# load original state_dict
with open(checkpoint_path, "rb") as f:
data = pickle.load(f)
state_dict = data["model"]
# for name, param in state_dict.items():
# print(name, param.shape)
# rename keys
rename_keys = create_rename_keys(config)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
read_in_swin_q_k_v(state_dict, config.backbone_config)
read_in_decoder_q_k_v(state_dict, config)
# update to torch tensors
for key, value in state_dict.items():
state_dict[key] = torch.from_numpy(value)
# load 🤗 model
model = MaskFormerForInstanceSegmentation(config)
model.eval()
for name, param in model.named_parameters():
print(name, param.shape)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
assert missing_keys == [
"model.pixel_level_module.encoder.model.layernorm.weight",
"model.pixel_level_module.encoder.model.layernorm.bias",
]
assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
# verify results
image = prepare_img()
if "vistas" in model_name:
ignore_index = 65
elif "cityscapes" in model_name:
ignore_index = 65535
else:
ignore_index = 255
reduce_labels = True if "ade" in model_name else False
feature_extractor = MaskFormerFeatureExtractor(ignore_index=ignore_index, reduce_labels=reduce_labels)
inputs = feature_extractor(image, return_tensors="pt")
outputs = model(**inputs)
print("Logits:", outputs.class_queries_logits[0, :3, :3])
if model_name == "maskformer-swin-tiny-ade":
expected_logits = torch.tensor(
[[3.6353, -4.4770, -2.6065], [0.5081, -4.2394, -3.5343], [2.1909, -5.0353, -1.9323]]
)
assert torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_logits, atol=1e-4)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
print(f"Saving model and feature extractor to {pytorch_dump_folder_path}")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
feature_extractor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print("Pushing model and feature extractor to the hub...")
model.push_to_hub(f"nielsr/{model_name}")
feature_extractor.push_to_hub(f"nielsr/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="maskformer-swin-tiny-ade",
type=str,
help=("Name of the MaskFormer model you'd like to convert",),
)
parser.add_argument(
"--checkpoint_path",
default="/Users/nielsrogge/Documents/MaskFormer_checkpoints/MaskFormer-Swin-tiny-ADE20k/model.pkl",
type=str,
help="Path to the original state dict (.pth file).",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
args = parser.parse_args()
convert_maskformer_checkpoint(
args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub
)

View File

@ -275,7 +275,6 @@ class RegNetEncoder(nn.Module):
return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel with ResNet->RegNet,resnet->regnet
class RegNetPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@ -287,6 +286,7 @@ class RegNetPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel._init_weights
def _init_weights(self, module):
if isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")

View File

@ -267,7 +267,7 @@ class ResNetPreTrainedModel(PreTrainedModel):
nn.init.constant_(module.bias, 0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, ResNetModel):
if isinstance(module, (ResNetModel, ResNetBackbone)):
module.gradient_checkpointing = value
@ -436,7 +436,8 @@ class ResNetBackbone(ResNetPreTrainedModel):
super().__init__(config)
self.stage_names = config.stage_names
self.resnet = ResNetModel(config)
self.embedder = ResNetEmbeddings(config)
self.encoder = ResNetEncoder(config)
self.out_features = config.out_features
@ -490,7 +491,9 @@ class ResNetBackbone(ResNetPreTrainedModel):
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
outputs = self.resnet(pixel_values, output_hidden_states=True, return_dict=True)
embedding_output = self.embedder(pixel_values)
outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True)
hidden_states = outputs.hidden_states

View File

@ -84,7 +84,7 @@ class SwinConfig(PretrainedConfig):
encoder_stride (`int`, `optional`, defaults to 32):
Factor to increase the spatial resolution by in the decoder head for masked image modeling.
Example:
Example:
```python
>>> from transformers import SwinConfig, SwinModel

View File

@ -320,16 +320,16 @@ def prepare_img():
@require_vision
@slow
class MaskFormerModelIntegrationTest(unittest.TestCase):
@cached_property
def model_checkpoints(self):
return "facebook/maskformer-swin-small-coco"
@cached_property
def default_feature_extractor(self):
return MaskFormerFeatureExtractor.from_pretrained(self.model_checkpoints) if is_vision_available() else None
return (
MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-small-coco")
if is_vision_available()
else None
)
def test_inference_no_head(self):
model = MaskFormerModel.from_pretrained(self.model_checkpoints).to(torch_device)
model = MaskFormerModel.from_pretrained("facebook/maskformer-swin-small-coco").to(torch_device)
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
@ -370,7 +370,11 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
)
def test_inference_instance_segmentation_head(self):
model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
model = (
MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
.to(torch_device)
.eval()
)
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
@ -385,7 +389,8 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
# masks_queries_logits
masks_queries_logits = outputs.masks_queries_logits
self.assertEqual(
masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4)
masks_queries_logits.shape,
(1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
)
expected_slice = [
[-1.3737124, -1.7724937, -1.9364233],
@ -396,7 +401,9 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
# class_queries_logits
class_queries_logits = outputs.class_queries_logits
self.assertEqual(class_queries_logits.shape, (1, model.config.num_queries, model.config.num_labels + 1))
self.assertEqual(
class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
)
expected_slice = torch.tensor(
[
[1.6512e00, -5.2572e00, -3.3519e00],
@ -406,8 +413,48 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
).to(torch_device)
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
def test_inference_instance_segmentation_head_resnet_backbone(self):
model = (
MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-resnet101-coco-stuff")
.to(torch_device)
.eval()
)
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
inputs_shape = inputs["pixel_values"].shape
# check size is divisible by 32
self.assertTrue((inputs_shape[-1] % 32) == 0 and (inputs_shape[-2] % 32) == 0)
# check size
self.assertEqual(inputs_shape, (1, 3, 800, 1088))
with torch.no_grad():
outputs = model(**inputs)
# masks_queries_logits
masks_queries_logits = outputs.masks_queries_logits
self.assertEqual(
masks_queries_logits.shape,
(1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
)
expected_slice = [[-0.9046, -2.6366, -4.6062], [-3.4179, -5.7890, -8.8057], [-4.9179, -7.6560, -10.7711]]
expected_slice = torch.tensor(expected_slice).to(torch_device)
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
# class_queries_logits
class_queries_logits = outputs.class_queries_logits
self.assertEqual(
class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
)
expected_slice = torch.tensor(
[[4.7188, -3.2585, -2.8857], [6.6871, -2.9181, -1.2487], [7.2449, -2.2764, -2.1874]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
def test_with_segmentation_maps_and_loss(self):
model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
model = (
MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
.to(torch_device)
.eval()
)
feature_extractor = self.default_feature_extractor
inputs = feature_extractor(