Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 10:12:23 +06:00)
[MaskFormer] Add support for ResNet backbone (#20483)
* Add SwinBackbone
* Add hidden_states_before_downsampling support
* Fix Swin tests
* Improve conversion script
* Add id2label mappings
* Add vistas mapping
* Update comments
* Fix backbone
* Improve tests
* Extend conversion script
* Add Swin conversion script
* Fix style
* Revert config attribute
* Remove SwinBackbone from main init
* Remove unused attribute
* Use encoder for ResNet backbone
* Improve conversion script and add integration test
* Apply suggestion

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
parent 6c1a0b3931
commit b610c47f89
@@ -18,7 +18,7 @@ from typing import Dict, Optional
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
-from ..auto.configuration_auto import AutoConfig
+from ..auto import CONFIG_MAPPING
 from ..detr import DetrConfig
 from ..swin import SwinConfig
 
@@ -97,7 +97,7 @@ class MaskFormerConfig(PretrainedConfig):
     """
 
     model_type = "maskformer"
     attribute_map = {"hidden_size": "mask_feature_size"}
-    backbones_supported = ["swin"]
+    backbones_supported = ["resnet", "swin"]
    decoders_supported = ["detr"]
 
     def __init__(
@@ -127,27 +127,38 @@ class MaskFormerConfig(PretrainedConfig):
                 num_heads=[4, 8, 16, 32],
                 window_size=12,
                 drop_path_rate=0.3,
                 out_features=["stage1", "stage2", "stage3", "stage4"],
             )
         else:
-            backbone_model_type = backbone_config.pop("model_type")
             # verify that the backbone is supported
+            backbone_model_type = (
+                backbone_config.pop("model_type") if isinstance(backbone_config, dict) else backbone_config.model_type
+            )
             if backbone_model_type not in self.backbones_supported:
                 raise ValueError(
                     f"Backbone {backbone_model_type} not supported, please use one of"
                     f" {','.join(self.backbones_supported)}"
                 )
-            backbone_config = AutoConfig.for_model(backbone_model_type, **backbone_config)
+            if isinstance(backbone_config, dict):
+                config_class = CONFIG_MAPPING[backbone_model_type]
+                backbone_config = config_class.from_dict(backbone_config)
 
         if decoder_config is None:
             # fall back to https://huggingface.co/facebook/detr-resnet-50
             decoder_config = DetrConfig()
         else:
-            decoder_type = decoder_config.pop("model_type")
             # verify that the decoder is supported
+            decoder_type = (
+                decoder_config.pop("model_type") if isinstance(decoder_config, dict) else decoder_config.model_type
+            )
             if decoder_type not in self.decoders_supported:
                 raise ValueError(
                     f"Transformer Decoder {decoder_type} not supported, please use one of"
                     f" {','.join(self.decoders_supported)}"
                 )
-            decoder_config = AutoConfig.for_model(decoder_type, **decoder_config)
+            if isinstance(decoder_config, dict):
+                config_class = CONFIG_MAPPING[decoder_type]
+                decoder_config = config_class.from_dict(decoder_config)
 
         self.backbone_config = backbone_config
         self.decoder_config = decoder_config
@@ -186,8 +197,8 @@ class MaskFormerConfig(PretrainedConfig):
             [`MaskFormerConfig`]: An instance of a configuration object
         """
         return cls(
-            backbone_config=backbone_config.to_dict(),
-            decoder_config=decoder_config.to_dict(),
+            backbone_config=backbone_config,
+            decoder_config=decoder_config,
             **kwargs,
         )
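To make the new behaviour concrete, here is a minimal sketch (not part of the diff itself) of how MaskFormerConfig accepts either a backbone config object or a plain dict after this change; it assumes ResNetConfig exposes out_features, as used in the conversion script further below.

# Sketch only: both styles are assumed to work after this change.
from transformers import DetrConfig, MaskFormerConfig, ResNetConfig

# 1) Pass a config object: its model_type attribute is checked against backbones_supported.
backbone_config = ResNetConfig(out_features=["stage1", "stage2", "stage3", "stage4"])
config = MaskFormerConfig(backbone_config=backbone_config)

# 2) Pass a plain dict (e.g. loaded from JSON): "model_type" selects the config class via
# CONFIG_MAPPING and the remaining keys are parsed with from_dict.
config_from_dict = MaskFormerConfig(backbone_config={"model_type": "resnet", "out_features": ["stage4"]})
print(type(config_from_dict.backbone_config).__name__)  # ResNetConfig

# The classmethod shown above now forwards config objects directly instead of calling .to_dict():
config_from_helper = MaskFormerConfig.from_backbone_and_decoder_configs(
    backbone_config=ResNetConfig(), decoder_config=DetrConfig()
)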
@@ -69,7 +69,7 @@ class MaskFormerSwinConfig(PretrainedConfig):
         layer_norm_eps (`float`, *optional*, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
         out_features (`List[str]`, *optional*):
-            If used as a backbone, list of feature names to output, e.g. `["stem", "stage1"]`.
+            If used as a backbone, list of feature names to output, e.g. `["stage1", "stage2"]`.
 
     Example:
@@ -0,0 +1,390 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert MaskFormer checkpoints with ResNet backbone from the original repository. URL:
https://github.com/facebookresearch/MaskFormer"""


import argparse
import json
import pickle
from pathlib import Path

import torch
from PIL import Image

import requests
from huggingface_hub import hf_hub_download
from transformers import MaskFormerConfig, MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation, ResNetConfig
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

def get_maskformer_config(model_name: str):
    if "resnet101c" in model_name:
        # TODO add support for ResNet-C backbone, which uses a "deeplab" stem
        raise NotImplementedError("To do")
    elif "resnet101" in model_name:
        backbone_config = ResNetConfig.from_pretrained(
            "microsoft/resnet-101", out_features=["stage1", "stage2", "stage3", "stage4"]
        )
    else:
        backbone_config = ResNetConfig.from_pretrained(
            "microsoft/resnet-50", out_features=["stage1", "stage2", "stage3", "stage4"]
        )
    config = MaskFormerConfig(backbone_config=backbone_config)

    repo_id = "huggingface/label-files"
    if "ade20k-full" in model_name:
        config.num_labels = 847
        filename = "maskformer-ade20k-full-id2label.json"
    elif "ade" in model_name:
        config.num_labels = 150
        filename = "ade20k-id2label.json"
    elif "coco-stuff" in model_name:
        config.num_labels = 171
        filename = "maskformer-coco-stuff-id2label.json"
    elif "coco" in model_name:
        # TODO
        config.num_labels = 133
        filename = "coco-panoptic-id2label.json"
    elif "cityscapes" in model_name:
        config.num_labels = 19
        filename = "cityscapes-id2label.json"
    elif "vistas" in model_name:
        config.num_labels = 65
        filename = "mapillary-vistas-id2label.json"

    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}

    return config


def create_rename_keys(config):
|
||||
rename_keys = []
|
||||
# stem
|
||||
# fmt: off
|
||||
rename_keys.append(("backbone.stem.conv1.weight", "model.pixel_level_module.encoder.embedder.embedder.convolution.weight"))
|
||||
rename_keys.append(("backbone.stem.conv1.norm.weight", "model.pixel_level_module.encoder.embedder.embedder.normalization.weight"))
|
||||
rename_keys.append(("backbone.stem.conv1.norm.bias", "model.pixel_level_module.encoder.embedder.embedder.normalization.bias"))
|
||||
rename_keys.append(("backbone.stem.conv1.norm.running_mean", "model.pixel_level_module.encoder.embedder.embedder.normalization.running_mean"))
|
||||
rename_keys.append(("backbone.stem.conv1.norm.running_var", "model.pixel_level_module.encoder.embedder.embedder.normalization.running_var"))
|
||||
# fmt: on
|
||||
# stages
|
||||
for stage_idx in range(len(config.backbone_config.depths)):
|
||||
for layer_idx in range(config.backbone_config.depths[stage_idx]):
|
||||
# shortcut
|
||||
if layer_idx == 0:
|
||||
rename_keys.append(
|
||||
(
|
||||
f"backbone.res{stage_idx + 2}.{layer_idx}.shortcut.weight",
|
||||
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.convolution.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"backbone.res{stage_idx + 2}.{layer_idx}.shortcut.norm.weight",
|
||||
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"backbone.res{stage_idx + 2}.{layer_idx}.shortcut.norm.bias",
|
||||
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"backbone.res{stage_idx + 2}.{layer_idx}.shortcut.norm.running_mean",
|
||||
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_mean",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"backbone.res{stage_idx + 2}.{layer_idx}.shortcut.norm.running_var",
|
||||
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_var",
|
||||
)
|
||||
)
|
||||
# 3 convs
|
||||
for i in range(3):
|
||||
rename_keys.append(
|
||||
(
|
||||
f"backbone.res{stage_idx + 2}.{layer_idx}.conv{i+1}.weight",
|
||||
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.convolution.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"backbone.res{stage_idx + 2}.{layer_idx}.conv{i+1}.norm.weight",
|
||||
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"backbone.res{stage_idx + 2}.{layer_idx}.conv{i+1}.norm.bias",
|
||||
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"backbone.res{stage_idx + 2}.{layer_idx}.conv{i+1}.norm.running_mean",
|
||||
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.running_mean",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"backbone.res{stage_idx + 2}.{layer_idx}.conv{i+1}.norm.running_var",
|
||||
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.running_var",
|
||||
)
|
||||
)
|
||||
|
||||
# FPN
|
||||
# fmt: off
|
||||
rename_keys.append(("sem_seg_head.layer_4.weight", "model.pixel_level_module.decoder.fpn.stem.0.weight"))
|
||||
rename_keys.append(("sem_seg_head.layer_4.norm.weight", "model.pixel_level_module.decoder.fpn.stem.1.weight"))
|
||||
rename_keys.append(("sem_seg_head.layer_4.norm.bias", "model.pixel_level_module.decoder.fpn.stem.1.bias"))
|
||||
for source_index, target_index in zip(range(3, 0, -1), range(0, 3)):
|
||||
rename_keys.append((f"sem_seg_head.adapter_{source_index}.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.0.weight"))
|
||||
rename_keys.append((f"sem_seg_head.adapter_{source_index}.norm.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.1.weight"))
|
||||
rename_keys.append((f"sem_seg_head.adapter_{source_index}.norm.bias", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.1.bias"))
|
||||
rename_keys.append((f"sem_seg_head.layer_{source_index}.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.0.weight"))
|
||||
rename_keys.append((f"sem_seg_head.layer_{source_index}.norm.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.1.weight"))
|
||||
rename_keys.append((f"sem_seg_head.layer_{source_index}.norm.bias", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.1.bias"))
|
||||
rename_keys.append(("sem_seg_head.mask_features.weight", "model.pixel_level_module.decoder.mask_projection.weight"))
|
||||
rename_keys.append(("sem_seg_head.mask_features.bias", "model.pixel_level_module.decoder.mask_projection.bias"))
|
||||
# fmt: on
|
||||
|
||||
# Transformer decoder
|
||||
# fmt: off
|
||||
for idx in range(config.decoder_config.decoder_layers):
|
||||
# self-attention out projection
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.out_proj.weight", f"model.transformer_module.decoder.layers.{idx}.self_attn.out_proj.weight"))
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.out_proj.bias", f"model.transformer_module.decoder.layers.{idx}.self_attn.out_proj.bias"))
|
||||
# cross-attention out projection
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.out_proj.weight", f"model.transformer_module.decoder.layers.{idx}.encoder_attn.out_proj.weight"))
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.out_proj.bias", f"model.transformer_module.decoder.layers.{idx}.encoder_attn.out_proj.bias"))
|
||||
# MLP 1
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear1.weight", f"model.transformer_module.decoder.layers.{idx}.fc1.weight"))
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear1.bias", f"model.transformer_module.decoder.layers.{idx}.fc1.bias"))
|
||||
# MLP 2
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear2.weight", f"model.transformer_module.decoder.layers.{idx}.fc2.weight"))
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear2.bias", f"model.transformer_module.decoder.layers.{idx}.fc2.bias"))
|
||||
# layernorm 1 (self-attention layernorm)
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm1.weight", f"model.transformer_module.decoder.layers.{idx}.self_attn_layer_norm.weight"))
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm1.bias", f"model.transformer_module.decoder.layers.{idx}.self_attn_layer_norm.bias"))
|
||||
# layernorm 2 (cross-attention layernorm)
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm2.weight", f"model.transformer_module.decoder.layers.{idx}.encoder_attn_layer_norm.weight"))
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm2.bias", f"model.transformer_module.decoder.layers.{idx}.encoder_attn_layer_norm.bias"))
|
||||
# layernorm 3 (final layernorm)
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm3.weight", f"model.transformer_module.decoder.layers.{idx}.final_layer_norm.weight"))
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm3.bias", f"model.transformer_module.decoder.layers.{idx}.final_layer_norm.bias"))
|
||||
|
||||
rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.weight", "model.transformer_module.decoder.layernorm.weight"))
|
||||
rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.bias", "model.transformer_module.decoder.layernorm.bias"))
|
||||
# fmt: on
|
||||
|
||||
# heads on top
|
||||
# fmt: off
|
||||
rename_keys.append(("sem_seg_head.predictor.query_embed.weight", "model.transformer_module.queries_embedder.weight"))
|
||||
|
||||
rename_keys.append(("sem_seg_head.predictor.input_proj.weight", "model.transformer_module.input_projection.weight"))
|
||||
rename_keys.append(("sem_seg_head.predictor.input_proj.bias", "model.transformer_module.input_projection.bias"))
|
||||
|
||||
rename_keys.append(("sem_seg_head.predictor.class_embed.weight", "class_predictor.weight"))
|
||||
rename_keys.append(("sem_seg_head.predictor.class_embed.bias", "class_predictor.bias"))
|
||||
|
||||
for i in range(3):
|
||||
rename_keys.append((f"sem_seg_head.predictor.mask_embed.layers.{i}.weight", f"mask_embedder.{i}.0.weight"))
|
||||
rename_keys.append((f"sem_seg_head.predictor.mask_embed.layers.{i}.bias", f"mask_embedder.{i}.0.bias"))
|
||||
# fmt: on
|
||||
|
||||
return rename_keys
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
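As a quick illustration, a self-contained toy sketch of what applying the (source, destination) pairs built above to a checkpoint dict amounts to (the destination key is taken from the stem mapping above; the value is a placeholder):

state_dict = {"backbone.stem.conv1.weight": "tensor-placeholder"}
pairs = [("backbone.stem.conv1.weight", "model.pixel_level_module.encoder.embedder.embedder.convolution.weight")]

for src, dest in pairs:
    # same effect as rename_key(state_dict, src, dest)
    val = state_dict.pop(src)
    state_dict[dest] = val

print(list(state_dict))  # only the renamed key remains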
|
||||
|
||||
|
||||
# we split up the matrix of each encoder layer into queries, keys and values
|
||||
def read_in_decoder_q_k_v(state_dict, config):
|
||||
# fmt: off
|
||||
hidden_size = config.decoder_config.hidden_size
|
||||
for idx in range(config.decoder_config.decoder_layers):
|
||||
# read in weights + bias of self-attention input projection layer (in the original implementation, this is a single matrix + bias)
|
||||
in_proj_weight = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.in_proj_weight")
|
||||
in_proj_bias = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.in_proj_bias")
|
||||
# next, add query, keys and values (in that order) to the state dict
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.q_proj.weight"] = in_proj_weight[: hidden_size, :]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.q_proj.bias"] = in_proj_bias[:config.hidden_size]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.k_proj.weight"] = in_proj_weight[hidden_size : hidden_size * 2, :]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.v_proj.weight"] = in_proj_weight[-hidden_size :, :]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.v_proj.bias"] = in_proj_bias[-hidden_size :]
|
||||
# read in weights + bias of cross-attention input projection layer (in the original implementation, this is a single matrix + bias)
|
||||
in_proj_weight = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.in_proj_weight")
|
||||
in_proj_bias = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.in_proj_bias")
|
||||
# next, add query, keys and values (in that order) to the state dict
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.q_proj.weight"] = in_proj_weight[: hidden_size, :]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.q_proj.bias"] = in_proj_bias[:config.hidden_size]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.k_proj.weight"] = in_proj_weight[hidden_size : hidden_size * 2, :]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.v_proj.weight"] = in_proj_weight[-hidden_size :, :]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.v_proj.bias"] = in_proj_bias[-hidden_size :]
|
||||
# fmt: on
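The slicing above follows the usual PyTorch convention that in_proj_weight stacks the query, key and value projections along dim 0. A small self-contained sketch with toy sizes (not the real decoder dimensions):

import torch

hidden_size = 4  # toy value; the script uses config.decoder_config.hidden_size
in_proj_weight = torch.randn(3 * hidden_size, hidden_size)  # stacked [q; k; v]
in_proj_bias = torch.randn(3 * hidden_size)

q_w = in_proj_weight[:hidden_size, :]
k_w = in_proj_weight[hidden_size : hidden_size * 2, :]
v_w = in_proj_weight[-hidden_size:, :]
q_b, k_b, v_b = in_proj_bias[:hidden_size], in_proj_bias[hidden_size : hidden_size * 2], in_proj_bias[-hidden_size:]

# splitting and re-concatenating recovers the original fused projection
assert torch.equal(torch.cat([q_w, k_w, v_w], dim=0), in_proj_weight)
assert torch.equal(torch.cat([q_b, k_b, v_b], dim=0), in_proj_bias)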
# We will verify our results on an image of cute cats
def prepare_img() -> Image.Image:
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


@torch.no_grad()
def convert_maskformer_checkpoint(
    model_name: str, checkpoint_path: str, pytorch_dump_folder_path: str, push_to_hub: bool = False
):
    """
    Copy/paste/tweak model's weights to our MaskFormer structure.
    """
    config = get_maskformer_config(model_name)

    # load original state_dict
    with open(checkpoint_path, "rb") as f:
        data = pickle.load(f)
    state_dict = data["model"]

    # rename keys
    rename_keys = create_rename_keys(config)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_decoder_q_k_v(state_dict, config)

    # update to torch tensors
    for key, value in state_dict.items():
        state_dict[key] = torch.from_numpy(value)

    # load 🤗 model
    model = MaskFormerForInstanceSegmentation(config)
    model.eval()

    model.load_state_dict(state_dict)

    # verify results
    image = prepare_img()
    if "vistas" in model_name:
        ignore_index = 65
    elif "cityscapes" in model_name:
        ignore_index = 65535
    else:
        ignore_index = 255
    reduce_labels = True if "ade" in model_name else False
    feature_extractor = MaskFormerFeatureExtractor(ignore_index=ignore_index, reduce_labels=reduce_labels)

    inputs = feature_extractor(image, return_tensors="pt")

    outputs = model(**inputs)

    if model_name == "maskformer-resnet50-ade":
        expected_logits = torch.tensor(
            [[6.7710, -0.1452, -3.5687], [1.9165, -1.0010, -1.8614], [3.6209, -0.2950, -1.3813]]
        )
    elif model_name == "maskformer-resnet101-ade":
        expected_logits = torch.tensor(
            [[4.0381, -1.1483, -1.9688], [2.7083, -1.9147, -2.2555], [3.4367, -1.3711, -2.1609]]
        )
    elif model_name == "maskformer-resnet50-coco-stuff":
        expected_logits = torch.tensor(
            [[3.2309, -3.0481, -2.8695], [5.4986, -5.4242, -2.4211], [6.2100, -5.2279, -2.7786]]
        )
    elif model_name == "maskformer-resnet101-coco-stuff":
        expected_logits = torch.tensor(
            [[4.7188, -3.2585, -2.8857], [6.6871, -2.9181, -1.2487], [7.2449, -2.2764, -2.1874]]
        )
    elif model_name == "maskformer-resnet101-cityscapes":
        expected_logits = torch.tensor(
            [[-1.8861, -1.5465, 0.6749], [-2.3677, -1.6707, -0.0867], [-2.2314, -1.9530, -0.9132]]
        )
    elif model_name == "maskformer-resnet50-vistas":
        expected_logits = torch.tensor(
            [[-6.3917, -1.5216, -1.1392], [-5.5335, -4.5318, -1.8339], [-4.3576, -4.0301, 0.2162]]
        )
    elif model_name == "maskformer-resnet50-ade20k-full":
        expected_logits = torch.tensor(
            [[3.6146, -1.9367, -3.2534], [4.0099, 0.2027, -2.7576], [3.3913, -2.3644, -3.9519]]
        )
    elif model_name == "maskformer-resnet101-ade20k-full":
        expected_logits = torch.tensor(
            [[3.2211, -1.6550, -2.7605], [2.8559, -2.4512, -2.9574], [2.6331, -2.6775, -2.1844]]
        )

    assert torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_logits, atol=1e-4)
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        print(f"Saving model and feature extractor of {model_name} to {pytorch_dump_folder_path}")
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        model.save_pretrained(pytorch_dump_folder_path)
        feature_extractor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        print(f"Pushing model and feature extractor of {model_name} to the hub...")
        model.push_to_hub(f"facebook/{model_name}")
        feature_extractor.push_to_hub(f"facebook/{model_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="maskformer-resnet50-ade",
        type=str,
        required=True,
        choices=[
            "maskformer-resnet50-ade",
            "maskformer-resnet101-ade",
            "maskformer-resnet50-coco-stuff",
            "maskformer-resnet101-coco-stuff",
            "maskformer-resnet101-cityscapes",
            "maskformer-resnet50-vistas",
            "maskformer-resnet50-ade20k-full",
            "maskformer-resnet101-ade20k-full",
        ],
        help="Name of the MaskFormer model you'd like to convert",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        required=True,
        help="Path to the original pickle file (.pkl) of the original checkpoint.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
    )

    args = parser.parse_args()
    convert_maskformer_checkpoint(
        args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub
    )
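For orientation, a hedged usage sketch of the entry point above; the checkpoint path is hypothetical and the script would normally be run from the command line with the flags defined above.

# Sketch only: how the conversion entry point might be invoked.
# CLI form:
#   python <this_script>.py --model_name maskformer-resnet50-ade \
#       --checkpoint_path /path/to/model_final.pkl --pytorch_dump_folder_path ./maskformer-resnet50-ade
convert_maskformer_checkpoint(
    model_name="maskformer-resnet50-ade",
    checkpoint_path="/path/to/model_final.pkl",  # hypothetical local path to the original pickle checkpoint
    pytorch_dump_folder_path="./maskformer-resnet50-ade",
    push_to_hub=False,
)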
@@ -0,0 +1,333 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert MaskFormer checkpoints with Swin backbone from the original repository. URL:
https://github.com/facebookresearch/MaskFormer"""


import argparse
import json
import pickle
from pathlib import Path

import torch
from PIL import Image

import requests
from huggingface_hub import hf_hub_download
from transformers import MaskFormerConfig, MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation, SwinConfig
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

def get_maskformer_config(model_name: str):
|
||||
backbone_config = SwinConfig.from_pretrained(
|
||||
"microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"]
|
||||
)
|
||||
config = MaskFormerConfig(backbone_config=backbone_config)
|
||||
|
||||
repo_id = "huggingface/label-files"
|
||||
if "ade20k-full" in model_name:
|
||||
# this should be ok
|
||||
config.num_labels = 847
|
||||
filename = "maskformer-ade20k-full-id2label.json"
|
||||
elif "ade" in model_name:
|
||||
# this should be ok
|
||||
config.num_labels = 150
|
||||
filename = "ade20k-id2label.json"
|
||||
elif "coco-stuff" in model_name:
|
||||
# this should be ok
|
||||
config.num_labels = 171
|
||||
filename = "maskformer-coco-stuff-id2label.json"
|
||||
elif "coco" in model_name:
|
||||
# TODO
|
||||
config.num_labels = 133
|
||||
filename = "coco-panoptic-id2label.json"
|
||||
elif "cityscapes" in model_name:
|
||||
# this should be ok
|
||||
config.num_labels = 19
|
||||
filename = "cityscapes-id2label.json"
|
||||
elif "vistas" in model_name:
|
||||
# this should be ok
|
||||
config.num_labels = 65
|
||||
filename = "mapillary-vistas-id2label.json"
|
||||
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def create_rename_keys(config):
|
||||
rename_keys = []
|
||||
# stem
|
||||
# fmt: off
|
||||
rename_keys.append(("backbone.patch_embed.proj.weight", "model.pixel_level_module.encoder.model.embeddings.patch_embeddings.projection.weight"))
|
||||
rename_keys.append(("backbone.patch_embed.proj.bias", "model.pixel_level_module.encoder.model.embeddings.patch_embeddings.projection.bias"))
|
||||
rename_keys.append(("backbone.patch_embed.norm.weight", "model.pixel_level_module.encoder.model.embeddings.norm.weight"))
|
||||
rename_keys.append(("backbone.patch_embed.norm.bias", "model.pixel_level_module.encoder.model.embeddings.norm.bias"))
|
||||
# stages
|
||||
for i in range(len(config.backbone_config.depths)):
|
||||
for j in range(config.backbone_config.depths[i]):
|
||||
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.norm1.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.layernorm_before.weight"))
|
||||
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.norm1.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.layernorm_before.bias"))
|
||||
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.attn.relative_position_bias_table", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.relative_position_bias_table"))
|
||||
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.attn.relative_position_index", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.relative_position_index"))
|
||||
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.attn.proj.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.output.dense.weight"))
|
||||
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.attn.proj.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.output.dense.bias"))
|
||||
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.norm2.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.layernorm_after.weight"))
|
||||
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.norm2.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.layernorm_after.bias"))
|
||||
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.mlp.fc1.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.intermediate.dense.weight"))
|
||||
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.mlp.fc1.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.intermediate.dense.bias"))
|
||||
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.mlp.fc2.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.output.dense.weight"))
|
||||
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.mlp.fc2.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.output.dense.bias"))
|
||||
|
||||
if i < 3:
|
||||
rename_keys.append((f"backbone.layers.{i}.downsample.reduction.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.downsample.reduction.weight"))
|
||||
rename_keys.append((f"backbone.layers.{i}.downsample.norm.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.downsample.norm.weight"))
|
||||
rename_keys.append((f"backbone.layers.{i}.downsample.norm.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.downsample.norm.bias"))
|
||||
rename_keys.append((f"backbone.norm{i}.weight", f"model.pixel_level_module.encoder.hidden_states_norms.{i}.weight"))
|
||||
rename_keys.append((f"backbone.norm{i}.bias", f"model.pixel_level_module.encoder.hidden_states_norms.{i}.bias"))
|
||||
|
||||
# FPN
|
||||
rename_keys.append(("sem_seg_head.layer_4.weight", "model.pixel_level_module.decoder.fpn.stem.0.weight"))
|
||||
rename_keys.append(("sem_seg_head.layer_4.norm.weight", "model.pixel_level_module.decoder.fpn.stem.1.weight"))
|
||||
rename_keys.append(("sem_seg_head.layer_4.norm.bias", "model.pixel_level_module.decoder.fpn.stem.1.bias"))
|
||||
for source_index, target_index in zip(range(3, 0, -1), range(0, 3)):
|
||||
rename_keys.append((f"sem_seg_head.adapter_{source_index}.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.0.weight"))
|
||||
rename_keys.append((f"sem_seg_head.adapter_{source_index}.norm.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.1.weight"))
|
||||
rename_keys.append((f"sem_seg_head.adapter_{source_index}.norm.bias", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.1.bias"))
|
||||
rename_keys.append((f"sem_seg_head.layer_{source_index}.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.0.weight"))
|
||||
rename_keys.append((f"sem_seg_head.layer_{source_index}.norm.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.1.weight"))
|
||||
rename_keys.append((f"sem_seg_head.layer_{source_index}.norm.bias", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.1.bias"))
|
||||
rename_keys.append(("sem_seg_head.mask_features.weight", "model.pixel_level_module.decoder.mask_projection.weight"))
|
||||
rename_keys.append(("sem_seg_head.mask_features.bias", "model.pixel_level_module.decoder.mask_projection.bias"))
|
||||
|
||||
# Transformer decoder
|
||||
for idx in range(config.decoder_config.decoder_layers):
|
||||
# self-attention out projection
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.out_proj.weight", f"model.transformer_module.decoder.layers.{idx}.self_attn.out_proj.weight"))
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.out_proj.bias", f"model.transformer_module.decoder.layers.{idx}.self_attn.out_proj.bias"))
|
||||
# cross-attention out projection
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.out_proj.weight", f"model.transformer_module.decoder.layers.{idx}.encoder_attn.out_proj.weight"))
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.out_proj.bias", f"model.transformer_module.decoder.layers.{idx}.encoder_attn.out_proj.bias"))
|
||||
# MLP 1
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear1.weight", f"model.transformer_module.decoder.layers.{idx}.fc1.weight"))
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear1.bias", f"model.transformer_module.decoder.layers.{idx}.fc1.bias"))
|
||||
# MLP 2
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear2.weight", f"model.transformer_module.decoder.layers.{idx}.fc2.weight"))
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear2.bias", f"model.transformer_module.decoder.layers.{idx}.fc2.bias"))
|
||||
# layernorm 1 (self-attention layernorm)
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm1.weight", f"model.transformer_module.decoder.layers.{idx}.self_attn_layer_norm.weight"))
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm1.bias", f"model.transformer_module.decoder.layers.{idx}.self_attn_layer_norm.bias"))
|
||||
# layernorm 2 (cross-attention layernorm)
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm2.weight", f"model.transformer_module.decoder.layers.{idx}.encoder_attn_layer_norm.weight"))
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm2.bias", f"model.transformer_module.decoder.layers.{idx}.encoder_attn_layer_norm.bias"))
|
||||
# layernorm 3 (final layernorm)
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm3.weight", f"model.transformer_module.decoder.layers.{idx}.final_layer_norm.weight"))
|
||||
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm3.bias", f"model.transformer_module.decoder.layers.{idx}.final_layer_norm.bias"))
|
||||
|
||||
rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.weight", "model.transformer_module.decoder.layernorm.weight"))
|
||||
rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.bias", "model.transformer_module.decoder.layernorm.bias"))
|
||||
|
||||
# heads on top
|
||||
rename_keys.append(("sem_seg_head.predictor.query_embed.weight", "model.transformer_module.queries_embedder.weight"))
|
||||
|
||||
rename_keys.append(("sem_seg_head.predictor.input_proj.weight", "model.transformer_module.input_projection.weight"))
|
||||
rename_keys.append(("sem_seg_head.predictor.input_proj.bias", "model.transformer_module.input_projection.bias"))
|
||||
|
||||
rename_keys.append(("sem_seg_head.predictor.class_embed.weight", "class_predictor.weight"))
|
||||
rename_keys.append(("sem_seg_head.predictor.class_embed.bias", "class_predictor.bias"))
|
||||
|
||||
for i in range(3):
|
||||
rename_keys.append((f"sem_seg_head.predictor.mask_embed.layers.{i}.weight", f"mask_embedder.{i}.0.weight"))
|
||||
rename_keys.append((f"sem_seg_head.predictor.mask_embed.layers.{i}.bias", f"mask_embedder.{i}.0.bias"))
|
||||
# fmt: on
|
||||
|
||||
return rename_keys
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
# we split up the matrix of each encoder layer into queries, keys and values
|
||||
def read_in_swin_q_k_v(state_dict, backbone_config):
|
||||
num_features = [int(backbone_config.embed_dim * 2**i) for i in range(len(backbone_config.depths))]
|
||||
for i in range(len(backbone_config.depths)):
|
||||
dim = num_features[i]
|
||||
for j in range(backbone_config.depths[i]):
|
||||
# fmt: off
|
||||
# read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
|
||||
in_proj_weight = state_dict.pop(f"backbone.layers.{i}.blocks.{j}.attn.qkv.weight")
|
||||
in_proj_bias = state_dict.pop(f"backbone.layers.{i}.blocks.{j}.attn.qkv.bias")
|
||||
# next, add query, keys and values (in that order) to the state dict
|
||||
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.query.weight"] = in_proj_weight[:dim, :]
|
||||
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.query.bias"] = in_proj_bias[: dim]
|
||||
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.key.weight"] = in_proj_weight[
|
||||
dim : dim * 2, :
|
||||
]
|
||||
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.key.bias"] = in_proj_bias[
|
||||
dim : dim * 2
|
||||
]
|
||||
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.value.weight"] = in_proj_weight[
|
||||
-dim :, :
|
||||
]
|
||||
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.value.bias"] = in_proj_bias[-dim :]
|
||||
# fmt: on
|
||||
|
||||
|
||||
# we split up the matrix of each encoder layer into queries, keys and values
|
||||
def read_in_decoder_q_k_v(state_dict, config):
|
||||
# fmt: off
|
||||
hidden_size = config.decoder_config.hidden_size
|
||||
for idx in range(config.decoder_config.decoder_layers):
|
||||
# read in weights + bias of self-attention input projection layer (in the original implementation, this is a single matrix + bias)
|
||||
in_proj_weight = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.in_proj_weight")
|
||||
in_proj_bias = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.in_proj_bias")
|
||||
# next, add query, keys and values (in that order) to the state dict
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.q_proj.weight"] = in_proj_weight[: hidden_size, :]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.q_proj.bias"] = in_proj_bias[:config.hidden_size]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.k_proj.weight"] = in_proj_weight[hidden_size : hidden_size * 2, :]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.v_proj.weight"] = in_proj_weight[-hidden_size :, :]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.v_proj.bias"] = in_proj_bias[-hidden_size :]
|
||||
# read in weights + bias of cross-attention input projection layer (in the original implementation, this is a single matrix + bias)
|
||||
in_proj_weight = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.in_proj_weight")
|
||||
in_proj_bias = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.in_proj_bias")
|
||||
# next, add query, keys and values (in that order) to the state dict
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.q_proj.weight"] = in_proj_weight[: hidden_size, :]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.q_proj.bias"] = in_proj_bias[:config.hidden_size]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.k_proj.weight"] = in_proj_weight[hidden_size : hidden_size * 2, :]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.v_proj.weight"] = in_proj_weight[-hidden_size :, :]
|
||||
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.v_proj.bias"] = in_proj_bias[-hidden_size :]
|
||||
# fmt: on
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
def prepare_img() -> Image.Image:
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im

|
||||
|
||||
@torch.no_grad()
|
||||
def convert_maskformer_checkpoint(
|
||||
model_name: str, checkpoint_path: str, pytorch_dump_folder_path: str, push_to_hub: bool = False
|
||||
):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our MaskFormer structure.
|
||||
"""
|
||||
config = get_maskformer_config(model_name)
|
||||
|
||||
# load original state_dict
|
||||
with open(checkpoint_path, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
state_dict = data["model"]
|
||||
|
||||
# for name, param in state_dict.items():
|
||||
# print(name, param.shape)
|
||||
|
||||
# rename keys
|
||||
rename_keys = create_rename_keys(config)
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
read_in_swin_q_k_v(state_dict, config.backbone_config)
|
||||
read_in_decoder_q_k_v(state_dict, config)
|
||||
|
||||
# update to torch tensors
|
||||
for key, value in state_dict.items():
|
||||
state_dict[key] = torch.from_numpy(value)
|
||||
|
||||
# load 🤗 model
|
||||
model = MaskFormerForInstanceSegmentation(config)
|
||||
model.eval()
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
print(name, param.shape)
|
||||
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||
assert missing_keys == [
|
||||
"model.pixel_level_module.encoder.model.layernorm.weight",
|
||||
"model.pixel_level_module.encoder.model.layernorm.bias",
|
||||
]
|
||||
assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
|
||||
|
||||
# verify results
|
||||
image = prepare_img()
|
||||
if "vistas" in model_name:
|
||||
ignore_index = 65
|
||||
elif "cityscapes" in model_name:
|
||||
ignore_index = 65535
|
||||
else:
|
||||
ignore_index = 255
|
||||
reduce_labels = True if "ade" in model_name else False
|
||||
feature_extractor = MaskFormerFeatureExtractor(ignore_index=ignore_index, reduce_labels=reduce_labels)
|
||||
|
||||
inputs = feature_extractor(image, return_tensors="pt")
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
print("Logits:", outputs.class_queries_logits[0, :3, :3])
|
||||
|
||||
if model_name == "maskformer-swin-tiny-ade":
|
||||
expected_logits = torch.tensor(
|
||||
[[3.6353, -4.4770, -2.6065], [0.5081, -4.2394, -3.5343], [2.1909, -5.0353, -1.9323]]
|
||||
)
|
||||
assert torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_logits, atol=1e-4)
|
||||
print("Looks ok!")
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
print(f"Saving model and feature extractor to {pytorch_dump_folder_path}")
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
print("Pushing model and feature extractor to the hub...")
|
||||
model.push_to_hub(f"nielsr/{model_name}")
|
||||
feature_extractor.push_to_hub(f"nielsr/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="maskformer-swin-tiny-ade",
        type=str,
        help="Name of the MaskFormer model you'd like to convert",
    )
    parser.add_argument(
        "--checkpoint_path",
        default="/Users/nielsrogge/Documents/MaskFormer_checkpoints/MaskFormer-Swin-tiny-ADE20k/model.pkl",
        type=str,
        help="Path to the original state dict (.pkl file).",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
    )

    args = parser.parse_args()
    convert_maskformer_checkpoint(
        args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub
    )
@@ -275,7 +275,6 @@ class RegNetEncoder(nn.Module):
         return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
 
 
-# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel with ResNet->RegNet,resnet->regnet
 class RegNetPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -287,6 +286,7 @@ class RegNetPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
 
+    # Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel._init_weights
     def _init_weights(self, module):
         if isinstance(module, nn.Conv2d):
             nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
@@ -267,7 +267,7 @@ class ResNetPreTrainedModel(PreTrainedModel):
             nn.init.constant_(module.bias, 0)
 
     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, ResNetModel):
+        if isinstance(module, (ResNetModel, ResNetBackbone)):
             module.gradient_checkpointing = value
 
 
@@ -436,7 +436,8 @@ class ResNetBackbone(ResNetPreTrainedModel):
         super().__init__(config)
 
         self.stage_names = config.stage_names
-        self.resnet = ResNetModel(config)
+        self.embedder = ResNetEmbeddings(config)
+        self.encoder = ResNetEncoder(config)
 
         self.out_features = config.out_features
 
@@ -490,7 +491,9 @@ class ResNetBackbone(ResNetPreTrainedModel):
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
 
-        outputs = self.resnet(pixel_values, output_hidden_states=True, return_dict=True)
+        embedding_output = self.embedder(pixel_values)
+
+        outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True)
 
         hidden_states = outputs.hidden_states
 
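The refactor means the backbone now runs the embedder and encoder directly instead of wrapping a full ResNetModel. A minimal sketch of how the backbone is used to obtain multi-scale feature maps, assuming ResNetBackbone is exported at the top level of transformers as in this version of the library:

import torch
from transformers import ResNetBackbone, ResNetConfig

config = ResNetConfig(out_features=["stage2", "stage3", "stage4"])
model = ResNetBackbone(config)

pixel_values = torch.randn(1, 3, 224, 224)  # random input, just to show the call signature
with torch.no_grad():
    outputs = model(pixel_values)

# one feature map per requested stage, ordered like config.out_features
for name, feature_map in zip(model.out_features, outputs.feature_maps):
    print(name, tuple(feature_map.shape))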
@@ -84,7 +84,7 @@ class SwinConfig(PretrainedConfig):
         encoder_stride (`int`, `optional`, defaults to 32):
             Factor to increase the spatial resolution by in the decoder head for masked image modeling.
 
-    Example:
+    Example:
 
     ```python
     >>> from transformers import SwinConfig, SwinModel
@@ -320,16 +320,16 @@ def prepare_img():
 @require_vision
 @slow
 class MaskFormerModelIntegrationTest(unittest.TestCase):
-    @cached_property
-    def model_checkpoints(self):
-        return "facebook/maskformer-swin-small-coco"
-
     @cached_property
     def default_feature_extractor(self):
-        return MaskFormerFeatureExtractor.from_pretrained(self.model_checkpoints) if is_vision_available() else None
+        return (
+            MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-small-coco")
+            if is_vision_available()
+            else None
+        )
 
     def test_inference_no_head(self):
-        model = MaskFormerModel.from_pretrained(self.model_checkpoints).to(torch_device)
+        model = MaskFormerModel.from_pretrained("facebook/maskformer-swin-small-coco").to(torch_device)
         feature_extractor = self.default_feature_extractor
         image = prepare_img()
         inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
@@ -370,7 +370,11 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
         )
 
     def test_inference_instance_segmentation_head(self):
-        model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
+        model = (
+            MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
+            .to(torch_device)
+            .eval()
+        )
         feature_extractor = self.default_feature_extractor
         image = prepare_img()
         inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
@@ -385,7 +389,8 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
         # masks_queries_logits
         masks_queries_logits = outputs.masks_queries_logits
         self.assertEqual(
-            masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4)
+            masks_queries_logits.shape,
+            (1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
         )
         expected_slice = [
             [-1.3737124, -1.7724937, -1.9364233],
@@ -396,7 +401,9 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
         self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
         # class_queries_logits
         class_queries_logits = outputs.class_queries_logits
-        self.assertEqual(class_queries_logits.shape, (1, model.config.num_queries, model.config.num_labels + 1))
+        self.assertEqual(
+            class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
+        )
         expected_slice = torch.tensor(
             [
                 [1.6512e00, -5.2572e00, -3.3519e00],
@@ -406,8 +413,48 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
 
+    def test_inference_instance_segmentation_head_resnet_backbone(self):
+        model = (
+            MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-resnet101-coco-stuff")
+            .to(torch_device)
+            .eval()
+        )
+        feature_extractor = self.default_feature_extractor
+        image = prepare_img()
+        inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
+        inputs_shape = inputs["pixel_values"].shape
+        # check size is divisible by 32
+        self.assertTrue((inputs_shape[-1] % 32) == 0 and (inputs_shape[-2] % 32) == 0)
+        # check size
+        self.assertEqual(inputs_shape, (1, 3, 800, 1088))
+
+        with torch.no_grad():
+            outputs = model(**inputs)
+        # masks_queries_logits
+        masks_queries_logits = outputs.masks_queries_logits
+        self.assertEqual(
+            masks_queries_logits.shape,
+            (1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
+        )
+        expected_slice = [[-0.9046, -2.6366, -4.6062], [-3.4179, -5.7890, -8.8057], [-4.9179, -7.6560, -10.7711]]
+        expected_slice = torch.tensor(expected_slice).to(torch_device)
+        self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
+        # class_queries_logits
+        class_queries_logits = outputs.class_queries_logits
+        self.assertEqual(
+            class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
+        )
+        expected_slice = torch.tensor(
+            [[4.7188, -3.2585, -2.8857], [6.6871, -2.9181, -1.2487], [7.2449, -2.2764, -2.1874]]
+        ).to(torch_device)
+        self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
+
     def test_with_segmentation_maps_and_loss(self):
-        model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
+        model = (
+            MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
+            .to(torch_device)
+            .eval()
+        )
         feature_extractor = self.default_feature_extractor
 
         inputs = feature_extractor(
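Outside the test suite, a hedged end-to-end sketch of using one of the converted ResNet-backbone checkpoints referenced above; it assumes the feature extractor exposes post_process_semantic_segmentation(..., target_sizes=...) as in contemporary versions of the library, otherwise the raw logits shapes mirror what the integration test checks.

import requests
import torch
from PIL import Image

from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-resnet101-coco-stuff")
model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-resnet101-coco-stuff")

inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# class_queries_logits: (batch, num_queries, num_labels + 1); masks_queries_logits: (batch, num_queries, h/4, w/4)
print(outputs.class_queries_logits.shape, outputs.masks_queries_logits.shape)

# collapse the per-query predictions into a per-pixel semantic map (assumed post-processing API)
semantic_map = feature_extractor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]
print(semantic_map.shape)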