diff --git a/src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py index d4041ed59aa..045d2bc0f51 100644 --- a/src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py @@ -169,12 +169,15 @@ class OriginalMaskFormerConfigToFeatureExtractorConverter: def __call__(self, original_config: object) -> MaskFormerFeatureExtractor: model = original_config.MODEL model_input = original_config.INPUT + dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST[0]) return MaskFormerFeatureExtractor( image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(), image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(), size=model_input.MIN_SIZE_TEST, max_size=model_input.MAX_SIZE_TEST, + num_labels=model.SEM_SEG_HEAD.NUM_CLASSES, + ignore_index=dataset_catalog.ignore_label, size_divisibility=32, # 32 is required by swin ) @@ -552,7 +555,7 @@ class OriginalMaskFormerCheckpointToOursConverter: yield config, checkpoint -def test(original_model, our_model: MaskFormerForInstanceSegmentation): +def test(original_model, our_model: MaskFormerForInstanceSegmentation, feature_extractor: MaskFormerFeatureExtractor): with torch.no_grad(): original_model = original_model.eval() @@ -600,8 +603,6 @@ def test(original_model, our_model: MaskFormerForInstanceSegmentation): our_model_out: MaskFormerForInstanceSegmentationOutput = our_model(x) - feature_extractor = MaskFormerFeatureExtractor() - our_segmentation = feature_extractor.post_process_segmentation(our_model_out, target_size=(384, 384)) assert torch.allclose( @@ -707,7 +708,7 @@ if __name__ == "__main__": mask_former_for_instance_segmentation ) - test(original_model, mask_former_for_instance_segmentation) + test(original_model, mask_former_for_instance_segmentation, feature_extractor) model_name = get_name(checkpoint_file) logger.info(f"🪄 Saving {model_name}") diff --git a/src/transformers/models/maskformer/feature_extraction_maskformer.py b/src/transformers/models/maskformer/feature_extraction_maskformer.py index bd8adc04d0b..5e466f2ddb0 100644 --- a/src/transformers/models/maskformer/feature_extraction_maskformer.py +++ b/src/transformers/models/maskformer/feature_extraction_maskformer.py @@ -54,6 +54,10 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM max_size (`int`, *optional*, defaults to 1333): The largest size an image dimension can have (otherwise it's capped). Only has an effect if `do_resize` is set to `True`. + resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`): + An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`, + `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect + if `do_resize` is set to `True`. size_divisibility (`int`, *optional*, defaults to 32): Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in Swin Transformer. @@ -64,8 +68,12 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`): The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the ImageNet std. 
- ignore_index (`int`, *optional*, default to 255): - Value of the index (label) to ignore. + ignore_index (`int`, *optional*): + Value of the index (label) to be removed from the segmentation maps. + reduce_labels (`bool`, *optional*, defaults to `False`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is + used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The + background label will be replaced by `ignore_index`. """ @@ -76,24 +84,28 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM do_resize=True, size=800, max_size=1333, + resample=Image.BILINEAR, size_divisibility=32, do_normalize=True, image_mean=None, image_std=None, - ignore_index=255, + ignore_index=None, + reduce_labels=False, **kwargs ): super().__init__(**kwargs) self.do_resize = do_resize self.size = size self.max_size = max_size + self.resample = resample self.size_divisibility = size_divisibility - self.ignore_index = ignore_index self.do_normalize = do_normalize self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406] # ImageNet mean self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225] # ImageNet std + self.ignore_index = ignore_index + self.reduce_labels = reduce_labels - def _resize(self, image, size, target=None, max_size=None): + def _resize_with_size_divisibility(self, image, size, target=None, max_size=None): """ Resize the image to the given size. Size can be min_size (scalar) or (width, height) tuple. If size is an int, smaller edge of the image will be matched to this number. @@ -138,30 +150,19 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM width = int(np.ceil(width / self.size_divisibility)) * self.size_divisibility size = (width, height) - rescaled_image = self.resize(image, size=size) + image = self.resize(image, size=size, resample=self.resample) - has_target = target is not None + if target is not None: + target = self.resize(target, size=size, resample=Image.NEAREST) - if has_target: - target = target.copy() - # store original_size - target["original_size"] = image.size - if "masks" in target: - masks = torch.from_numpy(target["masks"])[:, None].float() - # use PyTorch as current workaround - # TODO replace by self.resize - interpolated_masks = ( - nn.functional.interpolate(masks, size=(height, width), mode="nearest")[:, 0] > 0.5 - ).float() - target["masks"] = interpolated_masks.numpy() - - return rescaled_image, target + return image, target def __call__( self, images: ImageInput, - annotations: Union[List[Dict], List[List[Dict]]] = None, + segmentation_maps: ImageInput = None, pad_and_return_pixel_mask: Optional[bool] = True, + instance_id_to_semantic_id: Optional[Dict[int, int]] = None, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs, ) -> BatchFeature: @@ -170,6 +171,12 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM padded up to the largest image in a batch, and a pixel mask is created that indicates which pixels are real/which are padding. + MaskFormer addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps + will be converted to lists of binary masks and their respective labels. 
For example, assuming + `segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels = + [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for + each mask. + + NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass @@ -183,10 +190,8 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a number of channels, H and W are image height and width. - annotations (`Dict`, `List[Dict]`, *optional*): - The corresponding annotations as dictionary of numpy arrays with the following keys: - - **masks** (`np.ndarray`) The target mask of shape `(num_classes, height, width)`. - - **labels** (`np.ndarray`) The target labels of shape `(num_classes)`. + segmentation_maps (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*): + Optionally, the corresponding semantic segmentation maps with the pixel-wise annotations. pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`): Whether or not to pad images up to the largest image in a batch and create a pixel mask. @@ -196,7 +201,12 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM - 1 for pixels that are real (i.e. **not masked**), - 0 for pixels that are padding (i.e. **masked**). - return_tensors (`str` or [`~utils.TensorType`], *optional*): + instance_id_to_semantic_id (`Dict[int, int]`, *optional*): + If passed, we treat `segmentation_maps` as an instance segmentation map where each pixel represents an + instance id. To convert it to a binary mask of shape (`batch, num_labels, height, width`) we need a + dictionary mapping instance ids to label ids to create a semantic segmentation map. + + return_tensors (`str` or [`~file_utils.TensorType`], *optional*): If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor` objects. @@ -206,15 +216,16 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM - **pixel_values** -- Pixel values to be fed to a model. - **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if *"pixel_mask"* is in `self.model_input_names`). - - **mask_labels** -- Optional mask labels of shape `(batch_size, num_classes, height, width) to be fed to a - model (when `annotations` are provided). - - **class_labels** -- Optional class labels of shape `(batch_size, num_classes) to be fed to a model (when - `annotations` are provided). + - **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model + (when `segmentation_maps` are provided). + - **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when + `segmentation_maps` are provided). They identify the labels of `mask_labels`, e.g. the label of + `mask_labels[i][j]` is `class_labels[i][j]`.
""" # Input type checking for clearer error valid_images = False - valid_annotations = False + valid_segmentation_maps = False # Check that images has a valid type if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images): @@ -228,6 +239,23 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), " "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)." ) + # Check that segmentation maps has a valid type + if segmentation_maps is not None: + if isinstance(segmentation_maps, (Image.Image, np.ndarray)) or is_torch_tensor(segmentation_maps): + valid_segmentation_maps = True + elif isinstance(segmentation_maps, (list, tuple)): + if ( + len(segmentation_maps) == 0 + or isinstance(segmentation_maps[0], (Image.Image, np.ndarray)) + or is_torch_tensor(segmentation_maps[0]) + ): + valid_segmentation_maps = True + + if not valid_segmentation_maps: + raise ValueError( + "Segmentation maps must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example)," + "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)." + ) is_batched = bool( isinstance(images, (list, tuple)) @@ -236,35 +264,33 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM if not is_batched: images = [images] - if annotations is not None: - annotations = [annotations] - - # Check that annotations has a valid type - if annotations is not None: - valid_annotations = type(annotations) is list and "masks" in annotations[0] and "labels" in annotations[0] - if not valid_annotations: - raise ValueError( - "Annotations must of type `Dict` (single image) or `List[Dict]` (batch of images)." 
- "The annotations must be numpy arrays in the following format:" - "{ 'masks' : the target mask, with shape [C,H,W], 'labels' : the target labels, with shape [C]}" - ) + if segmentation_maps is not None: + segmentation_maps = [segmentation_maps] # transformations (resizing + normalization) if self.do_resize and self.size is not None: - if annotations is not None: - for idx, (image, target) in enumerate(zip(images, annotations)): - image, target = self._resize(image=image, target=target, size=self.size, max_size=self.max_size) + if segmentation_maps is not None: + for idx, (image, target) in enumerate(zip(images, segmentation_maps)): + image, target = self._resize_with_size_divisibility( + image=image, target=target, size=self.size, max_size=self.max_size + ) images[idx] = image - annotations[idx] = target + segmentation_maps[idx] = target else: for idx, image in enumerate(images): - images[idx] = self._resize(image=image, target=None, size=self.size, max_size=self.max_size)[0] + images[idx] = self._resize_with_size_divisibility( + image=image, target=None, size=self.size, max_size=self.max_size + )[0] if self.do_normalize: images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] # NOTE I will be always forced to pad them them since they have to be stacked in the batch dim encoded_inputs = self.encode_inputs( - images, annotations, pad_and_return_pixel_mask, return_tensors=return_tensors + images, + segmentation_maps, + pad_and_return_pixel_mask, + instance_id_to_semantic_id=instance_id_to_semantic_id, + return_tensors=return_tensors, ) # Convert to TensorType @@ -287,25 +313,57 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM maxes[index] = max(maxes[index], item) return maxes + def convert_segmentation_map_to_binary_masks( + self, + segmentation_map: "np.ndarray", + instance_id_to_semantic_id: Optional[Dict[int, int]] = None, + ): + if self.reduce_labels: + if self.ignore_index is None: + raise ValueError("`ignore_index` must be set when `reduce_labels` is `True`.") + segmentation_map[segmentation_map == 0] = self.ignore_index + # instances ids start from 1! + segmentation_map -= 1 + segmentation_map[segmentation_map == self.ignore_index - 1] = self.ignore_index + + if instance_id_to_semantic_id is not None: + # segmentation_map will be treated as an instance segmentation map where each pixel is a instance id + # thus it has to be converted to a semantic segmentation map + for instance_id, label_id in instance_id_to_semantic_id.items(): + segmentation_map[segmentation_map == instance_id] = label_id + # get all the labels in the image + labels = np.unique(segmentation_map) + # remove ignore index (if we have one) + if self.ignore_index is not None: + labels = labels[labels != self.ignore_index] + # helping broadcast by making mask [1,W,H] and labels [C, 1, 1] + binary_masks = segmentation_map[None] == labels[:, None, None] + return binary_masks.astype(np.float32), labels.astype(np.int64) + def encode_inputs( self, - pixel_values_list: List["torch.Tensor"], - annotations: Optional[List[Dict]] = None, - pad_and_return_pixel_mask: Optional[bool] = True, + pixel_values_list: List["np.ndarray"], + segmentation_maps: ImageInput = None, + pad_and_return_pixel_mask: bool = True, + instance_id_to_semantic_id: Optional[Dict[int, int]] = None, return_tensors: Optional[Union[str, TensorType]] = None, ): """ Pad images up to the largest image in a batch and create a corresponding `pixel_mask`. 
+ MaskFormer addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps + will be converted to lists of binary masks and their respective labels. For example, assuming + `segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels = + [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for + each mask. + Args: pixel_values_list (`List[torch.Tensor]`): List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height, width)`. - annotations (`Dict`, `List[Dict]`, *optional*): - The corresponding annotations as dictionary of numpy arrays with the following keys: - - **masks** (`np.ndarray`) The target mask of shape `(num_classes, height, width)`. - - **labels** (`np.ndarray`) The target labels of shape `(num_classes)`. + segmentation_maps (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*): + The corresponding semantic segmentation maps with the pixel-wise annotations. pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`): Whether or not to pad images up to the largest image in a batch and create a pixel mask. @@ -315,7 +373,12 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM - 1 for pixels that are real (i.e. **not masked**), - 0 for pixels that are padding (i.e. **masked**). - return_tensors (`str` or [`~utils.TensorType`], *optional*): + instance_id_to_semantic_id (`Dict[int, int]`, *optional*): + If passed, we treat `segmentation_maps` as an instance segmentation map where each pixel represents an + instance id. To convert it to a binary mask of shape (`batch, num_labels, height, width`) we need a + dictionary mapping instance ids to label ids to create a semantic segmentation map. + + return_tensors (`str` or [`~file_utils.TensorType`], *optional*): If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor` objects. @@ -325,13 +388,29 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM - **pixel_values** -- Pixel values to be fed to a model. - **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if *"pixel_mask"* is in `self.model_input_names`). - - **mask_labels** -- Optional mask labels of shape `(batch_size, num_classes, height, width) to be fed to a - model (when `annotations` are provided). - - **class_labels** -- Optional class labels of shape `(batch_size, num_classes) to be fed to a model (when - `annotations` are provided). + - **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model + (when `segmentation_maps` are provided). + - **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when + `segmentation_maps` are provided). They identify the labels of `mask_labels`, e.g. the label of + `mask_labels[i][j]` is `class_labels[i][j]`.
""" max_size = self._max_by_axis([list(image.shape) for image in pixel_values_list]) + + annotations = None + if segmentation_maps is not None: + segmentation_maps = map(np.array, segmentation_maps) + converted_segmentation_maps = [] + for segmentation_map in segmentation_maps: + converted_segmentation_map = self.convert_segmentation_map_to_binary_masks( + segmentation_map, instance_id_to_semantic_id + ) + converted_segmentation_maps.append(converted_segmentation_map) + + annotations = [] + for mask, classes in converted_segmentation_maps: + annotations.append({"masks": mask, "classes": classes}) + channels, height, width = max_size pixel_values = [] pixel_mask = [] @@ -339,35 +418,37 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM class_labels = [] for idx, image in enumerate(pixel_values_list): # create padded image - if pad_and_return_pixel_mask: - padded_image = np.zeros((channels, height, width), dtype=np.float32) - padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image) - image = padded_image + padded_image = np.zeros((channels, height, width), dtype=np.float32) + padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image) + image = padded_image pixel_values.append(image) # if we have a target, pad it if annotations: annotation = annotations[idx] masks = annotation["masks"] - if pad_and_return_pixel_mask: - padded_masks = np.zeros((masks.shape[0], height, width), dtype=masks.dtype) - padded_masks[:, : masks.shape[1], : masks.shape[2]] = np.copy(masks) - masks = padded_masks - mask_labels.append(masks) - class_labels.append(annotation["labels"]) - if pad_and_return_pixel_mask: - # create pixel mask - mask = np.zeros((height, width), dtype=np.int64) - mask[: image.shape[1], : image.shape[2]] = True - pixel_mask.append(mask) + # pad mask with `ignore_index` + masks = np.pad( + masks, + ((0, 0), (0, height - masks.shape[1]), (0, width - masks.shape[2])), + constant_values=self.ignore_index, + ) + annotation["masks"] = masks + # create pixel mask + mask = np.zeros((height, width), dtype=np.int64) + mask[: image.shape[1], : image.shape[2]] = True + pixel_mask.append(mask) # return as BatchFeature data = {"pixel_values": pixel_values, "pixel_mask": pixel_mask} - - if annotations: - data["mask_labels"] = mask_labels - data["class_labels"] = class_labels - encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + # we cannot batch them since they don't share a common class size + if annotations: + for label in annotations: + mask_labels.append(torch.from_numpy(label["masks"])) + class_labels.append(torch.from_numpy(label["classes"])) + + encoded_inputs["mask_labels"] = mask_labels + encoded_inputs["class_labels"] = class_labels return encoded_inputs diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 450ba50b59e..cc793d88657 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -269,7 +269,7 @@ class MaskFormerForInstanceSegmentationOutput(ModelOutput): A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each query. masks_queries_logits (`torch.FloatTensor`): - A tensor of shape `(batch_size, num_queries, num_classes + 1)` representing the proposed classes for each + A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each query. 
Note the `+ 1` is needed because we incorporate the null class. encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Last hidden states (final feature map) of the last stage of the encoder model (backbone). @@ -424,7 +424,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor: """ inputs = inputs.sigmoid().flatten(1) numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels) - # using broadcasting to get a [NUM_QUERIES, NUM_CLASSES] matrix + # using broadcasting to get a [num_queries, num_labels] matrix denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :] loss = 1 - (numerator + 1) / (denominator + 1) return loss @@ -918,7 +918,9 @@ class MaskFormerSwinBlock(nn.Module): outputs = self_attention_outputs[1:] # add self attentions if we output attention weights attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) - shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) # B H' W' C + shifted_windows = window_reverse( + attention_windows, self.window_size, height_pad, width_pad + ) # B height' width' C # reverse cyclic shift if self.shift_size > 0: @@ -1621,7 +1623,7 @@ class MaskFormerHungarianMatcher(nn.Module): Params: class_queries_logits (`torch.Tensor`): - A tensor` of dim `batch_size, num_queries, num_classes` with the + A tensor of dim `batch_size, num_queries, num_labels` with the classification logits. masks_queries_logits (`torch.Tensor`): A tensor of dim `batch_size, num_queries, height, width` with the @@ -1644,24 +1646,23 @@ class MaskFormerHungarianMatcher(nn.Module): indices: List[Tuple[np.array]] = [] preds_masks = masks_queries_logits - preds_probs = class_queries_logits.softmax(dim=-1) - # downsample all masks in one go -> save memory - mask_labels = nn.functional.interpolate(mask_labels, size=preds_masks.shape[-2:], mode="nearest") + preds_probs = class_queries_logits # iterate through batch size for pred_probs, pred_mask, target_mask, labels in zip(preds_probs, preds_masks, mask_labels, class_labels): + # downsample the target mask, save memory + target_mask = nn.functional.interpolate(target_mask[:, None], size=pred_mask.shape[-2:], mode="nearest") + pred_probs = pred_probs.softmax(-1) # Compute the classification cost. Contrary to the loss, we don't use the NLL, # but approximate it in 1 - proba[target class]. # The 1 is a constant that doesn't change the matching, it can be omitted.
cost_class = -pred_probs[:, labels] # flatten spatial dimension "q h w -> q (h w)" - num_queries, height, width = pred_mask.shape - pred_mask_flat = pred_mask.view(num_queries, height * width) # [num_queries, H*W] + pred_mask_flat = pred_mask.flatten(1) # [num_queries, height*width] # same for target_mask "c h w -> c (h w)" - num_channels, height, width = target_mask.shape - target_mask_flat = target_mask.view(num_channels, height * width) # [num_total_labels, H*W] - # compute the focal loss between each mask pairs -> shape [NUM_QUERIES, CLASSES] + target_mask_flat = target_mask[:, 0].flatten(1) # [num_total_labels, height*width] + # compute the focal loss between each pair of masks -> shape (num_queries, num_labels) cost_mask = pair_wise_sigmoid_focal_loss(pred_mask_flat, target_mask_flat) - # Compute the dice loss betwen each mask pairs -> shape [NUM_QUERIES, CLASSES] + # Compute the dice loss between each pair of masks -> shape (num_queries, num_labels) cost_dice = pair_wise_dice_loss(pred_mask_flat, target_mask_flat) # final cost matrix cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice @@ -1691,7 +1692,7 @@ class MaskFormerLoss(nn.Module): def __init__( self, - num_classes: int, + num_labels: int, matcher: MaskFormerHungarianMatcher, weight_dict: Dict[str, float], eos_coef: float, @@ -1702,7 +1703,7 @@ class MaskFormerLoss(nn.Module): matched ground-truth / prediction (supervise class and mask) Args: - num_classes (`int`): + num_labels (`int`): The number of classes. matcher (`MaskFormerHungarianMatcher`): A torch module that computes the assignments between the predictions and labels. @@ -1714,24 +1715,50 @@ class MaskFormerLoss(nn.Module): super().__init__() requires_backends(self, ["scipy"]) - self.num_classes = num_classes + self.num_labels = num_labels self.matcher = matcher self.weight_dict = weight_dict self.eos_coef = eos_coef - empty_weight = torch.ones(self.num_classes + 1) + empty_weight = torch.ones(self.num_labels + 1) empty_weight[-1] = self.eos_coef self.register_buffer("empty_weight", empty_weight) + def _max_by_axis(self, the_list: List[List[int]]) -> List[int]: + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + def _pad_images_to_max_in_batch(self, tensors: List[Tensor]) -> Tuple[Tensor, Tensor]: + # get the maximum size in the batch + max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors]) + batch_size = len(tensors) + # compute final size + batch_shape = [batch_size] + max_size + b, _, h, w = batch_shape + # get metadata + dtype = tensors[0].dtype + device = tensors[0].device + padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device) + padding_masks = torch.ones((b, h, w), dtype=torch.bool, device=device) + # pad the tensors to the size of the biggest one + for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks): + padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor) + padding_mask[: tensor.shape[1], : tensor.shape[2]] = False + + return padded_tensors, padding_masks + def loss_labels( - self, class_queries_logits: Tensor, class_labels: Tensor, indices: Tuple[np.array] + self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array] ) -> Dict[str, Tensor]: """Compute the losses related to the labels using cross entropy.
Args: class_queries_logits (`torch.Tensor`): - A tensor of shape `batch_size, num_queries, num_classes` - class_labels (`Dict[str, Tensor]`): - A tensor of shape `batch_size, num_classes` + A tensor of shape `batch_size, num_queries, num_labels` + class_labels (`List[torch.Tensor]`): + List of class labels of shape `(labels)`. indices (`Tuple[np.array]`): The indices computed by the Hungarian matcher. @@ -1744,21 +1771,21 @@ class MaskFormerLoss(nn.Module): batch_size, num_queries, _ = pred_logits.shape criterion = nn.CrossEntropyLoss(weight=self.empty_weight) idx = self._get_predictions_permutation_indices(indices) - # shape = [BATCH, N_QUERIES] + # shape = (batch_size, num_queries) target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)]) - # shape = [BATCH, N_QUERIES] + # shape = (batch_size, num_queries) target_classes = torch.full( - (batch_size, num_queries), fill_value=self.num_classes, dtype=torch.int64, device=pred_logits.device + (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device ) target_classes[idx] = target_classes_o - # target_classes is a [BATCH, CLASSES, N_QUERIES], we need to permute pred_logits "b q c -> b c q" - pred_logits_permuted = pred_logits.permute(0, 2, 1) - loss_ce = criterion(pred_logits_permuted, target_classes) + # CrossEntropyLoss expects logits of shape (batch_size, num_labels, num_queries), so transpose pred_logits "b q c -> b c q" + pred_logits_transposed = pred_logits.transpose(1, 2) + loss_ce = criterion(pred_logits_transposed, target_classes) losses = {"loss_cross_entropy": loss_ce} return losses def loss_masks( - self, masks_queries_logits: Tensor, mask_labels: Tensor, indices: Tuple[np.array], num_masks: int + self, masks_queries_logits: Tensor, mask_labels: List[Tensor], indices: Tuple[np.array], num_masks: int ) -> Dict[str, Tensor]: """Compute the losses related to the masks using focal and dice loss. @@ -1766,7 +1793,7 @@ masks_queries_logits (`torch.Tensor`): A tensor of shape `batch_size, num_queries, height, width` mask_labels (`torch.Tensor`): - A tensor of shape `batch_size, num_queries, height, width` + List of mask labels of shape `(labels, height, width)`. indices (`Tuple[np.array]`): The indices computed by the Hungarian matcher.
num_masks (`int)`: @@ -1780,10 +1807,12 @@ class MaskFormerLoss(nn.Module): """ src_idx = self._get_predictions_permutation_indices(indices) tgt_idx = self._get_targets_permutation_indices(indices) - pred_masks = masks_queries_logits # shape [BATCH, NUM_QUERIES, H, W] - pred_masks = pred_masks[src_idx] # shape [BATCH * NUM_QUERIES, H, W] - target_masks = mask_labels # shape [BATCH, NUM_QUERIES, H, W] - target_masks = target_masks[tgt_idx] # shape [BATCH * NUM_QUERIES, H, W] + # shape (batch_size * num_queries, height, width) + pred_masks = masks_queries_logits[src_idx] + # shape (batch_size, num_queries, height, width) + # pad all and stack the targets to the num_labels dimension + target_masks, _ = self._pad_images_to_max_in_batch(mask_labels) + target_masks = target_masks[tgt_idx] # upsample predictions to the target size, we have to add one dim to use interpolate pred_masks = nn.functional.interpolate( pred_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False @@ -1791,7 +1820,6 @@ class MaskFormerLoss(nn.Module): pred_masks = pred_masks[:, 0].flatten(1) target_masks = target_masks.flatten(1) - target_masks = target_masks.view(pred_masks.shape) losses = { "loss_mask": sigmoid_focal_loss(pred_masks, target_masks, num_masks), "loss_dice": dice_loss(pred_masks, target_masks, num_masks), @@ -1810,19 +1838,13 @@ class MaskFormerLoss(nn.Module): target_indices = torch.cat([tgt for (_, tgt) in indices]) return batch_indices, target_indices - def get_loss(self, loss, outputs, labels, indices, num_masks): - loss_map = {"labels": self.loss_labels, "masks": self.loss_masks} - if loss not in loss_map: - raise KeyError(f"{loss} not in loss_map") - return loss_map[loss](outputs, labels, indices, num_masks) - def forward( self, - masks_queries_logits: torch.Tensor, - class_queries_logits: torch.Tensor, - mask_labels: torch.Tensor, - class_labels: torch.Tensor, - auxiliary_predictions: Optional[Dict[str, torch.Tensor]] = None, + masks_queries_logits: Tensor, + class_queries_logits: Tensor, + mask_labels: List[Tensor], + class_labels: List[Tensor], + auxiliary_predictions: Optional[Dict[str, Tensor]] = None, ) -> Dict[str, Tensor]: """ This performs the loss computation. @@ -1831,11 +1853,11 @@ class MaskFormerLoss(nn.Module): masks_queries_logits (`torch.Tensor`): A tensor of shape `batch_size, num_queries, height, width` class_queries_logits (`torch.Tensor`): - A tensor of shape `batch_size, num_queries, num_classes` + A tensor of shape `batch_size, num_queries, num_labels` mask_labels (`torch.Tensor`): - A tensor of shape `batch_size, num_classes, height, width` - class_labels (`torch.Tensor`): - A tensor of shape `batch_size, num_classes` + List of mask labels of shape `(labels, height, width)`. + class_labels (`List[torch.Tensor]`): + List of class labels of shape `(labels)`. auxiliary_predictions (`Dict[str, torch.Tensor]`, *optional*): if `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], then it contains the logits from the inner layers of the Detr's Decoder. @@ -1850,19 +1872,16 @@ class MaskFormerLoss(nn.Module): for each auxiliary predictions. 
""" - # Retrieve the matching between the outputs of the last layer and the labels + # retrieve the matching between the outputs of the last layer and the labels indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels) - - # Compute the average number of target masks accross all nodes, for normalization purposes - num_masks: Number = self.get_num_masks(class_labels, device=class_labels.device) - - # Compute all the requested losses + # compute the average number of target masks for normalization purposes + num_masks: Number = self.get_num_masks(class_labels, device=class_labels[0].device) + # get all the losses losses: Dict[str, Tensor] = { **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks), **self.loss_labels(class_queries_logits, class_labels, indices), } - - # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + # in case of auxiliary losses, we repeat this process with the output of each intermediate layer. if auxiliary_predictions is not None: for idx, aux_outputs in enumerate(auxiliary_predictions): masks_queries_logits = aux_outputs["masks_queries_logits"] @@ -1874,8 +1893,10 @@ class MaskFormerLoss(nn.Module): return losses def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor: - # Compute the average number of target masks accross all nodes, for normalization purposes - num_masks = class_labels.shape[0] + """ + Computes the average number of target masks accross the batch, for normalization purposes. + """ + num_masks = sum([len(classes) for classes in class_labels]) num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device) return num_masks_pt @@ -2380,11 +2401,13 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel): loss_dict: Dict[str, Tensor] = self.criterion( masks_queries_logits, class_queries_logits, mask_labels, class_labels, auxiliary_logits ) - # weight each loss by `self.weight_dict[]` - weighted_loss_dict: Dict[str, Tensor] = { - k: v * self.weight_dict[k] for k, v in loss_dict.items() if k in self.weight_dict - } - return weighted_loss_dict + # weight each loss by `self.weight_dict[]` including auxiliary losses + for key, weight in self.weight_dict.items(): + for loss_key, loss in loss_dict.items(): + if key in loss_key: + loss *= weight + + return loss_dict def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor: return sum(loss_dict.values()) @@ -2425,8 +2448,8 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel): def forward( self, pixel_values: Tensor, - mask_labels: Optional[Tensor] = None, - class_labels: Optional[Tensor] = None, + mask_labels: Optional[List[Tensor]] = None, + class_labels: Optional[List[Tensor]] = None, pixel_mask: Optional[Tensor] = None, output_auxiliary_logits: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -2434,10 +2457,11 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel): return_dict: Optional[bool] = None, ) -> MaskFormerForInstanceSegmentationOutput: r""" - mask_labels (`torch.FloatTensor`, *optional*): - The target mask of shape `(num_classes, height, width)`. - class_labels (`torch.LongTensor`, *optional*): - The target labels of shape `(num_classes)`. 
+ mask_labels (`List[torch.Tensor]`, *optional*): + List of mask labels of shape `(num_labels, height, width)` to be fed to a model. + class_labels (`List[torch.LongTensor]`, *optional*): + List of target class labels of shape `(num_labels)` to be fed to a model. They identify the + labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` is `class_labels[i][j]`. Returns: diff --git a/tests/maskformer/test_feature_extraction_maskformer.py b/tests/maskformer/test_feature_extraction_maskformer.py index c44899ebf50..259954643fc 100644 --- a/tests/maskformer/test_feature_extraction_maskformer.py +++ b/tests/maskformer/test_feature_extraction_maskformer.py @@ -49,6 +49,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase): do_normalize=True, image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5], + num_labels=10, + reduce_labels=True, + ignore_index=255, ): self.parent = parent self.batch_size = batch_size @@ -68,6 +71,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase): self.num_classes = 2 self.height = 3 self.width = 4 + self.num_labels = num_labels + self.reduce_labels = reduce_labels + self.ignore_index = ignore_index def prepare_feat_extract_dict(self): return { @@ -78,6 +84,9 @@ "image_mean": self.image_mean, "image_std": self.image_std, "size_divisibility": self.size_divisibility, + "num_labels": self.num_labels, + "reduce_labels": self.reduce_labels, + "ignore_index": self.ignore_index, } def get_expected_values(self, image_inputs, batched=False): @@ -140,6 +149,8 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest self.assertTrue(hasattr(feature_extractor, "do_resize")) self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "max_size")) + self.assertTrue(hasattr(feature_extractor, "ignore_index")) + self.assertTrue(hasattr(feature_extractor, "num_labels")) def test_batch_feature(self): pass @@ -245,7 +256,9 @@ def test_equivalence_pad_and_create_pixel_mask(self): # Initialize feature_extractors feature_extractor_1 = self.feature_extraction_class(**self.feat_extract_dict) - feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False) + feature_extractor_2 = self.feature_extraction_class( + do_resize=False, do_normalize=False, num_labels=self.feature_extract_tester.num_classes + ) # create random PyTorch tensors image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True) for image in image_inputs: @@ -262,28 +275,41 @@ torch.allclose(encoded_images_with_method["pixel_mask"], encoded_images["pixel_mask"], atol=1e-4) ) - def comm_get_feature_extractor_inputs(self, with_annotations=False): + def comm_get_feature_extractor_inputs( + self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np" + ): feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) # prepare image and target - num_classes = 8 batch_size = self.feature_extract_tester.batch_size + num_labels = self.feature_extract_tester.num_labels annotations = None - - if with_annotations: - annotations = [ - { - "masks": np.random.rand(num_classes, 384, 384).astype(np.float32), - "labels": (np.random.rand(num_classes) > 0.5).astype(np.int64), + instance_id_to_semantic_id = None + if
with_segmentation_maps: + high = num_labels + if is_instance_map: + high *= 2 + labels_expanded = list(range(num_labels)) * 2 + instance_id_to_semantic_id = { + instance_id: label_id for instance_id, label_id in enumerate(labels_expanded) } - for _ in range(batch_size) - ] + annotations = [np.random.randint(0, high, (384, 384)).astype(np.uint8) for _ in range(batch_size)] + if segmentation_type == "pil": + annotations = [Image.fromarray(annotation) for annotation in annotations] image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False) - - inputs = feature_extractor(image_inputs, annotations, return_tensors="pt", pad_and_return_pixel_mask=True) + inputs = feature_extractor( + image_inputs, + annotations, + return_tensors="pt", + instance_id_to_semantic_id=instance_id_to_semantic_id, + pad_and_return_pixel_mask=True, + ) return inputs + def test_init_without_params(self): + pass + def test_with_size_divisibility(self): size_divisibilities = [8, 16, 32] weird_input_sizes = [(407, 802), (582, 1094)] @@ -297,27 +323,29 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest self.assertTrue((pixel_values.shape[-1] % size_divisibility) == 0) self.assertTrue((pixel_values.shape[-2] % size_divisibility) == 0) - def test_call_with_numpy_annotations(self): - num_classes = 8 - batch_size = self.feature_extract_tester.batch_size + def test_call_with_segmentation_maps(self): + def common(is_instance_map=False, segmentation_type=None): + inputs = self.comm_get_feature_extractor_inputs( + with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type + ) - inputs = self.comm_get_feature_extractor_inputs(with_annotations=True) + mask_labels = inputs["mask_labels"] + class_labels = inputs["class_labels"] + pixel_values = inputs["pixel_values"] - # check the batch_size - for el in inputs.values(): - self.assertEqual(el.shape[0], batch_size) + # check the batch_size + for mask_label, class_label in zip(mask_labels, class_labels): + self.assertEqual(mask_label.shape[0], class_label.shape[0]) + # this ensures padding has happened + self.assertEqual(mask_label.shape[1:], pixel_values.shape[2:]) - pixel_values = inputs["pixel_values"] - mask_labels = inputs["mask_labels"] - class_labels = inputs["class_labels"] - - self.assertEqual(pixel_values.shape[-2], mask_labels.shape[-2]) - self.assertEqual(pixel_values.shape[-1], mask_labels.shape[-1]) - self.assertEqual(mask_labels.shape[1], class_labels.shape[1]) - self.assertEqual(mask_labels.shape[1], num_classes) + common() + common(is_instance_map=True) + common(is_instance_map=False, segmentation_type="pil") + common(is_instance_map=True, segmentation_type="pil") def test_post_process_segmentation(self): - fature_extractor = self.feature_extraction_class() + fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes) outputs = self.feature_extract_tester.get_fake_maskformer_outputs() segmentation = fature_extractor.post_process_segmentation(outputs) @@ -340,7 +368,7 @@ ) def test_post_process_semantic_segmentation(self): - fature_extractor = self.feature_extraction_class() + fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes) outputs = self.feature_extract_tester.get_fake_maskformer_outputs() segmentation = fature_extractor.post_process_semantic_segmentation(outputs) @@ -361,7 +389,7 @@ class
MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest self.assertEqual(segmentation.shape, (self.feature_extract_tester.batch_size, *target_size)) def test_post_process_panoptic_segmentation(self): - fature_extractor = self.feature_extraction_class() + fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes) outputs = self.feature_extract_tester.get_fake_maskformer_outputs() segmentation = fature_extractor.post_process_panoptic_segmentation(outputs, object_mask_threshold=0) diff --git a/tests/maskformer/test_modeling_maskformer.py b/tests/maskformer/test_modeling_maskformer.py index a27fe41a69c..50dbecb8de4 100644 --- a/tests/maskformer/test_modeling_maskformer.py +++ b/tests/maskformer/test_modeling_maskformer.py @@ -397,18 +397,19 @@ class MaskFormerModelIntegrationTest(unittest.TestCase): ).to(torch_device) self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE)) - def test_with_annotations_and_loss(self): + def test_with_segmentation_maps_and_loss(self): model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval() feature_extractor = self.default_feature_extractor inputs = feature_extractor( [np.zeros((3, 800, 1333)), np.zeros((3, 800, 1333))], - annotations=[ - {"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)}, - {"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)}, - ], + segmentation_maps=[np.zeros((384, 384)).astype(np.float32), np.zeros((384, 384)).astype(np.float32)], return_tensors="pt", - ).to(torch_device) + ) + + inputs["pixel_values"] = inputs["pixel_values"].to(torch_device) + inputs["mask_labels"] = [el.to(torch_device) for el in inputs["mask_labels"]] + inputs["class_labels"] = [el.to(torch_device) for el in inputs["class_labels"]] with torch.no_grad(): outputs = model(**inputs)
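As a quick end-to-end sketch of the new `segmentation_maps` path for reviewers (the image size, label ids and the `num_labels`/`ignore_index` values below are made up for illustration):

```python
import numpy as np

from transformers import MaskFormerFeatureExtractor

# toy inputs: one channels-first float image and one semantic map with label ids {0, 1, 2}
image = np.random.rand(3, 512, 768).astype(np.float32)
segmentation_map = np.random.randint(0, 3, (512, 768)).astype(np.uint8)

feature_extractor = MaskFormerFeatureExtractor(num_labels=3, ignore_index=255)
inputs = feature_extractor([image], segmentation_maps=[segmentation_map], return_tensors="pt")

print(inputs["pixel_values"].shape)    # (1, 3, H, W), with H and W divisible by 32
print(inputs["mask_labels"][0].shape)  # (labels_present, H, W): one binary mask per label
print(inputs["class_labels"][0])       # 1D tensor with the label id of each mask
```

Note that `mask_labels` and `class_labels` come back as Python lists rather than stacked tensors, since the number of labels differs per image.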
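The conversion itself boils down to a broadcast comparison. A stripped-down NumPy mirror of `convert_segmentation_map_to_binary_masks` (ignoring `reduce_labels` and the instance-id remapping) reproduces the `[[2,6,7,9]]` example from the docstring:

```python
import numpy as np

def to_binary_masks(segmentation_map, ignore_index=None):
    # every distinct pixel value becomes one binary mask
    labels = np.unique(segmentation_map)
    if ignore_index is not None:
        labels = labels[labels != ignore_index]
    # broadcast (1, H, W) against (C, 1, 1) -> (C, H, W)
    binary_masks = segmentation_map[None] == labels[:, None, None]
    return binary_masks.astype(np.float32), labels.astype(np.int64)

masks, classes = to_binary_masks(np.array([[2, 6, 7, 9]]))
print(masks.shape)  # (4, 1, 4): four binary masks for the 1x4 map
print(classes)      # [2 6 7 9]
```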
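For instance maps, `instance_id_to_semantic_id` remaps every instance id to its semantic label before binarization, so instances of the same class collapse into a single mask. A sketch of feeding the result to the model, reusing `image` and `feature_extractor` from the first snippet and assuming the converted `facebook/maskformer-swin-base-ade` checkpoint is available:

```python
import numpy as np

from transformers import MaskFormerForInstanceSegmentation

# hypothetical instance map: instance id 1 on the left half, instance id 2 on the right
instance_map = np.ones((512, 768), dtype=np.uint8)
instance_map[:, 384:] = 2

inputs = feature_extractor(
    [image],
    segmentation_maps=[instance_map],
    instance_id_to_semantic_id={1: 0, 2: 0},  # both instances carry semantic class 0
    return_tensors="pt",
)
# both ids were remapped to label 0, so inputs["mask_labels"][0] holds a single binary mask

model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade")
outputs = model(
    pixel_values=inputs["pixel_values"],
    mask_labels=inputs["mask_labels"],    # list of (labels_i, H, W) tensors
    class_labels=inputs["class_labels"],  # list of (labels_i,) tensors
)
print(outputs.loss)  # weighted sum of the cross entropy, focal and dice losses
```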