Feature Extractor accepts segmentation_maps (#15964)

* feature extractor accepts

* resolved conversations

* added examples in test for ADE20K

* num_classes -> num_labels

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* resolving conversations

* resolving conversations

* removed ADE

* CI

* minor changes in conversion script

* reduce_labels in feature extractor

* minor changes

* correct preprocessing for instance segmentation maps

* minor changes

* minor changes

* CI

* debugging

* better padding

* going to update labels inside the model

* going to update labels inside the model

* minor changes

* tests

* removed changes in feature_extractor_utils

* conversation

* conversation

* example in feature extractor

* more docstring in modeling

* test

* make style

* doc

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Francesco Saverio Zuppichini 2022-03-30 18:46:51 +02:00 committed by GitHub
parent c2f8eaf6bc
commit c4deb7b3ae
5 changed files with 328 additions and 193 deletions


@ -169,12 +169,15 @@ class OriginalMaskFormerConfigToFeatureExtractorConverter:
def __call__(self, original_config: object) -> MaskFormerFeatureExtractor:
model = original_config.MODEL
model_input = original_config.INPUT
dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST[0])
return MaskFormerFeatureExtractor(
image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(),
image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(),
size=model_input.MIN_SIZE_TEST,
max_size=model_input.MAX_SIZE_TEST,
num_labels=model.SEM_SEG_HEAD.NUM_CLASSES,
ignore_index=dataset_catalog.ignore_label,
size_divisibility=32, # 32 is required by swin
)
@ -552,7 +555,7 @@ class OriginalMaskFormerCheckpointToOursConverter:
yield config, checkpoint
def test(original_model, our_model: MaskFormerForInstanceSegmentation):
def test(original_model, our_model: MaskFormerForInstanceSegmentation, feature_extractor: MaskFormerFeatureExtractor):
with torch.no_grad():
original_model = original_model.eval()
@ -600,8 +603,6 @@ def test(original_model, our_model: MaskFormerForInstanceSegmentation):
our_model_out: MaskFormerForInstanceSegmentationOutput = our_model(x)
feature_extractor = MaskFormerFeatureExtractor()
our_segmentation = feature_extractor.post_process_segmentation(our_model_out, target_size=(384, 384))
assert torch.allclose(
@ -707,7 +708,7 @@ if __name__ == "__main__":
mask_former_for_instance_segmentation
)
test(original_model, mask_former_for_instance_segmentation)
test(original_model, mask_former_for_instance_segmentation, feature_extractor)
model_name = get_name(checkpoint_file)
logger.info(f"🪄 Saving {model_name}")
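For reference, the conversion-time sanity check shown above can be reproduced outside the script; a minimal sketch, assuming a converted checkpoint is available (the checkpoint name below is illustrative only):

    import torch
    from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation

    # illustrative checkpoint name; any converted MaskFormer checkpoint works the same way
    checkpoint = "facebook/maskformer-swin-base-ade"
    feature_extractor = MaskFormerFeatureExtractor.from_pretrained(checkpoint)
    model = MaskFormerForInstanceSegmentation.from_pretrained(checkpoint).eval()

    pixel_values = torch.zeros(1, 3, 384, 384)  # dummy input batch
    with torch.no_grad():
        outputs = model(pixel_values)
    segmentation = feature_extractor.post_process_segmentation(outputs, target_size=(384, 384))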


@ -54,6 +54,10 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
max_size (`int`, *optional*, defaults to 1333):
The largest size an image dimension can have (otherwise it's capped). Only has an effect if `do_resize` is
set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
`PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
if `do_resize` is set to `True`.
size_divisibility (`int`, *optional*, defaults to 32):
Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in
Swin Transformer.
@ -64,8 +68,12 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
ImageNet std.
ignore_index (`int`, *optional*, default to 255):
Value of the index (label) to ignore.
ignore_index (`int`, *optional*):
Value of the index (label) to be removed from the segmentation maps.
reduce_labels (`bool`, *optional*, defaults to `False`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
background label will be replaced by `ignore_index`.
"""
@ -76,24 +84,28 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
do_resize=True,
size=800,
max_size=1333,
resample=Image.BILINEAR,
size_divisibility=32,
do_normalize=True,
image_mean=None,
image_std=None,
ignore_index=255,
ignore_index=None,
reduce_labels=False,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.max_size = max_size
self.resample = resample
self.size_divisibility = size_divisibility
self.ignore_index = ignore_index
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406] # ImageNet mean
self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225] # ImageNet std
self.ignore_index = ignore_index
self.reduce_labels = reduce_labels
def _resize(self, image, size, target=None, max_size=None):
def _resize_with_size_divisibility(self, image, size, target=None, max_size=None):
"""
Resize the image to the given size. Size can be min_size (scalar) or (width, height) tuple. If size is an int,
smaller edge of the image will be matched to this number.
@ -138,30 +150,19 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
width = int(np.ceil(width / self.size_divisibility)) * self.size_divisibility
size = (width, height)
rescaled_image = self.resize(image, size=size)
image = self.resize(image, size=size, resample=self.resample)
has_target = target is not None
if target is not None:
target = self.resize(target, size=size, resample=Image.NEAREST)
if has_target:
target = target.copy()
# store original_size
target["original_size"] = image.size
if "masks" in target:
masks = torch.from_numpy(target["masks"])[:, None].float()
# use PyTorch as current workaround
# TODO replace by self.resize
interpolated_masks = (
nn.functional.interpolate(masks, size=(height, width), mode="nearest")[:, 0] > 0.5
).float()
target["masks"] = interpolated_masks.numpy()
return rescaled_image, target
return image, target
def __call__(
self,
images: ImageInput,
annotations: Union[List[Dict], List[List[Dict]]] = None,
segmentation_maps: ImageInput = None,
pad_and_return_pixel_mask: Optional[bool] = True,
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
) -> BatchFeature:
@ -170,6 +171,12 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
padded up to the largest image in a batch, and a pixel mask is created that indicates which pixels are
real/which are padding.
MaskFormer addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps
will be converted to lists of binary masks and their respective labels. Let's see an example, assuming
`segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels =
[[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for
each mask.
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
@ -183,10 +190,8 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
annotations (`Dict`, `List[Dict]`, *optional*):
The corresponding annotations as dictionary of numpy arrays with the following keys:
- **masks** (`np.ndarray`) The target mask of shape `(num_classes, height, width)`.
- **labels** (`np.ndarray`) The target labels of shape `(num_classes)`.
segmentation_maps (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
Optionally, the corresponding semantic segmentation maps with the pixel-wise annotations.
pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`):
Whether or not to pad images up to the largest image in a batch and create a pixel mask.
@ -196,7 +201,12 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
instance_id_to_semantic_id (`Dict[int, int]`, *optional*):
If passed, we treat `segmentation_maps` as an instance segmentation map where each pixel represents an
instance id. To convert it to a binary mask of shape (`batch, num_labels, height, width`) we need a
dictionary mapping instance ids to label ids to create a semantic segmentation map.
return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
objects.
@ -206,15 +216,16 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
- **pixel_values** -- Pixel values to be fed to a model.
- **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if
*"pixel_mask"* is in `self.model_input_names`).
- **mask_labels** -- Optional mask labels of shape `(batch_size, num_classes, height, width) to be fed to a
model (when `annotations` are provided).
- **class_labels** -- Optional class labels of shape `(batch_size, num_classes) to be fed to a model (when
`annotations` are provided).
- **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model
(when `annotations` are provided).
- **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when
`annotations` are provided). They identify the labels of `mask_labels`, e.g. the label of
`mask_labels[i][j]` is `class_labels[i][j]`.
"""
# Input type checking for clearer error
valid_images = False
valid_annotations = False
valid_segmentation_maps = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
@ -228,6 +239,23 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
# Check that segmentation maps has a valid type
if segmentation_maps is not None:
if isinstance(segmentation_maps, (Image.Image, np.ndarray)) or is_torch_tensor(segmentation_maps):
valid_segmentation_maps = True
elif isinstance(segmentation_maps, (list, tuple)):
if (
len(segmentation_maps) == 0
or isinstance(segmentation_maps[0], (Image.Image, np.ndarray))
or is_torch_tensor(segmentation_maps[0])
):
valid_segmentation_maps = True
if not valid_segmentation_maps:
raise ValueError(
"Segmentation maps must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example),"
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
@ -236,35 +264,33 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
if not is_batched:
images = [images]
if annotations is not None:
annotations = [annotations]
# Check that annotations has a valid type
if annotations is not None:
valid_annotations = type(annotations) is list and "masks" in annotations[0] and "labels" in annotations[0]
if not valid_annotations:
raise ValueError(
"Annotations must of type `Dict` (single image) or `List[Dict]` (batch of images)."
"The annotations must be numpy arrays in the following format:"
"{ 'masks' : the target mask, with shape [C,H,W], 'labels' : the target labels, with shape [C]}"
)
if segmentation_maps is not None:
segmentation_maps = [segmentation_maps]
# transformations (resizing + normalization)
if self.do_resize and self.size is not None:
if annotations is not None:
for idx, (image, target) in enumerate(zip(images, annotations)):
image, target = self._resize(image=image, target=target, size=self.size, max_size=self.max_size)
if segmentation_maps is not None:
for idx, (image, target) in enumerate(zip(images, segmentation_maps)):
image, target = self._resize_with_size_divisibility(
image=image, target=target, size=self.size, max_size=self.max_size
)
images[idx] = image
annotations[idx] = target
segmentation_maps[idx] = target
else:
for idx, image in enumerate(images):
images[idx] = self._resize(image=image, target=None, size=self.size, max_size=self.max_size)[0]
images[idx] = self._resize_with_size_divisibility(
image=image, target=None, size=self.size, max_size=self.max_size
)[0]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# NOTE: we are always forced to pad them since they have to be stacked along the batch dimension
encoded_inputs = self.encode_inputs(
images, annotations, pad_and_return_pixel_mask, return_tensors=return_tensors
images,
segmentation_maps,
pad_and_return_pixel_mask,
instance_id_to_semantic_id=instance_id_to_semantic_id,
return_tensors=return_tensors,
)
# Convert to TensorType
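For illustration, a minimal sketch of the new `segmentation_maps` path end to end (random data, keyword arguments as documented above):

    import numpy as np
    from transformers import MaskFormerFeatureExtractor

    feature_extractor = MaskFormerFeatureExtractor(ignore_index=255, reduce_labels=False)

    image = np.random.rand(3, 512, 512).astype(np.float32)                    # (C, H, W)
    segmentation_map = np.random.randint(0, 10, (512, 512)).astype(np.uint8)  # pixel-wise class ids

    inputs = feature_extractor([image], segmentation_maps=[segmentation_map], return_tensors="pt")
    print(inputs["pixel_values"].shape)    # (1, 3, H, W) after resizing and padding
    print(inputs["pixel_mask"].shape)      # (1, H, W)
    print(inputs["mask_labels"][0].shape)  # (number of labels present in the map, H, W) binary masks
    print(inputs["class_labels"][0])       # one class id per binary mask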
@ -287,25 +313,57 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
maxes[index] = max(maxes[index], item)
return maxes
def convert_segmentation_map_to_binary_masks(
self,
segmentation_map: "np.ndarray",
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
):
if self.reduce_labels:
if self.ignore_index is None:
raise ValueError("`ignore_index` must be set when `reduce_labels` is `True`.")
segmentation_map[segmentation_map == 0] = self.ignore_index
# instance ids start from 1!
segmentation_map -= 1
segmentation_map[segmentation_map == self.ignore_index - 1] = self.ignore_index
if instance_id_to_semantic_id is not None:
# segmentation_map will be treated as an instance segmentation map where each pixel is an instance id
# thus it has to be converted to a semantic segmentation map
for instance_id, label_id in instance_id_to_semantic_id.items():
segmentation_map[segmentation_map == instance_id] = label_id
# get all the labels in the image
labels = np.unique(segmentation_map)
# remove ignore index (if we have one)
if self.ignore_index is not None:
labels = labels[labels != self.ignore_index]
# help broadcasting by making the mask [1, H, W] and the labels [C, 1, 1]
binary_masks = segmentation_map[None] == labels[:, None, None]
return binary_masks.astype(np.float32), labels.astype(np.int64)
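For illustration, the broadcasting trick above can be checked on a toy map; a minimal NumPy sketch:

    import numpy as np

    segmentation_map = np.array([[2, 2, 6],
                                 [7, 9, 9]])
    labels = np.unique(segmentation_map)                            # [2 6 7 9]
    binary_masks = segmentation_map[None] == labels[:, None, None]
    print(binary_masks.shape)                                       # (4, 2, 3): one mask per label
    print(binary_masks[0].astype(int))                              # the mask for label 2
    # [[1 1 0]
    #  [0 0 0]]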
def encode_inputs(
self,
pixel_values_list: List["torch.Tensor"],
annotations: Optional[List[Dict]] = None,
pad_and_return_pixel_mask: Optional[bool] = True,
pixel_values_list: List["np.ndarray"],
segmentation_maps: ImageInput = None,
pad_and_return_pixel_mask: bool = True,
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
):
"""
Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
MaskFormer addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps
will be converted to lists of binary masks and their respective labels. Let's see an example, assuming
`segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels =
[[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for
each mask.
Args:
pixel_values_list (`List[torch.Tensor]`):
List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
width)`.
annotations (`Dict`, `List[Dict]`, *optional*):
The corresponding annotations as dictionary of numpy arrays with the following keys:
- **masks** (`np.ndarray`) The target mask of shape `(num_classes, height, width)`.
- **labels** (`np.ndarray`) The target labels of shape `(num_classes)`.
segmentation_maps (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
The corresponding semantic segmentation maps with the pixel-wise annotations.
pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`):
Whether or not to pad images up to the largest image in a batch and create a pixel mask.
@ -315,7 +373,12 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
instance_id_to_semantic_id (`Dict[int, int]`, *optional*):
If passed, we treat `segmentation_maps` as an instance segmentation map where each pixel represents an
instance id. To convert it to a binary mask of shape (`batch, num_labels, height, width`) we need a
dictionary mapping instance ids to label ids to create a semantic segmentation map.
return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
objects.
@ -325,13 +388,29 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
- **pixel_values** -- Pixel values to be fed to a model.
- **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if
*"pixel_mask"* is in `self.model_input_names`).
- **mask_labels** -- Optional mask labels of shape `(batch_size, num_classes, height, width) to be fed to a
model (when `annotations` are provided).
- **class_labels** -- Optional class labels of shape `(batch_size, num_classes) to be fed to a model (when
`annotations` are provided).
- **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model
(when `annotations` are provided).
- **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when
`annotations` are provided). They identify the labels of `mask_labels`, e.g. the label of
`mask_labels[i][j]` if `class_labels[i][j]`.
"""
max_size = self._max_by_axis([list(image.shape) for image in pixel_values_list])
annotations = None
if segmentation_maps is not None:
segmentation_maps = map(np.array, segmentation_maps)
converted_segmentation_maps = []
for segmentation_map in segmentation_maps:
converted_segmentation_map = self.convert_segmentation_map_to_binary_masks(
segmentation_map, instance_id_to_semantic_id
)
converted_segmentation_maps.append(converted_segmentation_map)
annotations = []
for mask, classes in converted_segmentation_maps:
annotations.append({"masks": mask, "classes": classes})
channels, height, width = max_size
pixel_values = []
pixel_mask = []
@ -339,35 +418,37 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
class_labels = []
for idx, image in enumerate(pixel_values_list):
# create padded image
if pad_and_return_pixel_mask:
padded_image = np.zeros((channels, height, width), dtype=np.float32)
padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
image = padded_image
padded_image = np.zeros((channels, height, width), dtype=np.float32)
padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
image = padded_image
pixel_values.append(image)
# if we have a target, pad it
if annotations:
annotation = annotations[idx]
masks = annotation["masks"]
if pad_and_return_pixel_mask:
padded_masks = np.zeros((masks.shape[0], height, width), dtype=masks.dtype)
padded_masks[:, : masks.shape[1], : masks.shape[2]] = np.copy(masks)
masks = padded_masks
mask_labels.append(masks)
class_labels.append(annotation["labels"])
if pad_and_return_pixel_mask:
# create pixel mask
mask = np.zeros((height, width), dtype=np.int64)
mask[: image.shape[1], : image.shape[2]] = True
pixel_mask.append(mask)
# pad mask with `ignore_index`
masks = np.pad(
masks,
((0, 0), (0, height - masks.shape[1]), (0, width - masks.shape[2])),
constant_values=self.ignore_index,
)
annotation["masks"] = masks
# create pixel mask
mask = np.zeros((height, width), dtype=np.int64)
mask[: image.shape[1], : image.shape[2]] = True
pixel_mask.append(mask)
# return as BatchFeature
data = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}
if annotations:
data["mask_labels"] = mask_labels
data["class_labels"] = class_labels
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
# we cannot batch them since they don't share a common class size
if annotations:
for label in annotations:
mask_labels.append(torch.from_numpy(label["masks"]))
class_labels.append(torch.from_numpy(label["classes"]))
encoded_inputs["mask_labels"] = mask_labels
encoded_inputs["class_labels"] = class_labels
return encoded_inputs


@ -269,7 +269,7 @@ class MaskFormerForInstanceSegmentationOutput(ModelOutput):
A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
query.
masks_queries_logits (`torch.FloatTensor`):
A tensor of shape `(batch_size, num_queries, num_classes + 1)` representing the proposed classes for each
A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
query. Note the `+ 1` is needed because we incorporate the null class.
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the encoder model (backbone).
@ -424,7 +424,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
"""
inputs = inputs.sigmoid().flatten(1)
numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
# using broadcasting to get a [NUM_QUERIES, NUM_CLASSES] matrix
# using broadcasting to get a (num_queries, num_labels) matrix
denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
loss = 1 - (numerator + 1) / (denominator + 1)
return loss
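For illustration, a minimal sketch of the shapes involved: the pairwise dice loss above turns `num_queries` flattened mask predictions and `num_labels` flattened target masks into a `(num_queries, num_labels)` cost matrix:

    import torch

    num_queries, num_labels, height, width = 5, 3, 4, 4
    inputs = torch.randn(num_queries, height * width).sigmoid()  # flattened mask probabilities
    labels = torch.rand(num_labels, height * width)              # flattened target masks

    numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
    denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
    loss = 1 - (numerator + 1) / (denominator + 1)
    print(loss.shape)  # torch.Size([5, 3])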
@ -918,7 +918,9 @@ class MaskFormerSwinBlock(nn.Module):
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) # B H' W' C
shifted_windows = window_reverse(
attention_windows, self.window_size, height_pad, width_pad
) # B height' width' C
# reverse cyclic shift
if self.shift_size > 0:
@ -1621,7 +1623,7 @@ class MaskFormerHungarianMatcher(nn.Module):
Params:
masks_queries_logits (`torch.Tensor`):
A tensor` of dim `batch_size, num_queries, num_classes` with the
A tensor of dim `batch_size, num_queries, num_labels` with the
classification logits.
class_queries_logits (`torch.Tensor`):
A tensor of dim `batch_size, num_queries, height, width` with the
@ -1644,24 +1646,23 @@ class MaskFormerHungarianMatcher(nn.Module):
indices: List[Tuple[np.array]] = []
preds_masks = masks_queries_logits
preds_probs = class_queries_logits.softmax(dim=-1)
# downsample all masks in one go -> save memory
mask_labels = nn.functional.interpolate(mask_labels, size=preds_masks.shape[-2:], mode="nearest")
preds_probs = class_queries_logits
# iterate through batch size
for pred_probs, pred_mask, target_mask, labels in zip(preds_probs, preds_masks, mask_labels, class_labels):
# downsample the target mask, save memory
target_mask = nn.functional.interpolate(target_mask[:, None], size=pred_mask.shape[-2:], mode="nearest")
pred_probs = pred_probs.softmax(-1)
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it in 1 - proba[target class].
# The 1 is a constant that doesn't change the matching, it can be omitted.
cost_class = -pred_probs[:, labels]
# flatten spatial dimension "q h w -> q (h w)"
num_queries, height, width = pred_mask.shape
pred_mask_flat = pred_mask.view(num_queries, height * width) # [num_queries, H*W]
pred_mask_flat = pred_mask.flatten(1) # [num_queries, height*width]
# same for target_mask "c h w -> c (h w)"
num_channels, height, width = target_mask.shape
target_mask_flat = target_mask.view(num_channels, height * width) # [num_total_labels, H*W]
# compute the focal loss between each mask pairs -> shape [NUM_QUERIES, CLASSES]
target_mask_flat = target_mask[:, 0].flatten(1) # [num_total_labels, height*width]
# compute the focal loss between each pair of masks -> shape (num_queries, num_labels)
cost_mask = pair_wise_sigmoid_focal_loss(pred_mask_flat, target_mask_flat)
# Compute the dice loss betwen each mask pairs -> shape [NUM_QUERIES, CLASSES]
# Compute the dice loss between each pair of masks -> shape (num_queries, num_labels)
cost_dice = pair_wise_dice_loss(pred_mask_flat, target_mask_flat)
# final cost matrix
cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
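A minimal sketch of how such a cost matrix is typically resolved into an assignment; the exact call inside the matcher is not shown in this hunk, but scipy is a stated requirement of the loss module, so `linear_sum_assignment` is the natural candidate:

    import torch
    from scipy.optimize import linear_sum_assignment

    num_queries, num_target_labels = 100, 4
    cost_matrix = torch.randn(num_queries, num_target_labels)
    # one (query index, target index) pair per target label, minimizing the total cost
    pred_indices, target_indices = linear_sum_assignment(cost_matrix.numpy())
    print(pred_indices, target_indices)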
@ -1691,7 +1692,7 @@ class MaskFormerHungarianMatcher(nn.Module):
class MaskFormerLoss(nn.Module):
def __init__(
self,
num_classes: int,
num_labels: int,
matcher: MaskFormerHungarianMatcher,
weight_dict: Dict[str, float],
eos_coef: float,
@ -1702,7 +1703,7 @@ class MaskFormerLoss(nn.Module):
matched ground-truth / prediction (supervise class and mask)
Args:
num_classes (`int`):
num_labels (`int`):
The number of classes.
matcher (`MaskFormerHungarianMatcher`):
A torch module that computes the assignments between the predictions and labels.
@ -1714,24 +1715,50 @@ class MaskFormerLoss(nn.Module):
super().__init__()
requires_backends(self, ["scipy"])
self.num_classes = num_classes
self.num_labels = num_labels
self.matcher = matcher
self.weight_dict = weight_dict
self.eos_coef = eos_coef
empty_weight = torch.ones(self.num_classes + 1)
empty_weight = torch.ones(self.num_labels + 1)
empty_weight[-1] = self.eos_coef
self.register_buffer("empty_weight", empty_weight)
def _max_by_axis(self, the_list: List[List[int]]) -> List[int]:
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
def _pad_images_to_max_in_batch(self, tensors: List[Tensor]) -> Tuple[Tensor, Tensor]:
# get the maximum size in the batch
max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
batch_size = len(tensors)
# compute the final size
batch_shape = [batch_size] + max_size
b, _, h, w = batch_shape
# get metadata
dtype = tensors[0].dtype
device = tensors[0].device
padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
padding_masks = torch.ones((b, h, w), dtype=torch.bool, device=device)
# pad the tensors to the size of the biggest one
for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
padding_mask[: tensor.shape[1], : tensor.shape[2]] = False
return padded_tensors, padding_masks
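For illustration, a standalone sketch of what the padding helper above produces for two mask tensors of different sizes (same logic, written outside the class):

    import torch

    def pad_to_max(tensors):
        max_size = [max(sizes) for sizes in zip(*[t.shape for t in tensors])]
        padded = torch.zeros([len(tensors)] + max_size, dtype=tensors[0].dtype)
        masks = torch.ones((len(tensors), max_size[1], max_size[2]), dtype=torch.bool)
        for tensor, padded_tensor, padding_mask in zip(tensors, padded, masks):
            padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
            padding_mask[: tensor.shape[1], : tensor.shape[2]] = False  # False marks real pixels
        return padded, masks

    padded, padding_masks = pad_to_max([torch.ones(2, 10, 12), torch.ones(5, 8, 16)])
    print(padded.shape, padding_masks.shape)  # torch.Size([2, 5, 10, 16]) torch.Size([2, 10, 16])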
def loss_labels(
self, class_queries_logits: Tensor, class_labels: Tensor, indices: Tuple[np.array]
self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array]
) -> Dict[str, Tensor]:
"""Compute the losses related to the labels using cross entropy.
Args:
class_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, num_classes`
class_labels (`Dict[str, Tensor]`):
A tensor of shape `batch_size, num_classes`
A tensor of shape `batch_size, num_queries, num_labels`
class_labels (`List[torch.Tensor]`):
List of class labels of shape `(labels)`.
indices (`Tuple[np.array])`:
The indices computed by the Hungarian matcher.
@ -1744,21 +1771,21 @@ class MaskFormerLoss(nn.Module):
batch_size, num_queries, _ = pred_logits.shape
criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
idx = self._get_predictions_permutation_indices(indices)
# shape = [BATCH, N_QUERIES]
# shape = (batch_size, num_queries)
target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)])
# shape = [BATCH, N_QUERIES]
# shape = (batch_size, num_queries)
target_classes = torch.full(
(batch_size, num_queries), fill_value=self.num_classes, dtype=torch.int64, device=pred_logits.device
(batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device
)
target_classes[idx] = target_classes_o
# target_classes is a [BATCH, CLASSES, N_QUERIES], we need to permute pred_logits "b q c -> b c q"
pred_logits_permuted = pred_logits.permute(0, 2, 1)
loss_ce = criterion(pred_logits_permuted, target_classes)
# transpose pred_logits "b q c -> b c q" so that CrossEntropyLoss receives the class dimension second
pred_logits_transposed = pred_logits.transpose(1, 2)
loss_ce = criterion(pred_logits_transposed, target_classes)
losses = {"loss_cross_entropy": loss_ce}
return losses
def loss_masks(
self, masks_queries_logits: Tensor, mask_labels: Tensor, indices: Tuple[np.array], num_masks: int
self, masks_queries_logits: Tensor, mask_labels: List[Tensor], indices: Tuple[np.array], num_masks: int
) -> Dict[str, Tensor]:
"""Compute the losses related to the masks using focal and dice loss.
@ -1766,7 +1793,7 @@ class MaskFormerLoss(nn.Module):
masks_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, height, width`
mask_labels (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, height, width`
List of mask labels of shape `(labels, height, width)`.
indices (`Tuple[np.array])`:
The indices computed by the Hungarian matcher.
num_masks (`int)`:
@ -1780,10 +1807,12 @@ class MaskFormerLoss(nn.Module):
"""
src_idx = self._get_predictions_permutation_indices(indices)
tgt_idx = self._get_targets_permutation_indices(indices)
pred_masks = masks_queries_logits # shape [BATCH, NUM_QUERIES, H, W]
pred_masks = pred_masks[src_idx] # shape [BATCH * NUM_QUERIES, H, W]
target_masks = mask_labels # shape [BATCH, NUM_QUERIES, H, W]
target_masks = target_masks[tgt_idx] # shape [BATCH * NUM_QUERIES, H, W]
# shape (batch_size * num_queries, height, width)
pred_masks = masks_queries_logits[src_idx]
# shape (batch_size, num_labels, height, width)
# pad all targets to the same size and stack them along the batch dimension
target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
target_masks = target_masks[tgt_idx]
# upsample predictions to the target size, we have to add one dim to use interpolate
pred_masks = nn.functional.interpolate(
pred_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
@ -1791,7 +1820,6 @@ class MaskFormerLoss(nn.Module):
pred_masks = pred_masks[:, 0].flatten(1)
target_masks = target_masks.flatten(1)
target_masks = target_masks.view(pred_masks.shape)
losses = {
"loss_mask": sigmoid_focal_loss(pred_masks, target_masks, num_masks),
"loss_dice": dice_loss(pred_masks, target_masks, num_masks),
@ -1810,19 +1838,13 @@ class MaskFormerLoss(nn.Module):
target_indices = torch.cat([tgt for (_, tgt) in indices])
return batch_indices, target_indices
def get_loss(self, loss, outputs, labels, indices, num_masks):
loss_map = {"labels": self.loss_labels, "masks": self.loss_masks}
if loss not in loss_map:
raise KeyError(f"{loss} not in loss_map")
return loss_map[loss](outputs, labels, indices, num_masks)
def forward(
self,
masks_queries_logits: torch.Tensor,
class_queries_logits: torch.Tensor,
mask_labels: torch.Tensor,
class_labels: torch.Tensor,
auxiliary_predictions: Optional[Dict[str, torch.Tensor]] = None,
masks_queries_logits: Tensor,
class_queries_logits: Tensor,
mask_labels: List[Tensor],
class_labels: List[Tensor],
auxiliary_predictions: Optional[Dict[str, Tensor]] = None,
) -> Dict[str, Tensor]:
"""
This performs the loss computation.
@ -1831,11 +1853,11 @@ class MaskFormerLoss(nn.Module):
masks_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, height, width`
class_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, num_classes`
A tensor of shape `batch_size, num_queries, num_labels`
mask_labels (`torch.Tensor`):
A tensor of shape `batch_size, num_classes, height, width`
class_labels (`torch.Tensor`):
A tensor of shape `batch_size, num_classes`
List of mask labels of shape `(labels, height, width)`.
class_labels (`List[torch.Tensor]`):
List of class labels of shape `(labels)`.
auxiliary_predictions (`Dict[str, torch.Tensor]`, *optional*):
if `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], then it contains the logits from the
inner layers of the Detr's Decoder.
@ -1850,19 +1872,16 @@ class MaskFormerLoss(nn.Module):
for each auxiliary predictions.
"""
# Retrieve the matching between the outputs of the last layer and the labels
# retrieve the matching between the outputs of the last layer and the labels
indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
# Compute the average number of target masks accross all nodes, for normalization purposes
num_masks: Number = self.get_num_masks(class_labels, device=class_labels.device)
# Compute all the requested losses
# compute the average number of target masks for normalization purposes
num_masks: Number = self.get_num_masks(class_labels, device=class_labels[0].device)
# get all the losses
losses: Dict[str, Tensor] = {
**self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
**self.loss_labels(class_queries_logits, class_labels, indices),
}
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
# in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
if auxiliary_predictions is not None:
for idx, aux_outputs in enumerate(auxiliary_predictions):
masks_queries_logits = aux_outputs["masks_queries_logits"]
@ -1874,8 +1893,10 @@ class MaskFormerLoss(nn.Module):
return losses
def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
# Compute the average number of target masks accross all nodes, for normalization purposes
num_masks = class_labels.shape[0]
"""
Computes the average number of target masks across the batch, for normalization purposes.
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
return num_masks_pt
@ -2380,11 +2401,13 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
loss_dict: Dict[str, Tensor] = self.criterion(
masks_queries_logits, class_queries_logits, mask_labels, class_labels, auxiliary_logits
)
# weight each loss by `self.weight_dict[<LOSS_NAME>]`
weighted_loss_dict: Dict[str, Tensor] = {
k: v * self.weight_dict[k] for k, v in loss_dict.items() if k in self.weight_dict
}
return weighted_loss_dict
# weight each loss by `self.weight_dict[<LOSS_NAME>]` including auxiliary losses
for key, weight in self.weight_dict.items():
for loss_key, loss in loss_dict.items():
if key in loss_key:
loss *= weight
return loss_dict
def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor:
return sum(loss_dict.values())
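For illustration, the substring matching above is what lets the auxiliary-layer losses pick up the same weights as the base losses; a minimal sketch with illustrative weight values (the auxiliary keys are shown with an index suffix purely for the sake of the example):

    weight_dict = {"loss_cross_entropy": 1.0, "loss_mask": 20.0, "loss_dice": 1.0}
    loss_dict = {"loss_cross_entropy": 0.7, "loss_mask": 0.4, "loss_dice": 0.3, "loss_mask_0": 0.5, "loss_dice_0": 0.2}

    for key, weight in weight_dict.items():
        for loss_key in loss_dict:
            if key in loss_key:
                loss_dict[loss_key] = loss_dict[loss_key] * weight

    print(loss_dict["loss_mask_0"])  # 10.0: the auxiliary mask loss was weighted too
    print(sum(loss_dict.values()))   # the final scalar loss, as in get_loss above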
@ -2425,8 +2448,8 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
def forward(
self,
pixel_values: Tensor,
mask_labels: Optional[Tensor] = None,
class_labels: Optional[Tensor] = None,
mask_labels: Optional[List[Tensor]] = None,
class_labels: Optional[List[Tensor]] = None,
pixel_mask: Optional[Tensor] = None,
output_auxiliary_logits: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@ -2434,10 +2457,11 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
return_dict: Optional[bool] = None,
) -> MaskFormerForInstanceSegmentationOutput:
r"""
mask_labels (`torch.FloatTensor`, *optional*):
The target mask of shape `(num_classes, height, width)`.
class_labels (`torch.LongTensor`, *optional*):
The target labels of shape `(num_classes)`.
mask_labels (`List[torch.Tensor]`, *optional*):
List of mask labels of shape `(num_labels, height, width)` to be fed to the model.
class_labels (`List[torch.LongTensor]`, *optional*):
List of target class labels of shape `(num_labels)` to be fed to the model. They identify the
labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` is `class_labels[i][j]`.
Returns:

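For illustration, the list-based label format described above lets each image in a batch carry a different number of masks; a minimal sketch (the model call is left commented out and assumes an already loaded `MaskFormerForInstanceSegmentation`):

    import torch

    # two images, with 4 and 7 annotated masks respectively
    mask_labels = [torch.rand(4, 384, 384), torch.rand(7, 384, 384)]
    class_labels = [torch.randint(0, 10, (4,)), torch.randint(0, 10, (7,))]
    # outputs = model(pixel_values=pixel_values, mask_labels=mask_labels, class_labels=class_labels)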

@ -49,6 +49,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
num_labels=10,
reduce_labels=True,
ignore_index=255,
):
self.parent = parent
self.batch_size = batch_size
@ -68,6 +71,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
self.num_classes = 2
self.height = 3
self.width = 4
self.num_labels = num_labels
self.reduce_labels = reduce_labels
self.ignore_index = ignore_index
def prepare_feat_extract_dict(self):
return {
@ -78,6 +84,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
"image_mean": self.image_mean,
"image_std": self.image_std,
"size_divisibility": self.size_divisibility,
"num_labels": self.num_labels,
"reduce_labels": self.reduce_labels,
"ignore_index": self.ignore_index,
}
def get_expected_values(self, image_inputs, batched=False):
@ -140,6 +149,8 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "size"))
self.assertTrue(hasattr(feature_extractor, "max_size"))
self.assertTrue(hasattr(feature_extractor, "ignore_index"))
self.assertTrue(hasattr(feature_extractor, "num_labels"))
def test_batch_feature(self):
pass
@ -245,7 +256,9 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
def test_equivalence_pad_and_create_pixel_mask(self):
# Initialize feature_extractors
feature_extractor_1 = self.feature_extraction_class(**self.feat_extract_dict)
feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False)
feature_extractor_2 = self.feature_extraction_class(
do_resize=False, do_normalize=False, num_labels=self.feature_extract_tester.num_classes
)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
for image in image_inputs:
@ -262,28 +275,41 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
torch.allclose(encoded_images_with_method["pixel_mask"], encoded_images["pixel_mask"], atol=1e-4)
)
def comm_get_feature_extractor_inputs(self, with_annotations=False):
def comm_get_feature_extractor_inputs(
self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np"
):
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
# prepare image and target
num_classes = 8
batch_size = self.feature_extract_tester.batch_size
num_labels = self.feature_extract_tester.num_labels
annotations = None
if with_annotations:
annotations = [
{
"masks": np.random.rand(num_classes, 384, 384).astype(np.float32),
"labels": (np.random.rand(num_classes) > 0.5).astype(np.int64),
instance_id_to_semantic_id = None
if with_segmentation_maps:
high = num_labels
if is_instance_map:
high *= 2
labels_expanded = list(range(num_labels)) * 2
instance_id_to_semantic_id = {
instance_id: label_id for instance_id, label_id in enumerate(labels_expanded)
}
for _ in range(batch_size)
]
annotations = [np.random.randint(0, high, (384, 384)).astype(np.uint8) for _ in range(batch_size)]
if segmentation_type == "pil":
annotations = [Image.fromarray(annotation) for annotation in annotations]
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
inputs = feature_extractor(image_inputs, annotations, return_tensors="pt", pad_and_return_pixel_mask=True)
inputs = feature_extractor(
image_inputs,
annotations,
return_tensors="pt",
instance_id_to_semantic_id=instance_id_to_semantic_id,
pad_and_return_pixel_mask=True,
)
return inputs
def test_init_without_params(self):
pass
def test_with_size_divisibility(self):
size_divisibilities = [8, 16, 32]
weird_input_sizes = [(407, 802), (582, 1094)]
@ -297,27 +323,29 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
self.assertTrue((pixel_values.shape[-1] % size_divisibility) == 0)
self.assertTrue((pixel_values.shape[-2] % size_divisibility) == 0)
def test_call_with_numpy_annotations(self):
num_classes = 8
batch_size = self.feature_extract_tester.batch_size
def test_call_with_segmentation_maps(self):
def common(is_instance_map=False, segmentation_type=None):
inputs = self.comm_get_feature_extractor_inputs(
with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type
)
inputs = self.comm_get_feature_extractor_inputs(with_annotations=True)
mask_labels = inputs["mask_labels"]
class_labels = inputs["class_labels"]
pixel_values = inputs["pixel_values"]
# check the batch_size
for el in inputs.values():
self.assertEqual(el.shape[0], batch_size)
# check the batch_size
for mask_label, class_label in zip(mask_labels, class_labels):
self.assertEqual(mask_label.shape[0], class_label.shape[0])
# this ensures padding has happened
self.assertEqual(mask_label.shape[1:], pixel_values.shape[2:])
pixel_values = inputs["pixel_values"]
mask_labels = inputs["mask_labels"]
class_labels = inputs["class_labels"]
self.assertEqual(pixel_values.shape[-2], mask_labels.shape[-2])
self.assertEqual(pixel_values.shape[-1], mask_labels.shape[-1])
self.assertEqual(mask_labels.shape[1], class_labels.shape[1])
self.assertEqual(mask_labels.shape[1], num_classes)
common()
common(is_instance_map=True)
common(is_instance_map=False, segmentation_type="pil")
common(is_instance_map=True, segmentation_type="pil")
def test_post_process_segmentation(self):
feature_extractor = self.feature_extraction_class()
feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
segmentation = feature_extractor.post_process_segmentation(outputs)
@ -340,7 +368,7 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
)
def test_post_process_semantic_segmentation(self):
feature_extractor = self.feature_extraction_class()
feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
segmentation = feature_extractor.post_process_semantic_segmentation(outputs)
@ -361,7 +389,7 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
self.assertEqual(segmentation.shape, (self.feature_extract_tester.batch_size, *target_size))
def test_post_process_panoptic_segmentation(self):
feature_extractor = self.feature_extraction_class()
feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
segmentation = feature_extractor.post_process_panoptic_segmentation(outputs, object_mask_threshold=0)


@ -397,18 +397,19 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
).to(torch_device)
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
def test_with_annotations_and_loss(self):
def test_with_segmentation_maps_and_loss(self):
model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
feature_extractor = self.default_feature_extractor
inputs = feature_extractor(
[np.zeros((3, 800, 1333)), np.zeros((3, 800, 1333))],
annotations=[
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
],
segmentation_maps=[np.zeros((384, 384)).astype(np.float32), np.zeros((384, 384)).astype(np.float32)],
return_tensors="pt",
).to(torch_device)
)
inputs["pixel_values"] = inputs["pixel_values"].to(torch_device)
inputs["mask_labels"] = [el.to(torch_device) for el in inputs["mask_labels"]]
inputs["class_labels"] = [el.to(torch_device) for el in inputs["class_labels"]]
with torch.no_grad():
outputs = model(**inputs)