mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
[SegFormer] Add support for segmentation masks with one label (#20279)
* Add support for binary segmentation * Fix loss calculation and add test * Remove space * use fstring Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local> Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
This commit is contained in:
parent
2280880cb7
commit
2875fa971c
@ -806,15 +806,20 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if not self.config.num_labels > 1:
|
# upsample logits to the images' original size
|
||||||
raise ValueError("The number of labels should be greater than one")
|
upsampled_logits = nn.functional.interpolate(
|
||||||
else:
|
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
|
||||||
# upsample logits to the images' original size
|
)
|
||||||
upsampled_logits = nn.functional.interpolate(
|
if self.config.num_labels > 1:
|
||||||
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
|
|
||||||
)
|
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
|
loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
|
||||||
loss = loss_fct(upsampled_logits, labels)
|
loss = loss_fct(upsampled_logits, labels)
|
||||||
|
elif self.config.num_labels == 1:
|
||||||
|
valid_mask = ((labels >= 0) & (labels != self.config.semantic_loss_ignore_index)).float()
|
||||||
|
loss_fct = BCEWithLogitsLoss(reduction="none")
|
||||||
|
loss = loss_fct(upsampled_logits.squeeze(1), labels.float())
|
||||||
|
loss = (loss * valid_mask).mean()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Number of labels should be >=0: {self.config.num_labels}")
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
|
@ -140,6 +140,16 @@ class SegformerModelTester:
|
|||||||
self.parent.assertEqual(
|
self.parent.assertEqual(
|
||||||
result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
|
result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
|
||||||
)
|
)
|
||||||
|
self.parent.assertGreater(result.loss, 0.0)
|
||||||
|
|
||||||
|
def create_and_check_for_binary_image_segmentation(self, config, pixel_values, labels):
    """Check that a single-label (binary) segmentation model returns a positive loss.

    Sets ``config.num_labels`` to 1 on the passed config, builds a
    ``SegformerForSemanticSegmentation`` from it, moves it to ``torch_device``
    and puts it in eval mode. The model is then called on ``pixel_values`` with
    freshly sampled binary masks, and the resulting loss is asserted to be
    greater than zero.

    Args:
        config: model configuration; its ``num_labels`` attribute is overwritten.
        pixel_values: input passed to the model as its first positional argument.
        labels: unused on entry; replaced by randomly sampled binary masks of
            shape ``(batch_size, image_size, image_size)``.
    """
    config.num_labels = 1
    model = SegformerForSemanticSegmentation(config=config)
    model.to(torch_device)
    model.eval()
    # torch.randint's upper bound is exclusive: (0, 2) samples both classes,
    # whereas (0, 1) can only ever produce all-zero masks.
    labels = torch.randint(0, 2, (self.batch_size, self.image_size, self.image_size)).to(torch_device)
    result = model(pixel_values, labels=labels)
    self.parent.assertGreater(result.loss, 0.0)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
@ -177,6 +187,10 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_binary_image_segmentation(self):
    """Run the model tester's binary segmentation check on freshly prepared inputs."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_for_binary_image_segmentation(*prepared)
|
||||||
|
|
||||||
def test_for_image_segmentation(self):
|
def test_for_image_segmentation(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs)
|
self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs)
|
||||||
|
Loading…
Reference in New Issue
Block a user