[SegFormer] Add support for segmentation masks with one label (#20279)

* Add support for binary segmentation

* Fix loss calculation and add test

* Remove space

* use fstring

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
NielsRogge 2022-12-20 16:46:50 +01:00 committed by GitHub
parent 2280880cb7
commit 2875fa971c
2 changed files with 26 additions and 7 deletions


@@ -806,15 +806,20 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
         loss = None
         if labels is not None:
-            if not self.config.num_labels > 1:
-                raise ValueError("The number of labels should be greater than one")
-            else:
-                # upsample logits to the images' original size
-                upsampled_logits = nn.functional.interpolate(
-                    logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
-                )
+            # upsample logits to the images' original size
+            upsampled_logits = nn.functional.interpolate(
+                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
+            )
+            if self.config.num_labels > 1:
                 loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
                 loss = loss_fct(upsampled_logits, labels)
+            elif self.config.num_labels == 1:
+                valid_mask = ((labels >= 0) & (labels != self.config.semantic_loss_ignore_index)).float()
+                loss_fct = BCEWithLogitsLoss(reduction="none")
+                loss = loss_fct(upsampled_logits.squeeze(1), labels.float())
+                loss = (loss * valid_mask).mean()
+            else:
+                raise ValueError(f"Number of labels should be >=0: {self.config.num_labels}")
 
         if not return_dict:
             if output_hidden_states:
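For reference, the new num_labels == 1 branch can be reproduced in isolation with plain PyTorch. The snippet below is a minimal sketch on dummy tensors, not part of the commit; the ignore index of 255 is assumed (the default semantic_loss_ignore_index):

import torch
from torch.nn import BCEWithLogitsLoss

ignore_index = 255
upsampled_logits = torch.randn(2, 1, 8, 8)       # (batch, 1, height, width), single channel
labels = torch.randint(0, 2, (2, 8, 8))          # binary mask per pixel
labels[:, :2, :] = ignore_index                  # mark some pixels as "ignore"

# mirror the commit: per-pixel BCE, then zero out ignored pixels before averaging
valid_mask = ((labels >= 0) & (labels != ignore_index)).float()
loss_fct = BCEWithLogitsLoss(reduction="none")
loss = loss_fct(upsampled_logits.squeeze(1), labels.float())
loss = (loss * valid_mask).mean()                # ignored pixels contribute zero to the mean
print(loss)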


@@ -140,6 +140,16 @@ class SegformerModelTester:
         self.parent.assertEqual(
             result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
         )
+        self.parent.assertGreater(result.loss, 0.0)
+
+    def create_and_check_for_binary_image_segmentation(self, config, pixel_values, labels):
+        config.num_labels = 1
+        model = SegformerForSemanticSegmentation(config=config)
+        model.to(torch_device)
+        model.eval()
+        labels = torch.randint(0, 1, (self.batch_size, self.image_size, self.image_size)).to(torch_device)
+        result = model(pixel_values, labels=labels)
+        self.parent.assertGreater(result.loss, 0.0)
 
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
@@ -177,6 +187,10 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
 
+    def test_for_binary_image_segmentation(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_binary_image_segmentation(*config_and_inputs)
+
     def test_for_image_segmentation(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs)
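As a quick usage sketch (not part of the commit), the new path can also be exercised end to end with a randomly initialized model; the input size and batch shape below are illustrative:

import torch
from transformers import SegformerConfig, SegformerForSemanticSegmentation

config = SegformerConfig(num_labels=1)            # single-label (binary) segmentation
model = SegformerForSemanticSegmentation(config).eval()

pixel_values = torch.randn(1, 3, 64, 64)          # dummy image batch
labels = torch.randint(0, 2, (1, 64, 64))         # per-pixel 0/1 mask

with torch.no_grad():
    outputs = model(pixel_values=pixel_values, labels=labels)

print(outputs.loss)                               # BCE-with-logits loss from the new branch
print(outputs.logits.shape)                       # (1, 1, 16, 16): one logit channel at 1/4 resolution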