mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Check for case where auxiliary_head
is None
in UperNetPreTrainedModel
(#25514)
check for case where auxiliary_head is None in UperNetPreTrainedModel
This commit is contained in:
parent
b42010bb1d
commit
df91ff5314
@ -305,13 +305,15 @@ class UperNetPreTrainedModel(PreTrainedModel):
|
||||
if isinstance(module, UperNetPreTrainedModel):
|
||||
module.backbone.init_weights()
|
||||
module.decode_head.init_weights()
|
||||
module.auxiliary_head.init_weights()
|
||||
if module.auxiliary_head is not None:
|
||||
module.auxiliary_head.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize the weights"""
|
||||
self.backbone.init_weights()
|
||||
self.decode_head.init_weights()
|
||||
self.auxiliary_head.init_weights()
|
||||
if self.auxiliary_head is not None:
|
||||
self.auxiliary_head.init_weights()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, BackboneMixin):
|
||||
@ -429,9 +431,10 @@ class UperNetForSemanticSegmentation(UperNetPreTrainedModel):
|
||||
else:
|
||||
# compute weighted loss
|
||||
loss_fct = CrossEntropyLoss(ignore_index=self.config.loss_ignore_index)
|
||||
main_loss = loss_fct(logits, labels)
|
||||
auxiliary_loss = loss_fct(auxiliary_logits, labels)
|
||||
loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
|
||||
loss = loss_fct(logits, labels)
|
||||
if auxiliary_logits is not None:
|
||||
auxiliary_loss = loss_fct(auxiliary_logits, labels)
|
||||
loss += self.config.auxiliary_loss_weight * auxiliary_loss
|
||||
|
||||
if not return_dict:
|
||||
if output_hidden_states:
|
||||
|
Loading…
Reference in New Issue
Block a user