Improve DETR models (#19644)

* Improve DETR models

* Fix Deformable DETR loss and matcher

* Fixup

* Fix integration tests

* Improve variable names

* Apply suggestion

* Fix copies

* Fix DeformableDetrLoss

* Make Conditional DETR copy from Deformable DETR

* Copy from deformable detr's hungarian matcher

* Fix bug
NielsRogge 2022-10-18 10:29:14 +02:00 committed by GitHub
parent 072dfdaee4
commit 90071fe42b
12 changed files with 253 additions and 257 deletions

View File

@@ -23,7 +23,7 @@ The abstract from the paper is the following:
 Tips:
 
-- One can use the [`AutoFeatureExtractor`] API to prepare images (and optional targets) for the model. This will instantiate a [`DetrFeatureExtractor`] behind the scenes.
+- One can use [`DeformableDetrFeatureExtractor`] to prepare images (and optional targets) for the model.
 - Training Deformable DETR is equivalent to training the original [DETR](detr) model. Demo notebooks can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/DETR).
 
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/deformable_detr_architecture.png"

View File

@@ -110,6 +110,8 @@ class ConditionalDetrConfig(PretrainedConfig):
             Relative weight of the generalized IoU loss in the object detection loss.
         eos_coefficient (`float`, *optional*, defaults to 0.1):
             Relative classification weight of the 'no-object' class in the object detection loss.
+        focal_alpha (`float`, *optional*, defaults to 0.25):
+            Alpha parameter in the focal loss.
 
     Examples:
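
For reference, once this lands the new attribute can be set like any other config field — a minimal sketch (the non-default value is only for illustration):

import torch
from transformers import ConditionalDetrConfig

config = ConditionalDetrConfig(focal_alpha=0.5)  # defaults to 0.25
print(config.focal_alpha)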

View File

@@ -44,8 +44,8 @@ def center_to_corners_format(x):
     Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
     (x_0, y_0, x_1, y_1).
     """
-    x_c, y_c, w, h = x.unbind(-1)
-    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    center_x, center_y, width, height = x.unbind(-1)
+    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
     return torch.stack(b, dim=-1)
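
The rename is purely cosmetic; the conversion itself is unchanged. A quick standalone sanity check of the formula above:

import torch

boxes = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # (center_x, center_y, width, height), normalized
center_x, center_y, width, height = boxes.unbind(-1)
corners = torch.stack(
    [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height], dim=-1
)
print(corners)  # tensor([[0.4000, 0.3000, 0.6000, 0.7000]])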

View File

@@ -33,7 +33,6 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     is_scipy_available,
     is_timm_available,
-    is_vision_available,
     logging,
     replace_return_docstrings,
     requires_backends,
@@ -44,9 +43,6 @@ from .configuration_conditional_detr import ConditionalDetrConfig
 if is_scipy_available():
     from scipy.optimize import linear_sum_assignment
 
-if is_vision_available():
-    from .feature_extraction_conditional_detr import center_to_corners_format
-
 if is_timm_available():
     from timm import create_model
@@ -679,11 +675,11 @@ class ConditionalDetrAttention(nn.Module):
         self.out_proj = nn.Linear(out_dim, out_dim, bias=bias)
 
-    def _qk_shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+    def _qk_shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
+        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
-    def _v_shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous()
+    def _v_shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
+        return tensor.view(batch_size, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous()
 
     def forward(
         self,
@@ -695,37 +691,38 @@ class ConditionalDetrAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
-        bsz, tgt_len, _ = hidden_states.size()
+        batch_size, target_len, _ = hidden_states.size()
 
         # get query proj
         query_states = hidden_states * self.scaling
         # get key, value proj
-        key_states = self._qk_shape(key_states, -1, bsz)
-        value_states = self._v_shape(value_states, -1, bsz)
+        key_states = self._qk_shape(key_states, -1, batch_size)
+        value_states = self._v_shape(value_states, -1, batch_size)
 
-        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-        v_proj_shape = (bsz * self.num_heads, -1, self.v_head_dim)
-        query_states = self._qk_shape(query_states, tgt_len, bsz).view(*proj_shape)
+        proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
+        v_proj_shape = (batch_size * self.num_heads, -1, self.v_head_dim)
+        query_states = self._qk_shape(query_states, target_len, batch_size).view(*proj_shape)
         key_states = key_states.view(*proj_shape)
         value_states = value_states.view(*v_proj_shape)
 
-        src_len = key_states.size(1)
+        source_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
 
-        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+        if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
             raise ValueError(
-                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
                 f" {attn_weights.size()}"
             )
 
         if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+            if attention_mask.size() != (batch_size, 1, target_len, source_len):
                 raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
+                    f" {attention_mask.size()}"
                 )
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
+            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
 
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@@ -734,8 +731,8 @@ class ConditionalDetrAttention(nn.Module):
             # make sure that attn_weights keeps its gradient.
             # In order to do so, attn_weights have to reshaped
             # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+            attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
+            attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
         else:
             attn_weights_reshaped = None
@@ -743,15 +740,15 @@ class ConditionalDetrAttention(nn.Module):
         attn_output = torch.bmm(attn_probs, value_states)
 
-        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.v_head_dim):
+        if attn_output.size() != (batch_size * self.num_heads, target_len, self.v_head_dim):
             raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.v_head_dim)}, but is"
+                f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.v_head_dim)}, but is"
                 f" {attn_output.size()}"
             )
 
-        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.v_head_dim)
+        attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.v_head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, self.out_dim)
+        attn_output = attn_output.reshape(batch_size, target_len, self.out_dim)
 
         attn_output = self.out_proj(attn_output)
@@ -887,7 +884,8 @@ class ConditionalDetrDecoderLayer(nn.Module):
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
             attention_mask (`torch.FloatTensor`): attention mask of size
-                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
+                values.
             position_embeddings (`torch.FloatTensor`, *optional*):
                 position embeddings that are added to the queries and keys
                 in the cross-attention layer.
@@ -897,7 +895,8 @@ class ConditionalDetrDecoderLayer(nn.Module):
             encoder_hidden_states (`torch.FloatTensor`):
                 cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
             encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
-                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
+                values.
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more detail.
@@ -940,7 +939,7 @@ class ConditionalDetrDecoderLayer(nn.Module):
             v = self.ca_v_proj(encoder_hidden_states)
 
             batch_size, num_queries, n_model = q_content.shape
-            _, src_len, _ = k_content.shape
+            _, source_len, _ = k_content.shape
 
             k_pos = self.ca_kpos_proj(position_embeddings)
@@ -958,9 +957,9 @@ class ConditionalDetrDecoderLayer(nn.Module):
             query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
             query_sine_embed = query_sine_embed.view(batch_size, num_queries, self.nhead, n_model // self.nhead)
             q = torch.cat([q, query_sine_embed], dim=3).view(batch_size, num_queries, n_model * 2)
-            k = k.view(batch_size, src_len, self.nhead, n_model // self.nhead)
-            k_pos = k_pos.view(batch_size, src_len, self.nhead, n_model // self.nhead)
-            k = torch.cat([k, k_pos], dim=3).view(batch_size, src_len, n_model * 2)
+            k = k.view(batch_size, source_len, self.nhead, n_model // self.nhead)
+            k_pos = k_pos.view(batch_size, source_len, self.nhead, n_model // self.nhead)
+            k = torch.cat([k, k_pos], dim=3).view(batch_size, source_len, n_model * 2)
 
             # Cross-Attention Block
             cross_attn_weights = None
@@ -1333,14 +1332,14 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
         combined_attention_mask = None
 
         if attention_mask is not None and combined_attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
             combined_attention_mask = combined_attention_mask + _expand_mask(
-                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                attention_mask, inputs_embeds.dtype, target_len=input_shape[-1]
             )
 
         # expand encoder attention mask
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
             encoder_attention_mask = _expand_mask(
                 encoder_attention_mask, inputs_embeds.dtype, target_len=input_shape[-1]
             )
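
For reference, the renamed `target_len` keyword belongs to the mask-expansion helper that turns a `[batch_size, seq_len]` padding mask into the 4-D additive mask described in the comments above. A standalone toy version of that expansion (not the library's `_expand_mask` itself, whose exact implementation may differ slightly):

import torch

mask = torch.tensor([[1, 1, 0]])  # [batch_size, seq_len]; 0 marks padding
target_len, source_len = 2, mask.shape[1]

expanded = mask[:, None, None, :].expand(1, 1, target_len, source_len).to(torch.float32)
inverted = 1.0 - expanded  # 1.0 wherever the position is padded
additive = inverted.masked_fill(inverted.bool(), torch.finfo(torch.float32).min)
print(additive.shape)  # torch.Size([1, 1, 2, 3]), very large negatives on the padded column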
@@ -2061,7 +2060,6 @@ def _expand(tensor, length: int):
     return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
 
 
-# taken from https://github.com/facebookresearch/detr/blob/master/models/segmentation.py
 # Copied from transformers.models.detr.modeling_detr.DetrMaskHeadSmallConv with Detr->ConditionalDetr
 class ConditionalDetrMaskHeadSmallConv(nn.Module):
     """
@@ -2172,6 +2170,7 @@ class ConditionalDetrMHAttentionMap(nn.Module):
         return weights
 
 
+# Copied from transformers.models.detr.modeling_detr.dice_loss
 def dice_loss(inputs, targets, num_boxes):
     """
     Compute the DICE loss, similar to generalized IOU for masks
@@ -2191,26 +2190,28 @@ def dice_loss(inputs, targets, num_boxes):
     return loss.sum() / num_boxes
 
 
+# Copied from transformers.models.detr.modeling_detr.sigmoid_focal_loss
 def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
     """
     Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
 
     Args:
-        inputs: A float tensor of arbitrary shape.
-                The predictions for each example.
-        targets: A float tensor with the same shape as inputs. Stores the binary
-                 classification label for each element in inputs (0 for the negative class and 1 for the positive
-                 class).
-        alpha: (optional) Weighting factor in range (0,1) to balance
-                positive vs negative examples. Default = -1 (no weighting).
-        gamma: Exponent of the modulating factor (1 - p_t) to
-               balance easy vs hard examples.
+        inputs (`torch.FloatTensor` of arbitrary shape):
+            The predictions for each example.
+        targets (`torch.FloatTensor` with the same shape as `inputs`)
+            A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
+            and 1 for the positive class).
+        alpha (`float`, *optional*, defaults to `0.25`):
+            Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
+        gamma (`int`, *optional*, defaults to `2`):
+            Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.
 
     Returns:
         Loss tensor
     """
     prob = inputs.sigmoid()
     ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+    # add modulating factor
     p_t = prob * targets + (1 - prob) * (1 - targets)
     loss = ce_loss * ((1 - p_t) ** gamma)
@@ -2221,26 +2222,24 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f
     return loss.mean(1).sum() / num_boxes
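
For intuition, the modulating factor `(1 - p_t) ** gamma` is what down-weights well-classified examples — a standalone numeric check of the function body above:

import torch
import torch.nn as nn

inputs = torch.tensor([[2.0, -2.0]])  # logits for two 'easy' predictions
targets = torch.tensor([[1.0, 0.0]])
prob = inputs.sigmoid()
ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** 2)  # gamma = 2
print(ce_loss, loss)  # the focal-weighted loss is ~70x smaller than the plain BCE here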
 
 
-# taken from https://github.com/Atten4Vis/conditionalDETR/blob/master/models/conditional_detr.py
 class ConditionalDetrLoss(nn.Module):
     """
     This class computes the losses for ConditionalDetrForObjectDetection/ConditionalDetrForSegmentation. The process
     happens in two steps: 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 2)
     we supervise each pair of matched ground-truth / prediction (supervise class and box).
 
     Args:
         matcher (`ConditionalDetrHungarianMatcher`):
             Module able to compute a matching between targets and proposals.
         num_classes (`int`):
             Number of object categories, omitting the special no-object category.
         focal_alpha (`float`):
-            Alpha parmeter in focal loss.
+            Alpha parameter in focal loss.
         losses (`List[str]`):
             List of all the losses to be applied. See `get_loss` for a list of all available losses.
     """
 
+    # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.__init__
     def __init__(self, matcher, num_classes, focal_alpha, losses):
         super().__init__()
         self.matcher = matcher
@@ -2248,7 +2247,7 @@ class ConditionalDetrLoss(nn.Module):
         self.focal_alpha = focal_alpha
         self.losses = losses
 
-    # removed logging parameter, which was part of the original implementation
+    # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.loss_labels
     def loss_labels(self, outputs, targets, indices, num_boxes):
         """
         Classification loss (Binary focal loss) targets dicts must contain the key "class_labels" containing a tensor
@@ -2256,33 +2255,34 @@ class ConditionalDetrLoss(nn.Module):
         """
         if "logits" not in outputs:
             raise KeyError("No logits were found in the outputs")
-        src_logits = outputs["logits"]
+        source_logits = outputs["logits"]
 
-        idx = self._get_src_permutation_idx(indices)
+        idx = self._get_source_permutation_idx(indices)
         target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
         target_classes = torch.full(
-            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
+            source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
         )
         target_classes[idx] = target_classes_o
 
         target_classes_onehot = torch.zeros(
-            [src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
-            dtype=src_logits.dtype,
-            layout=src_logits.layout,
-            device=src_logits.device,
+            [source_logits.shape[0], source_logits.shape[1], source_logits.shape[2] + 1],
+            dtype=source_logits.dtype,
+            layout=source_logits.layout,
+            device=source_logits.device,
         )
         target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
 
         target_classes_onehot = target_classes_onehot[:, :, :-1]
         loss_ce = (
-            sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2)
-            * src_logits.shape[1]
+            sigmoid_focal_loss(source_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2)
+            * source_logits.shape[1]
         )
         losses = {"loss_ce": loss_ce}
 
         return losses
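
The one-hot construction above allocates `num_classes + 1` slots so the scatter is always in range, then drops the extra 'no-object' column, leaving unmatched queries with an all-zero target row. A toy version of the same trick:

import torch

num_classes = 3
target_classes = torch.tensor([[0, 3, 1]])  # one image, 3 queries; 3 = no-object
onehot = torch.zeros(1, 3, num_classes + 1)
onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
onehot = onehot[:, :, :-1]  # the no-object query becomes an all-zero row
print(onehot)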
     @torch.no_grad()
+    # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.loss_cardinality
     def loss_cardinality(self, outputs, targets, indices, num_boxes):
         """
         Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
@@ -2291,13 +2291,14 @@ class ConditionalDetrLoss(nn.Module):
         """
         logits = outputs["logits"]
         device = logits.device
-        tgt_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
+        target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
         # Count the number of predictions that are NOT "no-object" (which is the last class)
         card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
-        card_err = nn.functional.l1_loss(card_pred.float(), tgt_lengths.float())
+        card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
         losses = {"cardinality_error": card_err}
         return losses
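
Cardinality error just compares, per image, how many queries predict a real class against the number of ground-truth boxes — for example:

import torch
import torch.nn as nn

logits = torch.randn(2, 5, 4)  # 2 images, 5 queries, 3 classes + no-object
target_lengths = torch.tensor([3, 1])  # ground-truth boxes per image
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
print(card_err)  # mean absolute count mismatch over the batch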
+    # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.loss_boxes
     def loss_boxes(self, outputs, targets, indices, num_boxes):
         """
         Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
@@ -2307,21 +2308,22 @@ class ConditionalDetrLoss(nn.Module):
         """
         if "pred_boxes" not in outputs:
             raise KeyError("No predicted boxes found in outputs")
-        idx = self._get_src_permutation_idx(indices)
-        src_boxes = outputs["pred_boxes"][idx]
+        idx = self._get_source_permutation_idx(indices)
+        source_boxes = outputs["pred_boxes"][idx]
         target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
 
-        loss_bbox = nn.functional.l1_loss(src_boxes, target_boxes, reduction="none")
+        loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
 
         losses = {}
         losses["loss_bbox"] = loss_bbox.sum() / num_boxes
 
         loss_giou = 1 - torch.diag(
-            generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
+            generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
         )
         losses["loss_giou"] = loss_giou.sum() / num_boxes
         return losses
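
Both box losses operate on matched pairs in corner format. A minimal sketch of the two terms, using torchvision's generalized_box_iou in place of the module-level copy defined in this file (an assumption, purely for self-containedness), with one matched pair and num_boxes = 1:

import torch
from torchvision.ops import generalized_box_iou

source_boxes = torch.tensor([[0.40, 0.30, 0.60, 0.70]])  # (x0, y0, x1, y1)
target_boxes = torch.tensor([[0.45, 0.30, 0.65, 0.70]])
loss_bbox = torch.nn.functional.l1_loss(source_boxes, target_boxes, reduction="none").sum() / 1
loss_giou = (1 - torch.diag(generalized_box_iou(source_boxes, target_boxes))).sum() / 1
print(loss_bbox, loss_giou)  # tensor(0.1000) tensor(0.4000)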
+    # Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_masks
     def loss_masks(self, outputs, targets, indices, num_boxes):
         """
         Compute the losses related to the masks: the focal loss and the dice loss.
@@ -2331,42 +2333,45 @@ class ConditionalDetrLoss(nn.Module):
         if "pred_masks" not in outputs:
             raise KeyError("No predicted masks found in outputs")
 
-        src_idx = self._get_src_permutation_idx(indices)
-        tgt_idx = self._get_tgt_permutation_idx(indices)
-        src_masks = outputs["pred_masks"]
-        src_masks = src_masks[src_idx]
+        source_idx = self._get_source_permutation_idx(indices)
+        target_idx = self._get_target_permutation_idx(indices)
+        source_masks = outputs["pred_masks"]
+        source_masks = source_masks[source_idx]
         masks = [t["masks"] for t in targets]
         # TODO use valid to mask invalid areas due to padding in loss
         target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
-        target_masks = target_masks.to(src_masks)
-        target_masks = target_masks[tgt_idx]
+        target_masks = target_masks.to(source_masks)
+        target_masks = target_masks[target_idx]
 
         # upsample predictions to the target size
-        src_masks = nn.functional.interpolate(
-            src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
+        source_masks = nn.functional.interpolate(
+            source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
         )
-        src_masks = src_masks[:, 0].flatten(1)
+        source_masks = source_masks[:, 0].flatten(1)
 
         target_masks = target_masks.flatten(1)
-        target_masks = target_masks.view(src_masks.shape)
+        target_masks = target_masks.view(source_masks.shape)
         losses = {
-            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
-            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
+            "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
+            "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
         }
         return losses
-    def _get_src_permutation_idx(self, indices):
+    # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss._get_source_permutation_idx
+    def _get_source_permutation_idx(self, indices):
         # permute predictions following indices
-        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
-        src_idx = torch.cat([src for (src, _) in indices])
-        return batch_idx, src_idx
+        batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
+        source_idx = torch.cat([source for (source, _) in indices])
+        return batch_idx, source_idx
 
-    def _get_tgt_permutation_idx(self, indices):
+    # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss._get_target_permutation_idx
+    def _get_target_permutation_idx(self, indices):
         # permute targets following indices
-        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
-        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
-        return batch_idx, tgt_idx
+        batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
+        target_idx = torch.cat([target for (_, target) in indices])
+        return batch_idx, target_idx
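
These helpers flatten the matcher's per-image `(source, target)` index pairs into batch-level coordinates — a toy run of the same logic:

import torch

indices = [
    (torch.tensor([2, 0]), torch.tensor([0, 1])),  # image 0: queries 2 and 0 matched to targets 0 and 1
    (torch.tensor([1]), torch.tensor([0])),        # image 1: query 1 matched to target 0
]
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
source_idx = torch.cat([source for (source, _) in indices])
print(batch_idx, source_idx)  # tensor([0, 0, 1]) tensor([2, 0, 1])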
+    # Copied from transformers.models.detr.modeling_detr.DetrLoss.get_loss
     def get_loss(self, loss, outputs, targets, indices, num_boxes):
         loss_map = {
             "labels": self.loss_labels,
@@ -2378,6 +2383,7 @@ class ConditionalDetrLoss(nn.Module):
             raise ValueError(f"Loss {loss} not supported")
         return loss_map[loss](outputs, targets, indices, num_boxes)
 
+    # Copied from transformers.models.detr.modeling_detr.DetrLoss.forward
     def forward(self, outputs, targets):
         """
         This performs the loss computation.
@@ -2386,7 +2392,7 @@ class ConditionalDetrLoss(nn.Module):
             outputs (`dict`, *optional*):
                 Dictionary of tensors, see the output specification of the model for the format.
             targets (`List[dict]`, *optional*):
-                List of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the
+                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
                 losses applied, see each loss' doc.
         """
         outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
@@ -2394,7 +2400,7 @@ class ConditionalDetrLoss(nn.Module):
         # Retrieve the matching between the outputs of the last layer and the targets
         indices = self.matcher(outputs_without_aux, targets)
 
-        # Compute the average number of target boxes accross all nodes, for normalization purposes
+        # Compute the average number of target boxes across all nodes, for normalization purposes
         num_boxes = sum(len(t["class_labels"]) for t in targets)
         num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
         # (Niels): comment out function below, distributed training to be added
@@ -2445,6 +2451,7 @@ class ConditionalDetrMLPPredictionHead(nn.Module):
         return x
 
 
+# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrHungarianMatcher with DeformableDetr->ConditionalDetr
 class ConditionalDetrHungarianMatcher(nn.Module):
     """
     This class computes an assignment between the targets and the predictions of the network.
@@ -2500,21 +2507,21 @@ class ConditionalDetrHungarianMatcher(nn.Module):
         out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
 
         # Also concat the target labels and boxes
-        tgt_ids = torch.cat([v["class_labels"] for v in targets])
-        tgt_bbox = torch.cat([v["boxes"] for v in targets])
+        target_ids = torch.cat([v["class_labels"] for v in targets])
+        target_bbox = torch.cat([v["boxes"] for v in targets])
 
         # Compute the classification cost.
         alpha = 0.25
         gamma = 2.0
         neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
         pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
-        class_cost = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
+        class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]
 
         # Compute the L1 cost between boxes
-        bbox_cost = torch.cdist(out_bbox, tgt_bbox, p=1)
+        bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
 
         # Compute the giou cost between boxes
-        giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(tgt_bbox))
+        giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
 
         # Final cost matrix
         cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
@@ -2525,9 +2532,7 @@ class ConditionalDetrHungarianMatcher(nn.Module):
         return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
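
The cost matrix assembled above is handed to scipy's linear_sum_assignment for the optimal one-to-one pairing — illustrated with a made-up 2-query x 3-target cost matrix:

import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[0.9, 0.1, 0.4],
                 [0.2, 0.8, 0.7]])
row, col = linear_sum_assignment(cost)
print(row, col)  # [0 1] [1 0]: query 0 <-> target 1, query 1 <-> target 0 is the cheapest pairing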
 
 
-# below: bounding box utilities taken from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
+# Copied from transformers.models.detr.modeling_detr._upcast
 def _upcast(t: Tensor) -> Tensor:
     # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
     if t.is_floating_point():
@@ -2595,9 +2600,18 @@ def generalized_box_iou(boxes1, boxes2):
     return iou - (area - union) / area
 
 
-# below: taken from https://github.com/facebookresearch/detr/blob/master/util/misc.py#L306
+# Copied from transformers.models.detr.modeling_detr.center_to_corners_format
+def center_to_corners_format(x):
+    """
+    Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
+    (x_0, y_0, x_1, y_1).
+    """
+    center_x, center_y, width, height = x.unbind(-1)
+    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
+    return torch.stack(b, dim=-1)
+
+
+# Copied from transformers.models.detr.modeling_detr._max_by_axis
 def _max_by_axis(the_list):
     # type: (List[List[int]]) -> List[int]
     maxes = the_list[0]

View File

@@ -114,6 +114,8 @@ class DeformableDetrConfig(PretrainedConfig):
         with_box_refine (`bool`, *optional*, defaults to `False`):
             Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
             based on the predictions from the previous layer.
+        focal_alpha (`float`, *optional*, defaults to 0.25):
+            Alpha parameter in the focal loss.
 
     Examples:
@@ -174,6 +176,7 @@ class DeformableDetrConfig(PretrainedConfig):
         bbox_loss_coefficient=5,
         giou_loss_coefficient=2,
         eos_coefficient=0.1,
+        focal_alpha=0.25,
         **kwargs
     ):
         self.num_queries = num_queries
@@ -216,6 +219,7 @@ class DeformableDetrConfig(PretrainedConfig):
         self.bbox_loss_coefficient = bbox_loss_coefficient
         self.giou_loss_coefficient = giou_loss_coefficient
         self.eos_coefficient = eos_coefficient
+        self.focal_alpha = focal_alpha
 
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
 
     @property

View File

@@ -44,8 +44,8 @@ def center_to_corners_format(x):
     Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
     (x_0, y_0, x_1, y_1).
     """
-    x_c, y_c, w, h = x.unbind(-1)
-    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    center_x, center_y, width, height = x.unbind(-1)
+    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
     return torch.stack(b, dim=-1)

View File

@@ -35,7 +35,6 @@ from ...file_utils import (
     is_scipy_available,
     is_timm_available,
     is_torch_cuda_available,
-    is_vision_available,
     replace_return_docstrings,
     requires_backends,
 )
@@ -111,9 +110,6 @@ class MultiScaleDeformableAttentionFunction(Function):
 if is_scipy_available():
     from scipy.optimize import linear_sum_assignment
 
-if is_vision_available():
-    from transformers.models.detr.feature_extraction_detr import center_to_corners_format
-
 if is_timm_available():
     from timm import create_model
@@ -1952,7 +1948,7 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
         criterion = DeformableDetrLoss(
             matcher=matcher,
             num_classes=self.config.num_labels,
-            eos_coef=self.config.eos_coefficient,
+            focal_alpha=self.config.focal_alpha,
             losses=losses,
         )
         criterion.to(self.device)
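
The criterion is now constructed with the config's `focal_alpha`, matching the simplified `DeformableDetrLoss.__init__` below. A hedged end-to-end sketch of the path this exercises — building a model from a default config (requires scipy and timm; the timm backbone may download pretrained weights; input and label values are dummies):

import torch
from transformers import DeformableDetrConfig, DeformableDetrForObjectDetection

model = DeformableDetrForObjectDetection(DeformableDetrConfig(focal_alpha=0.25))
pixel_values = torch.randn(1, 3, 256, 256)
labels = [{"class_labels": torch.tensor([0]), "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.2]])}]
outputs = model(pixel_values=pixel_values, labels=labels)
print(outputs.loss)  # the criterion above runs under the hood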
@@ -2065,46 +2061,38 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f
     return loss.mean(1).sum() / num_boxes
 
 
-# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
 class DeformableDetrLoss(nn.Module):
     """
-    This class computes the losses for DeformableDetrForObjectDetection. The process happens in two steps: 1) we
+    This class computes the losses for `DeformableDetrForObjectDetection`. The process happens in two steps: 1) we
     compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of
-    matched ground-truth / prediction (supervise class and box)
+    matched ground-truth / prediction (supervise class and box).
+
+    Args:
+        matcher (`DeformableDetrHungarianMatcher`):
+            Module able to compute a matching between targets and proposals.
+        num_classes (`int`):
+            Number of object categories, omitting the special no-object category.
+        focal_alpha (`float`):
+            Alpha parameter in focal loss.
+        losses (`List[str]`):
+            List of all the losses to be applied. See `get_loss` for a list of all available losses.
     """
 
-    def __init__(self, matcher, num_classes, eos_coef, losses, focal_alpha=0.25):
-        """
-        Create the criterion.
-
-        A note on the num_classes parameter (copied from original repo in detr.py): "the naming of the `num_classes`
-        parameter of the criterion is somewhat misleading. it indeed corresponds to `max_obj_id + 1`, where max_obj_id
-        is the maximum id for a class in your dataset. For example, COCO has a max_obj_id of 90, so we pass
-        `num_classes` to be 91. As another example, for a dataset that has a single class with id 1, you should pass
-        `num_classes` to be 2 (max_obj_id + 1). For more details on this, check the following discussion
-        https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
-
-        Parameters:
-            matcher: module able to compute a matching between targets and proposals.
-            num_classes: number of object categories, omitting the special no-object category.
-            eos_coef: relative classification weight applied to the no-object category.
-            losses: list of all the losses to be applied. See get_loss for list of available losses.
-            focal_alpha: alpha in Focal Loss.
-        """
+    def __init__(self, matcher, num_classes, focal_alpha, losses):
         super().__init__()
         self.matcher = matcher
         self.num_classes = num_classes
-        self.losses = losses
         self.focal_alpha = focal_alpha
+        self.losses = losses
 
-    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
-        """Classification loss (NLL)
-        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+    # removed logging parameter, which was part of the original implementation
+    def loss_labels(self, outputs, targets, indices, num_boxes):
+        """
+        Classification loss (Binary focal loss) targets dicts must contain the key "class_labels" containing a tensor
+        of dim [nb_target_boxes]
         """
         if "logits" not in outputs:
-            raise ValueError("No logits were found in the outputs")
+            raise KeyError("No logits were found in the outputs")
         source_logits = outputs["logits"]
 
         idx = self._get_source_permutation_idx(indices)
@@ -2132,6 +2120,7 @@ class DeformableDetrLoss(nn.Module):
         return losses
     @torch.no_grad()
+    # Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_cardinality
     def loss_cardinality(self, outputs, targets, indices, num_boxes):
         """
         Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
@@ -2147,6 +2136,7 @@ class DeformableDetrLoss(nn.Module):
         losses = {"cardinality_error": card_err}
         return losses
+    # Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_boxes
     def loss_boxes(self, outputs, targets, indices, num_boxes):
         """
         Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
@@ -2155,8 +2145,7 @@ class DeformableDetrLoss(nn.Module):
         are expected in format (center_x, center_y, w, h), normalized by the image size.
         """
         if "pred_boxes" not in outputs:
-            raise ValueError("No predicted boxes found in outputs")
-
+            raise KeyError("No predicted boxes found in outputs")
         idx = self._get_source_permutation_idx(indices)
         source_boxes = outputs["pred_boxes"][idx]
         target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
@@ -2172,12 +2161,14 @@ class DeformableDetrLoss(nn.Module):
         losses["loss_giou"] = loss_giou.sum() / num_boxes
         return losses
+    # Copied from transformers.models.detr.modeling_detr.DetrLoss._get_source_permutation_idx
     def _get_source_permutation_idx(self, indices):
         # permute predictions following indices
         batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
         source_idx = torch.cat([source for (source, _) in indices])
         return batch_idx, source_idx
 
+    # Copied from transformers.models.detr.modeling_detr.DetrLoss._get_target_permutation_idx
     def _get_target_permutation_idx(self, indices):
         # permute targets following indices
         batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
@@ -2192,17 +2183,18 @@ class DeformableDetrLoss(nn.Module):
         }
         if loss not in loss_map:
             raise ValueError(f"Loss {loss} not supported")
         return loss_map[loss](outputs, targets, indices, num_boxes)
 
     def forward(self, outputs, targets):
         """
         This performs the loss computation.
 
-        Parameters:
-            outputs: dict of tensors, see the output specification of the model for the format
-            targets: list of dicts, such that len(targets) == batch_size.
-                The expected keys in each dict depends on the losses applied, see each loss' doc
+        Args:
+            outputs (`dict`, *optional*):
+                Dictionary of tensors, see the output specification of the model for the format.
+            targets (`List[dict]`, *optional*):
+                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
+                losses applied, see each loss' doc.
         """
         outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
@@ -2272,7 +2264,6 @@ class DeformableDetrMLPPredictionHead(nn.Module):
         return x
 
 
-# Copied from transformers.models.detr.modeling_detr.DetrHungarianMatcher
 class DeformableDetrHungarianMatcher(nn.Module):
     """
     This class computes an assignment between the targets and the predictions of the network.
@@ -2324,17 +2315,19 @@ class DeformableDetrHungarianMatcher(nn.Module):
         batch_size, num_queries = outputs["logits"].shape[:2]
 
         # We flatten to compute the cost matrices in a batch
-        out_prob = outputs["logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
+        out_prob = outputs["logits"].flatten(0, 1).sigmoid()  # [batch_size * num_queries, num_classes]
         out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
 
         # Also concat the target labels and boxes
         target_ids = torch.cat([v["class_labels"] for v in targets])
         target_bbox = torch.cat([v["boxes"] for v in targets])
 
-        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
-        # but approximate it in 1 - proba[target class].
-        # The 1 is a constant that doesn't change the matching, it can be ommitted.
-        class_cost = -out_prob[:, target_ids]
+        # Compute the classification cost.
+        alpha = 0.25
+        gamma = 2.0
+        neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
+        pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
+        class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]
 
         # Compute the L1 cost between boxes
         bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
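
To see why the new focal-style cost prefers confident correct predictions, here it is evaluated by hand for a single query and two classes (dummy probabilities):

import torch

out_prob = torch.tensor([[0.9, 0.1]])  # sigmoid probabilities for one query
target_ids = torch.tensor([0])
alpha, gamma = 0.25, 2.0
neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]
print(class_cost)  # about -1.40: strongly negative (cheap) because the query already favors class 0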
@@ -2419,6 +2412,17 @@ def generalized_box_iou(boxes1, boxes2):
     return iou - (area - union) / area
 
 
+# Copied from transformers.models.detr.modeling_detr.center_to_corners_format
+def center_to_corners_format(x):
+    """
+    Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
+    (x_0, y_0, x_1, y_1).
+    """
+    center_x, center_y, width, height = x.unbind(-1)
+    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
+    return torch.stack(b, dim=-1)
+
+
 # Copied from transformers.models.detr.modeling_detr._max_by_axis
 def _max_by_axis(the_list):
     # type: (List[List[int]]) -> List[int]

View File

@@ -44,8 +44,8 @@ def center_to_corners_format(x):
     Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
     (x_0, y_0, x_1, y_1).
     """
-    x_c, y_c, w, h = x.unbind(-1)
-    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    center_x, center_y, width, height = x.unbind(-1)
+    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
     return torch.stack(b, dim=-1)

View File

@@ -33,7 +33,6 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     is_scipy_available,
     is_timm_available,
-    is_vision_available,
     logging,
     replace_return_docstrings,
     requires_backends,
@@ -44,9 +43,6 @@ from .configuration_detr import DetrConfig
 if is_scipy_available():
     from scipy.optimize import linear_sum_assignment
 
-if is_vision_available():
-    from .feature_extraction_detr import center_to_corners_format
-
 if is_timm_available():
     from timm import create_model
@@ -1964,16 +1960,16 @@ class DetrLoss(nn.Module):
         """
         if "logits" not in outputs:
             raise KeyError("No logits were found in the outputs")
-        src_logits = outputs["logits"]
+        source_logits = outputs["logits"]
 
-        idx = self._get_src_permutation_idx(indices)
+        idx = self._get_source_permutation_idx(indices)
         target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
         target_classes = torch.full(
-            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
+            source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
         )
         target_classes[idx] = target_classes_o
-        loss_ce = nn.functional.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
+        loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
         losses = {"loss_ce": loss_ce}
 
         return losses
@@ -2003,17 +1999,17 @@ class DetrLoss(nn.Module):
         """
         if "pred_boxes" not in outputs:
             raise KeyError("No predicted boxes found in outputs")
-        idx = self._get_src_permutation_idx(indices)
-        src_boxes = outputs["pred_boxes"][idx]
+        idx = self._get_source_permutation_idx(indices)
+        source_boxes = outputs["pred_boxes"][idx]
         target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
 
-        loss_bbox = nn.functional.l1_loss(src_boxes, target_boxes, reduction="none")
+        loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
 
         losses = {}
         losses["loss_bbox"] = loss_bbox.sum() / num_boxes
 
         loss_giou = 1 - torch.diag(
-            generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
+            generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
         )
         losses["loss_giou"] = loss_giou.sum() / num_boxes
         return losses
@@ -2027,41 +2023,41 @@ class DetrLoss(nn.Module):
         if "pred_masks" not in outputs:
             raise KeyError("No predicted masks found in outputs")
 
-        src_idx = self._get_src_permutation_idx(indices)
-        tgt_idx = self._get_tgt_permutation_idx(indices)
-        src_masks = outputs["pred_masks"]
-        src_masks = src_masks[src_idx]
+        source_idx = self._get_source_permutation_idx(indices)
+        target_idx = self._get_target_permutation_idx(indices)
+        source_masks = outputs["pred_masks"]
+        source_masks = source_masks[source_idx]
         masks = [t["masks"] for t in targets]
         # TODO use valid to mask invalid areas due to padding in loss
         target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
-        target_masks = target_masks.to(src_masks)
-        target_masks = target_masks[tgt_idx]
+        target_masks = target_masks.to(source_masks)
+        target_masks = target_masks[target_idx]
 
         # upsample predictions to the target size
-        src_masks = nn.functional.interpolate(
-            src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
+        source_masks = nn.functional.interpolate(
+            source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
         )
-        src_masks = src_masks[:, 0].flatten(1)
+        source_masks = source_masks[:, 0].flatten(1)
 
         target_masks = target_masks.flatten(1)
-        target_masks = target_masks.view(src_masks.shape)
+        target_masks = target_masks.view(source_masks.shape)
         losses = {
-            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
-            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
+            "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
+            "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
         }
         return losses
 
-    def _get_src_permutation_idx(self, indices):
+    def _get_source_permutation_idx(self, indices):
         # permute predictions following indices
-        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
-        src_idx = torch.cat([src for (src, _) in indices])
-        return batch_idx, src_idx
+        batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
+        source_idx = torch.cat([source for (source, _) in indices])
+        return batch_idx, source_idx
 
-    def _get_tgt_permutation_idx(self, indices):
+    def _get_target_permutation_idx(self, indices):
         # permute targets following indices
-        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
-        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
-        return batch_idx, tgt_idx
+        batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
+        target_idx = torch.cat([target for (_, target) in indices])
+        return batch_idx, target_idx
     def get_loss(self, loss, outputs, targets, indices, num_boxes):
         loss_map = {
@@ -2082,7 +2078,7 @@ class DetrLoss(nn.Module):
             outputs (`dict`, *optional*):
                 Dictionary of tensors, see the output specification of the model for the format.
             targets (`List[dict]`, *optional*):
-                List of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the
+                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
                 losses applied, see each loss' doc.
         """
         outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
@@ -2288,6 +2284,17 @@ def generalized_box_iou(boxes1, boxes2):
     return iou - (area - union) / area


+# Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
+def center_to_corners_format(x):
+    """
+    Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
+    (x_0, y_0, x_1, y_1).
+    """
+    center_x, center_y, width, height = x.unbind(-1)
+    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
+    return torch.stack(b, dim=-1)
+
+
 # below: taken from https://github.com/facebookresearch/detr/blob/master/util/misc.py#L306
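For a quick sanity check of the helper being copied in, torchvision ships an equivalent conversion (assuming torchvision >= 0.9 is available; the expected output below is computed by hand):

import torch
from torchvision.ops import box_convert

# A box centered at (0.5, 0.5) with width 0.4 and height 0.2
boxes = torch.tensor([[0.5, 0.5, 0.4, 0.2]])
corners = box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy")
print(corners)  # tensor([[0.3000, 0.4000, 0.7000, 0.6000]]), same result as center_to_corners_format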


@@ -42,8 +42,8 @@ def center_to_corners_format(x):
     Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
     (x_0, y_0, x_1, y_1).
     """
-    x_c, y_c, w, h = x.unbind(-1)
-    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    center_x, center_y, width, height = x.unbind(-1)
+    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
     return torch.stack(b, dim=-1)


@@ -959,16 +959,16 @@ class YolosLoss(nn.Module):
         """
         if "logits" not in outputs:
             raise KeyError("No logits were found in the outputs")
-        src_logits = outputs["logits"]
+        source_logits = outputs["logits"]

-        idx = self._get_src_permutation_idx(indices)
+        idx = self._get_source_permutation_idx(indices)
         target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
         target_classes = torch.full(
-            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
+            source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
         )
         target_classes[idx] = target_classes_o
-        loss_ce = nn.functional.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
+        loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
         losses = {"loss_ce": loss_ce}

         return losses
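The `self.empty_weight` passed to `cross_entropy` above is what keeps the many unmatched queries from dominating the classification loss. A self-contained sketch of the mechanism, with illustrative values (91 classes and a 0.1 no-object coefficient; neither value is taken from this diff):

import torch
from torch import nn

num_classes = 91  # illustrative; the real value comes from the model config
empty_weight = torch.ones(num_classes + 1)
empty_weight[-1] = 0.1  # eos_coefficient: down-weight the trailing "no-object" class

logits = torch.randn(2, 100, num_classes + 1)       # (batch, queries, classes)
target_classes = torch.full((2, 100), num_classes)  # default every query to "no object"
target_classes[0, 3] = 17                           # pretend query 3 of image 0 matched a box
loss_ce = nn.functional.cross_entropy(logits.transpose(1, 2), target_classes, empty_weight)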
@@ -998,17 +998,17 @@ class YolosLoss(nn.Module):
         """
         if "pred_boxes" not in outputs:
             raise KeyError("No predicted boxes found in outputs")
-        idx = self._get_src_permutation_idx(indices)
-        src_boxes = outputs["pred_boxes"][idx]
+        idx = self._get_source_permutation_idx(indices)
+        source_boxes = outputs["pred_boxes"][idx]
         target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

-        loss_bbox = nn.functional.l1_loss(src_boxes, target_boxes, reduction="none")
+        loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")

         losses = {}
         losses["loss_bbox"] = loss_bbox.sum() / num_boxes

         loss_giou = 1 - torch.diag(
-            generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
+            generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
         )
         losses["loss_giou"] = loss_giou.sum() / num_boxes
         return losses
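Taken together, the box loss is an L1 term on the raw center-format coordinates plus a GIoU term on corner-format boxes; `torch.diag` picks out the matched pairs from the pairwise GIoU matrix. A sketch of the arithmetic using torchvision's `box_convert` and `generalized_box_iou` in place of the module-level helpers (assumed equivalent; the boxes are invented):

import torch
from torch import nn
from torchvision.ops import box_convert, generalized_box_iou

num_boxes = 3
# Matched predictions/targets in normalized (center_x, center_y, width, height) format
source_boxes = torch.tensor([[0.50, 0.50, 0.20, 0.20], [0.30, 0.60, 0.10, 0.20], [0.70, 0.40, 0.20, 0.10]])
target_boxes = torch.tensor([[0.50, 0.50, 0.25, 0.20], [0.35, 0.60, 0.10, 0.20], [0.70, 0.45, 0.20, 0.10]])

# L1 loss on the raw coordinates, normalized by the number of boxes
loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none").sum() / num_boxes

# GIoU loss on corner-format boxes; the diagonal holds the matched pairs
giou = generalized_box_iou(
    box_convert(source_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
    box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
)
loss_giou = (1 - torch.diag(giou)).sum() / num_boxes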
@@ -1022,41 +1022,41 @@ class YolosLoss(nn.Module):
         if "pred_masks" not in outputs:
             raise KeyError("No predicted masks found in outputs")

-        src_idx = self._get_src_permutation_idx(indices)
-        tgt_idx = self._get_tgt_permutation_idx(indices)
-        src_masks = outputs["pred_masks"]
-        src_masks = src_masks[src_idx]
+        source_idx = self._get_source_permutation_idx(indices)
+        target_idx = self._get_target_permutation_idx(indices)
+        source_masks = outputs["pred_masks"]
+        source_masks = source_masks[source_idx]
         masks = [t["masks"] for t in targets]
         # TODO use valid to mask invalid areas due to padding in loss
         target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
-        target_masks = target_masks.to(src_masks)
-        target_masks = target_masks[tgt_idx]
+        target_masks = target_masks.to(source_masks)
+        target_masks = target_masks[target_idx]
         # upsample predictions to the target size
-        src_masks = nn.functional.interpolate(
-            src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
+        source_masks = nn.functional.interpolate(
+            source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
         )
-        src_masks = src_masks[:, 0].flatten(1)
+        source_masks = source_masks[:, 0].flatten(1)
         target_masks = target_masks.flatten(1)
-        target_masks = target_masks.view(src_masks.shape)
+        target_masks = target_masks.view(source_masks.shape)
         losses = {
-            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
-            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
+            "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
+            "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
         }
         return losses
-    def _get_src_permutation_idx(self, indices):
+    def _get_source_permutation_idx(self, indices):
         # permute predictions following indices
-        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
-        src_idx = torch.cat([src for (src, _) in indices])
-        return batch_idx, src_idx
+        batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
+        source_idx = torch.cat([source for (source, _) in indices])
+        return batch_idx, source_idx

-    def _get_tgt_permutation_idx(self, indices):
+    def _get_target_permutation_idx(self, indices):
         # permute targets following indices
-        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
-        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
-        return batch_idx, tgt_idx
+        batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
+        target_idx = torch.cat([target for (_, target) in indices])
+        return batch_idx, target_idx

     def get_loss(self, loss, outputs, targets, indices, num_boxes):
         loss_map = {
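For reference, the `sigmoid_focal_loss` and `dice_loss` called by `loss_masks` above follow the upstream DETR definitions. A sketch of those two functions (the in-file implementations may differ in minor details):

import torch
from torch import nn

def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    # Binary focal loss on per-pixel mask logits; alpha balances positives vs.
    # negatives, gamma down-weights easy examples.
    prob = inputs.sigmoid()
    ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss.mean(1).sum() / num_boxes

def dice_loss(inputs, targets, num_boxes):
    # Soft dice loss on sigmoid mask probabilities, averaged over matched boxes.
    inputs = inputs.sigmoid().flatten(1)
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_boxes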
@@ -1077,7 +1077,7 @@ class YolosLoss(nn.Module):
             outputs (`dict`, *optional*):
                 Dictionary of tensors, see the output specification of the model for the format.
             targets (`List[dict]`, *optional*):
-                List of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the
+                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
                 losses applied, see each loss' doc.
         """
         outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}


@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Testing suite for the PyTorch CONDITIONAL_DETR model. """
+""" Testing suite for the PyTorch Conditional DETR model. """


 import inspect
@@ -213,19 +213,19 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_conditional_detr_object_detection_head_model(*config_and_inputs)

-    @unittest.skip(reason="CONDITIONAL_DETR does not use inputs_embeds")
+    @unittest.skip(reason="Conditional DETR does not use inputs_embeds")
     def test_inputs_embeds(self):
         pass

-    @unittest.skip(reason="CONDITIONAL_DETR does not have a get_input_embeddings method")
+    @unittest.skip(reason="Conditional DETR does not have a get_input_embeddings method")
     def test_model_common_attributes(self):
         pass

-    @unittest.skip(reason="CONDITIONAL_DETR is not a generative model")
+    @unittest.skip(reason="Conditional DETR is not a generative model")
     def test_generate_without_input_ids(self):
         pass

-    @unittest.skip(reason="CONDITIONAL_DETR does not use token embeddings")
+    @unittest.skip(reason="Conditional DETR does not use token embeddings")
     def test_resize_tokens_embeddings(self):
         pass
@@ -474,7 +474,7 @@ class ConditionalDetrModelIntegrationTests(unittest.TestCase):
         expected_shape = torch.Size((1, 300, 256))
         self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
         expected_slice = torch.tensor(
-            [[0.0616, -0.5146, -0.4032], [-0.7629, -0.4934, -1.7153], [-0.4768, -0.6403, -0.7826]]
+            [[0.4222, 0.7471, 0.8760], [0.6395, -0.2729, 0.7127], [-0.3090, 0.7642, 0.9529]]
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
@@ -495,48 +495,13 @@ class ConditionalDetrModelIntegrationTests(unittest.TestCase):
         expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels))
         self.assertEqual(outputs.logits.shape, expected_shape_logits)
         expected_slice_logits = torch.tensor(
-            [[-19.1194, -0.0893, -11.0154], [-17.3640, -1.8035, -14.0219], [-20.0461, -0.5837, -11.1060]]
+            [[-10.4372, -5.7558, -8.6764], [-10.5410, -5.8704, -8.0590], [-10.6827, -6.3469, -8.3923]]
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits, atol=1e-4))

         expected_shape_boxes = torch.Size((1, model.config.num_queries, 4))
         self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
         expected_slice_boxes = torch.tensor(
-            [[0.4433, 0.5302, 0.8853], [0.5494, 0.2517, 0.0529], [0.4998, 0.5360, 0.9956]]
+            [[0.7733, 0.6576, 0.4496], [0.5171, 0.1184, 0.9094], [0.8846, 0.5647, 0.2486]]
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))

-    def test_inference_panoptic_segmentation_head(self):
-        model = ConditionalDetrForSegmentation.from_pretrained("microsoft/conditional-detr-resnet-50-panoptic").to(
-            torch_device
-        )
-
-        feature_extractor = self.default_feature_extractor
-        image = prepare_img()
-        encoding = feature_extractor(images=image, return_tensors="pt").to(torch_device)
-        pixel_values = encoding["pixel_values"].to(torch_device)
-        pixel_mask = encoding["pixel_mask"].to(torch_device)
-
-        with torch.no_grad():
-            outputs = model(pixel_values, pixel_mask)
-
-        expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels))
-        self.assertEqual(outputs.logits.shape, expected_shape_logits)
-        expected_slice_logits = torch.tensor(
-            [[-18.1565, -1.7568, -13.5029], [-16.8888, -1.4138, -14.1028], [-17.5709, -2.5080, -11.8654]]
-        ).to(torch_device)
-        self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits, atol=1e-4))
-
-        expected_shape_boxes = torch.Size((1, model.config.num_queries, 4))
-        self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
-        expected_slice_boxes = torch.tensor(
-            [[0.5344, 0.1789, 0.9285], [0.4420, 0.0572, 0.0875], [0.6630, 0.6887, 0.1017]]
-        ).to(torch_device)
-        self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
-
-        expected_shape_masks = torch.Size((1, model.config.num_queries, 200, 267))
-        self.assertEqual(outputs.pred_masks.shape, expected_shape_masks)
-        expected_slice_masks = torch.tensor(
-            [[-7.7558, -10.8788, -11.9797], [-11.8881, -16.4329, -17.7451], [-14.7316, -19.7383, -20.3004]]
-        ).to(torch_device)
-        self.assertTrue(torch.allclose(outputs.pred_masks[0, 0, :3, :3], expected_slice_masks, atol=1e-3))
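To regenerate or verify the updated object-detection expectations locally, something along these lines should work. This is a sketch, assuming the public `microsoft/conditional-detr-resnet-50` checkpoint and the usual COCO sample image used throughout the test suite:

import requests
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, ConditionalDetrForObjectDetection

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # the standard test image
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/conditional-detr-resnet-50")
model = ConditionalDetrForObjectDetection.from_pretrained("microsoft/conditional-detr-resnet-50")

inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.logits[0, :3, :3])      # should match expected_slice_logits above
print(outputs.pred_boxes[0, :3, :3])  # should match expected_slice_boxes above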