Improve DETR models (#19644)

* Improve DETR models

* Fix Deformable DETR loss and matcher

* Fixup

* Fix integration tests

* Improve variable names

* Apply suggestion

* Fix copies

* Fix DeformableDetrLoss

* Make Conditional DETR copy from Deformable DETR

* Copy from deformable detr's hungarian matcher

* Fix bug
This commit is contained in:
NielsRogge 2022-10-18 10:29:14 +02:00 committed by GitHub
parent 072dfdaee4
commit 90071fe42b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 253 additions and 257 deletions

View File

@ -23,7 +23,7 @@ The abstract from the paper is the following:
Tips:
- One can use the [`AutoFeatureExtractor`] API to prepare images (and optional targets) for the model. This will instantiate a [`DetrFeatureExtractor`] behind the scenes.
- One can use [`DeformableDetrFeatureExtractor`] to prepare images (and optional targets) for the model.
- Training Deformable DETR is equivalent to training the original [DETR](detr) model. Demo notebooks can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/DETR).
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/deformable_detr_architecture.png"

View File

@ -110,6 +110,8 @@ class ConditionalDetrConfig(PretrainedConfig):
Relative weight of the generalized IoU loss in the object detection loss.
eos_coefficient (`float`, *optional*, defaults to 0.1):
Relative classification weight of the 'no-object' class in the object detection loss.
focal_alpha (`float`, *optional*, defaults to 0.25):
Alpha parameter in the focal loss.
Examples:

View File

@ -44,8 +44,8 @@ def center_to_corners_format(x):
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)

View File

