[DETA] fix backbone freeze/unfreeze function (#27843)

* [DETA] fix freeze/unfreeze function

* Update src/transformers/models/deta/modeling_deta.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/deta/modeling_deta.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* add freeze/unfreeze test case in DETA

* fix type

* fix typo 2

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Sangbum Daniel Choi 2023-12-11 15:57:30 +09:00 committed by GitHub
parent df5c5c62ae
commit 235be08569
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 4 deletions

View File

@ -1414,14 +1414,12 @@ class DetaModel(DetaPreTrainedModel):
def get_decoder(self):
return self.decoder
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.freeze_backbone
def freeze_backbone(self):
for name, param in self.backbone.conv_encoder.model.named_parameters():
for name, param in self.backbone.model.named_parameters():
param.requires_grad_(False)
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.unfreeze_backbone
def unfreeze_backbone(self):
for name, param in self.backbone.conv_encoder.model.named_parameters():
for name, param in self.backbone.model.named_parameters():
param.requires_grad_(True)
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_valid_ratio

View File

@ -162,6 +162,26 @@ class DetaModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.hidden_size))
def create_and_check_deta_freeze_backbone(self, config, pixel_values, pixel_mask, labels):
model = DetaModel(config=config)
model.to(torch_device)
model.eval()
model.freeze_backbone()
for _, param in model.backbone.model.named_parameters():
self.parent.assertEqual(False, param.requires_grad)
def create_and_check_deta_unfreeze_backbone(self, config, pixel_values, pixel_mask, labels):
model = DetaModel(config=config)
model.to(torch_device)
model.eval()
model.unfreeze_backbone()
for _, param in model.backbone.model.named_parameters():
self.parent.assertEqual(True, param.requires_grad)
def create_and_check_deta_object_detection_head_model(self, config, pixel_values, pixel_mask, labels):
model = DetaForObjectDetection(config=config)
model.to(torch_device)
@ -250,6 +270,14 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deta_model(*config_and_inputs)
def test_deta_freeze_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deta_freeze_backbone(*config_and_inputs)
def test_deta_unfreeze_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deta_unfreeze_backbone(*config_and_inputs)
def test_deta_object_detection_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deta_object_detection_head_model(*config_and_inputs)