diff --git a/src/transformers/models/detr/configuration_detr.py b/src/transformers/models/detr/configuration_detr.py index fb320afdd20..604a7dad0f4 100644 --- a/src/transformers/models/detr/configuration_detr.py +++ b/src/transformers/models/detr/configuration_detr.py @@ -42,8 +42,9 @@ class DetrConfig(PretrainedConfig): Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. - Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. num_queries (`int`, *optional*, defaults to 100): Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetrModel`] can detect in a single image. For COCO, we recommend 100 queries. @@ -132,6 +133,7 @@ class DetrConfig(PretrainedConfig): def __init__( self, + num_channels=3, num_queries=100, max_position_embeddings=1024, encoder_layers=6, @@ -167,6 +169,7 @@ class DetrConfig(PretrainedConfig): eos_coefficient=0.1, **kwargs ): + self.num_channels = num_channels self.num_queries = num_queries self.max_position_embeddings = max_position_embeddings self.d_model = d_model diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index deaab873fc4..9d974f1a551 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -326,7 +326,7 @@ class DetrTimmConvEncoder(nn.Module): """ - def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool): + def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool, num_channels: int = 3): super().__init__() kwargs = {} @@ -336,7 +336,12 @@ class DetrTimmConvEncoder(nn.Module): requires_backends(self, ["timm"]) backbone = create_model( - name, pretrained=use_pretrained_backbone, features_only=True, out_indices=(1, 2, 3, 4), **kwargs + name, + pretrained=use_pretrained_backbone, + features_only=True, + out_indices=(1, 2, 3, 4), + in_chans=num_channels, + **kwargs, ) # replace batch norm by frozen batch norm with torch.no_grad(): @@ -1179,7 +1184,9 @@ class DetrModel(DetrPreTrainedModel): super().__init__(config) # Create backbone + positional encoding - backbone = DetrTimmConvEncoder(config.backbone, config.dilation, config.use_pretrained_backbone) + backbone = DetrTimmConvEncoder( + config.backbone, config.dilation, config.use_pretrained_backbone, config.num_channels + ) position_embeddings = build_position_encoding(config) self.backbone = DetrConvModel(backbone, position_embeddings) diff --git a/tests/models/detr/test_modeling_detr.py b/tests/models/detr/test_modeling_detr.py index 7b0b7eeb754..d64c6a062e7 100644 --- a/tests/models/detr/test_modeling_detr.py +++ b/tests/models/detr/test_modeling_detr.py @@ -416,6 +416,26 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): self.assertTrue(outputs) + def test_greyscale_images(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # use greyscale pixel values + inputs_dict["pixel_values"] = floats_tensor( + [self.model_tester.batch_size, 1, self.model_tester.min_size, self.model_tester.max_size] + ) + + # let's set num_channels to 1 + config.num_channels = 1 + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + self.assertTrue(outputs) + def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()