RT-DETR parameterized batchnorm freezing (#32631)

* fix: Parameterized norm freezing

For the R18 model, the RT-DETR authors don't freeze the batch norm layers in the backbone, so backbone norm freezing is now configurable via `freeze_backbone_batch_norms` rather than hard-coded.
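
A minimal usage sketch of the new option (pairing it with `RTDetrForObjectDetection` is standard transformers usage; nothing here beyond `freeze_backbone_batch_norms` comes from this diff):

    from transformers import RTDetrConfig, RTDetrForObjectDetection

    # Keep the backbone's batch norm layers trainable, matching the authors' R18 recipe.
    config = RTDetrConfig(freeze_backbone_batch_norms=False)
    model = RTDetrForObjectDetection(config)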

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

---------

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
Authored by Alan-Blanchet on 2024-08-19 15:50:57 +02:00, committed by GitHub
parent 8a4857c0db
commit 5f6c080b62
2 changed files with 8 additions and 3 deletions

src/transformers/models/rt_detr/configuration_rt_detr.py

@@ -55,6 +55,8 @@ class RTDetrConfig(PretrainedConfig):
         use_timm_backbone (`bool`, *optional*, defaults to `False`):
             Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
             library.
+        freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
+            Whether to freeze the batch normalization layers in the backbone.
         backbone_kwargs (`dict`, *optional*):
             Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
             e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
@@ -190,6 +192,7 @@ class RTDetrConfig(PretrainedConfig):
         backbone=None,
         use_pretrained_backbone=False,
         use_timm_backbone=False,
+        freeze_backbone_batch_norms=True,
         backbone_kwargs=None,
         # encoder HybridEncoder
         encoder_hidden_dim=256,
@@ -280,6 +283,7 @@ class RTDetrConfig(PretrainedConfig):
         self.backbone = backbone
         self.use_pretrained_backbone = use_pretrained_backbone
         self.use_timm_backbone = use_timm_backbone
+        self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
         self.backbone_kwargs = backbone_kwargs
         # encoder
         self.encoder_hidden_dim = encoder_hidden_dim
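
Since the flag is a regular config field, it can also be overridden when loading a pretrained checkpoint. A sketch (the checkpoint name is assumed here for illustration):

    from transformers import RTDetrForObjectDetection

    # Kwargs not consumed by from_pretrained are forwarded to the config,
    # so the new flag can be flipped at load time.
    model = RTDetrForObjectDetection.from_pretrained(
        "PekingU/rtdetr_r18vd", freeze_backbone_batch_norms=False
    )
    print(model.config.freeze_backbone_batch_norms)  # False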

src/transformers/models/rt_detr/modeling_rt_detr.py

@@ -559,9 +559,10 @@ class RTDetrConvEncoder(nn.Module):
         backbone = load_backbone(config)
-        # replace batch norm by frozen batch norm
-        with torch.no_grad():
-            replace_batch_norm(backbone)
+        if config.freeze_backbone_batch_norms:
+            # replace batch norm by frozen batch norm
+            with torch.no_grad():
+                replace_batch_norm(backbone)
         self.model = backbone
         self.intermediate_channel_sizes = self.model.channels
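
For context, `replace_batch_norm` swaps each `nn.BatchNorm2d` in the backbone for a frozen variant whose affine terms and running statistics live in buffers rather than parameters, so they receive no gradients and are never updated. A minimal sketch of that pattern (simplified; the library's actual class is the DETR-style `RTDetrFrozenBatchNorm2d`, and this is not its exact code):

    import torch
    from torch import nn

    class FrozenBatchNorm2d(nn.Module):
        """BatchNorm2d with fixed affine weights and statistics (no updates, no gradients)."""

        def __init__(self, num_features, eps=1e-5):
            super().__init__()
            # Buffers, not parameters: invisible to the optimizer, never touched by autograd.
            self.register_buffer("weight", torch.ones(num_features))
            self.register_buffer("bias", torch.zeros(num_features))
            self.register_buffer("running_mean", torch.zeros(num_features))
            self.register_buffer("running_var", torch.ones(num_features))
            self.eps = eps

        def forward(self, x):
            # Fold the normalization into a per-channel scale and shift.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            shift = self.bias - self.running_mean * scale
            return x * scale.reshape(1, -1, 1, 1) + shift.reshape(1, -1, 1, 1)

    def replace_batch_norm_sketch(module):
        # Recursively swap BatchNorm2d layers for frozen copies that keep their stats.
        # Call this under torch.no_grad(), as the diff above does.
        for name, child in module.named_children():
            if isinstance(child, nn.BatchNorm2d):
                frozen = FrozenBatchNorm2d(child.num_features, child.eps)
                frozen.weight.copy_(child.weight)
                frozen.bias.copy_(child.bias)
                frozen.running_mean.copy_(child.running_mean)
                frozen.running_var.copy_(child.running_var)
                setattr(module, name, frozen)
            else:
                replace_batch_norm_sketch(child)

Freezing backbone norms is the usual recipe when fine-tuning detection models with small per-GPU batches, where batch statistics are noisy; the new flag simply makes that trade-off optional for backbones like R18 that were trained without it.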