Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[SegFormer] Add support for segmentation masks with one label (#20279)
* Add support for binary segmentation
* Fix loss calculation and add test
* Remove space
* use fstring

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
Parent: 2280880cb7
Commit: 2875fa971c
@@ -806,15 +806,20 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
 
         loss = None
         if labels is not None:
-            if not self.config.num_labels > 1:
-                raise ValueError("The number of labels should be greater than one")
-            else:
-                # upsample logits to the images' original size
-                upsampled_logits = nn.functional.interpolate(
-                    logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
-                )
+            # upsample logits to the images' original size
+            upsampled_logits = nn.functional.interpolate(
+                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
+            )
+            if self.config.num_labels > 1:
                 loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
                 loss = loss_fct(upsampled_logits, labels)
+            elif self.config.num_labels == 1:
+                valid_mask = ((labels >= 0) & (labels != self.config.semantic_loss_ignore_index)).float()
+                loss_fct = BCEWithLogitsLoss(reduction="none")
+                loss = loss_fct(upsampled_logits.squeeze(1), labels.float())
+                loss = (loss * valid_mask).mean()
+            else:
+                raise ValueError(f"Number of labels should be >=0: {self.config.num_labels}")
 
         if not return_dict:
             if output_hidden_states:
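To see what the new num_labels == 1 branch computes, here is a minimal standalone sketch with toy tensors. The shapes and the ignore index of 255 are illustrative assumptions; only the masked BCE-with-logits computation mirrors the code above.

# Minimal sketch of the binary-segmentation loss branch, with dummy data.
import torch
from torch.nn import BCEWithLogitsLoss

batch_size, height, width = 2, 4, 4
upsampled_logits = torch.randn(batch_size, 1, height, width)  # single channel for one label
labels = torch.randint(0, 2, (batch_size, height, width))     # 0/1 ground-truth mask
labels[0, 0, 0] = 255                                         # pretend one pixel should be ignored

semantic_loss_ignore_index = 255  # assumed ignore index, as in the config attribute above
valid_mask = ((labels >= 0) & (labels != semantic_loss_ignore_index)).float()

loss_fct = BCEWithLogitsLoss(reduction="none")
# squeeze the channel dimension so logits and labels are both (batch, height, width)
loss = loss_fct(upsampled_logits.squeeze(1), labels.float())
# zero out the contribution of ignored pixels; the mean is then taken over the
# full tensor, exactly as the model code above does
loss = (loss * valid_mask).mean()
print(loss)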
@@ -140,6 +140,16 @@ class SegformerModelTester:
         self.parent.assertEqual(
             result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
         )
         self.parent.assertGreater(result.loss, 0.0)
 
+    def create_and_check_for_binary_image_segmentation(self, config, pixel_values, labels):
+        config.num_labels = 1
+        model = SegformerForSemanticSegmentation(config=config)
+        model.to(torch_device)
+        model.eval()
+        labels = torch.randint(0, 1, (self.batch_size, self.image_size, self.image_size)).to(torch_device)
+        result = model(pixel_values, labels=labels)
+        self.parent.assertGreater(result.loss, 0.0)
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
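The new test exercises the binary path through the public API. A rough usage sketch along the same lines, with a randomly initialised model and a dummy binary mask; the 512x512 resolution and default config values are assumptions, and it requires a transformers version that already contains this change.

# Usage sketch: SegFormer with a single label computes the BCE loss added above.
import torch
from transformers import SegformerConfig, SegformerForSemanticSegmentation

config = SegformerConfig(num_labels=1)
model = SegformerForSemanticSegmentation(config)
model.eval()

pixel_values = torch.randn(1, 3, 512, 512)    # dummy batch of one RGB image
labels = torch.randint(0, 2, (1, 512, 512))   # dummy 0/1 ground-truth mask

with torch.no_grad():
    outputs = model(pixel_values, labels=labels)

print(outputs.loss)          # scalar loss from the num_labels == 1 branch
print(outputs.logits.shape)  # (1, 1, 128, 128): logits at 1/4 of the input resolution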
@@ -177,6 +187,10 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
 
+    def test_for_binary_image_segmentation(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_binary_image_segmentation(*config_and_inputs)
+
     def test_for_image_segmentation(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs)
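For completeness, one way a caller might post-process the single-channel logits into a hard 0/1 mask at inference time; the sigmoid plus 0.5 threshold is an assumed convention, not something the model or these tests prescribe.

# Post-processing sketch: single-channel logits -> binary mask at input resolution.
import torch

logits = torch.randn(1, 1, 128, 128)  # stand-in for outputs.logits with num_labels == 1
image_size = (512, 512)               # original input resolution

upsampled = torch.nn.functional.interpolate(
    logits, size=image_size, mode="bilinear", align_corners=False
)
binary_mask = (upsampled.squeeze(1).sigmoid() > 0.5).long()  # shape (1, 512, 512), values in {0, 1}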