Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-30 17:52:35 +06:00)
Feature Extractor accepts segmentation_maps (#15964)

* feature extractor accepts
* resolved conversations
* added examples in test for ADE20K
* num_classes -> num_labels
* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* resolving conversations
* resolving conversations
* removed ADE
* CI
* minor changes in conversion script
* reduce_labels in feature extractor
* minor changes
* correct preprocess for instance segmentation maps
* minor changes
* minor changes
* CI
* debugging
* better padding
* going to update labels inside the model
* going to update labels inside the model
* minor changes
* tests
* removed changes in feature_extractor_utils
* conversation
* conversation
* example in feature extractor
* more docstring in modeling
* test
* make style
* doc

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent c2f8eaf6bc
commit c4deb7b3ae
@@ -169,12 +169,15 @@ class OriginalMaskFormerConfigToFeatureExtractorConverter:
     def __call__(self, original_config: object) -> MaskFormerFeatureExtractor:
         model = original_config.MODEL
         model_input = original_config.INPUT
+        dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST[0])

         return MaskFormerFeatureExtractor(
             image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(),
             image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(),
             size=model_input.MIN_SIZE_TEST,
             max_size=model_input.MAX_SIZE_TEST,
+            num_labels=model.SEM_SEG_HEAD.NUM_CLASSES,
+            ignore_index=dataset_catalog.ignore_label,
             size_divisibility=32,  # 32 is required by swin
         )
@@ -552,7 +555,7 @@ class OriginalMaskFormerCheckpointToOursConverter:
         yield config, checkpoint


-def test(original_model, our_model: MaskFormerForInstanceSegmentation):
+def test(original_model, our_model: MaskFormerForInstanceSegmentation, feature_extractor: MaskFormerFeatureExtractor):
     with torch.no_grad():

         original_model = original_model.eval()
@@ -600,8 +603,6 @@ def test(original_model, our_model: MaskFormerForInstanceSegmentation):

     our_model_out: MaskFormerForInstanceSegmentationOutput = our_model(x)

-    feature_extractor = MaskFormerFeatureExtractor()
-
     our_segmentation = feature_extractor.post_process_segmentation(our_model_out, target_size=(384, 384))

     assert torch.allclose(
@@ -707,7 +708,7 @@ if __name__ == "__main__":
         mask_former_for_instance_segmentation
     )

-    test(original_model, mask_former_for_instance_segmentation)
+    test(original_model, mask_former_for_instance_segmentation, feature_extractor)

     model_name = get_name(checkpoint_file)
     logger.info(f"🪄 Saving {model_name}")
@@ -54,6 +54,10 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
        max_size (`int`, *optional*, defaults to 1333):
            The largest size an image dimension can have (otherwise it's capped). Only has an effect if `do_resize` is
            set to `True`.
+       resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
+           An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
+           `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
+           if `do_resize` is set to `True`.
        size_divisibility (`int`, *optional*, defaults to 32):
            Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in
            Swin Transformer.
@@ -64,8 +68,12 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
        image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
            The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
            ImageNet std.
-       ignore_index (`int`, *optional*, defaults to 255):
-           Value of the index (label) to ignore.
+       ignore_index (`int`, *optional*):
+           Value of the index (label) to be removed from the segmentation maps.
+       reduce_labels (`bool`, *optional*, defaults to `False`):
+           Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
+           used for background, and the background itself is not included in all classes of a dataset (e.g. ADE20k).
+           The background label will be replaced by `ignore_index`.

    """
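To make the `reduce_labels` behavior above concrete, here is a minimal standalone sketch (NumPy only, made-up values) of the same three steps the feature extractor applies:

import numpy as np

ignore_index = 255
segmentation_map = np.array([[0, 1], [2, 0]], dtype=np.int64)  # 0 = background, as in ADE20k
segmentation_map[segmentation_map == 0] = ignore_index  # background becomes the ignore value
segmentation_map -= 1                                   # shift all remaining labels down by one
segmentation_map[segmentation_map == ignore_index - 1] = ignore_index  # keep the ignore value stable
# labels 1 and 2 have become 0 and 1; every background pixel is now 255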
@@ -76,24 +84,28 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
        do_resize=True,
        size=800,
        max_size=1333,
+       resample=Image.BILINEAR,
        size_divisibility=32,
        do_normalize=True,
        image_mean=None,
        image_std=None,
-       ignore_index=255,
+       ignore_index=None,
+       reduce_labels=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.size = size
        self.max_size = max_size
+       self.resample = resample
        self.size_divisibility = size_divisibility
-       self.ignore_index = ignore_index
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406]  # ImageNet mean
        self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225]  # ImageNet std
+       self.ignore_index = ignore_index
+       self.reduce_labels = reduce_labels

-   def _resize(self, image, size, target=None, max_size=None):
+   def _resize_with_size_divisibility(self, image, size, target=None, max_size=None):
        """
        Resize the image to the given size. Size can be min_size (scalar) or (width, height) tuple. If size is an int,
        the smaller edge of the image will be matched to this number.
@@ -138,30 +150,19 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
            width = int(np.ceil(width / self.size_divisibility)) * self.size_divisibility

        size = (width, height)
-       rescaled_image = self.resize(image, size=size)
+       image = self.resize(image, size=size, resample=self.resample)

-       has_target = target is not None
+       if target is not None:
+           target = self.resize(target, size=size, resample=Image.NEAREST)

-       if has_target:
-           target = target.copy()
-           # store original_size
-           target["original_size"] = image.size
-           if "masks" in target:
-               masks = torch.from_numpy(target["masks"])[:, None].float()
-               # use PyTorch as current workaround
-               # TODO replace by self.resize
-               interpolated_masks = (
-                   nn.functional.interpolate(masks, size=(height, width), mode="nearest")[:, 0] > 0.5
-               ).float()
-               target["masks"] = interpolated_masks.numpy()

-       return rescaled_image, target
+       return image, target

    def __call__(
        self,
        images: ImageInput,
-       annotations: Union[List[Dict], List[List[Dict]]] = None,
+       segmentation_maps: ImageInput = None,
        pad_and_return_pixel_mask: Optional[bool] = True,
+       instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
@@ -170,6 +171,12 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
        padded up to the largest image in a batch, and a pixel mask is created that indicates which pixels are
        real/which are padding.

+       MaskFormer addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps
+       will be converted to lists of binary masks and their respective labels. Let's see an example: assuming
+       `segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels =
+       [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for
+       each mask.
+
        <Tip warning={true}>

        NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
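For orientation, a call to the updated API might look like this (a sketch; the parameter values and array contents are made up, and `num_labels` is simply forwarded through `**kwargs`):

import numpy as np
from transformers import MaskFormerFeatureExtractor

feature_extractor = MaskFormerFeatureExtractor(num_labels=10, ignore_index=255)
images = [np.zeros((3, 384, 384), dtype=np.float32)]
segmentation_maps = [np.random.randint(0, 10, (384, 384)).astype(np.uint8)]
inputs = feature_extractor(images, segmentation_maps=segmentation_maps, return_tensors="pt")
# inputs["pixel_values"] and inputs["pixel_mask"] are batched tensors, while
# inputs["mask_labels"] and inputs["class_labels"] are per-image lists (ragged across the batch)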
@@ -183,10 +190,8 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width.

-           annotations (`Dict`, `List[Dict]`, *optional*):
-               The corresponding annotations as dictionary of numpy arrays with the following keys:
-               - **masks** (`np.ndarray`) The target mask of shape `(num_classes, height, width)`.
-               - **labels** (`np.ndarray`) The target labels of shape `(num_classes)`.
+           segmentation_maps (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
+               Optionally, the corresponding semantic segmentation maps with the pixel-wise annotations.

            pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`):
                Whether or not to pad images up to the largest image in a batch and create a pixel mask.
@@ -196,7 +201,12 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
                - 1 for pixels that are real (i.e. **not masked**),
                - 0 for pixels that are padding (i.e. **masked**).

-           return_tensors (`str` or [`~utils.TensorType`], *optional*):
+           instance_id_to_semantic_id (`Dict[int, int]`, *optional*):
+               If passed, we treat `segmentation_maps` as an instance segmentation map where each pixel represents an
+               instance id. To convert it to a binary mask of shape (`batch, num_labels, height, width`) we need a
+               dictionary mapping instance ids to label ids to create a semantic segmentation map.
+
+           return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
                If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
                objects.
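As a concrete (hypothetical) example of such a mapping: an instance map containing three objects, two of class 0 and one of class 1, could be collapsed into a semantic map with:

# instance ids 1 and 2 are both "chair" (label 0), instance id 3 is "table" (label 1)
instance_id_to_semantic_id = {1: 0, 2: 0, 3: 1}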
@@ -206,15 +216,16 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
            - **pixel_values** -- Pixel values to be fed to a model.
            - **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if
              *"pixel_mask"* is in `self.model_input_names`).
-           - **mask_labels** -- Optional mask labels of shape `(batch_size, num_classes, height, width)` to be fed to
-             a model (when `annotations` are provided).
-           - **class_labels** -- Optional class labels of shape `(batch_size, num_classes)` to be fed to a model (when
-             `annotations` are provided).
+           - **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model
+             (when `annotations` are provided).
+           - **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when
+             `annotations` are provided). They identify the labels of `mask_labels`, e.g. the label of
+             `mask_labels[i][j]` is `class_labels[i][j]`.
        """
        # Input type checking for clearer error

        valid_images = False
        valid_annotations = False
+       valid_segmentation_maps = False

        # Check that images has a valid type
        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
@@ -228,6 +239,23 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
                "Images must be of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
            )
+       # Check that segmentation maps have a valid type
+       if segmentation_maps is not None:
+           if isinstance(segmentation_maps, (Image.Image, np.ndarray)) or is_torch_tensor(segmentation_maps):
+               valid_segmentation_maps = True
+           elif isinstance(segmentation_maps, (list, tuple)):
+               if (
+                   len(segmentation_maps) == 0
+                   or isinstance(segmentation_maps[0], (Image.Image, np.ndarray))
+                   or is_torch_tensor(segmentation_maps[0])
+               ):
+                   valid_segmentation_maps = True
+
+           if not valid_segmentation_maps:
+               raise ValueError(
+                   "Segmentation maps must be of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single"
+                   " example), `List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
+               )

        is_batched = bool(
            isinstance(images, (list, tuple))
@@ -236,35 +264,33 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):

        if not is_batched:
            images = [images]
-           if annotations is not None:
-               annotations = [annotations]
-
-       # Check that annotations has a valid type
-       if annotations is not None:
-           valid_annotations = type(annotations) is list and "masks" in annotations[0] and "labels" in annotations[0]
-           if not valid_annotations:
-               raise ValueError(
-                   "Annotations must be of type `Dict` (single image) or `List[Dict]` (batch of images)."
-                   " The annotations must be numpy arrays in the following format:"
-                   " { 'masks' : the target mask, with shape [C,H,W], 'labels' : the target labels, with shape [C]}"
-               )
+           if segmentation_maps is not None:
+               segmentation_maps = [segmentation_maps]

        # transformations (resizing + normalization)
        if self.do_resize and self.size is not None:
-           if annotations is not None:
-               for idx, (image, target) in enumerate(zip(images, annotations)):
-                   image, target = self._resize(image=image, target=target, size=self.size, max_size=self.max_size)
+           if segmentation_maps is not None:
+               for idx, (image, target) in enumerate(zip(images, segmentation_maps)):
+                   image, target = self._resize_with_size_divisibility(
+                       image=image, target=target, size=self.size, max_size=self.max_size
+                   )
                    images[idx] = image
-                   annotations[idx] = target
+                   segmentation_maps[idx] = target
            else:
                for idx, image in enumerate(images):
-                   images[idx] = self._resize(image=image, target=None, size=self.size, max_size=self.max_size)[0]
+                   images[idx] = self._resize_with_size_divisibility(
+                       image=image, target=None, size=self.size, max_size=self.max_size
+                   )[0]

        if self.do_normalize:
            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
        # NOTE: we are always forced to pad the masks since they have to be stacked in the batch dim
        encoded_inputs = self.encode_inputs(
-           images, annotations, pad_and_return_pixel_mask, return_tensors=return_tensors
+           images,
+           segmentation_maps,
+           pad_and_return_pixel_mask,
+           instance_id_to_semantic_id=instance_id_to_semantic_id,
+           return_tensors=return_tensors,
        )

        # Convert to TensorType
@@ -287,25 +313,57 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
                maxes[index] = max(maxes[index], item)
        return maxes

+   def convert_segmentation_map_to_binary_masks(
+       self,
+       segmentation_map: "np.ndarray",
+       instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
+   ):
+       if self.reduce_labels:
+           if self.ignore_index is None:
+               raise ValueError("`ignore_index` must be set when `reduce_labels` is `True`.")
+           segmentation_map[segmentation_map == 0] = self.ignore_index
+           # instance ids start from 1!
+           segmentation_map -= 1
+           segmentation_map[segmentation_map == self.ignore_index - 1] = self.ignore_index
+
+       if instance_id_to_semantic_id is not None:
+           # segmentation_map will be treated as an instance segmentation map where each pixel is an instance id,
+           # thus it has to be converted to a semantic segmentation map
+           for instance_id, label_id in instance_id_to_semantic_id.items():
+               segmentation_map[segmentation_map == instance_id] = label_id
+       # get all the labels in the image
+       labels = np.unique(segmentation_map)
+       # remove the ignore index (if we have one)
+       if self.ignore_index is not None:
+           labels = labels[labels != self.ignore_index]
+       # help broadcasting by making the mask [1, W, H] and the labels [C, 1, 1]
+       binary_masks = segmentation_map[None] == labels[:, None, None]
+       return binary_masks.astype(np.float32), labels.astype(np.int64)
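The broadcasting trick in the method above can be checked in isolation; a standalone sketch with made-up values:

import numpy as np

segmentation_map = np.array([[2, 2, 9], [6, 9, 9]])
labels = np.unique(segmentation_map)  # array([2, 6, 9])
# (1, H, W) compared against (num_labels, 1, 1) broadcasts to (num_labels, H, W)
binary_masks = segmentation_map[None] == labels[:, None, None]
print(binary_masks.shape)  # (3, 2, 3): one binary mask per label present in the map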

    def encode_inputs(
        self,
-       pixel_values_list: List["torch.Tensor"],
-       annotations: Optional[List[Dict]] = None,
-       pad_and_return_pixel_mask: Optional[bool] = True,
+       pixel_values_list: List["np.ndarray"],
+       segmentation_maps: ImageInput = None,
+       pad_and_return_pixel_mask: bool = True,
+       instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ):
        """
        Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.

+       MaskFormer addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps
+       will be converted to lists of binary masks and their respective labels. Let's see an example: assuming
+       `segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels =
+       [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for
+       each mask.
+
        Args:
            pixel_values_list (`List[torch.Tensor]`):
                List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
                width)`.

-           annotations (`Dict`, `List[Dict]`, *optional*):
-               The corresponding annotations as dictionary of numpy arrays with the following keys:
-               - **masks** (`np.ndarray`) The target mask of shape `(num_classes, height, width)`.
-               - **labels** (`np.ndarray`) The target labels of shape `(num_classes)`.
+           segmentation_maps (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
+               The corresponding semantic segmentation maps with the pixel-wise annotations.

            pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`):
                Whether or not to pad images up to the largest image in a batch and create a pixel mask.
@@ -315,7 +373,12 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
                - 1 for pixels that are real (i.e. **not masked**),
                - 0 for pixels that are padding (i.e. **masked**).

-           return_tensors (`str` or [`~utils.TensorType`], *optional*):
+           instance_id_to_semantic_id (`Dict[int, int]`, *optional*):
+               If passed, we treat `segmentation_maps` as an instance segmentation map where each pixel represents an
+               instance id. To convert it to a binary mask of shape (`batch, num_labels, height, width`) we need a
+               dictionary mapping instance ids to label ids to create a semantic segmentation map.
+
+           return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
                If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
                objects.
@@ -325,13 +388,29 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
            - **pixel_values** -- Pixel values to be fed to a model.
            - **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if
              *"pixel_mask"* is in `self.model_input_names`).
-           - **mask_labels** -- Optional mask labels of shape `(batch_size, num_classes, height, width)` to be fed to
-             a model (when `annotations` are provided).
-           - **class_labels** -- Optional class labels of shape `(batch_size, num_classes)` to be fed to a model (when
-             `annotations` are provided).
+           - **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model
+             (when `annotations` are provided).
+           - **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when
+             `annotations` are provided). They identify the labels of `mask_labels`, e.g. the label of
+             `mask_labels[i][j]` is `class_labels[i][j]`.
        """

        max_size = self._max_by_axis([list(image.shape) for image in pixel_values_list])

+       annotations = None
+       if segmentation_maps is not None:
+           segmentation_maps = map(np.array, segmentation_maps)
+           converted_segmentation_maps = []
+           for segmentation_map in segmentation_maps:
+               converted_segmentation_map = self.convert_segmentation_map_to_binary_masks(
+                   segmentation_map, instance_id_to_semantic_id
+               )
+               converted_segmentation_maps.append(converted_segmentation_map)
+
+           annotations = []
+           for mask, classes in converted_segmentation_maps:
+               annotations.append({"masks": mask, "classes": classes})
+
        channels, height, width = max_size
        pixel_values = []
        pixel_mask = []
@@ -339,35 +418,37 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
        class_labels = []
        for idx, image in enumerate(pixel_values_list):
            # create padded image
-           if pad_and_return_pixel_mask:
-               padded_image = np.zeros((channels, height, width), dtype=np.float32)
-               padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
-               image = padded_image
+           padded_image = np.zeros((channels, height, width), dtype=np.float32)
+           padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
+           image = padded_image
            pixel_values.append(image)
            # if we have a target, pad it
            if annotations:
                annotation = annotations[idx]
                masks = annotation["masks"]
-               if pad_and_return_pixel_mask:
-                   padded_masks = np.zeros((masks.shape[0], height, width), dtype=masks.dtype)
-                   padded_masks[:, : masks.shape[1], : masks.shape[2]] = np.copy(masks)
-                   masks = padded_masks
-               mask_labels.append(masks)
-               class_labels.append(annotation["labels"])
-           if pad_and_return_pixel_mask:
-               # create pixel mask
-               mask = np.zeros((height, width), dtype=np.int64)
-               mask[: image.shape[1], : image.shape[2]] = True
-               pixel_mask.append(mask)
+               # pad mask with `ignore_index`
+               masks = np.pad(
+                   masks,
+                   ((0, 0), (0, height - masks.shape[1]), (0, width - masks.shape[2])),
+                   constant_values=self.ignore_index,
+               )
+               annotation["masks"] = masks
+           # create pixel mask
+           mask = np.zeros((height, width), dtype=np.int64)
+           mask[: image.shape[1], : image.shape[2]] = True
+           pixel_mask.append(mask)

        # return as BatchFeature
        data = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}

-       if annotations:
-           data["mask_labels"] = mask_labels
-           data["class_labels"] = class_labels
-
        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+       # we cannot batch them since they don't share a common class size
+       if annotations:
+           for label in annotations:
+               mask_labels.append(torch.from_numpy(label["masks"]))
+               class_labels.append(torch.from_numpy(label["classes"]))
+
+           encoded_inputs["mask_labels"] = mask_labels
+           encoded_inputs["class_labels"] = class_labels

        return encoded_inputs
@@ -269,7 +269,7 @@ class MaskFormerForInstanceSegmentationOutput(ModelOutput):
            A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
            query.
        masks_queries_logits (`torch.FloatTensor`):
-           A tensor of shape `(batch_size, num_queries, num_classes + 1)` representing the proposed classes for each
+           A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
            query. Note the `+ 1` is needed because we incorporate the null class.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Last hidden states (final feature map) of the last stage of the encoder model (backbone).
@@ -424,7 +424,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
    """
    inputs = inputs.sigmoid().flatten(1)
    numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
-   # using broadcasting to get a [NUM_QUERIES, NUM_CLASSES] matrix
+   # using broadcasting to get a [num_queries, num_labels] matrix
    denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss
@@ -918,7 +918,9 @@ class MaskFormerSwinBlock(nn.Module):
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
-       shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)  # B H' W' C
+       shifted_windows = window_reverse(
+           attention_windows, self.window_size, height_pad, width_pad
+       )  # B height' width' C

        # reverse cyclic shift
        if self.shift_size > 0:
@@ -1621,7 +1623,7 @@ class MaskFormerHungarianMatcher(nn.Module):

        Params:
            masks_queries_logits (`torch.Tensor`):
-               A tensor of dim `batch_size, num_queries, num_classes` with the
+               A tensor of dim `batch_size, num_queries, num_labels` with the
                classification logits.
            class_queries_logits (`torch.Tensor`):
                A tensor of dim `batch_size, num_queries, height, width` with the
@@ -1644,24 +1646,23 @@ class MaskFormerHungarianMatcher(nn.Module):
        indices: List[Tuple[np.array]] = []

        preds_masks = masks_queries_logits
-       preds_probs = class_queries_logits.softmax(dim=-1)
-       # downsample all masks in one go -> save memory
-       mask_labels = nn.functional.interpolate(mask_labels, size=preds_masks.shape[-2:], mode="nearest")
+       preds_probs = class_queries_logits
        # iterate through batch size
        for pred_probs, pred_mask, target_mask, labels in zip(preds_probs, preds_masks, mask_labels, class_labels):
+           # downsample the target mask, save memory
+           target_mask = nn.functional.interpolate(target_mask[:, None], size=pred_mask.shape[-2:], mode="nearest")
+           pred_probs = pred_probs.softmax(-1)
            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be omitted.
            cost_class = -pred_probs[:, labels]
            # flatten spatial dimension "q h w -> q (h w)"
-           num_queries, height, width = pred_mask.shape
-           pred_mask_flat = pred_mask.view(num_queries, height * width)  # [num_queries, H*W]
+           pred_mask_flat = pred_mask.flatten(1)  # [num_queries, height*width]
            # same for target_mask "c h w -> c (h w)"
-           num_channels, height, width = target_mask.shape
-           target_mask_flat = target_mask.view(num_channels, height * width)  # [num_total_labels, H*W]
-           # compute the focal loss between each pair of masks -> shape [NUM_QUERIES, CLASSES]
+           target_mask_flat = target_mask[:, 0].flatten(1)  # [num_total_labels, height*width]
+           # compute the focal loss between each pair of masks -> shape (num_queries, num_labels)
            cost_mask = pair_wise_sigmoid_focal_loss(pred_mask_flat, target_mask_flat)
-           # compute the dice loss between each pair of masks -> shape [NUM_QUERIES, CLASSES]
+           # compute the dice loss between each pair of masks -> shape (num_queries, num_labels)
            cost_dice = pair_wise_dice_loss(pred_mask_flat, target_mask_flat)
            # final cost matrix
            cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
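The per-image cost matrix built above is then fed to SciPy's Hungarian solver; roughly (a sketch of the idea with random costs, not the module's exact code):

import torch
from scipy.optimize import linear_sum_assignment

num_queries, num_target_masks = 100, 5
cost_matrix = torch.rand(num_queries, num_target_masks)  # cost_mask + cost_class + cost_dice
pred_indices, target_indices = linear_sum_assignment(cost_matrix.numpy())
# pred_indices[i] is matched to target_indices[i], minimizing the summed cost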
@@ -1691,7 +1692,7 @@ class MaskFormerHungarianMatcher(nn.Module):
class MaskFormerLoss(nn.Module):
    def __init__(
        self,
-       num_classes: int,
+       num_labels: int,
        matcher: MaskFormerHungarianMatcher,
        weight_dict: Dict[str, float],
        eos_coef: float,
@@ -1702,7 +1703,7 @@ class MaskFormerLoss(nn.Module):
            matched ground-truth / prediction (supervise class and mask)

        Args:
-           num_classes (`int`):
+           num_labels (`int`):
                The number of classes.
            matcher (`MaskFormerHungarianMatcher`):
                A torch module that computes the assignments between the predictions and labels.
@@ -1714,24 +1715,50 @@ class MaskFormerLoss(nn.Module):

        super().__init__()
        requires_backends(self, ["scipy"])
-       self.num_classes = num_classes
+       self.num_labels = num_labels
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
-       empty_weight = torch.ones(self.num_classes + 1)
+       empty_weight = torch.ones(self.num_labels + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)

+   def _max_by_axis(self, the_list: List[List[int]]) -> List[int]:
+       maxes = the_list[0]
+       for sublist in the_list[1:]:
+           for index, item in enumerate(sublist):
+               maxes[index] = max(maxes[index], item)
+       return maxes
+
+   def _pad_images_to_max_in_batch(self, tensors: List[Tensor]) -> Tuple[Tensor, Tensor]:
+       # get the maximum size in the batch
+       max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
+       batch_size = len(tensors)
+       # compute the final size
+       batch_shape = [batch_size] + max_size
+       b, _, h, w = batch_shape
+       # get metadata
+       dtype = tensors[0].dtype
+       device = tensors[0].device
+       padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
+       padding_masks = torch.ones((b, h, w), dtype=torch.bool, device=device)
+       # pad the tensors to the size of the biggest one
+       for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
+           padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
+           padding_mask[: tensor.shape[1], : tensor.shape[2]] = False
+
+       return padded_tensors, padding_masks
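Since `mask_labels` can now be ragged across the batch, the helper above re-pads them at loss time. A compact standalone rendition of the same padding logic, with made-up shapes:

import torch

tensors = [torch.ones(3, 10, 12), torch.ones(5, 8, 16)]  # e.g. 3 masks of 10x12, 5 masks of 8x16
max_size = [max(sizes) for sizes in zip(*(t.shape for t in tensors))]  # [5, 10, 16]
padded = torch.zeros(len(tensors), *max_size)
padding_masks = torch.ones(len(tensors), max_size[1], max_size[2], dtype=torch.bool)
for tensor, padded_tensor, padding_mask in zip(tensors, padded, padding_masks):
    padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
    padding_mask[: tensor.shape[1], : tensor.shape[2]] = False  # False marks real pixels
print(padded.shape, padding_masks.shape)  # torch.Size([2, 5, 10, 16]) torch.Size([2, 10, 16])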

    def loss_labels(
-       self, class_queries_logits: Tensor, class_labels: Tensor, indices: Tuple[np.array]
+       self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array]
    ) -> Dict[str, Tensor]:
        """Compute the losses related to the labels using cross entropy.

        Args:
            class_queries_logits (`torch.Tensor`):
-               A tensor of shape `batch_size, num_queries, num_classes`
-           class_labels (`Dict[str, Tensor]`):
-               A tensor of shape `batch_size, num_classes`
+               A tensor of shape `batch_size, num_queries, num_labels`
+           class_labels (`List[torch.Tensor]`):
+               List of class labels of shape `(labels)`.
            indices (`Tuple[np.array]`):
                The indices computed by the Hungarian matcher.
@@ -1744,21 +1771,21 @@ class MaskFormerLoss(nn.Module):
        batch_size, num_queries, _ = pred_logits.shape
        criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
        idx = self._get_predictions_permutation_indices(indices)
-       # shape = [BATCH, N_QUERIES]
+       # shape = (batch_size, num_queries)
        target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)])
-       # shape = [BATCH, N_QUERIES]
+       # shape = (batch_size, num_queries)
        target_classes = torch.full(
-           (batch_size, num_queries), fill_value=self.num_classes, dtype=torch.int64, device=pred_logits.device
+           (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device
        )
        target_classes[idx] = target_classes_o
-       # target_classes is a [BATCH, CLASSES, N_QUERIES], we need to permute pred_logits "b q c -> b c q"
-       pred_logits_permuted = pred_logits.permute(0, 2, 1)
-       loss_ce = criterion(pred_logits_permuted, target_classes)
+       # target_classes is a (batch_size, num_labels, num_queries), we need to permute pred_logits "b q c -> b c q"
+       pred_logits_transposed = pred_logits.transpose(1, 2)
+       loss_ce = criterion(pred_logits_transposed, target_classes)
        losses = {"loss_cross_entropy": loss_ce}
        return losses

    def loss_masks(
-       self, masks_queries_logits: Tensor, mask_labels: Tensor, indices: Tuple[np.array], num_masks: int
+       self, masks_queries_logits: Tensor, mask_labels: List[Tensor], indices: Tuple[np.array], num_masks: int
    ) -> Dict[str, Tensor]:
        """Compute the losses related to the masks using focal and dice loss.
@@ -1766,7 +1793,7 @@ class MaskFormerLoss(nn.Module):
        Args:
            masks_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, height, width`
            mask_labels (`torch.Tensor`):
-               A tensor of shape `batch_size, num_queries, height, width`
+               List of mask labels of shape `(labels, height, width)`.
            indices (`Tuple[np.array]`):
                The indices computed by the Hungarian matcher.
            num_masks (`int`):
@@ -1780,10 +1807,12 @@ class MaskFormerLoss(nn.Module):
        """
        src_idx = self._get_predictions_permutation_indices(indices)
        tgt_idx = self._get_targets_permutation_indices(indices)
-       pred_masks = masks_queries_logits  # shape [BATCH, NUM_QUERIES, H, W]
-       pred_masks = pred_masks[src_idx]  # shape [BATCH * NUM_QUERIES, H, W]
-       target_masks = mask_labels  # shape [BATCH, NUM_QUERIES, H, W]
-       target_masks = target_masks[tgt_idx]  # shape [BATCH * NUM_QUERIES, H, W]
+       # shape (batch_size * num_queries, height, width)
+       pred_masks = masks_queries_logits[src_idx]
+       # shape (batch_size, num_queries, height, width)
+       # pad all and stack the targets to the num_labels dimension
+       target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
+       target_masks = target_masks[tgt_idx]
        # upsample predictions to the target size, we have to add one dim to use interpolate
        pred_masks = nn.functional.interpolate(
            pred_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
@@ -1791,7 +1820,6 @@ class MaskFormerLoss(nn.Module):
        pred_masks = pred_masks[:, 0].flatten(1)

        target_masks = target_masks.flatten(1)
-       target_masks = target_masks.view(pred_masks.shape)
        losses = {
            "loss_mask": sigmoid_focal_loss(pred_masks, target_masks, num_masks),
            "loss_dice": dice_loss(pred_masks, target_masks, num_masks),
@@ -1810,19 +1838,13 @@ class MaskFormerLoss(nn.Module):
        target_indices = torch.cat([tgt for (_, tgt) in indices])
        return batch_indices, target_indices

-   def get_loss(self, loss, outputs, labels, indices, num_masks):
-       loss_map = {"labels": self.loss_labels, "masks": self.loss_masks}
-       if loss not in loss_map:
-           raise KeyError(f"{loss} not in loss_map")
-       return loss_map[loss](outputs, labels, indices, num_masks)
-
    def forward(
        self,
-       masks_queries_logits: torch.Tensor,
-       class_queries_logits: torch.Tensor,
-       mask_labels: torch.Tensor,
-       class_labels: torch.Tensor,
-       auxiliary_predictions: Optional[Dict[str, torch.Tensor]] = None,
+       masks_queries_logits: Tensor,
+       class_queries_logits: Tensor,
+       mask_labels: List[Tensor],
+       class_labels: List[Tensor],
+       auxiliary_predictions: Optional[Dict[str, Tensor]] = None,
    ) -> Dict[str, Tensor]:
        """
        This performs the loss computation.
@@ -1831,11 +1853,11 @@ class MaskFormerLoss(nn.Module):
            masks_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, height, width`
            class_queries_logits (`torch.Tensor`):
-               A tensor of shape `batch_size, num_queries, num_classes`
+               A tensor of shape `batch_size, num_queries, num_labels`
            mask_labels (`List[torch.Tensor]`):
-               A tensor of shape `batch_size, num_classes, height, width`
-           class_labels (`torch.Tensor`):
-               A tensor of shape `batch_size, num_classes`
+               List of mask labels of shape `(labels, height, width)`.
+           class_labels (`List[torch.Tensor]`):
+               List of class labels of shape `(labels)`.
            auxiliary_predictions (`Dict[str, torch.Tensor]`, *optional*):
                If `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], then it contains the logits from
                the inner layers of the Detr's Decoder.
@@ -1850,19 +1872,16 @@ class MaskFormerLoss(nn.Module):
            for each auxiliary prediction.
        """

-       # Retrieve the matching between the outputs of the last layer and the labels
+       # retrieve the matching between the outputs of the last layer and the labels
        indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
-
-       # Compute the average number of target masks across all nodes, for normalization purposes
-       num_masks: Number = self.get_num_masks(class_labels, device=class_labels.device)
-
-       # Compute all the requested losses
+       # compute the average number of target masks for normalization purposes
+       num_masks: Number = self.get_num_masks(class_labels, device=class_labels[0].device)
+       # get all the losses
        losses: Dict[str, Tensor] = {
            **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
            **self.loss_labels(class_queries_logits, class_labels, indices),
        }

-       # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+       # in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if auxiliary_predictions is not None:
            for idx, aux_outputs in enumerate(auxiliary_predictions):
                masks_queries_logits = aux_outputs["masks_queries_logits"]
@@ -1874,8 +1893,10 @@ class MaskFormerLoss(nn.Module):
        return losses

    def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
-       # Compute the average number of target masks across all nodes, for normalization purposes
-       num_masks = class_labels.shape[0]
+       """
+       Computes the average number of target masks across the batch, for normalization purposes.
+       """
+       num_masks = sum([len(classes) for classes in class_labels])
        num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
        return num_masks_pt
@@ -2380,11 +2401,13 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
        loss_dict: Dict[str, Tensor] = self.criterion(
            masks_queries_logits, class_queries_logits, mask_labels, class_labels, auxiliary_logits
        )
-       # weight each loss by `self.weight_dict[<LOSS_NAME>]`
-       weighted_loss_dict: Dict[str, Tensor] = {
-           k: v * self.weight_dict[k] for k, v in loss_dict.items() if k in self.weight_dict
-       }
-       return weighted_loss_dict
+       # weight each loss by `self.weight_dict[<LOSS_NAME>]`, including the auxiliary losses
+       for key, weight in self.weight_dict.items():
+           for loss_key, loss in loss_dict.items():
+               if key in loss_key:
+                   loss *= weight
+
+       return loss_dict

    def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor:
        return sum(loss_dict.values())
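The substring match (`key in loss_key`) is what lets one weight entry also scale its auxiliary copies, e.g. `"loss_mask"` matches `"loss_mask_0"` from an intermediate decoder layer. An illustrative sketch (the key names follow the losses above; the values are made up):

import torch

weight_dict = {"loss_cross_entropy": 1.0, "loss_mask": 20.0, "loss_dice": 1.0}
loss_dict = {
    "loss_mask": torch.tensor(0.5),
    "loss_mask_0": torch.tensor(0.7),  # auxiliary copy from an intermediate layer
}
for key, weight in weight_dict.items():
    for loss_key, loss in loss_dict.items():
        if key in loss_key:
            loss *= weight  # in-place on the tensor, so loss_dict is updated
print(loss_dict)  # both "loss_mask" entries scaled by 20.0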
@@ -2425,8 +2448,8 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
    def forward(
        self,
        pixel_values: Tensor,
-       mask_labels: Optional[Tensor] = None,
-       class_labels: Optional[Tensor] = None,
+       mask_labels: Optional[List[Tensor]] = None,
+       class_labels: Optional[List[Tensor]] = None,
        pixel_mask: Optional[Tensor] = None,
        output_auxiliary_logits: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
@@ -2434,10 +2457,11 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
        return_dict: Optional[bool] = None,
    ) -> MaskFormerForInstanceSegmentationOutput:
        r"""
-       mask_labels (`torch.FloatTensor`, *optional*):
-           The target mask of shape `(num_classes, height, width)`.
-       class_labels (`torch.LongTensor`, *optional*):
-           The target labels of shape `(num_classes)`.
+       mask_labels (`List[torch.Tensor]`, *optional*):
+           List of mask labels of shape `(num_labels, height, width)` to be fed to a model.
+       class_labels (`List[torch.LongTensor]`, *optional*):
+           List of target class labels of shape `(num_labels)` to be fed to a model. They identify the labels of
+           `mask_labels`, e.g. the label of `mask_labels[i][j]` is `class_labels[i][j]`.

        Returns:
@@ -49,6 +49,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
        do_normalize=True,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
+       num_labels=10,
+       reduce_labels=True,
+       ignore_index=255,
    ):
        self.parent = parent
        self.batch_size = batch_size
@@ -68,6 +71,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
        self.num_classes = 2
        self.height = 3
        self.width = 4
+       self.num_labels = num_labels
+       self.reduce_labels = reduce_labels
+       self.ignore_index = ignore_index

    def prepare_feat_extract_dict(self):
        return {
@@ -78,6 +84,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
            "image_mean": self.image_mean,
            "image_std": self.image_std,
            "size_divisibility": self.size_divisibility,
+           "num_labels": self.num_labels,
+           "reduce_labels": self.reduce_labels,
+           "ignore_index": self.ignore_index,
        }

    def get_expected_values(self, image_inputs, batched=False):
@@ -140,6 +149,8 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
        self.assertTrue(hasattr(feature_extractor, "do_resize"))
        self.assertTrue(hasattr(feature_extractor, "size"))
        self.assertTrue(hasattr(feature_extractor, "max_size"))
+       self.assertTrue(hasattr(feature_extractor, "ignore_index"))
+       self.assertTrue(hasattr(feature_extractor, "num_labels"))

    def test_batch_feature(self):
        pass
@@ -245,7 +256,9 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
    def test_equivalence_pad_and_create_pixel_mask(self):
        # Initialize feature_extractors
        feature_extractor_1 = self.feature_extraction_class(**self.feat_extract_dict)
-       feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False)
+       feature_extractor_2 = self.feature_extraction_class(
+           do_resize=False, do_normalize=False, num_labels=self.feature_extract_tester.num_classes
+       )
        # create random PyTorch tensors
        image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
        for image in image_inputs:
@@ -262,28 +275,41 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
            torch.allclose(encoded_images_with_method["pixel_mask"], encoded_images["pixel_mask"], atol=1e-4)
        )

-   def comm_get_feature_extractor_inputs(self, with_annotations=False):
+   def comm_get_feature_extractor_inputs(
+       self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np"
+   ):
        feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
        # prepare image and target
-       num_classes = 8
        batch_size = self.feature_extract_tester.batch_size
+       num_labels = self.feature_extract_tester.num_labels
        annotations = None

-       if with_annotations:
-           annotations = [
-               {
-                   "masks": np.random.rand(num_classes, 384, 384).astype(np.float32),
-                   "labels": (np.random.rand(num_classes) > 0.5).astype(np.int64),
-               }
-               for _ in range(batch_size)
-           ]
+       instance_id_to_semantic_id = None
+       if with_segmentation_maps:
+           high = num_labels
+           if is_instance_map:
+               high *= 2
+               labels_expanded = list(range(num_labels)) * 2
+               instance_id_to_semantic_id = {
+                   instance_id: label_id for instance_id, label_id in enumerate(labels_expanded)
+               }
+           annotations = [np.random.randint(0, high, (384, 384)).astype(np.uint8) for _ in range(batch_size)]
+           if segmentation_type == "pil":
+               annotations = [Image.fromarray(annotation) for annotation in annotations]

        image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)

-       inputs = feature_extractor(image_inputs, annotations, return_tensors="pt", pad_and_return_pixel_mask=True)
+       inputs = feature_extractor(
+           image_inputs,
+           annotations,
+           return_tensors="pt",
+           instance_id_to_semantic_id=instance_id_to_semantic_id,
+           pad_and_return_pixel_mask=True,
+       )

        return inputs

    def test_init_without_params(self):
        pass

    def test_with_size_divisibility(self):
        size_divisibilities = [8, 16, 32]
        weird_input_sizes = [(407, 802), (582, 1094)]
@@ -297,27 +323,29 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
                self.assertTrue((pixel_values.shape[-1] % size_divisibility) == 0)
                self.assertTrue((pixel_values.shape[-2] % size_divisibility) == 0)

-   def test_call_with_numpy_annotations(self):
-       num_classes = 8
-       batch_size = self.feature_extract_tester.batch_size
+   def test_call_with_segmentation_maps(self):
+       def common(is_instance_map=False, segmentation_type=None):
+           inputs = self.comm_get_feature_extractor_inputs(
+               with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type
+           )

-       inputs = self.comm_get_feature_extractor_inputs(with_annotations=True)
+           mask_labels = inputs["mask_labels"]
+           class_labels = inputs["class_labels"]
+           pixel_values = inputs["pixel_values"]

-       # check the batch_size
-       for el in inputs.values():
-           self.assertEqual(el.shape[0], batch_size)
+           # check the batch_size
+           for mask_label, class_label in zip(mask_labels, class_labels):
+               self.assertEqual(mask_label.shape[0], class_label.shape[0])
+               # this ensures padding has happened
+               self.assertEqual(mask_label.shape[1:], pixel_values.shape[2:])

-       pixel_values = inputs["pixel_values"]
-       mask_labels = inputs["mask_labels"]
-       class_labels = inputs["class_labels"]
-
-       self.assertEqual(pixel_values.shape[-2], mask_labels.shape[-2])
-       self.assertEqual(pixel_values.shape[-1], mask_labels.shape[-1])
-       self.assertEqual(mask_labels.shape[1], class_labels.shape[1])
-       self.assertEqual(mask_labels.shape[1], num_classes)
+       common()
+       common(is_instance_map=True)
+       common(is_instance_map=False, segmentation_type="pil")
+       common(is_instance_map=True, segmentation_type="pil")

    def test_post_process_segmentation(self):
-       fature_extractor = self.feature_extraction_class()
+       fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
        outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
        segmentation = fature_extractor.post_process_segmentation(outputs)
@@ -340,7 +368,7 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
        )

    def test_post_process_semantic_segmentation(self):
-       fature_extractor = self.feature_extraction_class()
+       fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
        outputs = self.feature_extract_tester.get_fake_maskformer_outputs()

        segmentation = fature_extractor.post_process_semantic_segmentation(outputs)
@@ -361,7 +389,7 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
        self.assertEqual(segmentation.shape, (self.feature_extract_tester.batch_size, *target_size))

    def test_post_process_panoptic_segmentation(self):
-       fature_extractor = self.feature_extraction_class()
+       fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
        outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
        segmentation = fature_extractor.post_process_panoptic_segmentation(outputs, object_mask_threshold=0)
|
@ -397,18 +397,19 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
|
||||
|
||||
def test_with_annotations_and_loss(self):
|
||||
def test_with_segmentation_maps_and_loss(self):
|
||||
model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
|
||||
feature_extractor = self.default_feature_extractor
|
||||
|
||||
inputs = feature_extractor(
|
||||
[np.zeros((3, 800, 1333)), np.zeros((3, 800, 1333))],
|
||||
annotations=[
|
||||
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
|
||||
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
|
||||
],
|
||||
segmentation_maps=[np.zeros((384, 384)).astype(np.float32), np.zeros((384, 384)).astype(np.float32)],
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
)
|
||||
|
||||
inputs["pixel_values"] = inputs["pixel_values"].to(torch_device)
|
||||
inputs["mask_labels"] = [el.to(torch_device) for el in inputs["mask_labels"]]
|
||||
inputs["class_labels"] = [el.to(torch_device) for el in inputs["class_labels"]]
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|