mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
[DETA] fix backbone freeze/unfreeze function (#27843)
* [DETA] fix freeze/unfreeze function * Update src/transformers/models/deta/modeling_deta.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/deta/modeling_deta.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add freeze/unfreeze test case in DETA * fix type * fix typo 2 --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
df5c5c62ae
commit
235be08569
@ -1414,14 +1414,12 @@ class DetaModel(DetaPreTrainedModel):
|
||||
def get_decoder(self):
|
||||
return self.decoder
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.freeze_backbone
|
||||
def freeze_backbone(self):
|
||||
for name, param in self.backbone.conv_encoder.model.named_parameters():
|
||||
for name, param in self.backbone.model.named_parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.unfreeze_backbone
|
||||
def unfreeze_backbone(self):
|
||||
for name, param in self.backbone.conv_encoder.model.named_parameters():
|
||||
for name, param in self.backbone.model.named_parameters():
|
||||
param.requires_grad_(True)
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_valid_ratio
|
||||
|
@ -162,6 +162,26 @@ class DetaModelTester:
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.hidden_size))
|
||||
|
||||
def create_and_check_deta_freeze_backbone(self, config, pixel_values, pixel_mask, labels):
|
||||
model = DetaModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
model.freeze_backbone()
|
||||
|
||||
for _, param in model.backbone.model.named_parameters():
|
||||
self.parent.assertEqual(False, param.requires_grad)
|
||||
|
||||
def create_and_check_deta_unfreeze_backbone(self, config, pixel_values, pixel_mask, labels):
|
||||
model = DetaModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
model.unfreeze_backbone()
|
||||
|
||||
for _, param in model.backbone.model.named_parameters():
|
||||
self.parent.assertEqual(True, param.requires_grad)
|
||||
|
||||
def create_and_check_deta_object_detection_head_model(self, config, pixel_values, pixel_mask, labels):
|
||||
model = DetaForObjectDetection(config=config)
|
||||
model.to(torch_device)
|
||||
@ -250,6 +270,14 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_deta_model(*config_and_inputs)
|
||||
|
||||
def test_deta_freeze_backbone(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_deta_freeze_backbone(*config_and_inputs)
|
||||
|
||||
def test_deta_unfreeze_backbone(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_deta_unfreeze_backbone(*config_and_inputs)
|
||||
|
||||
def test_deta_object_detection_head_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_deta_object_detection_head_model(*config_and_inputs)
|
||||
|
Loading…
Reference in New Issue
Block a user