Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-02 19:21:31 +06:00
RT-DETR parameterized batchnorm freezing (#32631)
* fix: Parameterized norm freezing

  For the R18 model, the authors don't freeze norms in the backbone.

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
parent 8a4857c0db
commit 5f6c080b62
src/transformers/models/rt_detr/configuration_rt_detr.py

@@ -55,6 +55,8 @@ class RTDetrConfig(PretrainedConfig):
         use_timm_backbone (`bool`, *optional*, defaults to `False`):
             Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
             library.
+        freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
+            Whether to freeze the batch normalization layers in the backbone.
         backbone_kwargs (`dict`, *optional*):
             Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
             e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
@@ -190,6 +192,7 @@ class RTDetrConfig(PretrainedConfig):
         backbone=None,
         use_pretrained_backbone=False,
         use_timm_backbone=False,
+        freeze_backbone_batch_norms=True,
         backbone_kwargs=None,
         # encoder HybridEncoder
         encoder_hidden_dim=256,
@@ -280,6 +283,7 @@ class RTDetrConfig(PretrainedConfig):
         self.backbone = backbone
         self.use_pretrained_backbone = use_pretrained_backbone
         self.use_timm_backbone = use_timm_backbone
+        self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
         self.backbone_kwargs = backbone_kwargs
         # encoder
         self.encoder_hidden_dim = encoder_hidden_dim
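The three hunks above thread the new option through RTDetrConfig: it is documented, accepted in __init__ with a default of True (preserving the previous behavior), and stored on the config. A minimal usage sketch, assuming a transformers build that includes this commit; RTDetrConfig and RTDetrForObjectDetection are the existing public classes, and the "R18-style" comment simply reflects the commit message:

from transformers import RTDetrConfig, RTDetrForObjectDetection

# Default behaviour is unchanged: batch norms in the backbone are frozen.
config = RTDetrConfig()
assert config.freeze_backbone_batch_norms is True

# R18-style recipe from the commit message: keep backbone batch norms trainable.
config_r18_style = RTDetrConfig(freeze_backbone_batch_norms=False)
model = RTDetrForObjectDetection(config_r18_style)  # randomly initialized, no weights download needed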
src/transformers/models/rt_detr/modeling_rt_detr.py

@@ -559,9 +559,10 @@ class RTDetrConvEncoder(nn.Module):
         backbone = load_backbone(config)

-        # replace batch norm by frozen batch norm
-        with torch.no_grad():
-            replace_batch_norm(backbone)
+        if config.freeze_backbone_batch_norms:
+            # replace batch norm by frozen batch norm
+            with torch.no_grad():
+                replace_batch_norm(backbone)
         self.model = backbone
         self.intermediate_channel_sizes = self.model.channels
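For context on what the gated call does: replace_batch_norm swaps the backbone's batch-norm layers for a "frozen" variant whose running statistics and affine parameters never update, and the change above applies that swap only when config.freeze_backbone_batch_norms is True. Below is a minimal, self-contained sketch of the idea; FrozenBatchNorm2d and freeze_batch_norms_ are illustrative stand-ins, not the actual helpers defined in modeling_rt_detr.py.

import torch
from torch import nn


class FrozenBatchNorm2d(nn.Module):
    # Batch norm with fixed statistics and affine parameters: everything is a
    # buffer, so nothing is updated by the optimizer or by running-stat tracking.
    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize with the stored statistics and affine terms only.
        scale = self.weight / (self.running_var + self.eps).sqrt()
        shift = self.bias - self.running_mean * scale
        return x * scale[None, :, None, None] + shift[None, :, None, None]


def freeze_batch_norms_(module: nn.Module) -> None:
    # Recursively replace nn.BatchNorm2d children with the frozen variant,
    # copying their learned parameters and running statistics (assumes affine=True).
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            frozen = FrozenBatchNorm2d(child.num_features, eps=child.eps)
            frozen.weight.copy_(child.weight.detach())
            frozen.bias.copy_(child.bias.detach())
            frozen.running_mean.copy_(child.running_mean)
            frozen.running_var.copy_(child.running_var)
            setattr(module, name, frozen)
        else:
            freeze_batch_norms_(child)

With the flag set to False, RTDetrConvEncoder skips the replacement entirely and the backbone's ordinary nn.BatchNorm2d layers remain trainable.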