@ -33,7 +33,6 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_scipy_available,
is_timm_available,
is_vision_available,
logging,
replace_return_docstrings,
requires_backends,
@ -44,9 +43,6 @@ from .configuration_conditional_detr import ConditionalDetrConfig
if is_scipy_available():
from scipy.optimize import linear_sum_assignment
if is_vision_available():
from .feature_extraction_conditional_detr import center_to_corners_format
if is_timm_available():
from timm import create_model
@ -679,11 +675,11 @@ class ConditionalDetrAttention(nn.Module):
self.out_proj = nn.Linear(out_dim, out_dim, bias=bias)
def _qk_shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def _qk_shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def _v_shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous()
def _v_shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
return tensor.view(batch_size, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous()
def forward(
self,
@ -695,37 +691,38 @@ class ConditionalDetrAttention(nn.Module):
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
batch_size, target_len, _ = hidden_states.size()
# get query proj
query_states = hidden_states * self.scaling
# get key, value proj
key_states = self._qk_shape(key_states, -1, bsz)
value_states = self._v_shape(value_states, -1, bsz)
key_states = self._qk_shape(key_states, -1, batch_size)
value_states = self._v_shape(value_states, -1, batch_size)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
v_proj_shape = (bsz * self.num_heads, -1, self.v_head_dim)
query_states = self._qk_shape(query_states, tgt_len, bsz).view(*proj_shape)
proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
v_proj_shape = (batch_size * self.num_heads, -1, self.v_head_dim)
query_states = self._qk_shape(query_states, target_len, batch_size).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*v_proj_shape)
src_len = key_states.size(1)
source_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
if attention_mask.size() != (batch_size, 1, target_len, source_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
f" {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@ -734,8 +731,8 @@ class ConditionalDetrAttention(nn.Module):
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
else:
attn_weights_reshaped = None
@ -743,15 +740,15 @@ class ConditionalDetrAttention(nn.Module):
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.v_head_dim):
if attn_output.size() != (batch_size * self.num_heads, target_len, self.v_head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.v_head_dim)}, but is"
f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.v_head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.v_head_dim)
attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.v_head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.out_dim)
attn_output = attn_output.reshape(batch_size, target_len, self.out_dim)
attn_output = self.out_proj(attn_output)
@ -887,7 +884,8 @@ class ConditionalDetrDecoderLayer(nn.Module):
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
values.
position_embeddings (`torch.FloatTensor`, *optional*):
position embeddings that are added to the queries and keys
in the cross-attention layer.
@ -897,7 +895,8 @@ class ConditionalDetrDecoderLayer(nn.Module):
encoder_hidden_states (`torch.FloatTensor`):
cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
@ -940,7 +939,7 @@ class ConditionalDetrDecoderLayer(nn.Module):
v = self.ca_v_proj(encoder_hidden_states)
batch_size, num_queries, n_model = q_content.shape
_, src_len, _ = k_content.shape
_, source_len, _ = k_content.shape
k_pos = self.ca_kpos_proj(position_embeddings)
@ -958,9 +957,9 @@ class ConditionalDetrDecoderLayer(nn.Module):
query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
query_sine_embed = query_sine_embed.view(batch_size, num_queries, self.nhead, n_model // self.nhead)
q = torch.cat([q, query_sine_embed], dim=3).view(batch_size, num_queries, n_model * 2)
k = k.view(batch_size, src_len, self.nhead, n_model // self.nhead)
k_pos = k_pos.view(batch_size, src_len, self.nhead, n_model // self.nhead)
k = torch.cat([k, k_pos], dim=3).view(batch_size, src_len, n_model * 2)
k = k.view(batch_size, source_len, self.nhead, n_model // self.nhead)
k_pos = k_pos.view(batch_size, source_len, self.nhead, n_model // self.nhead)
k = torch.cat([k, k_pos], dim=3).view(batch_size, source_len, n_model * 2)
# Cross-Attention Block
cross_attn_weights = None
@ -1333,14 +1332,14 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
combined_attention_mask = None
if attention_mask is not None and combined_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
# [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
combined_attention_mask = combined_attention_mask + _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
attention_mask, inputs_embeds.dtype, target_len=input_shape[-1]
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
# [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
encoder_attention_mask = _expand_mask(
encoder_attention_mask, inputs_embeds.dtype, target_len=input_shape[-1]
)
@ -2061,7 +2060,6 @@ def _expand(tensor, length: int):
return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
# taken from https://github.com/facebookresearch/detr/blob/master/models/segmentation.py
# Copied from transformers.models.detr.modeling_detr.DetrMaskHeadSmallConv with Detr->ConditionalDetr
class ConditionalDetrMaskHeadSmallConv(nn.Module):
"""
@ -2172,6 +2170,7 @@ class ConditionalDetrMHAttentionMap(nn.Module):
return weights
# Copied from transformers.models.detr.modeling_detr.dice_loss
def dice_loss(inputs, targets, num_boxes):
"""
Compute the DICE loss, similar to generalized IOU for masks
@ -2191,26 +2190,28 @@ def dice_loss(inputs, targets, num_boxes):
return loss.sum() / num_boxes
# Copied from transformers.models.detr.modeling_detr.sigmoid_focal_loss
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs (0 for the negative class and 1 for the positive
class).
alpha: (optional) Weighting factor in range (0,1) to balance
positive vs negative examples. Default = -1 (no weighting).
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
inputs (`torch.FloatTensor` of arbitrary shape):
The predictions for each example.
targets (`torch.FloatTensor` with the same shape as `inputs`)
A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
and 1 for the positive class).
alpha (`float`, *optional*, defaults to `0.25`):
Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
gamma (`int`, *optional*, defaults to `2`):
Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.
Returns:
Loss tensor
"""
prob = inputs.sigmoid()
ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
# add modulating factor
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
@ -2221,26 +2222,24 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f
return loss.mean(1).sum() / num_boxes
# taken from https://github.com/Atten4Vis/conditionalDETR/blob/master/models/conditional_detr.py
class ConditionalDetrLoss(nn.Module):
"""
This class computes the losses for ConditionalDetrForObjectDetection/ConditionalDetrForSegmentation. The process
happens in two steps: 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 2)
we supervise each pair of matched ground-truth / prediction (supervise class and box).
Args:
matcher (`ConditionalDetrHungarianMatcher`):
Module able to compute a matching between targets and proposals.
num_classes (`int`):
Number of object categories, omitting the special no-object category.
focal_alpha (`float`):
Alpha parmeter in focal loss.
Alpha parameter in focal loss.
losses (`List[str]`):
List of all the losses to be applied. See `get_loss` for a list of all available losses.
"""
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.__init__
def __init__(self, matcher, num_classes, focal_alpha, losses):
super().__init__()
self.matcher = matcher
@ -2248,7 +2247,7 @@ class ConditionalDetrLoss(nn.Module):
self.focal_alpha = focal_alpha
self.losses = losses
# removed logging parameter, which was part of the original implementation
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.loss_labels
def loss_labels(self, outputs, targets, indices, num_boxes):
"""
Classification loss (Binary focal loss) targets dicts must contain the key "class_labels" containing a tensor
@ -2256,33 +2255,34 @@ class ConditionalDetrLoss(nn.Module):
"""
if "logits" not in outputs:
raise KeyError("No logits were found in the outputs")
src_logits = outputs["logits"]
source_logits = outputs["logits"]
idx = self._get_src_permutation_idx(indices)
idx = self._get_source_permutation_idx(indices)
target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(
src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
)
target_classes[idx] = target_classes_o
target_classes_onehot = torch.zeros(
[src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
dtype=src_logits.dtype,
layout=src_logits.layout,
device=src_logits.device,
[source_logits.shape[0], source_logits.shape[1], source_logits.shape[2] + 1],
dtype=source_logits.dtype,
layout=source_logits.layout,
device=source_logits.device,
)
target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
target_classes_onehot = target_classes_onehot[:, :, :-1]
loss_ce = (
sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2)
* src_logits.shape[1]
sigmoid_focal_loss(source_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2)
* source_logits.shape[1]
)
losses = {"loss_ce": loss_ce}
return losses
@torch.no_grad()
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.loss_cardinality
def loss_cardinality(self, outputs, targets, indices, num_boxes):
"""
Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
@ -2291,13 +2291,14 @@ class ConditionalDetrLoss(nn.Module):
"""
logits = outputs["logits"]
device = logits.device
tgt_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
# Count the number of predictions that are NOT "no-object" (which is the last class)
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
card_err = nn.functional.l1_loss(card_pred.float(), tgt_lengths.float())
card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
losses = {"cardinality_error": card_err}
return losses
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.loss_boxes
def loss_boxes(self, outputs, targets, indices, num_boxes):
"""
Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
@ -2307,21 +2308,22 @@ class ConditionalDetrLoss(nn.Module):
"""
if "pred_boxes" not in outputs:
raise KeyError("No predicted boxes found in outputs")
idx = self._get_src_permutation_idx(indices)
src_boxes = outputs["pred_boxes"][idx]
idx = self._get_source_permutation_idx(indices)
source_boxes = outputs["pred_boxes"][idx]
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
loss_bbox = nn.functional.l1_loss(src_boxes, target_boxes, reduction="none")
loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
losses = {}
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
loss_giou = 1 - torch.diag(
generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
)
losses["loss_giou"] = loss_giou.sum() / num_boxes
return losses
# Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_masks
def loss_masks(self, outputs, targets, indices, num_boxes):
"""
Compute the losses related to the masks: the focal loss and the dice loss.
@ -2331,42 +2333,45 @@ class ConditionalDetrLoss(nn.Module):
if "pred_masks" not in outputs:
raise KeyError("No predicted masks found in outputs")
src_idx = self._get_src_permutation_idx(indices)
tgt_idx = self._get_tgt_permutation_idx(indices)
src_masks = outputs["pred_masks"]
src_masks = src_masks[src_idx]
source_idx = self._get_source_permutation_idx(indices)
target_idx = self._get_target_permutation_idx(indices)
source_masks = outputs["pred_masks"]
source_masks = source_masks[source_idx]
masks = [t["masks"] for t in targets]
# TODO use valid to mask invalid areas due to padding in loss
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
target_masks = target_masks.to(src_masks)
target_masks = target_masks[tgt_idx]
target_masks = target_masks.to(source_masks)
target_masks = target_masks[target_idx]
# upsample predictions to the target size
src_masks = nn.functional.interpolate(
src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
source_masks = nn.functional.interpolate(
source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
)
src_masks = src_masks[:, 0].flatten(1)
source_masks = source_masks[:, 0].flatten(1)
target_masks = target_masks.flatten(1)
target_masks = target_masks.view(src_masks.shape)
target_masks = target_masks.view(source_masks.shape)
losses = {
"loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
"loss_dice": dice_loss(src_masks, target_masks, num_boxes),
"loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
"loss_dice": dice_loss(source_masks, target_masks, num_boxes),
}
return losses
def _get_src_permutation_idx(self, indices):
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss._get_source_permutation_idx
def _get_source_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
return batch_idx, src_idx
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
source_idx = torch.cat([source for (source, _) in indices])
return batch_idx, source_idx
def _get_tgt_permutation_idx(self, indices):
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss._get_target_permutation_idx
def _get_target_permutation_idx(self, indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
return batch_idx, tgt_idx
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
target_idx = torch.cat([target for (_, target) in indices])
return batch_idx, target_idx
# Copied from transformers.models.detr.modeling_detr.DetrLoss.get_loss
def get_loss(self, loss, outputs, targets, indices, num_boxes):
loss_map = {
"labels": self.loss_labels,
@ -2378,6 +2383,7 @@ class ConditionalDetrLoss(nn.Module):
raise ValueError(f"Loss {loss} not supported")
return loss_map[loss](outputs, targets, indices, num_boxes)
# Copied from transformers.models.detr.modeling_detr.DetrLoss.forward
def forward(self, outputs, targets):
"""
This performs the loss computation.
@ -2386,7 +2392,7 @@ class ConditionalDetrLoss(nn.Module):
outputs (`dict`, *optional*):
Dictionary of tensors, see the output specification of the model for the format.
targets (`List[dict]`, *optional*):
List of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
losses applied, see each loss' doc.
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
@ -2394,7 +2400,7 @@ class ConditionalDetrLoss(nn.Module):
# Retrieve the matching between the outputs of the last layer and the targets
indices = self.matcher(outputs_without_aux, targets)
# Compute the average number of target boxes accross all nodes, for normalization purposes
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# (Niels): comment out function below, distributed training to be added
@ -2445,6 +2451,7 @@ class ConditionalDetrMLPPredictionHead(nn.Module):
return x
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrHungarianMatcher with DeformableDetr->ConditionalDetr
class ConditionalDetrHungarianMatcher(nn.Module):
"""
This class computes an assignment between the targets and the predictions of the network.
@ -2500,21 +2507,21 @@ class ConditionalDetrHungarianMatcher(nn.Module):
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
# Also concat the target labels and boxes
tgt_ids = torch.cat([v["class_labels"] for v in targets])
tgt_bbox = torch.cat([v["boxes"] for v in targets])
target_ids = torch.cat([v["class_labels"] for v in targets])
target_bbox = torch.cat([v["boxes"] for v in targets])
# Compute the classification cost.
alpha = 0.25
gamma = 2.0
neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
class_cost = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]
# Compute the L1 cost between boxes
bbox_cost = torch.cdist(out_bbox, tgt_bbox, p=1)
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
# Compute the giou cost between boxes
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(tgt_bbox))
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
# Final cost matrix
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
@ -2525,9 +2532,7 @@ class ConditionalDetrHungarianMatcher(nn.Module):
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
# below: bounding box utilities taken from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
# Copied from transformers.models.detr.modeling_detr._upcast
def _upcast(t: Tensor) -> Tensor:
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
if t.is_floating_point():
@ -2595,9 +2600,18 @@ def generalized_box_iou(boxes1, boxes2):
return iou - (area - union) / area
# below: taken from https://github.com/facebookresearch/detr/blob/master/util/misc.py#L306
# Copied from transformers.models.detr.modeling_detr.center_to_corners_format
def center_to_corners_format(x):
"""
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)
# Copied from transformers.models.detr.modeling_detr._max_by_axis
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]

View File

@ -114,6 +114,8 @@ class DeformableDetrConfig(PretrainedConfig):
with_box_refine (`bool`, *optional*, defaults to `False`):
Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
based on the predictions from the previous layer.
focal_alpha (`float`, *optional*, defaults to 0.25):
Alpha parameter in the focal loss.
Examples:
@ -174,6 +176,7 @@ class DeformableDetrConfig(PretrainedConfig):
bbox_loss_coefficient=5,
giou_loss_coefficient=2,
eos_coefficient=0.1,
focal_alpha=0.25,
**kwargs
):
self.num_queries = num_queries
@ -216,6 +219,7 @@ class DeformableDetrConfig(PretrainedConfig):
self.bbox_loss_coefficient = bbox_loss_coefficient
self.giou_loss_coefficient = giou_loss_coefficient
self.eos_coefficient = eos_coefficient
self.focal_alpha = focal_alpha
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
@property

View File

@ -44,8 +44,8 @@ def center_to_corners_format(x):
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)

View File

@ -35,7 +35,6 @@ from ...file_utils import (
is_scipy_available,
is_timm_available,
is_torch_cuda_available,
is_vision_available,
replace_return_docstrings,
requires_backends,
)
@ -111,9 +110,6 @@ class MultiScaleDeformableAttentionFunction(Function):
if is_scipy_available():
from scipy.optimize import linear_sum_assignment
if is_vision_available():
from transformers.models.detr.feature_extraction_detr import center_to_corners_format
if is_timm_available():
from timm import create_model
@ -1952,7 +1948,7 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
criterion = DeformableDetrLoss(
matcher=matcher,
num_classes=self.config.num_labels,
eos_coef=self.config.eos_coefficient,
focal_alpha=self.config.focal_alpha,
losses=losses,
)
criterion.to(self.device)
@ -2065,46 +2061,38 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f
return loss.mean(1).sum() / num_boxes
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
class DeformableDetrLoss(nn.Module):
"""
This class computes the losses for DeformableDetrForObjectDetection. The process happens in two steps: 1) we
This class computes the losses for `DeformableDetrForObjectDetection`. The process happens in two steps: 1) we
compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of
matched ground-truth / prediction (supervise class and box)
matched ground-truth / prediction (supervise class and box).
Args:
matcher (`DeformableDetrHungarianMatcher`):
Module able to compute a matching between targets and proposals.
num_classes (`int`):
Number of object categories, omitting the special no-object category.
focal_alpha (`float`):
Alpha parameter in focal loss.
losses (`List[str]`):
List of all the losses to be applied. See `get_loss` for a list of all available losses.
"""
def __init__(self, matcher, num_classes, eos_coef, losses, focal_alpha=0.25):
"""
Create the criterion.
A note on the num_classes parameter (copied from original repo in detr.py): "the naming of the `num_classes`
parameter of the criterion is somewhat misleading. it indeed corresponds to `max_obj_id + 1`, where max_obj_id
is the maximum id for a class in your dataset. For example, COCO has a max_obj_id of 90, so we pass
`num_classes` to be 91. As another example, for a dataset that has a single class with id 1, you should pass
`num_classes` to be 2 (max_obj_id + 1). For more details on this, check the following discussion
https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
Parameters:
matcher: module able to compute a matching between targets and proposals.
num_classes: number of object categories, omitting the special no-object category.
eos_coef: relative classification weight applied to the no-object category.
losses: list of all the losses to be applied. See get_loss for list of available losses.
focal_alpha: alpha in Focal Loss.
"""
def __init__(self, matcher, num_classes, focal_alpha, losses):
super().__init__()
self.matcher = matcher
self.num_classes = num_classes
self.losses = losses
self.focal_alpha = focal_alpha
self.losses = losses
def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
"""Classification loss (NLL)
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
# removed logging parameter, which was part of the original implementation
def loss_labels(self, outputs, targets, indices, num_boxes):
"""
Classification loss (Binary focal loss) targets dicts must contain the key "class_labels" containing a tensor
of dim [nb_target_boxes]
"""
if "logits" not in outputs:
raise ValueError("No logits were found in the outputs")
raise KeyError("No logits were found in the outputs")
source_logits = outputs["logits"]
idx = self._get_source_permutation_idx(indices)
@ -2132,6 +2120,7 @@ class DeformableDetrLoss(nn.Module):
return losses
@torch.no_grad()
# Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_cardinality
def loss_cardinality(self, outputs, targets, indices, num_boxes):
"""
Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
@ -2147,6 +2136,7 @@ class DeformableDetrLoss(nn.Module):
losses = {"cardinality_error": card_err}
return losses
# Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_boxes
def loss_boxes(self, outputs, targets, indices, num_boxes):
"""
Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
@ -2155,8 +2145,7 @@ class DeformableDetrLoss(nn.Module):
are expected in format (center_x, center_y, w, h), normalized by the image size.
"""
if "pred_boxes" not in outputs:
raise ValueError("No predicted boxes found in outputs")
raise KeyError("No predicted boxes found in outputs")
idx = self._get_source_permutation_idx(indices)
source_boxes = outputs["pred_boxes"][idx]
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
@ -2172,12 +2161,14 @@ class DeformableDetrLoss(nn.Module):
losses["loss_giou"] = loss_giou.sum() / num_boxes
return losses
# Copied from transformers.models.detr.modeling_detr.DetrLoss._get_source_permutation_idx
def _get_source_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
source_idx = torch.cat([source for (source, _) in indices])
return batch_idx, source_idx
# Copied from transformers.models.detr.modeling_detr.DetrLoss._get_target_permutation_idx
def _get_target_permutation_idx(self, indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
@ -2192,17 +2183,18 @@ class DeformableDetrLoss(nn.Module):
}
if loss not in loss_map:
raise ValueError(f"Loss {loss} not supported")
return loss_map[loss](outputs, targets, indices, num_boxes)
def forward(self, outputs, targets):
"""
This performs the loss computation.
Parameters:
outputs: dict of tensors, see the output specification of the model for the format
targets: list of dicts, such that len(targets) == batch_size.
The expected keys in each dict depends on the losses applied, see each loss' doc
Args:
outputs (`dict`, *optional*):
Dictionary of tensors, see the output specification of the model for the format.
targets (`List[dict]`, *optional*):
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
losses applied, see each loss' doc.
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
@ -2272,7 +2264,6 @@ class DeformableDetrMLPPredictionHead(nn.Module):
return x
# Copied from transformers.models.detr.modeling_detr.DetrHungarianMatcher
class DeformableDetrHungarianMatcher(nn.Module):
"""
This class computes an assignment between the targets and the predictions of the network.
@ -2324,17 +2315,19 @@ class DeformableDetrHungarianMatcher(nn.Module):
batch_size, num_queries = outputs["logits"].shape[:2]
# We flatten to compute the cost matrices in a batch
out_prob = outputs["logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes]
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
# Also concat the target labels and boxes
target_ids = torch.cat([v["class_labels"] for v in targets])
target_bbox = torch.cat([v["boxes"] for v in targets])
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it in 1 - proba[target class].
# The 1 is a constant that doesn't change the matching, it can be ommitted.
class_cost = -out_prob[:, target_ids]
# Compute the classification cost.
alpha = 0.25
gamma = 2.0
neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]
# Compute the L1 cost between boxes
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
@ -2419,6 +2412,17 @@ def generalized_box_iou(boxes1, boxes2):
return iou - (area - union) / area
# Copied from transformers.models.detr.modeling_detr.center_to_corners_format
def center_to_corners_format(x):
"""
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)
# Copied from transformers.models.detr.modeling_detr._max_by_axis
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]

