mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[DETR and friends] Use AutoBackbone as alternative to timm (#20833)
* First draft
* More improvements
* Add conversion script
* More improvements
* Add docs
* Address review
* Rename class to ConvEncoder
* Address review
* Apply suggestion
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update all DETR friends
* Add corresponding test
* Improve test
* Fix bug
* Add more tests
* Set out_features to last stage by default

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent: c8d719ff7e
commit: 91ff7efeeb
src/transformers/models/conditional_detr/configuration_conditional_detr.py

@@ -22,6 +22,7 @@ from packaging import version
 from ...configuration_utils import PretrainedConfig
 from ...onnx import OnnxConfig
 from ...utils import logging
+from ..auto import CONFIG_MAPPING
 
 
 logger = logging.get_logger(__name__)
@@ -44,6 +45,12 @@ class ConditionalDetrConfig(PretrainedConfig):
     documentation from [`PretrainedConfig`] for more information.
 
     Args:
+        use_timm_backbone (`bool`, *optional*, defaults to `True`):
+            Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
+            API.
+        backbone_config (`PretrainedConfig` or `dict`, *optional*):
+            The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
+            case it will default to `ResNetConfig()`.
         num_channels (`int`, *optional*, defaults to 3):
             The number of input channels.
         num_queries (`int`, *optional*, defaults to 100):
@@ -87,13 +94,14 @@ class ConditionalDetrConfig(PretrainedConfig):
         position_embedding_type (`str`, *optional*, defaults to `"sine"`):
             Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
         backbone (`str`, *optional*, defaults to `"resnet50"`):
-            Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a
-            list of all available models, see [this
+            Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
+            backbone from the timm package. For a list of all available models, see [this
             page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
         use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
-            Whether to use pretrained weights for the backbone.
+            Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
         dilation (`bool`, *optional*, defaults to `False`):
-            Whether to replace stride with dilation in the last convolutional block (DC5).
+            Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
+            `use_timm_backbone` = `True`.
         class_cost (`float`, *optional*, defaults to 1):
             Relative weight of the classification error in the Hungarian matching cost.
         bbox_cost (`float`, *optional*, defaults to 5):
@@ -136,6 +144,8 @@ class ConditionalDetrConfig(PretrainedConfig):
 
     def __init__(
         self,
+        use_timm_backbone=True,
+        backbone_config=None,
         num_channels=3,
         num_queries=300,
         encoder_layers=6,
@@ -172,6 +182,20 @@ class ConditionalDetrConfig(PretrainedConfig):
         focal_alpha=0.25,
         **kwargs
     ):
+        if backbone_config is not None and use_timm_backbone:
+            raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
+
+        if not use_timm_backbone:
+            if backbone_config is None:
+                logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
+                backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
+            elif isinstance(backbone_config, dict):
+                backbone_model_type = backbone_config.get("model_type")
+                config_class = CONFIG_MAPPING[backbone_model_type]
+                backbone_config = config_class.from_dict(backbone_config)
+
+        self.use_timm_backbone = use_timm_backbone
+        self.backbone_config = backbone_config
         self.num_channels = num_channels
         self.num_queries = num_queries
         self.d_model = d_model
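For readers skimming the diff: after this change the backbone path is selected purely from the config, and the `ValueError` above rejects ambiguous combinations. A minimal sketch of the two paths (assuming a post-commit `transformers` install, with `timm` installed for the first path):

```python
from transformers import ConditionalDetrConfig, ResNetConfig

# Default: timm resolves the backbone by name, exactly as before this commit.
timm_config = ConditionalDetrConfig(use_timm_backbone=True, backbone="resnet50")

# New: any AutoBackbone-compatible config can be plugged in instead.
native_config = ConditionalDetrConfig(
    use_timm_backbone=False,
    backbone_config=ResNetConfig(out_features=["stage4"]),
)

# Passing backbone_config together with use_timm_backbone=True raises the
# ValueError added in this hunk.
```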
src/transformers/models/conditional_detr/modeling_conditional_detr.py

@@ -38,6 +38,7 @@ from ...utils import (
     replace_return_docstrings,
     requires_backends,
 )
+from ..auto import AutoBackbone
 from .configuration_conditional_detr import ConditionalDetrConfig
 
 
@@ -326,46 +327,57 @@ def replace_batch_norm(m, name=""):
         replace_batch_norm(ch, n)
 
 
-# Copied from transformers.models.detr.modeling_detr.DetrTimmConvEncoder
-class ConditionalDetrTimmConvEncoder(nn.Module):
+# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder
+class ConditionalDetrConvEncoder(nn.Module):
     """
-    Convolutional encoder (backbone) from the timm library.
+    Convolutional backbone, using either the AutoBackbone API or one from the timm library.
 
     nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
 
     """
 
-    def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool, num_channels: int = 3):
+    def __init__(self, config):
         super().__init__()
 
-        kwargs = {}
-        if dilation:
-            kwargs["output_stride"] = 16
+        self.config = config
 
-        requires_backends(self, ["timm"])
+        if config.use_timm_backbone:
+            requires_backends(self, ["timm"])
+            kwargs = {}
+            if config.dilation:
+                kwargs["output_stride"] = 16
+            backbone = create_model(
+                config.backbone,
+                pretrained=config.use_pretrained_backbone,
+                features_only=True,
+                out_indices=(1, 2, 3, 4),
+                in_chans=config.num_channels,
+                **kwargs,
+            )
+        else:
+            backbone = AutoBackbone.from_config(config.backbone_config)
 
-        backbone = create_model(
-            name,
-            pretrained=use_pretrained_backbone,
-            features_only=True,
-            out_indices=(1, 2, 3, 4),
-            in_chans=num_channels,
-            **kwargs,
-        )
         # replace batch norm by frozen batch norm
         with torch.no_grad():
             replace_batch_norm(backbone)
         self.model = backbone
-        self.intermediate_channel_sizes = self.model.feature_info.channels()
+        self.intermediate_channel_sizes = (
+            self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
+        )
 
-        if "resnet" in name:
+        backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
+        if "resnet" in backbone_model_type:
             for name, parameter in self.model.named_parameters():
-                if "layer2" not in name and "layer3" not in name and "layer4" not in name:
-                    parameter.requires_grad_(False)
+                if config.use_timm_backbone:
+                    if "layer2" not in name and "layer3" not in name and "layer4" not in name:
+                        parameter.requires_grad_(False)
+                else:
+                    if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
+                        parameter.requires_grad_(False)
 
     def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
         # send pixel_values through the model to get list of feature maps
-        features = self.model(pixel_values)
+        features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
 
         out = []
         for feature_map in features:
@@ -1468,9 +1480,7 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
         super().__init__(config)
 
         # Create backbone + positional encoding
-        backbone = ConditionalDetrTimmConvEncoder(
-            config.backbone, config.dilation, config.use_pretrained_backbone, config.num_channels
-        )
+        backbone = ConditionalDetrConvEncoder(config)
         position_embeddings = build_position_encoding(config)
         self.backbone = ConditionalDetrConvModel(backbone, position_embeddings)
 
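The branching above hinges on the two APIs exposing feature maps differently: timm's `features_only` model returns a plain list and reports channel counts via `feature_info`, while an [`AutoBackbone`] returns a `BackboneOutput` and reports channels via `.channels`. A small sketch of the native side (shapes are illustrative, for a ResNet-50-shaped config):

```python
import torch
from transformers import AutoBackbone, ResNetConfig

backbone = AutoBackbone.from_config(ResNetConfig(out_features=["stage4"]))

outputs = backbone(torch.randn(1, 3, 224, 224))
feature_maps = outputs.feature_maps  # tuple of tensors, one per requested stage
print(backbone.channels)        # e.g. [2048] for stage4 of a ResNet-50
print(feature_maps[0].shape)    # e.g. torch.Size([1, 2048, 7, 7]) at stride 32
```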
src/transformers/models/deformable_detr/configuration_deformable_detr.py

@@ -16,6 +16,7 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
+from ..auto import CONFIG_MAPPING
 
 
 logger = logging.get_logger(__name__)
@@ -37,6 +38,14 @@ class DeformableDetrConfig(PretrainedConfig):
     documentation from [`PretrainedConfig`] for more information.
 
     Args:
+        use_timm_backbone (`bool`, *optional*, defaults to `True`):
+            Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
+            API.
+        backbone_config (`PretrainedConfig` or `dict`, *optional*):
+            The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
+            case it will default to `ResNetConfig()`.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
         num_queries (`int`, *optional*, defaults to 300):
             Number of object queries, i.e. detection slots. This is the maximal number of objects
             [`DeformableDetrModel`] can detect in a single image. In case `two_stage` is set to `True`, we use
@@ -79,11 +88,14 @@ class DeformableDetrConfig(PretrainedConfig):
         position_embedding_type (`str`, *optional*, defaults to `"sine"`):
             Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
         backbone (`str`, *optional*, defaults to `"resnet50"`):
-            Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a
-            list of all available models, see [this
+            Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
+            backbone from the timm package. For a list of all available models, see [this
             page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
+        use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
+            Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
         dilation (`bool`, *optional*, defaults to `False`):
-            Whether to replace stride with dilation in the last convolutional block (DC5).
+            Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
+            `use_timm_backbone` = `True`.
         class_cost (`float`, *optional*, defaults to 1):
             Relative weight of the classification error in the Hungarian matching cost.
         bbox_cost (`float`, *optional*, defaults to 5):
@@ -139,6 +151,9 @@ class DeformableDetrConfig(PretrainedConfig):
 
     def __init__(
         self,
+        use_timm_backbone=True,
+        backbone_config=None,
+        num_channels=3,
         num_queries=300,
         max_position_embeddings=1024,
         encoder_layers=6,
@@ -161,6 +176,7 @@ class DeformableDetrConfig(PretrainedConfig):
         auxiliary_loss=False,
         position_embedding_type="sine",
         backbone="resnet50",
+        use_pretrained_backbone=True,
         dilation=False,
         num_feature_levels=4,
         encoder_n_points=4,
@@ -179,6 +195,20 @@ class DeformableDetrConfig(PretrainedConfig):
         focal_alpha=0.25,
         **kwargs
 ):
+        if backbone_config is not None and use_timm_backbone:
+            raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
+
+        if not use_timm_backbone:
+            if backbone_config is None:
+                logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
+                backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
+            elif isinstance(backbone_config, dict):
+                backbone_model_type = backbone_config.get("model_type")
+                config_class = CONFIG_MAPPING[backbone_model_type]
+                backbone_config = config_class.from_dict(backbone_config)
+        self.use_timm_backbone = use_timm_backbone
+        self.backbone_config = backbone_config
+        self.num_channels = num_channels
         self.num_queries = num_queries
         self.max_position_embeddings = max_position_embeddings
         self.d_model = d_model
@@ -199,6 +229,7 @@ class DeformableDetrConfig(PretrainedConfig):
         self.auxiliary_loss = auxiliary_loss
         self.position_embedding_type = position_embedding_type
         self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
         self.dilation = dilation
         # deformable attributes
         self.num_feature_levels = num_feature_levels
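The `elif isinstance(backbone_config, dict)` branch exists because a config reloaded from JSON carries the nested backbone config as a plain dict; its `model_type` key picks the right class out of `CONFIG_MAPPING`. A hedged sketch of that path, passing the dict form directly:

```python
from transformers import DeformableDetrConfig

# a plain dict, e.g. straight out of a deserialized config.json
config = DeformableDetrConfig(
    use_timm_backbone=False,
    backbone_config={"model_type": "resnet", "out_features": ["stage4"]},
)
print(type(config.backbone_config).__name__)  # ResNetConfig, rebuilt via from_dict
```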
src/transformers/models/deformable_detr/modeling_deformable_detr.py

@@ -43,6 +43,7 @@ from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import meshgrid
 from ...utils import is_ninja_available, logging
+from ..auto import AutoBackbone
 from .configuration_deformable_detr import DeformableDetrConfig
 from .load_custom import load_cuda_kernels
 
@@ -371,45 +372,57 @@ def replace_batch_norm(m, name=""):
         replace_batch_norm(ch, n)
 
 
-class DeformableDetrTimmConvEncoder(nn.Module):
+class DeformableDetrConvEncoder(nn.Module):
     """
-    Convolutional encoder (backbone) from the timm library.
+    Convolutional backbone, using either the AutoBackbone API or one from the timm library.
 
     nn.BatchNorm2d layers are replaced by DeformableDetrFrozenBatchNorm2d as defined above.
 
     """
 
     def __init__(self, config):
         super().__init__()
 
-        kwargs = {}
-        if config.dilation:
-            kwargs["output_stride"] = 16
+        self.config = config
 
-        requires_backends(self, ["timm"])
+        if config.use_timm_backbone:
+            requires_backends(self, ["timm"])
+            kwargs = {}
+            if config.dilation:
+                kwargs["output_stride"] = 16
+            backbone = create_model(
+                config.backbone,
+                pretrained=config.use_pretrained_backbone,
+                features_only=True,
+                out_indices=(2, 3, 4) if config.num_feature_levels > 1 else (4,),
+                in_chans=config.num_channels,
+                **kwargs,
+            )
+        else:
+            backbone = AutoBackbone.from_config(config.backbone_config)
 
-        out_indices = (2, 3, 4) if config.num_feature_levels > 1 else (4,)
-        backbone = create_model(
-            config.backbone, pretrained=True, features_only=True, out_indices=out_indices, **kwargs
-        )
         # replace batch norm by frozen batch norm
         with torch.no_grad():
             replace_batch_norm(backbone)
         self.model = backbone
-        self.intermediate_channel_sizes = self.model.feature_info.channels()
-        self.strides = self.model.feature_info.reduction()
+        self.intermediate_channel_sizes = (
+            self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
+        )
 
-        if "resnet" in config.backbone:
+        backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
+        if "resnet" in backbone_model_type:
             for name, parameter in self.model.named_parameters():
-                if "layer2" not in name and "layer3" not in name and "layer4" not in name:
-                    parameter.requires_grad_(False)
+                if config.use_timm_backbone:
+                    if "layer2" not in name and "layer3" not in name and "layer4" not in name:
+                        parameter.requires_grad_(False)
+                else:
+                    if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
+                        parameter.requires_grad_(False)
 
+    # Copied from transformers.models.detr.modeling_detr.DetrConvEncoder.forward with Detr->DeformableDetr
     def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
         """
         Outputs feature maps of latter stages C_3 through C_5 in ResNet if `config.num_feature_levels > 1`, otherwise
         outputs feature maps of C_5.
         """
         # send pixel_values through the model to get list of feature maps
-        features = self.model(pixel_values)
+        features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
 
         out = []
         for feature_map in features:
@@ -1438,13 +1451,13 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         super().__init__(config)
 
         # Create backbone + positional encoding
-        backbone = DeformableDetrTimmConvEncoder(config)
+        backbone = DeformableDetrConvEncoder(config)
         position_embeddings = build_position_encoding(config)
         self.backbone = DeformableDetrConvModel(backbone, position_embeddings)
 
         # Create input projection layers
         if config.num_feature_levels > 1:
-            num_backbone_outs = len(backbone.strides)
+            num_backbone_outs = len(backbone.intermediate_channel_sizes)
             input_proj_list = []
             for _ in range(num_backbone_outs):
                 in_channels = backbone.intermediate_channel_sizes[_]
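Note the companion change in `DeformableDetrModel`: `self.strides` (timm's `feature_info.reduction()`) is dropped from the encoder, so the number of input projections is now derived from `intermediate_channel_sizes`, which both paths populate with one entry per returned feature map. An illustrative check under that assumption, using the native path:

```python
import torch
from transformers import AutoBackbone, ResNetConfig

backbone = AutoBackbone.from_config(
    ResNetConfig(out_features=["stage2", "stage3", "stage4"])
)
feature_maps = backbone(torch.randn(1, 3, 224, 224)).feature_maps

# one channel count per feature map, so its length can stand in for len(strides)
assert len(backbone.channels) == len(feature_maps)  # 3 maps, at strides 8/16/32
```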
src/transformers/models/detr/configuration_detr.py

@@ -22,6 +22,7 @@ from packaging import version
 from ...configuration_utils import PretrainedConfig
 from ...onnx import OnnxConfig
 from ...utils import logging
+from ..auto import CONFIG_MAPPING
 
 
 logger = logging.get_logger(__name__)
@@ -43,6 +44,12 @@ class DetrConfig(PretrainedConfig):
     documentation from [`PretrainedConfig`] for more information.
 
     Args:
+        use_timm_backbone (`bool`, *optional*, defaults to `True`):
+            Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
+            API.
+        backbone_config (`PretrainedConfig` or `dict`, *optional*):
+            The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
+            case it will default to `ResNetConfig()`.
         num_channels (`int`, *optional*, defaults to 3):
             The number of input channels.
         num_queries (`int`, *optional*, defaults to 100):
@@ -86,13 +93,14 @@ class DetrConfig(PretrainedConfig):
         position_embedding_type (`str`, *optional*, defaults to `"sine"`):
             Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
         backbone (`str`, *optional*, defaults to `"resnet50"`):
-            Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a
-            list of all available models, see [this
+            Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
+            backbone from the timm package. For a list of all available models, see [this
             page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
         use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
-            Whether to use pretrained weights for the backbone.
+            Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
         dilation (`bool`, *optional*, defaults to `False`):
-            Whether to replace stride with dilation in the last convolutional block (DC5).
+            Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
+            `use_timm_backbone` = `True`.
         class_cost (`float`, *optional*, defaults to 1):
             Relative weight of the classification error in the Hungarian matching cost.
         bbox_cost (`float`, *optional*, defaults to 5):
@@ -133,6 +141,8 @@ class DetrConfig(PretrainedConfig):
 
     def __init__(
         self,
+        use_timm_backbone=True,
+        backbone_config=None,
         num_channels=3,
         num_queries=100,
         encoder_layers=6,
@@ -168,6 +178,20 @@ class DetrConfig(PretrainedConfig):
         eos_coefficient=0.1,
         **kwargs
     ):
+        if backbone_config is not None and use_timm_backbone:
+            raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
+
+        if not use_timm_backbone:
+            if backbone_config is None:
+                logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
+                backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
+            elif isinstance(backbone_config, dict):
+                backbone_model_type = backbone_config.get("model_type")
+                config_class = CONFIG_MAPPING[backbone_model_type]
+                backbone_config = config_class.from_dict(backbone_config)
+
+        self.use_timm_backbone = use_timm_backbone
+        self.backbone_config = backbone_config
         self.num_channels = num_channels
         self.num_queries = num_queries
         self.d_model = d_model
src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Convert DETR checkpoints."""
+"""Convert DETR checkpoints with timm backbone."""
 
 
 import argparse
src/transformers/models/detr/convert_detr_to_pytorch.py (new file, 375 lines)

@@ -0,0 +1,375 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert DETR checkpoints with native (Transformers) backbone."""
+
+
+import argparse
+import json
+from pathlib import Path
+
+import torch
+from PIL import Image
+
+import requests
+from huggingface_hub import hf_hub_download
+from transformers import DetrConfig, DetrForObjectDetection, DetrForSegmentation, DetrImageProcessor, ResNetConfig
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
+def get_detr_config(model_name):
+    config = DetrConfig(use_timm_backbone=False)
+
+    # set backbone attributes
+    if "resnet50" in model_name:
+        pass
+    elif "resnet101" in model_name:
+        config.backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-101")
+    else:
+        raise ValueError("Model name should include either resnet50 or resnet101")
+
+    # set label attributes
+    is_panoptic = "panoptic" in model_name
+    if is_panoptic:
+        config.num_labels = 250
+    else:
+        config.num_labels = 91
+        repo_id = "huggingface/label-files"
+        filename = "coco-detection-id2label.json"
+        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
+        id2label = {int(k): v for k, v in id2label.items()}
+        config.id2label = id2label
+        config.label2id = {v: k for k, v in id2label.items()}
+
+    return config, is_panoptic
+
+
+def create_rename_keys(config):
+    # here we list all keys to be renamed (original name on the left, our name on the right)
+    rename_keys = []
+
+    # stem
+    # fmt: off
+    rename_keys.append(("backbone.0.body.conv1.weight", "backbone.conv_encoder.model.embedder.embedder.convolution.weight"))
+    rename_keys.append(("backbone.0.body.bn1.weight", "backbone.conv_encoder.model.embedder.embedder.normalization.weight"))
+    rename_keys.append(("backbone.0.body.bn1.bias", "backbone.conv_encoder.model.embedder.embedder.normalization.bias"))
+    rename_keys.append(("backbone.0.body.bn1.running_mean", "backbone.conv_encoder.model.embedder.embedder.normalization.running_mean"))
+    rename_keys.append(("backbone.0.body.bn1.running_var", "backbone.conv_encoder.model.embedder.embedder.normalization.running_var"))
+    # stages
+    for stage_idx in range(len(config.backbone_config.depths)):
+        for layer_idx in range(config.backbone_config.depths[stage_idx]):
+            # shortcut
+            if layer_idx == 0:
+                rename_keys.append(
+                    (
+                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.0.weight",
+                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.convolution.weight",
+                    )
+                )
+                rename_keys.append(
+                    (
+                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.weight",
+                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.weight",
+                    )
+                )
+                rename_keys.append(
+                    (
+                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.bias",
+                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.bias",
+                    )
+                )
+                rename_keys.append(
+                    (
+                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.running_mean",
+                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_mean",
+                    )
+                )
+                rename_keys.append(
+                    (
+                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.running_var",
+                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_var",
+                    )
+                )
+            # 3 convs
+            for i in range(3):
+                rename_keys.append(
+                    (
+                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.conv{i+1}.weight",
+                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.convolution.weight",
+                    )
+                )
+                rename_keys.append(
+                    (
+                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.weight",
+                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.weight",
+                    )
+                )
+                rename_keys.append(
+                    (
+                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.bias",
+                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.bias",
+                    )
+                )
+                rename_keys.append(
+                    (
+                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.running_mean",
+                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.running_mean",
+                    )
+                )
+                rename_keys.append(
+                    (
+                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.running_var",
+                        f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.running_var",
+                    )
+                )
+    # fmt: on
+
+    for i in range(config.encoder_layers):
+        # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
+        rename_keys.append(
+            (
+                f"transformer.encoder.layers.{i}.self_attn.out_proj.weight",
+                f"encoder.layers.{i}.self_attn.out_proj.weight",
+            )
+        )
+        rename_keys.append(
+            (f"transformer.encoder.layers.{i}.self_attn.out_proj.bias", f"encoder.layers.{i}.self_attn.out_proj.bias")
+        )
+        rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"encoder.layers.{i}.fc1.weight"))
+        rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", f"encoder.layers.{i}.fc1.bias"))
+        rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"encoder.layers.{i}.fc2.weight"))
+        rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"encoder.layers.{i}.fc2.bias"))
+        rename_keys.append(
+            (f"transformer.encoder.layers.{i}.norm1.weight", f"encoder.layers.{i}.self_attn_layer_norm.weight")
+        )
+        rename_keys.append(
+            (f"transformer.encoder.layers.{i}.norm1.bias", f"encoder.layers.{i}.self_attn_layer_norm.bias")
+        )
+        rename_keys.append(
+            (f"transformer.encoder.layers.{i}.norm2.weight", f"encoder.layers.{i}.final_layer_norm.weight")
+        )
+        rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"encoder.layers.{i}.final_layer_norm.bias"))
+        # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
+        rename_keys.append(
+            (
+                f"transformer.decoder.layers.{i}.self_attn.out_proj.weight",
+                f"decoder.layers.{i}.self_attn.out_proj.weight",
+            )
+        )
+        rename_keys.append(
+            (f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"decoder.layers.{i}.self_attn.out_proj.bias")
+        )
+        rename_keys.append(
+            (
+                f"transformer.decoder.layers.{i}.multihead_attn.out_proj.weight",
+                f"decoder.layers.{i}.encoder_attn.out_proj.weight",
+            )
+        )
+        rename_keys.append(
+            (
+                f"transformer.decoder.layers.{i}.multihead_attn.out_proj.bias",
+                f"decoder.layers.{i}.encoder_attn.out_proj.bias",
+            )
+        )
+        rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"decoder.layers.{i}.fc1.weight"))
+        rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"decoder.layers.{i}.fc1.bias"))
+        rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"decoder.layers.{i}.fc2.weight"))
+        rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"decoder.layers.{i}.fc2.bias"))
+        rename_keys.append(
+            (f"transformer.decoder.layers.{i}.norm1.weight", f"decoder.layers.{i}.self_attn_layer_norm.weight")
+        )
+        rename_keys.append(
+            (f"transformer.decoder.layers.{i}.norm1.bias", f"decoder.layers.{i}.self_attn_layer_norm.bias")
+        )
+        rename_keys.append(
+            (f"transformer.decoder.layers.{i}.norm2.weight", f"decoder.layers.{i}.encoder_attn_layer_norm.weight")
+        )
+        rename_keys.append(
+            (f"transformer.decoder.layers.{i}.norm2.bias", f"decoder.layers.{i}.encoder_attn_layer_norm.bias")
+        )
+        rename_keys.append(
+            (f"transformer.decoder.layers.{i}.norm3.weight", f"decoder.layers.{i}.final_layer_norm.weight")
+        )
+        rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"decoder.layers.{i}.final_layer_norm.bias"))
+
+    # convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
+    rename_keys.extend(
+        [
+            ("input_proj.weight", "input_projection.weight"),
+            ("input_proj.bias", "input_projection.bias"),
+            ("query_embed.weight", "query_position_embeddings.weight"),
+            ("transformer.decoder.norm.weight", "decoder.layernorm.weight"),
+            ("transformer.decoder.norm.bias", "decoder.layernorm.bias"),
+            ("class_embed.weight", "class_labels_classifier.weight"),
+            ("class_embed.bias", "class_labels_classifier.bias"),
+            ("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"),
+            ("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"),
+            ("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"),
+            ("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"),
+            ("bbox_embed.layers.2.weight", "bbox_predictor.layers.2.weight"),
+            ("bbox_embed.layers.2.bias", "bbox_predictor.layers.2.bias"),
+        ]
+    )
+
+    return rename_keys
+
+
+def rename_key(state_dict, old, new):
+    val = state_dict.pop(old)
+    state_dict[new] = val
+
+
+def read_in_q_k_v(state_dict, is_panoptic=False):
+    prefix = ""
+    if is_panoptic:
+        prefix = "detr."
+
+    # first: transformer encoder
+    for i in range(6):
+        # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
+        in_proj_weight = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_weight")
+        in_proj_bias = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_bias")
+        # next, add query, keys and values (in that order) to the state dict
+        state_dict[f"encoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
+        state_dict[f"encoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
+        state_dict[f"encoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
+        state_dict[f"encoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
+        state_dict[f"encoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
+        state_dict[f"encoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
+    # next: transformer decoder (which is a bit more complex because it also includes cross-attention)
+    for i in range(6):
+        # read in weights + bias of input projection layer of self-attention
+        in_proj_weight = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.self_attn.in_proj_weight")
+        in_proj_bias = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.self_attn.in_proj_bias")
+        # next, add query, keys and values (in that order) to the state dict
+        state_dict[f"decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
+        state_dict[f"decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
+        state_dict[f"decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
+        state_dict[f"decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
+        state_dict[f"decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
+        state_dict[f"decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
+        # read in weights + bias of input projection layer of cross-attention
+        in_proj_weight_cross_attn = state_dict.pop(
+            f"{prefix}transformer.decoder.layers.{i}.multihead_attn.in_proj_weight"
+        )
+        in_proj_bias_cross_attn = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.multihead_attn.in_proj_bias")
+        # next, add query, keys and values (in that order) of cross-attention to the state dict
+        state_dict[f"decoder.layers.{i}.encoder_attn.q_proj.weight"] = in_proj_weight_cross_attn[:256, :]
+        state_dict[f"decoder.layers.{i}.encoder_attn.q_proj.bias"] = in_proj_bias_cross_attn[:256]
+        state_dict[f"decoder.layers.{i}.encoder_attn.k_proj.weight"] = in_proj_weight_cross_attn[256:512, :]
+        state_dict[f"decoder.layers.{i}.encoder_attn.k_proj.bias"] = in_proj_bias_cross_attn[256:512]
+        state_dict[f"decoder.layers.{i}.encoder_attn.v_proj.weight"] = in_proj_weight_cross_attn[-256:, :]
+        state_dict[f"decoder.layers.{i}.encoder_attn.v_proj.bias"] = in_proj_bias_cross_attn[-256:]
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    im = Image.open(requests.get(url, stream=True).raw)
+
+    return im
+
+
+@torch.no_grad()
+def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
+    """
+    Copy/paste/tweak model's weights to our DETR structure.
+    """
+
+    # load default config
+    config, is_panoptic = get_detr_config(model_name)
+
+    # load original model from torch hub
+    logger.info(f"Converting model {model_name}...")
+    detr = torch.hub.load("facebookresearch/detr", model_name, pretrained=True).eval()
+    state_dict = detr.state_dict()
+    # rename keys
+    for src, dest in create_rename_keys(config):
+        if is_panoptic:
+            src = "detr." + src
+        rename_key(state_dict, src, dest)
+    # query, key and value matrices need special treatment
+    read_in_q_k_v(state_dict, is_panoptic=is_panoptic)
+    # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
+    prefix = "detr.model." if is_panoptic else "model."
+    for key in state_dict.copy().keys():
+        if is_panoptic:
+            if (
+                key.startswith("detr")
+                and not key.startswith("class_labels_classifier")
+                and not key.startswith("bbox_predictor")
+            ):
+                val = state_dict.pop(key)
+                state_dict["detr.model" + key[4:]] = val
+            elif "class_labels_classifier" in key or "bbox_predictor" in key:
+                val = state_dict.pop(key)
+                state_dict["detr." + key] = val
+            elif key.startswith("bbox_attention") or key.startswith("mask_head"):
+                continue
+            else:
+                val = state_dict.pop(key)
+                state_dict[prefix + key] = val
+        else:
+            if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"):
+                val = state_dict.pop(key)
+                state_dict[prefix + key] = val
+
+    # finally, create HuggingFace model and load state dict
+    model = DetrForSegmentation(config) if is_panoptic else DetrForObjectDetection(config)
+    model.load_state_dict(state_dict)
+    model.eval()
+
+    # verify our conversion on an image
+    format = "coco_panoptic" if is_panoptic else "coco_detection"
+    processor = DetrImageProcessor(format=format)
+
+    encoding = processor(images=prepare_img(), return_tensors="pt")
+    pixel_values = encoding["pixel_values"]
+
+    original_outputs = detr(pixel_values)
+    outputs = model(pixel_values)
+
+    print("Logits:", outputs.logits[0, :3, :3])
+    print("Original logits:", original_outputs["pred_logits"][0, :3, :3])
+
+    assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-3)
+    assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-3)
+    if is_panoptic:
+        assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4)
+    print("Looks ok!")
+
+    if pytorch_dump_folder_path is not None:
+        # Save model and image processor
+        logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
+        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
+        model.save_pretrained(pytorch_dump_folder_path)
+        processor.save_pretrained(pytorch_dump_folder_path)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--model_name", default="detr_resnet50", type=str, help="Name of the DETR model you'd like to convert."
+    )
+    parser.add_argument(
+        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
+    )
+    args = parser.parse_args()
+    convert_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path)
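`read_in_q_k_v` above relies on the fact that PyTorch's `nn.MultiheadAttention` stores the query/key/value projections stacked in a single `(3 * d_model, d_model)` matrix; with DETR's `d_model = 256`, the slices `[:256]`, `[256:512]` and `[-256:]` peel them apart. A self-contained check of that slicing:

```python
import torch

d_model = 256
in_proj_weight = torch.randn(3 * d_model, d_model)  # stacked q, k, v projections

q, k, v = in_proj_weight.split(d_model, dim=0)
assert torch.equal(q, in_proj_weight[:256, :])
assert torch.equal(k, in_proj_weight[256:512, :])
assert torch.equal(v, in_proj_weight[-256:, :])
```

Given the argparse block at the bottom of the file, the script is run the usual way for conversion scripts:

```python
# from a shell, in a transformers checkout:
#   python src/transformers/models/detr/convert_detr_to_pytorch.py \
#       --model_name detr_resnet50 --pytorch_dump_folder_path ./detr-resnet50
```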
src/transformers/models/detr/modeling_detr.py

@@ -38,6 +38,7 @@ from ...utils import (
     replace_return_docstrings,
     requires_backends,
 )
+from ..auto import AutoBackbone
 from .configuration_detr import DetrConfig
 
 
@@ -320,45 +321,56 @@ def replace_batch_norm(m, name=""):
         replace_batch_norm(ch, n)
 
 
-class DetrTimmConvEncoder(nn.Module):
+class DetrConvEncoder(nn.Module):
     """
-    Convolutional encoder (backbone) from the timm library.
+    Convolutional backbone, using either the AutoBackbone API or one from the timm library.
 
     nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
 
     """
 
-    def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool, num_channels: int = 3):
+    def __init__(self, config):
        super().__init__()
 
-        kwargs = {}
-        if dilation:
-            kwargs["output_stride"] = 16
+        self.config = config
 
-        requires_backends(self, ["timm"])
+        if config.use_timm_backbone:
+            requires_backends(self, ["timm"])
+            kwargs = {}
+            if config.dilation:
+                kwargs["output_stride"] = 16
+            backbone = create_model(
+                config.backbone,
+                pretrained=config.use_pretrained_backbone,
+                features_only=True,
+                out_indices=(1, 2, 3, 4),
+                in_chans=config.num_channels,
+                **kwargs,
+            )
+        else:
+            backbone = AutoBackbone.from_config(config.backbone_config)
 
-        backbone = create_model(
-            name,
-            pretrained=use_pretrained_backbone,
-            features_only=True,
-            out_indices=(1, 2, 3, 4),
-            in_chans=num_channels,
-            **kwargs,
-        )
         # replace batch norm by frozen batch norm
         with torch.no_grad():
             replace_batch_norm(backbone)
         self.model = backbone
-        self.intermediate_channel_sizes = self.model.feature_info.channels()
+        self.intermediate_channel_sizes = (
+            self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
+        )
 
-        if "resnet" in name:
+        backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
+        if "resnet" in backbone_model_type:
             for name, parameter in self.model.named_parameters():
-                if "layer2" not in name and "layer3" not in name and "layer4" not in name:
-                    parameter.requires_grad_(False)
+                if config.use_timm_backbone:
+                    if "layer2" not in name and "layer3" not in name and "layer4" not in name:
+                        parameter.requires_grad_(False)
+                else:
+                    if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
+                        parameter.requires_grad_(False)
 
     def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
         # send pixel_values through the model to get list of feature maps
-        features = self.model(pixel_values)
+        features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
 
         out = []
         for feature_map in features:
@@ -1191,9 +1203,7 @@ class DetrModel(DetrPreTrainedModel):
         super().__init__(config)
 
         # Create backbone + positional encoding
-        backbone = DetrTimmConvEncoder(
-            config.backbone, config.dilation, config.use_pretrained_backbone, config.num_channels
-        )
+        backbone = DetrConvEncoder(config)
         position_embeddings = build_position_encoding(config)
         self.backbone = DetrConvModel(backbone, position_embeddings)
 
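The `forward` hunks above cut off at the loop header. For context, the loop body (not touched by this commit) pairs each feature map with a mask downsampled to its resolution; a sketch along the lines of the existing `DetrConvEncoder.forward`:

```python
# sketch of the existing loop body (unchanged here); features and pixel_mask
# are the names in scope in DetrConvEncoder.forward
out = []
for feature_map in features:
    # downsample pixel_mask to match the shape of the corresponding feature map
    mask = nn.functional.interpolate(
        pixel_mask[None].float(), size=feature_map.shape[-2:]
    ).to(torch.bool)[0]
    out.append((feature_map, mask))
```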
src/transformers/models/table_transformer/configuration_table_transformer.py

@@ -22,6 +22,7 @@ from packaging import version
 from ...configuration_utils import PretrainedConfig
 from ...onnx import OnnxConfig
 from ...utils import logging
+from ..auto import CONFIG_MAPPING
 
 
 logger = logging.get_logger(__name__)
@@ -44,6 +45,12 @@ class TableTransformerConfig(PretrainedConfig):
     documentation from [`PretrainedConfig`] for more information.
 
     Args:
+        use_timm_backbone (`bool`, *optional*, defaults to `True`):
+            Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
+            API.
+        backbone_config (`PretrainedConfig` or `dict`, *optional*):
+            The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
+            case it will default to `ResNetConfig()`.
         num_channels (`int`, *optional*, defaults to 3):
             The number of input channels.
         num_queries (`int`, *optional*, defaults to 100):
@@ -87,13 +94,14 @@ class TableTransformerConfig(PretrainedConfig):
         position_embedding_type (`str`, *optional*, defaults to `"sine"`):
             Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
         backbone (`str`, *optional*, defaults to `"resnet50"`):
-            Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a
-            list of all available models, see [this
+            Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
+            backbone from the timm package. For a list of all available models, see [this
             page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
         use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
-            Whether to use pretrained weights for the backbone.
+            Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
         dilation (`bool`, *optional*, defaults to `False`):
-            Whether to replace stride with dilation in the last convolutional block (DC5).
+            Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
+            `use_timm_backbone` = `True`.
         class_cost (`float`, *optional*, defaults to 1):
             Relative weight of the classification error in the Hungarian matching cost.
         bbox_cost (`float`, *optional*, defaults to 5):
@@ -135,6 +143,8 @@ class TableTransformerConfig(PretrainedConfig):
     # Copied from transformers.models.detr.configuration_detr.DetrConfig.__init__
     def __init__(
         self,
+        use_timm_backbone=True,
+        backbone_config=None,
         num_channels=3,
         num_queries=100,
         encoder_layers=6,
@@ -170,6 +180,20 @@ class TableTransformerConfig(PretrainedConfig):
         eos_coefficient=0.1,
         **kwargs
     ):
+        if backbone_config is not None and use_timm_backbone:
+            raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
+
+        if not use_timm_backbone:
+            if backbone_config is None:
+                logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
+                backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
+            elif isinstance(backbone_config, dict):
+                backbone_model_type = backbone_config.get("model_type")
+                config_class = CONFIG_MAPPING[backbone_model_type]
+                backbone_config = config_class.from_dict(backbone_config)
+
+        self.use_timm_backbone = use_timm_backbone
+        self.backbone_config = backbone_config
         self.num_channels = num_channels
         self.num_queries = num_queries
         self.d_model = d_model
src/transformers/models/table_transformer/modeling_table_transformer.py

@@ -38,6 +38,7 @@ from ...utils import (
     replace_return_docstrings,
     requires_backends,
 )
+from ..auto import AutoBackbone
 from .configuration_table_transformer import TableTransformerConfig
 
 
@@ -255,46 +256,57 @@ def replace_batch_norm(m, name=""):
         replace_batch_norm(ch, n)
 
 
-# Copied from transformers.models.detr.modeling_detr.DetrTimmConvEncoder with Detr->TableTransformer
-class TableTransformerTimmConvEncoder(nn.Module):
+# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->TableTransformer
+class TableTransformerConvEncoder(nn.Module):
     """
-    Convolutional encoder (backbone) from the timm library.
+    Convolutional backbone, using either the AutoBackbone API or one from the timm library.
 
     nn.BatchNorm2d layers are replaced by TableTransformerFrozenBatchNorm2d as defined above.
 
     """
 
-    def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool, num_channels: int = 3):
+    def __init__(self, config):
         super().__init__()
 
-        kwargs = {}
-        if dilation:
-            kwargs["output_stride"] = 16
+        self.config = config
 
-        requires_backends(self, ["timm"])
+        if config.use_timm_backbone:
+            requires_backends(self, ["timm"])
+            kwargs = {}
+            if config.dilation:
+                kwargs["output_stride"] = 16
+            backbone = create_model(
+                config.backbone,
+                pretrained=config.use_pretrained_backbone,
+                features_only=True,
+                out_indices=(1, 2, 3, 4),
+                in_chans=config.num_channels,
+                **kwargs,
+            )
+        else:
+            backbone = AutoBackbone.from_config(config.backbone_config)
 
-        backbone = create_model(
-            name,
-            pretrained=use_pretrained_backbone,
-            features_only=True,
-            out_indices=(1, 2, 3, 4),
-            in_chans=num_channels,
-            **kwargs,
-        )
         # replace batch norm by frozen batch norm
         with torch.no_grad():
             replace_batch_norm(backbone)
         self.model = backbone
-        self.intermediate_channel_sizes = self.model.feature_info.channels()
+        self.intermediate_channel_sizes = (
+            self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
+        )
 
-        if "resnet" in name:
+        backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
+        if "resnet" in backbone_model_type:
             for name, parameter in self.model.named_parameters():
-                if "layer2" not in name and "layer3" not in name and "layer4" not in name:
-                    parameter.requires_grad_(False)
+                if config.use_timm_backbone:
+                    if "layer2" not in name and "layer3" not in name and "layer4" not in name:
+                        parameter.requires_grad_(False)
+                else:
+                    if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
+                        parameter.requires_grad_(False)
 
     def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
         # send pixel_values through the model to get list of feature maps
-        features = self.model(pixel_values)
+        features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
 
         out = []
         for feature_map in features:
@@ -1136,9 +1148,7 @@ class TableTransformerModel(TableTransformerPreTrainedModel):
         super().__init__(config)
 
         # Create backbone + positional encoding
-        backbone = TableTransformerTimmConvEncoder(
-            config.backbone, config.dilation, config.use_pretrained_backbone, config.num_channels
-        )
+        backbone = TableTransformerConvEncoder(config)
         position_embeddings = build_position_encoding(config)
         self.backbone = TableTransformerConvModel(backbone, position_embeddings)
 
tests/models/conditional_detr/test_modeling_conditional_detr.py

@@ -31,7 +31,12 @@ from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_
 if is_timm_available():
     import torch
 
-    from transformers import ConditionalDetrForObjectDetection, ConditionalDetrForSegmentation, ConditionalDetrModel
+    from transformers import (
+        ConditionalDetrForObjectDetection,
+        ConditionalDetrForSegmentation,
+        ConditionalDetrModel,
+        ResNetConfig,
+    )
 
 
 if is_vision_available():
@@ -153,6 +158,25 @@ class ConditionalDetrModelTester:
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
         self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
 
+    def create_and_check_no_timm_backbone(self, config, pixel_values, pixel_mask, labels):
+        config.use_timm_backbone = False
+        config.backbone_config = ResNetConfig()
+        model = ConditionalDetrForObjectDetection(config=config)
+        model.to(torch_device)
+        model.eval()
+
+        result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
+        result = model(pixel_values)
+
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
+        self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
+
+        result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
+
+        self.parent.assertEqual(result.loss.shape, ())
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
+        self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
+
 
 @require_timm
 class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@@ -213,6 +237,10 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_conditional_detr_object_detection_head_model(*config_and_inputs)
 
+    def test_conditional_detr_no_timm_backbone(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_no_timm_backbone(*config_and_inputs)
+
     @unittest.skip(reason="Conditional DETR does not use inputs_embeds")
     def test_inputs_embeds(self):
         pass
tests/models/deformable_detr/test_modeling_deformable_detr.py

@@ -32,7 +32,7 @@ from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_
 if is_timm_available():
     import torch
 
-    from transformers import DeformableDetrForObjectDetection, DeformableDetrModel
+    from transformers import DeformableDetrForObjectDetection, DeformableDetrModel, ResNetConfig
 
 
 if is_vision_available():
@@ -164,6 +164,25 @@ class DeformableDetrModelTester:
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
         self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
 
+    def create_and_check_no_timm_backbone(self, config, pixel_values, pixel_mask, labels):
+        config.use_timm_backbone = False
+        config.backbone_config = ResNetConfig()
+        model = DeformableDetrForObjectDetection(config=config)
+        model.to(torch_device)
+        model.eval()
+
+        result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
+        result = model(pixel_values)
+
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
+        self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
+
+        result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
+
+        self.parent.assertEqual(result.loss.shape, ())
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
+        self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
+
 
 @require_timm
 class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@@ -221,6 +240,10 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_deformable_detr_object_detection_head_model(*config_and_inputs)
 
+    def test_deformable_detr_no_timm_backbone(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_no_timm_backbone(*config_and_inputs)
+
     @unittest.skip(reason="Deformable DETR does not use inputs_embeds")
     def test_inputs_embeds(self):
         pass
tests/models/detr/test_modeling_detr.py

@@ -31,7 +31,7 @@ from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_
 if is_timm_available():
     import torch
 
-    from transformers import DetrForObjectDetection, DetrForSegmentation, DetrModel
+    from transformers import DetrForObjectDetection, DetrForSegmentation, DetrModel, ResNetConfig
 
 
 if is_vision_available():
@@ -153,6 +153,25 @@ class DetrModelTester:
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1))
         self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
 
+    def create_and_check_no_timm_backbone(self, config, pixel_values, pixel_mask, labels):
+        config.use_timm_backbone = False
+        config.backbone_config = ResNetConfig()
+        model = DetrForObjectDetection(config=config)
+        model.to(torch_device)
+        model.eval()
+
+        result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
+        result = model(pixel_values)
+
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1))
+        self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
+
+        result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
+
+        self.parent.assertEqual(result.loss.shape, ())
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1))
+        self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
+
 
 @require_timm
 class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@@ -213,6 +232,10 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_detr_object_detection_head_model(*config_and_inputs)
 
+    def test_detr_no_timm_backbone(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_no_timm_backbone(*config_and_inputs)
+
     @unittest.skip(reason="DETR does not use inputs_embeds")
     def test_inputs_embeds(self):
         pass
tests/models/table_transformer/test_modeling_table_transformer.py

@@ -31,7 +31,7 @@ from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_
 if is_timm_available():
     import torch
 
-    from transformers import TableTransformerForObjectDetection, TableTransformerModel
+    from transformers import ResNetConfig, TableTransformerForObjectDetection, TableTransformerModel
 
 
 if is_vision_available():
@@ -153,6 +153,25 @@ class TableTransformerModelTester:
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1))
         self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
 
+    def create_and_check_table_transformer_no_timm_backbone(self, config, pixel_values, pixel_mask, labels):
+        config.use_timm_backbone = False
+        config.backbone_config = ResNetConfig()
+        model = TableTransformerForObjectDetection(config=config)
+        model.to(torch_device)
+        model.eval()
+
+        result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
+        result = model(pixel_values)
+
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1))
+        self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
+
+        result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
+
+        self.parent.assertEqual(result.loss.shape, ())
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1))
+        self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
+
 
 @require_timm
 class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@@ -212,6 +231,10 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unittes
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_table_transformer_object_detection_head_model(*config_and_inputs)
 
+    def test_table_transformer_no_timm_backbone(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_table_transformer_no_timm_backbone(*config_and_inputs)
+
     @unittest.skip(reason="Table Transformer does not use inputs_embeds")
     def test_inputs_embeds(self):
         pass