[DETR and friends] Use AutoBackbone as alternative to timm (#20833)

* First draft

* More improvements

* Add conversion script

* More improvements

* Add docs

* Address review

* Rename class to ConvEncoder

* Address review

* Apply suggestion

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update all DETR friends

* Add corresponding test

* Improve test

* Fix bug

* Add more tests

* Set out_features to last stage by default

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
NielsRogge 2023-01-23 12:15:47 +01:00 committed by GitHub
parent c8d719ff7e
commit 91ff7efeeb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 731 additions and 113 deletions

View File

@ -22,6 +22,7 @@ from packaging import version
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
@ -44,6 +45,12 @@ class ConditionalDetrConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information.
Args:
use_timm_backbone (`bool`, *optional*, defaults to `True`):
Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
API.
backbone_config (`PretrainedConfig` or `dict`, *optional*):
The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
case it will default to `ResNetConfig()`.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
num_queries (`int`, *optional*, defaults to 100):
@ -87,13 +94,14 @@ class ConditionalDetrConfig(PretrainedConfig):
position_embedding_type (`str`, *optional*, defaults to `"sine"`):
Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
backbone (`str`, *optional*, defaults to `"resnet50"`):
Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a
list of all available models, see [this
Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
backbone from the timm package. For a list of all available models, see [this
page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone.
Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5).
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
class_cost (`float`, *optional*, defaults to 1):
Relative weight of the classification error in the Hungarian matching cost.
bbox_cost (`float`, *optional*, defaults to 5):
@ -136,6 +144,8 @@ class ConditionalDetrConfig(PretrainedConfig):
def __init__(
self,
use_timm_backbone=True,
backbone_config=None,
num_channels=3,
num_queries=300,
encoder_layers=6,
@ -172,6 +182,20 @@ class ConditionalDetrConfig(PretrainedConfig):
focal_alpha=0.25,
**kwargs
):
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
if not use_timm_backbone:
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
self.num_channels = num_channels
self.num_queries = num_queries
self.d_model = d_model

View File

@ -38,6 +38,7 @@ from ...utils import (
replace_return_docstrings,
requires_backends,
)
from ..auto import AutoBackbone
from .configuration_conditional_detr import ConditionalDetrConfig
@ -326,46 +327,57 @@ def replace_batch_norm(m, name=""):
replace_batch_norm(ch, n)
# Copied from transformers.models.detr.modeling_detr.DetrTimmConvEncoder
class ConditionalDetrTimmConvEncoder(nn.Module):
# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder
class ConditionalDetrConvEncoder(nn.Module):
"""
Convolutional encoder (backbone) from the timm library.
Convolutional backbone, using either the AutoBackbone API or one from the timm library.
nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
"""
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool, num_channels: int = 3):
def __init__(self, config):
super().__init__()
kwargs = {}
if dilation:
kwargs["output_stride"] = 16
self.config = config
requires_backends(self, ["timm"])
if config.use_timm_backbone:
requires_backends(self, ["timm"])
kwargs = {}
if config.dilation:
kwargs["output_stride"] = 16
backbone = create_model(
config.backbone,
pretrained=config.use_pretrained_backbone,
features_only=True,
out_indices=(1, 2, 3, 4),
in_chans=config.num_channels,
**kwargs,
)
else:
backbone = AutoBackbone.from_config(config.backbone_config)
backbone = create_model(
name,
pretrained=use_pretrained_backbone,
features_only=True,
out_indices=(1, 2, 3, 4),
in_chans=num_channels,
**kwargs,
)
# replace batch norm by frozen batch norm
with torch.no_grad():
replace_batch_norm(backbone)
self.model = backbone
self.intermediate_channel_sizes = self.model.feature_info.channels()
self.intermediate_channel_sizes = (
self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
)
if "resnet" in name:
backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
if "resnet" in backbone_model_type:
for name, parameter in self.model.named_parameters():
if "layer2" not in name and "layer3" not in name and "layer4" not in name:
parameter.requires_grad_(False)
if config.use_timm_backbone:
if "layer2" not in name and "layer3" not in name and "layer4" not in name:
parameter.requires_grad_(False)
else:
if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
parameter.requires_grad_(False)
def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
# send pixel_values through the model to get list of feature maps
features = self.model(pixel_values)
features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
out = []
for feature_map in features:
@ -1468,9 +1480,7 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
super().__init__(config)
# Create backbone + positional encoding
backbone = ConditionalDetrTimmConvEncoder(
config.backbone, config.dilation, config.use_pretrained_backbone, config.num_channels
)
backbone = ConditionalDetrConvEncoder(config)
position_embeddings = build_position_encoding(config)
self.backbone = ConditionalDetrConvModel(backbone, position_embeddings)

View File

@ -16,6 +16,7 @@
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
@ -37,6 +38,14 @@ class DeformableDetrConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information.
Args:
use_timm_backbone (`bool`, *optional*, defaults to `True`):
Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
API.
backbone_config (`PretrainedConfig` or `dict`, *optional*):
The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
case it will default to `ResNetConfig()`.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
num_queries (`int`, *optional*, defaults to 300):
Number of object queries, i.e. detection slots. This is the maximal number of objects
[`DeformableDetrModel`] can detect in a single image. In case `two_stage` is set to `True`, we use
@ -79,11 +88,14 @@ class DeformableDetrConfig(PretrainedConfig):
position_embedding_type (`str`, *optional*, defaults to `"sine"`):
Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
backbone (`str`, *optional*, defaults to `"resnet50"`):
Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a
list of all available models, see [this
Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
backbone from the timm package. For a list of all available models, see [this
page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5).
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
class_cost (`float`, *optional*, defaults to 1):
Relative weight of the classification error in the Hungarian matching cost.
bbox_cost (`float`, *optional*, defaults to 5):
@ -139,6 +151,9 @@ class DeformableDetrConfig(PretrainedConfig):
def __init__(
self,
use_timm_backbone=True,
backbone_config=None,
num_channels=3,
num_queries=300,
max_position_embeddings=1024,
encoder_layers=6,
@ -161,6 +176,7 @@ class DeformableDetrConfig(PretrainedConfig):
auxiliary_loss=False,
position_embedding_type="sine",
backbone="resnet50",
use_pretrained_backbone=True,
dilation=False,
num_feature_levels=4,
encoder_n_points=4,
@ -179,6 +195,20 @@ class DeformableDetrConfig(PretrainedConfig):
focal_alpha=0.25,
**kwargs
):
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
if not use_timm_backbone:
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
self.num_channels = num_channels
self.num_queries = num_queries
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
@ -199,6 +229,7 @@ class DeformableDetrConfig(PretrainedConfig):
self.auxiliary_loss = auxiliary_loss
self.position_embedding_type = position_embedding_type
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.dilation = dilation
# deformable attributes
self.num_feature_levels = num_feature_levels

View File

@ -43,6 +43,7 @@ from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
from ...utils import is_ninja_available, logging
from ..auto import AutoBackbone
from .configuration_deformable_detr import DeformableDetrConfig
from .load_custom import load_cuda_kernels
@ -371,45 +372,57 @@ def replace_batch_norm(m, name=""):
replace_batch_norm(ch, n)
class DeformableDetrTimmConvEncoder(nn.Module):
class DeformableDetrConvEncoder(nn.Module):
"""
Convolutional encoder (backbone) from the timm library.
Convolutional backbone, using either the AutoBackbone API or one from the timm library.
nn.BatchNorm2d layers are replaced by DeformableDetrFrozenBatchNorm2d as defined above.
"""
def __init__(self, config):
super().__init__()
kwargs = {}
if config.dilation:
kwargs["output_stride"] = 16
self.config = config
requires_backends(self, ["timm"])
if config.use_timm_backbone:
requires_backends(self, ["timm"])
kwargs = {}
if config.dilation:
kwargs["output_stride"] = 16
backbone = create_model(
config.backbone,
pretrained=config.use_pretrained_backbone,
features_only=True,
out_indices=(2, 3, 4) if config.num_feature_levels > 1 else (4,),
in_chans=config.num_channels,
**kwargs,
)
else:
backbone = AutoBackbone.from_config(config.backbone_config)
out_indices = (2, 3, 4) if config.num_feature_levels > 1 else (4,)
backbone = create_model(
config.backbone, pretrained=True, features_only=True, out_indices=out_indices, **kwargs
)
# replace batch norm by frozen batch norm
with torch.no_grad():
replace_batch_norm(backbone)
self.model = backbone
self.intermediate_channel_sizes = self.model.feature_info.channels()
self.strides = self.model.feature_info.reduction()
self.intermediate_channel_sizes = (
self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
)
if "resnet" in config.backbone:
backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
if "resnet" in backbone_model_type:
for name, parameter in self.model.named_parameters():
if "layer2" not in name and "layer3" not in name and "layer4" not in name:
parameter.requires_grad_(False)
if config.use_timm_backbone:
if "layer2" not in name and "layer3" not in name and "layer4" not in name:
parameter.requires_grad_(False)
else:
if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
parameter.requires_grad_(False)
# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder.forward with Detr->DeformableDetr
def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
"""
Outputs feature maps of latter stages C_3 through C_5 in ResNet if `config.num_feature_levels > 1`, otherwise
outputs feature maps of C_5.
"""
# send pixel_values through the model to get list of feature maps
features = self.model(pixel_values)
features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
out = []
for feature_map in features:
@ -1438,13 +1451,13 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
super().__init__(config)
# Create backbone + positional encoding
backbone = DeformableDetrTimmConvEncoder(config)
backbone = DeformableDetrConvEncoder(config)
position_embeddings = build_position_encoding(config)
self.backbone = DeformableDetrConvModel(backbone, position_embeddings)
# Create input projection layers
if config.num_feature_levels > 1:
num_backbone_outs = len(backbone.strides)
num_backbone_outs = len(backbone.intermediate_channel_sizes)
input_proj_list = []
for _ in range(num_backbone_outs):
in_channels = backbone.intermediate_channel_sizes[_]

View File

@ -22,6 +22,7 @@ from packaging import version
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
@ -43,6 +44,12 @@ class DetrConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information.
Args:
use_timm_backbone (`bool`, *optional*, defaults to `True`):
Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
API.
backbone_config (`PretrainedConfig` or `dict`, *optional*):
The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
case it will default to `ResNetConfig()`.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
num_queries (`int`, *optional*, defaults to 100):
@ -86,13 +93,14 @@ class DetrConfig(PretrainedConfig):
position_embedding_type (`str`, *optional*, defaults to `"sine"`):
Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
backbone (`str`, *optional*, defaults to `"resnet50"`):
Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a
list of all available models, see [this
Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
backbone from the timm package. For a list of all available models, see [this
page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone.
Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5).
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
class_cost (`float`, *optional*, defaults to 1):
Relative weight of the classification error in the Hungarian matching cost.
bbox_cost (`float`, *optional*, defaults to 5):
@ -133,6 +141,8 @@ class DetrConfig(PretrainedConfig):
def __init__(
self,
use_timm_backbone=True,
backbone_config=None,
num_channels=3,
num_queries=100,
encoder_layers=6,
@ -168,6 +178,20 @@ class DetrConfig(PretrainedConfig):
eos_coefficient=0.1,
**kwargs
):
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
if not use_timm_backbone:
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
self.num_channels = num_channels
self.num_queries = num_queries
self.d_model = d_model

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert DETR checkpoints."""
"""Convert DETR checkpoints with timm backbone."""
import argparse

View File

@ -0,0 +1,375 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert DETR checkpoints with native (Transformers) backbone."""
import argparse
import json
from pathlib import Path
import torch
from PIL import Image
import requests
from huggingface_hub import hf_hub_download
from transformers import DetrConfig, DetrForObjectDetection, DetrForSegmentation, DetrImageProcessor, ResNetConfig
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_detr_config(model_name):
config = DetrConfig(use_timm_backbone=False)
# set backbone attributes
if "resnet50" in model_name:
pass
elif "resnet101" in model_name:
config.backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-101")
else:
raise ValueError("Model name should include either resnet50 or resnet101")
# set label attributes
is_panoptic = "panoptic" in model_name
if is_panoptic:
config.num_labels = 250
else:
config.num_labels = 91
repo_id = "huggingface/label-files"
filename = "coco-detection-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
return config, is_panoptic
def create_rename_keys(config):
# here we list all keys to be renamed (original name on the left, our name on the right)
rename_keys = []
# stem
# fmt: off
rename_keys.append(("backbone.0.body.conv1.weight", "backbone.conv_encoder.model.embedder.embedder.convolution.weight"))
rename_keys.append(("backbone.0.body.bn1.weight", "backbone.conv_encoder.model.embedder.embedder.normalization.weight"))
rename_keys.append(("backbone.0.body.bn1.bias", "backbone.conv_encoder.model.embedder.embedder.normalization.bias"))
rename_keys.append(("backbone.0.body.bn1.running_mean", "backbone.conv_encoder.model.embedder.embedder.normalization.running_mean"))
rename_keys.append(("backbone.0.body.bn1.running_var", "backbone.conv_encoder.model.embedder.embedder.normalization.running_var"))
# stages
for stage_idx in range(len(config.backbone_config.depths)):
for layer_idx in range(config.backbone_config.depths[stage_idx]):
# shortcut
if layer_idx == 0:
rename_keys.append(
(
f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.0.weight",
f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.convolution.weight",
)
)
rename_keys.append(
(
f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.weight",
f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.weight",
)
)
rename_keys.append(
(
f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.bias",
f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.bias",
)
)
rename_keys.append(
(
f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.running_mean",
f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_mean",
)
)
rename_keys.append(
(
f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.running_var",
f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_var",
)
)
# 3 convs
for i in range(3):
rename_keys.append(
(
f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.conv{i+1}.weight",
f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.convolution.weight",
)
)
rename_keys.append(
(
f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.weight",
f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.weight",
)
)
rename_keys.append(
(
f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.bias",
f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.bias",
)
)
rename_keys.append(
(
f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.running_mean",
f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.running_mean",
)
)
rename_keys.append(
(
f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.running_var",
f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.running_var",
)
)
# fmt: on
for i in range(config.encoder_layers):
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
rename_keys.append(
(
f"transformer.encoder.layers.{i}.self_attn.out_proj.weight",
f"encoder.layers.{i}.self_attn.out_proj.weight",
)
)
rename_keys.append(
(f"transformer.encoder.layers.{i}.self_attn.out_proj.bias", f"encoder.layers.{i}.self_attn.out_proj.bias")
)
rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"encoder.layers.{i}.fc1.weight"))
rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", f"encoder.layers.{i}.fc1.bias"))
rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"encoder.layers.{i}.fc2.weight"))
rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"encoder.layers.{i}.fc2.bias"))
rename_keys.append(
(f"transformer.encoder.layers.{i}.norm1.weight", f"encoder.layers.{i}.self_attn_layer_norm.weight")
)
rename_keys.append(
(f"transformer.encoder.layers.{i}.norm1.bias", f"encoder.layers.{i}.self_attn_layer_norm.bias")
)
rename_keys.append(
(f"transformer.encoder.layers.{i}.norm2.weight", f"encoder.layers.{i}.final_layer_norm.weight")
)
rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"encoder.layers.{i}.final_layer_norm.bias"))
# decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
rename_keys.append(
(
f"transformer.decoder.layers.{i}.self_attn.out_proj.weight",
f"decoder.layers.{i}.self_attn.out_proj.weight",
)
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"decoder.layers.{i}.self_attn.out_proj.bias")
)
rename_keys.append(
(
f"transformer.decoder.layers.{i}.multihead_attn.out_proj.weight",
f"decoder.layers.{i}.encoder_attn.out_proj.weight",
)
)
rename_keys.append(
(
f"transformer.decoder.layers.{i}.multihead_attn.out_proj.bias",
f"decoder.layers.{i}.encoder_attn.out_proj.bias",
)
)
rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"decoder.layers.{i}.fc1.weight"))
rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"decoder.layers.{i}.fc1.bias"))
rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"decoder.layers.{i}.fc2.weight"))
rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"decoder.layers.{i}.fc2.bias"))
rename_keys.append(
(f"transformer.decoder.layers.{i}.norm1.weight", f"decoder.layers.{i}.self_attn_layer_norm.weight")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.norm1.bias", f"decoder.layers.{i}.self_attn_layer_norm.bias")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.norm2.weight", f"decoder.layers.{i}.encoder_attn_layer_norm.weight")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.norm2.bias", f"decoder.layers.{i}.encoder_attn_layer_norm.bias")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.norm3.weight", f"decoder.layers.{i}.final_layer_norm.weight")
)
rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"decoder.layers.{i}.final_layer_norm.bias"))
# convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
rename_keys.extend(
[
("input_proj.weight", "input_projection.weight"),
("input_proj.bias", "input_projection.bias"),
("query_embed.weight", "query_position_embeddings.weight"),
("transformer.decoder.norm.weight", "decoder.layernorm.weight"),
("transformer.decoder.norm.bias", "decoder.layernorm.bias"),
("class_embed.weight", "class_labels_classifier.weight"),
("class_embed.bias", "class_labels_classifier.bias"),
("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"),
("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"),
("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"),
("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"),
("bbox_embed.layers.2.weight", "bbox_predictor.layers.2.weight"),
("bbox_embed.layers.2.bias", "bbox_predictor.layers.2.bias"),
]
)
return rename_keys
def rename_key(state_dict, old, new):
val = state_dict.pop(old)
state_dict[new] = val
def read_in_q_k_v(state_dict, is_panoptic=False):
prefix = ""
if is_panoptic:
prefix = "detr."
# first: transformer encoder
for i in range(6):
# read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"encoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
state_dict[f"encoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
state_dict[f"encoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
state_dict[f"encoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
state_dict[f"encoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
state_dict[f"encoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
# next: transformer decoder (which is a bit more complex because it also includes cross-attention)
for i in range(6):
# read in weights + bias of input projection layer of self-attention
in_proj_weight = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.self_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.self_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
state_dict[f"decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
state_dict[f"decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
state_dict[f"decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
state_dict[f"decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
state_dict[f"decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
# read in weights + bias of input projection layer of cross-attention
in_proj_weight_cross_attn = state_dict.pop(
f"{prefix}transformer.decoder.layers.{i}.multihead_attn.in_proj_weight"
)
in_proj_bias_cross_attn = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.multihead_attn.in_proj_bias")
# next, add query, keys and values (in that order) of cross-attention to the state dict
state_dict[f"decoder.layers.{i}.encoder_attn.q_proj.weight"] = in_proj_weight_cross_attn[:256, :]
state_dict[f"decoder.layers.{i}.encoder_attn.q_proj.bias"] = in_proj_bias_cross_attn[:256]
state_dict[f"decoder.layers.{i}.encoder_attn.k_proj.weight"] = in_proj_weight_cross_attn[256:512, :]
state_dict[f"decoder.layers.{i}.encoder_attn.k_proj.bias"] = in_proj_bias_cross_attn[256:512]
state_dict[f"decoder.layers.{i}.encoder_attn.v_proj.weight"] = in_proj_weight_cross_attn[-256:, :]
state_dict[f"decoder.layers.{i}.encoder_attn.v_proj.bias"] = in_proj_bias_cross_attn[-256:]
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
"""
Copy/paste/tweak model's weights to our DETR structure.
"""
# load default config
config, is_panoptic = get_detr_config(model_name)
# load original model from torch hub
logger.info(f"Converting model {model_name}...")
detr = torch.hub.load("facebookresearch/detr", model_name, pretrained=True).eval()
state_dict = detr.state_dict()
# rename keys
for src, dest in create_rename_keys(config):
if is_panoptic:
src = "detr." + src
rename_key(state_dict, src, dest)
# query, key and value matrices need special treatment
read_in_q_k_v(state_dict, is_panoptic=is_panoptic)
# important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
prefix = "detr.model." if is_panoptic else "model."
for key in state_dict.copy().keys():
if is_panoptic:
if (
key.startswith("detr")
and not key.startswith("class_labels_classifier")
and not key.startswith("bbox_predictor")
):
val = state_dict.pop(key)
state_dict["detr.model" + key[4:]] = val
elif "class_labels_classifier" in key or "bbox_predictor" in key:
val = state_dict.pop(key)
state_dict["detr." + key] = val
elif key.startswith("bbox_attention") or key.startswith("mask_head"):
continue
else:
val = state_dict.pop(key)
state_dict[prefix + key] = val
else:
if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"):
val = state_dict.pop(key)
state_dict[prefix + key] = val
# finally, create HuggingFace model and load state dict
model = DetrForSegmentation(config) if is_panoptic else DetrForObjectDetection(config)
model.load_state_dict(state_dict)
model.eval()
# verify our conversion on an image
format = "coco_panoptic" if is_panoptic else "coco_detection"
processor = DetrImageProcessor(format=format)
encoding = processor(images=prepare_img(), return_tensors="pt")
pixel_values = encoding["pixel_values"]
original_outputs = detr(pixel_values)
outputs = model(pixel_values)
print("Logits:", outputs.logits[0, :3, :3])
print("Original logits:", original_outputs["pred_logits"][0, :3, :3])
assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-3)
assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-3)
if is_panoptic:
assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
# Save model and image processor
logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name", default="detr_resnet50", type=str, help="Name of the DETR model you'd like to convert."
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
)
args = parser.parse_args()
convert_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path)

View File

@ -38,6 +38,7 @@ from ...utils import (
replace_return_docstrings,
requires_backends,
)
from ..auto import AutoBackbone
from .configuration_detr import DetrConfig
@ -320,45 +321,56 @@ def replace_batch_norm(m, name=""):
replace_batch_norm(ch, n)
class DetrTimmConvEncoder(nn.Module):
class DetrConvEncoder(nn.Module):
"""
Convolutional encoder (backbone) from the timm library.
Convolutional backbone, using either the AutoBackbone API or one from the timm library.
nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
"""
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool, num_channels: int = 3):
def __init__(self, config):
super().__init__()
kwargs = {}
if dilation:
kwargs["output_stride"] = 16
self.config = config
requires_backends(self, ["timm"])
if config.use_timm_backbone:
requires_backends(self, ["timm"])
kwargs = {}
if config.dilation:
kwargs["output_stride"] = 16
backbone = create_model(
config.backbone,
pretrained=config.use_pretrained_backbone,
features_only=True,
out_indices=(1, 2, 3, 4),
in_chans=config.num_channels,
**kwargs,
)
else:
backbone = AutoBackbone.from_config(config.backbone_config)
backbone = create_model(
name,
pretrained=use_pretrained_backbone,
features_only=True,
out_indices=(1, 2, 3, 4),
in_chans=num_channels,
**kwargs,
)
# replace batch norm by frozen batch norm
with torch.no_grad():
replace_batch_norm(backbone)
self.model = backbone
self.intermediate_channel_sizes = self.model.feature_info.channels()
self.intermediate_channel_sizes = (
self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
)
if "resnet" in name:
backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
if "resnet" in backbone_model_type:
for name, parameter in self.model.named_parameters():
if "layer2" not in name and "layer3" not in name and "layer4" not in name:
parameter.requires_grad_(False)
if config.use_timm_backbone:
if "layer2" not in name and "layer3" not in name and "layer4" not in name:
parameter.requires_grad_(False)
else:
if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
parameter.requires_grad_(False)
def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
# send pixel_values through the model to get list of feature maps
features = self.model(pixel_values)
features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
out = []
for feature_map in features:
@ -1191,9 +1203,7 @@ class DetrModel(DetrPreTrainedModel):
super().__init__(config)
# Create backbone + positional encoding
backbone = DetrTimmConvEncoder(
config.backbone, config.dilation, config.use_pretrained_backbone, config.num_channels
)
backbone = DetrConvEncoder(config)
position_embeddings = build_position_encoding(config)
self.backbone = DetrConvModel(backbone, position_embeddings)

View File

@ -22,6 +22,7 @@ from packaging import version
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
@ -44,6 +45,12 @@ class TableTransformerConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information.
Args:
use_timm_backbone (`bool`, *optional*, defaults to `True`):
Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
API.
backbone_config (`PretrainedConfig` or `dict`, *optional*):
The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
case it will default to `ResNetConfig()`.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
num_queries (`int`, *optional*, defaults to 100):
@ -87,13 +94,14 @@ class TableTransformerConfig(PretrainedConfig):
position_embedding_type (`str`, *optional*, defaults to `"sine"`):
Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
backbone (`str`, *optional*, defaults to `"resnet50"`):
Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a
list of all available models, see [this
Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
backbone from the timm package. For a list of all available models, see [this
page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone.
Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5).
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
class_cost (`float`, *optional*, defaults to 1):
Relative weight of the classification error in the Hungarian matching cost.
bbox_cost (`float`, *optional*, defaults to 5):
@ -135,6 +143,8 @@ class TableTransformerConfig(PretrainedConfig):
# Copied from transformers.models.detr.configuration_detr.DetrConfig.__init__
def __init__(
self,
use_timm_backbone=True,
backbone_config=None,
num_channels=3,
num_queries=100,
encoder_layers=6,
@ -170,6 +180,20 @@ class TableTransformerConfig(PretrainedConfig):
eos_coefficient=0.1,
**kwargs
):
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
if not use_timm_backbone:
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
self.num_channels = num_channels
self.num_queries = num_queries
self.d_model = d_model

View File

@ -38,6 +38,7 @@ from ...utils import (
replace_return_docstrings,
requires_backends,
)
from ..auto import AutoBackbone
from .configuration_table_transformer import TableTransformerConfig
@ -255,46 +256,57 @@ def replace_batch_norm(m, name=""):
replace_batch_norm(ch, n)
# Copied from transformers.models.detr.modeling_detr.DetrTimmConvEncoder with Detr->TableTransformer
class TableTransformerTimmConvEncoder(nn.Module):
# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->TableTransformer
class TableTransformerConvEncoder(nn.Module):
"""
Convolutional encoder (backbone) from the timm library.
Convolutional backbone, using either the AutoBackbone API or one from the timm library.
nn.BatchNorm2d layers are replaced by TableTransformerFrozenBatchNorm2d as defined above.
"""
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool, num_channels: int = 3):
def __init__(self, config):
super().__init__()
kwargs = {}
if dilation:
kwargs["output_stride"] = 16
self.config = config
requires_backends(self, ["timm"])
if config.use_timm_backbone:
requires_backends(self, ["timm"])
kwargs = {}
if config.dilation:
kwargs["output_stride"] = 16
backbone = create_model(
config.backbone,
pretrained=config.use_pretrained_backbone,
features_only=True,
out_indices=(1, 2, 3, 4),
in_chans=config.num_channels,
**kwargs,
)
else:
backbone = AutoBackbone.from_config(config.backbone_config)
backbone = create_model(
name,
pretrained=use_pretrained_backbone,
features_only=True,
out_indices=(1, 2, 3, 4),
in_chans=num_channels,
**kwargs,
)
# replace batch norm by frozen batch norm
with torch.no_grad():
replace_batch_norm(backbone)
self.model = backbone
self.intermediate_channel_sizes = self.model.feature_info.channels()
self.intermediate_channel_sizes = (
self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
)
if "resnet" in name:
backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
if "resnet" in backbone_model_type:
for name, parameter in self.model.named_parameters():
if "layer2" not in name and "layer3" not in name and "layer4" not in name:
parameter.requires_grad_(False)
if config.use_timm_backbone:
if "layer2" not in name and "layer3" not in name and "layer4" not in name:
parameter.requires_grad_(False)
else:
if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
parameter.requires_grad_(False)
def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
# send pixel_values through the model to get list of feature maps
features = self.model(pixel_values)
features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
out = []
for feature_map in features:
@ -1136,9 +1148,7 @@ class TableTransformerModel(TableTransformerPreTrainedModel):
super().__init__(config)
# Create backbone + positional encoding
backbone = TableTransformerTimmConvEncoder(
config.backbone, config.dilation, config.use_pretrained_backbone, config.num_channels
)
backbone = TableTransformerConvEncoder(config)
position_embeddings = build_position_encoding(config)
self.backbone = TableTransformerConvModel(backbone, position_embeddings)

View File

@ -31,7 +31,12 @@ from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_
if is_timm_available():
import torch
from transformers import ConditionalDetrForObjectDetection, ConditionalDetrForSegmentation, ConditionalDetrModel
from transformers import (
ConditionalDetrForObjectDetection,
ConditionalDetrForSegmentation,
ConditionalDetrModel,
ResNetConfig,
)
if is_vision_available():
@ -153,6 +158,25 @@ class ConditionalDetrModelTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
def create_and_check_no_timm_backbone(self, config, pixel_values, pixel_mask, labels):
config.use_timm_backbone = False
config.backbone_config = ResNetConfig()
model = ConditionalDetrForObjectDetection(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
@require_timm
class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@ -213,6 +237,10 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_conditional_detr_object_detection_head_model(*config_and_inputs)
def test_conditional_detr_no_timm_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_no_timm_backbone(*config_and_inputs)
@unittest.skip(reason="Conditional DETR does not use inputs_embeds")
def test_inputs_embeds(self):
pass

View File

@ -32,7 +32,7 @@ from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_
if is_timm_available():
import torch
from transformers import DeformableDetrForObjectDetection, DeformableDetrModel
from transformers import DeformableDetrForObjectDetection, DeformableDetrModel, ResNetConfig
if is_vision_available():
@ -164,6 +164,25 @@ class DeformableDetrModelTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
def create_and_check_no_timm_backbone(self, config, pixel_values, pixel_mask, labels):
config.use_timm_backbone = False
config.backbone_config = ResNetConfig()
model = DeformableDetrForObjectDetection(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
@require_timm
class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@ -221,6 +240,10 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deformable_detr_object_detection_head_model(*config_and_inputs)
def test_deformable_detr_no_timm_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_no_timm_backbone(*config_and_inputs)
@unittest.skip(reason="Deformable DETR does not use inputs_embeds")
def test_inputs_embeds(self):
pass

View File

@ -31,7 +31,7 @@ from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_
if is_timm_available():
import torch
from transformers import DetrForObjectDetection, DetrForSegmentation, DetrModel
from transformers import DetrForObjectDetection, DetrForSegmentation, DetrModel, ResNetConfig
if is_vision_available():
@ -153,6 +153,25 @@ class DetrModelTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
def create_and_check_no_timm_backbone(self, config, pixel_values, pixel_mask, labels):
config.use_timm_backbone = False
config.backbone_config = ResNetConfig()
model = DetrForObjectDetection(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
@require_timm
class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@ -213,6 +232,10 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_detr_object_detection_head_model(*config_and_inputs)
def test_detr_no_timm_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_no_timm_backbone(*config_and_inputs)
@unittest.skip(reason="DETR does not use inputs_embeds")
def test_inputs_embeds(self):
pass

View File

@ -31,7 +31,7 @@ from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_
if is_timm_available():
import torch
from transformers import TableTransformerForObjectDetection, TableTransformerModel
from transformers import ResNetConfig, TableTransformerForObjectDetection, TableTransformerModel
if is_vision_available():
@ -153,6 +153,25 @@ class TableTransformerModelTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
def create_and_check_table_transformer_no_timm_backbone(self, config, pixel_values, pixel_mask, labels):
config.use_timm_backbone = False
config.backbone_config = ResNetConfig()
model = TableTransformerForObjectDetection(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
@require_timm
class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@ -212,6 +231,10 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unittes
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_table_transformer_object_detection_head_model(*config_and_inputs)
def test_table_transformer_no_timm_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_table_transformer_no_timm_backbone(*config_and_inputs)
@unittest.skip(reason="Table Transformer does not use inputs_embeds")
def test_inputs_embeds(self):
pass