[DETR] Add num_channels attribute (#18714)

* Add num_channels attribute

* Fix code quality

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge 2022-08-31 18:04:42 +02:00 committed by GitHub
parent 811c4c9f79
commit 3b6943e7a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 4 deletions

View File

@ -42,8 +42,9 @@ class DetrConfig(PretrainedConfig):
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
num_queries (`int`, *optional*, defaults to 100):
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetrModel`] can
detect in a single image. For COCO, we recommend 100 queries.
@ -132,6 +133,7 @@ class DetrConfig(PretrainedConfig):
def __init__(
self,
num_channels=3,
num_queries=100,
max_position_embeddings=1024,
encoder_layers=6,
@ -167,6 +169,7 @@ class DetrConfig(PretrainedConfig):
eos_coefficient=0.1,
**kwargs
):
self.num_channels = num_channels
self.num_queries = num_queries
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model

View File

@ -326,7 +326,7 @@ class DetrTimmConvEncoder(nn.Module):
"""
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool):
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool, num_channels: int = 3):
super().__init__()
kwargs = {}
@ -336,7 +336,12 @@ class DetrTimmConvEncoder(nn.Module):
requires_backends(self, ["timm"])
backbone = create_model(
name, pretrained=use_pretrained_backbone, features_only=True, out_indices=(1, 2, 3, 4), **kwargs
name,
pretrained=use_pretrained_backbone,
features_only=True,
out_indices=(1, 2, 3, 4),
in_chans=num_channels,
**kwargs,
)
# replace batch norm by frozen batch norm
with torch.no_grad():
@ -1179,7 +1184,9 @@ class DetrModel(DetrPreTrainedModel):
super().__init__(config)
# Create backbone + positional encoding
backbone = DetrTimmConvEncoder(config.backbone, config.dilation, config.use_pretrained_backbone)
backbone = DetrTimmConvEncoder(
config.backbone, config.dilation, config.use_pretrained_backbone, config.num_channels
)
position_embeddings = build_position_encoding(config)
self.backbone = DetrConvModel(backbone, position_embeddings)

View File

@ -416,6 +416,26 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
self.assertTrue(outputs)
def test_greyscale_images(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# use greyscale pixel values
inputs_dict["pixel_values"] = floats_tensor(
[self.model_tester.batch_size, 1, self.model_tester.min_size, self.model_tester.max_size]
)
# let's set num_channels to 1
config.num_channels = 1
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
self.assertTrue(outputs)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()