mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[DETR] Add num_channels attribute (#18714)
* Add num_channels attribute * Fix code quality Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
parent
811c4c9f79
commit
3b6943e7a3
@ -42,8 +42,9 @@ class DetrConfig(PretrainedConfig):
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
num_queries (`int`, *optional*, defaults to 100):
|
||||
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetrModel`] can
|
||||
detect in a single image. For COCO, we recommend 100 queries.
|
||||
@ -132,6 +133,7 @@ class DetrConfig(PretrainedConfig):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_channels=3,
|
||||
num_queries=100,
|
||||
max_position_embeddings=1024,
|
||||
encoder_layers=6,
|
||||
@ -167,6 +169,7 @@ class DetrConfig(PretrainedConfig):
|
||||
eos_coefficient=0.1,
|
||||
**kwargs
|
||||
):
|
||||
self.num_channels = num_channels
|
||||
self.num_queries = num_queries
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.d_model = d_model
|
||||
|
@ -326,7 +326,7 @@ class DetrTimmConvEncoder(nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool):
|
||||
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool, num_channels: int = 3):
|
||||
super().__init__()
|
||||
|
||||
kwargs = {}
|
||||
@ -336,7 +336,12 @@ class DetrTimmConvEncoder(nn.Module):
|
||||
requires_backends(self, ["timm"])
|
||||
|
||||
backbone = create_model(
|
||||
name, pretrained=use_pretrained_backbone, features_only=True, out_indices=(1, 2, 3, 4), **kwargs
|
||||
name,
|
||||
pretrained=use_pretrained_backbone,
|
||||
features_only=True,
|
||||
out_indices=(1, 2, 3, 4),
|
||||
in_chans=num_channels,
|
||||
**kwargs,
|
||||
)
|
||||
# replace batch norm by frozen batch norm
|
||||
with torch.no_grad():
|
||||
@ -1179,7 +1184,9 @@ class DetrModel(DetrPreTrainedModel):
|
||||
super().__init__(config)
|
||||
|
||||
# Create backbone + positional encoding
|
||||
backbone = DetrTimmConvEncoder(config.backbone, config.dilation, config.use_pretrained_backbone)
|
||||
backbone = DetrTimmConvEncoder(
|
||||
config.backbone, config.dilation, config.use_pretrained_backbone, config.num_channels
|
||||
)
|
||||
position_embeddings = build_position_encoding(config)
|
||||
self.backbone = DetrConvModel(backbone, position_embeddings)
|
||||
|
||||
|
@ -416,6 +416,26 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertTrue(outputs)
|
||||
|
||||
def test_greyscale_images(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# use greyscale pixel values
|
||||
inputs_dict["pixel_values"] = floats_tensor(
|
||||
[self.model_tester.batch_size, 1, self.model_tester.min_size, self.model_tester.max_size]
|
||||
)
|
||||
|
||||
# let's set num_channels to 1
|
||||
config.num_channels = 1
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
self.assertTrue(outputs)
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user