View File

@ -44,8 +44,8 @@ def center_to_corners_format(x):
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)

View File

@ -33,7 +33,6 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_scipy_available,
is_timm_available,
is_vision_available,
logging,
replace_return_docstrings,
requires_backends,
@ -44,9 +43,6 @@ from .configuration_detr import DetrConfig
if is_scipy_available():
from scipy.optimize import linear_sum_assignment
if is_vision_available():
from .feature_extraction_detr import center_to_corners_format
if is_timm_available():
from timm import create_model
@ -1964,16 +1960,16 @@ class DetrLoss(nn.Module):
"""
if "logits" not in outputs:
raise KeyError("No logits were found in the outputs")
src_logits = outputs["logits"]
source_logits = outputs["logits"]
idx = self._get_src_permutation_idx(indices)
idx = self._get_source_permutation_idx(indices)
target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(
src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
)
target_classes[idx] = target_classes_o
loss_ce = nn.functional.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
losses = {"loss_ce": loss_ce}
return losses
@ -2003,17 +1999,17 @@ class DetrLoss(nn.Module):
"""
if "pred_boxes" not in outputs:
raise KeyError("No predicted boxes found in outputs")
idx = self._get_src_permutation_idx(indices)
src_boxes = outputs["pred_boxes"][idx]
idx = self._get_source_permutation_idx(indices)
source_boxes = outputs["pred_boxes"][idx]
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
loss_bbox = nn.functional.l1_loss(src_boxes, target_boxes, reduction="none")
loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
losses = {}
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
loss_giou = 1 - torch.diag(
generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
)
losses["loss_giou"] = loss_giou.sum() / num_boxes
return losses
@ -2027,41 +2023,41 @@ class DetrLoss(nn.Module):
if "pred_masks" not in outputs:
raise KeyError("No predicted masks found in outputs")
src_idx = self._get_src_permutation_idx(indices)
tgt_idx = self._get_tgt_permutation_idx(indices)
src_masks = outputs["pred_masks"]
src_masks = src_masks[src_idx]
source_idx = self._get_source_permutation_idx(indices)
target_idx = self._get_target_permutation_idx(indices)
source_masks = outputs["pred_masks"]
source_masks = source_masks[source_idx]
masks = [t["masks"] for t in targets]
# TODO use valid to mask invalid areas due to padding in loss
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
target_masks = target_masks.to(src_masks)
target_masks = target_masks[tgt_idx]
target_masks = target_masks.to(source_masks)
target_masks = target_masks[target_idx]
# upsample predictions to the target size
src_masks = nn.functional.interpolate(
src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
source_masks = nn.functional.interpolate(
source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
)
src_masks = src_masks[:, 0].flatten(1)
source_masks = source_masks[:, 0].flatten(1)
target_masks = target_masks.flatten(1)
target_masks = target_masks.view(src_masks.shape)
target_masks = target_masks.view(source_masks.shape)
losses = {
"loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
"loss_dice": dice_loss(src_masks, target_masks, num_boxes),
"loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
"loss_dice": dice_loss(source_masks, target_masks, num_boxes),
}
return losses
def _get_src_permutation_idx(self, indices):
def _get_source_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
return batch_idx, src_idx
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
source_idx = torch.cat([source for (source, _) in indices])
return batch_idx, source_idx
def _get_tgt_permutation_idx(self, indices):
def _get_target_permutation_idx(self, indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
return batch_idx, tgt_idx
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
target_idx = torch.cat([target for (_, target) in indices])
return batch_idx, target_idx
def get_loss(self, loss, outputs, targets, indices, num_boxes):
loss_map = {
@ -2082,7 +2078,7 @@ class DetrLoss(nn.Module):
outputs (`dict`, *optional*):
Dictionary of tensors, see the output specification of the model for the format.
targets (`List[dict]`, *optional*):
List of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
losses applied, see each loss' doc.
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
@ -2288,6 +2284,17 @@ def generalized_box_iou(boxes1, boxes2):
return iou - (area - union) / area
# Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
def center_to_corners_format(x):
"""
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)
# below: taken from https://github.com/facebookresearch/detr/blob/master/util/misc.py#L306

View File

@ -42,8 +42,8 @@ def center_to_corners_format(x):
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)

View File

@ -959,16 +959,16 @@ class YolosLoss(nn.Module):
"""
if "logits" not in outputs:
raise KeyError("No logits were found in the outputs")
src_logits = outputs["logits"]
source_logits = outputs["logits"]
idx = self._get_src_permutation_idx(indices)
idx = self._get_source_permutation_idx(indices)
target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(
src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
)
target_classes[idx] = target_classes_o
loss_ce = nn.functional.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
losses = {"loss_ce": loss_ce}
return losses
@ -998,17 +998,17 @@ class YolosLoss(nn.Module):
"""
if "pred_boxes" not in outputs:
raise KeyError("No predicted boxes found in outputs")
idx = self._get_src_permutation_idx(indices)
src_boxes = outputs["pred_boxes"][idx]
idx = self._get_source_permutation_idx(indices)
source_boxes = outputs["pred_boxes"][idx]
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
loss_bbox = nn.functional.l1_loss(src_boxes, target_boxes, reduction="none")
loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
losses = {}
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
loss_giou = 1 - torch.diag(
generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
)
losses["loss_giou"] = loss_giou.sum() / num_boxes
return losses
@ -1022,41 +1022,41 @@ class YolosLoss(nn.Module):
if "pred_masks" not in outputs:
raise KeyError("No predicted masks found in outputs")
src_idx = self._get_src_permutation_idx(indices)
tgt_idx = self._get_tgt_permutation_idx(indices)
src_masks = outputs["pred_masks"]
src_masks = src_masks[src_idx]
source_idx = self._get_source_permutation_idx(indices)
target_idx = self._get_target_permutation_idx(indices)
source_masks = outputs["pred_masks"]
source_masks = source_masks[source_idx]
masks = [t["masks"] for t in targets]
# TODO use valid to mask invalid areas due to padding in loss
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
target_masks = target_masks.to(src_masks)
target_masks = target_masks[tgt_idx]
target_masks = target_masks.to(source_masks)
target_masks = target_masks[target_idx]
# upsample predictions to the target size
src_masks = nn.functional.interpolate(
src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
source_masks = nn.functional.interpolate(
source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
)
src_masks = src_masks[:, 0].flatten(1)
source_masks = source_masks[:, 0].flatten(1)
target_masks = target_masks.flatten(1)
target_masks = target_masks.view(src_masks.shape)
target_masks = target_masks.view(source_masks.shape)
losses = {
"loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
"loss_dice": dice_loss(src_masks, target_masks, num_boxes),
"loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
"loss_dice": dice_loss(source_masks, target_masks, num_boxes),
}
return losses
def _get_src_permutation_idx(self, indices):
def _get_source_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
return batch_idx, src_idx
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
source_idx = torch.cat([source for (source, _) in indices])
return batch_idx, source_idx
def _get_tgt_permutation_idx(self, indices):
def _get_target_permutation_idx(self, indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
return batch_idx, tgt_idx
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
target_idx = torch.cat([target for (_, target) in indices])
return batch_idx, target_idx
def get_loss(self, loss, outputs, targets, indices, num_boxes):
loss_map = {
@ -1077,7 +1077,7 @@ class YolosLoss(nn.Module):
outputs (`dict`, *optional*):
Dictionary of tensors, see the output specification of the model for the format.
targets (`List[dict]`, *optional*):
List of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
losses applied, see each loss' doc.
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch CONDITIONAL_DETR model. """
""" Testing suite for the PyTorch Conditional DETR model. """
import inspect
@ -213,19 +213,19 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_conditional_detr_object_detection_head_model(*config_and_inputs)
@unittest.skip(reason="CONDITIONAL_DETR does not use inputs_embeds")
@unittest.skip(reason="Conditional DETR does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="CONDITIONAL_DETR does not have a get_input_embeddings method")
@unittest.skip(reason="Conditional DETR does not have a get_input_embeddings method")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="CONDITIONAL_DETR is not a generative model")
@unittest.skip(reason="Conditional DETR is not a generative model")
def test_generate_without_input_ids(self):
pass
@unittest.skip(reason="CONDITIONAL_DETR does not use token embeddings")
@unittest.skip(reason="Conditional DETR does not use token embeddings")
def test_resize_tokens_embeddings(self):
pass
@ -474,7 +474,7 @@ class ConditionalDetrModelIntegrationTests(unittest.TestCase):
expected_shape = torch.Size((1, 300, 256))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
expected_slice = torch.tensor(
[[0.0616, -0.5146, -0.4032], [-0.7629, -0.4934, -1.7153], [-0.4768, -0.6403, -0.7826]]
[[0.4222, 0.7471, 0.8760], [0.6395, -0.2729, 0.7127], [-0.3090, 0.7642, 0.9529]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
@ -495,48 +495,13 @@ class ConditionalDetrModelIntegrationTests(unittest.TestCase):
expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels))
self.assertEqual(outputs.logits.shape, expected_shape_logits)
expected_slice_logits = torch.tensor(
[[-19.1194, -0.0893, -11.0154], [-17.3640, -1.8035, -14.0219], [-20.0461, -0.5837, -11.1060]]
[[-10.4372, -5.7558, -8.6764], [-10.5410, -5.8704, -8.0590], [-10.6827, -6.3469, -8.3923]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits, atol=1e-4))
expected_shape_boxes = torch.Size((1, model.config.num_queries, 4))
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
expected_slice_boxes = torch.tensor(
[[0.4433, 0.5302, 0.8853], [0.5494, 0.2517, 0.0529], [0.4998, 0.5360, 0.9956]]
[[0.7733, 0.6576, 0.4496], [0.5171, 0.1184, 0.9094], [0.8846, 0.5647, 0.2486]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
def test_inference_panoptic_segmentation_head(self):
model = ConditionalDetrForSegmentation.from_pretrained("microsoft/conditional-detr-resnet-50-panoptic").to(
torch_device
)
feature_extractor = self.default_feature_extractor
image = prepare_img()
encoding = feature_extractor(images=image, return_tensors="pt").to(torch_device)
pixel_values = encoding["pixel_values"].to(torch_device)
pixel_mask = encoding["pixel_mask"].to(torch_device)
with torch.no_grad():
outputs = model(pixel_values, pixel_mask)
expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels))
self.assertEqual(outputs.logits.shape, expected_shape_logits)
expected_slice_logits = torch.tensor(
[[-18.1565, -1.7568, -13.5029], [-16.8888, -1.4138, -14.1028], [-17.5709, -2.5080, -11.8654]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits, atol=1e-4))
expected_shape_boxes = torch.Size((1, model.config.num_queries, 4))
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
expected_slice_boxes = torch.tensor(
[[0.5344, 0.1789, 0.9285], [0.4420, 0.0572, 0.0875], [0.6630, 0.6887, 0.1017]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
expected_shape_masks = torch.Size((1, model.config.num_queries, 200, 267))
self.assertEqual(outputs.pred_masks.shape, expected_shape_masks)
expected_slice_masks = torch.tensor(
[[-7.7558, -10.8788, -11.9797], [-11.8881, -16.4329, -17.7451], [-14.7316, -19.7383, -20.3004]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.pred_masks[0, 0, :3, :3], expected_slice_masks, atol=1e-3))