mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Handle empty change indices in SAM's mask to rle conversion (#35665)
* Handle empty change indices in RLE conversion for masks * [test] Add unit tests for RLE encoding of masks in SamProcessor * [test] Update RLE conversion tests to use TensorFlow implementation * [test] Fix formatting in SamProcessorTest according to check_code_quality action * [test] Fix formatting in SamProcessorTest according to check_code_quality * [test] Refactored rle test cases into one test and used tf tensors in tf test cases * [test] Fix: removed self parameter from refactored methods * [test] Removed nested methods in run-length encoding tests for PyTorch and TensorFlow * [test] Added description to individual to run-length encoding tests for PyTorch and TensorFlow.
This commit is contained in:
parent
47bd4296d6
commit
e4227eb4d4
@ -1373,6 +1373,14 @@ def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
|
||||
out = []
|
||||
for i in range(batch_size):
|
||||
cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
|
||||
if len(cur_idxs) == 0:
|
||||
# No changes => either all 0 or all 1
|
||||
# If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
|
||||
if input_mask[i, 0] == 0:
|
||||
out.append({"size": [height, width], "counts": [height * width]})
|
||||
else:
|
||||
out.append({"size": [height, width], "counts": [0, height * width]})
|
||||
continue
|
||||
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
|
||||
counts = [] if input_mask[i, 0] == 0 else [0]
|
||||
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
|
||||
@ -1396,6 +1404,14 @@ def _mask_to_rle_tf(input_mask: "tf.Tensor"):
|
||||
out = []
|
||||
for i in range(batch_size):
|
||||
cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
|
||||
if len(cur_idxs) == 0:
|
||||
# No changes => either all 0 or all 1
|
||||
# If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
|
||||
if input_mask[i, 0] == 0:
|
||||
out.append({"size": [height, width], "counts": [height * width]})
|
||||
else:
|
||||
out.append({"size": [height, width], "counts": [0, height * width]})
|
||||
continue
|
||||
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
|
||||
counts = [] if input_mask[i, 0] == 0 else [0]
|
||||
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
|
||||
|
@ -37,9 +37,13 @@ if is_vision_available():
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.models.sam.image_processing_sam import _mask_to_rle_pytorch
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers.models.sam.image_processing_sam import _mask_to_rle_tf
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_torchvision
|
||||
@ -161,6 +165,42 @@ class SamProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size))
|
||||
|
||||
def test_rle_encoding(self):
|
||||
"""
|
||||
Test the run-length encoding function.
|
||||
"""
|
||||
# Test that a mask of all zeros returns a single run [height * width].
|
||||
input_mask = torch.zeros((1, 2, 2), dtype=torch.long) # shape: 1 x 2 x 2
|
||||
rle = _mask_to_rle_pytorch(input_mask)
|
||||
|
||||
self.assertEqual(len(rle), 1)
|
||||
self.assertEqual(rle[0]["size"], [2, 2])
|
||||
# For a 2x2 all-zero mask, we expect a single run of length 4:
|
||||
self.assertEqual(rle[0]["counts"], [4])
|
||||
|
||||
# Test that a mask of all ones returns [0, height * width].
|
||||
input_mask = torch.ones((1, 2, 2), dtype=torch.long) # shape: 1 x 2 x 2
|
||||
rle = _mask_to_rle_pytorch(input_mask)
|
||||
|
||||
self.assertEqual(len(rle), 1)
|
||||
self.assertEqual(rle[0]["size"], [2, 2])
|
||||
# For a 2x2 all-one mask, we expect two runs: [0, 4].
|
||||
self.assertEqual(rle[0]["counts"], [0, 4])
|
||||
|
||||
# Test a mask with mixed 0s and 1s to ensure the run-length encoding is correct.
|
||||
# Example mask:
|
||||
# Row 0: [0, 1]
|
||||
# Row 1: [1, 1]
|
||||
# This is shape (1, 2, 2).
|
||||
# Flattened in Fortran order -> [0, 1, 1, 1].
|
||||
# The RLE for [0,1,1,1] is [1, 3].
|
||||
input_mask = torch.tensor([[[0, 1], [1, 1]]], dtype=torch.long)
|
||||
rle = _mask_to_rle_pytorch(input_mask)
|
||||
|
||||
self.assertEqual(len(rle), 1)
|
||||
self.assertEqual(rle[0]["size"], [2, 2])
|
||||
self.assertEqual(rle[0]["counts"], [1, 3]) # 1 zero, followed by 3 ones
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_tf
|
||||
@ -244,6 +284,42 @@ class TFSamProcessorTest(unittest.TestCase):
|
||||
dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf"
|
||||
)
|
||||
|
||||
def test_rle_encoding(self):
|
||||
"""
|
||||
Test the run-length encoding function.
|
||||
"""
|
||||
# Test that a mask of all zeros returns a single run [height * width].
|
||||
input_mask = tf.zeros((1, 2, 2), dtype=tf.int64) # shape: 1 x 2 x 2
|
||||
rle = _mask_to_rle_tf(input_mask)
|
||||
|
||||
self.assertEqual(len(rle), 1)
|
||||
self.assertEqual(rle[0]["size"], [2, 2])
|
||||
# For a 2x2 all-zero mask, we expect a single run of length 4:
|
||||
self.assertEqual(rle[0]["counts"], [4])
|
||||
|
||||
# Test that a mask of all ones returns [0, height * width].
|
||||
input_mask = tf.ones((1, 2, 2), dtype=tf.int64) # shape: 1 x 2 x 2
|
||||
rle = _mask_to_rle_tf(input_mask)
|
||||
|
||||
self.assertEqual(len(rle), 1)
|
||||
self.assertEqual(rle[0]["size"], [2, 2])
|
||||
# For a 2x2 all-one mask, we expect two runs: [0, 4].
|
||||
self.assertEqual(rle[0]["counts"], [0, 4])
|
||||
|
||||
# Test a mask with mixed 0s and 1s to ensure the run-length encoding is correct.
|
||||
# Example mask:
|
||||
# Row 0: [0, 1]
|
||||
# Row 1: [1, 1]
|
||||
# This is shape (1, 2, 2).
|
||||
# Flattened in Fortran order -> [0, 1, 1, 1].
|
||||
# The RLE for [0,1,1,1] is [1, 3].
|
||||
input_mask = tf.tensor([[[0, 1], [1, 1]]], dtype=tf.int64)
|
||||
rle = _mask_to_rle_tf(input_mask)
|
||||
|
||||
self.assertEqual(len(rle), 1)
|
||||
self.assertEqual(rle[0]["size"], [2, 2])
|
||||
self.assertEqual(rle[0]["counts"], [1, 3]) # 1 zero, followed by 3 ones
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_torchvision
|
||||
|
Loading…
Reference in New Issue
Block a user