Check for case where auxiliary_head is None in UperNetPreTrainedModel (#25514)

Michael Murray 2023-08-14 23:44:21 -07:00 committed by GitHub
parent b42010bb1d
commit df91ff5314


@@ -305,13 +305,15 @@ class UperNetPreTrainedModel(PreTrainedModel):
         if isinstance(module, UperNetPreTrainedModel):
             module.backbone.init_weights()
             module.decode_head.init_weights()
-            module.auxiliary_head.init_weights()
+            if module.auxiliary_head is not None:
+                module.auxiliary_head.init_weights()

     def init_weights(self):
         """Initialize the weights"""
         self.backbone.init_weights()
         self.decode_head.init_weights()
-        self.auxiliary_head.init_weights()
+        if self.auxiliary_head is not None:
+            self.auxiliary_head.init_weights()

     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, BackboneMixin):
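A minimal usage sketch (not part of the diff; it assumes the public UperNetConfig / UperNetForSemanticSegmentation API, where use_auxiliary_head=False leaves auxiliary_head as None):

    from transformers import UperNetConfig, UperNetForSemanticSegmentation

    # With use_auxiliary_head=False, model.auxiliary_head is None;
    # before this fix, weight initialization raised AttributeError.
    config = UperNetConfig(use_auxiliary_head=False)
    model = UperNetForSemanticSegmentation(config)  # now initializes cleanly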
@@ -429,9 +431,10 @@ class UperNetForSemanticSegmentation(UperNetPreTrainedModel):
             else:
                 # compute weighted loss
                 loss_fct = CrossEntropyLoss(ignore_index=self.config.loss_ignore_index)
-                main_loss = loss_fct(logits, labels)
-                auxiliary_loss = loss_fct(auxiliary_logits, labels)
-                loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
+                loss = loss_fct(logits, labels)
+                if auxiliary_logits is not None:
+                    auxiliary_loss = loss_fct(auxiliary_logits, labels)
+                    loss += self.config.auxiliary_loss_weight * auxiliary_loss

         if not return_dict:
             if output_hidden_states:
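For reference, a self-contained sketch of the fixed loss logic (illustrative tensor shapes; 255 and 0.4 mirror the loss_ignore_index and auxiliary_loss_weight config defaults):

    import torch
    from torch.nn import CrossEntropyLoss

    loss_fct = CrossEntropyLoss(ignore_index=255)  # loss_ignore_index
    logits = torch.randn(2, 19, 64, 64)            # (batch, num_labels, height, width)
    labels = torch.randint(0, 19, (2, 64, 64))
    auxiliary_logits = None                        # no auxiliary head configured

    loss = loss_fct(logits, labels)                # main loss is always computed
    if auxiliary_logits is not None:               # auxiliary term only when the head exists
        loss += 0.4 * loss_fct(auxiliary_logits, labels)  # auxiliary_loss_weight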