mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00

Improve DETR models (#19644)

* Improve DETR models
* Fix Deformable DETR loss and matcher
* Fixup
* Fix integration tests
* Improve variable names
* Apply suggestion
* Fix copies
* Fix DeformableDetrLoss
* Make Conditional DETR copy from Deformable DETR
* Copy from deformable detr's hungarian matcher
* Fix bug

This commit is contained in:
parent 072dfdaee4
commit 90071fe42b
@@ -23,7 +23,7 @@ The abstract from the paper is the following:

Tips:

- One can use the [`AutoFeatureExtractor`] API to prepare images (and optional targets) for the model. This will instantiate a [`DetrFeatureExtractor`] behind the scenes.
- One can use [`DeformableDetrFeatureExtractor`] to prepare images (and optional targets) for the model.
- Training Deformable DETR is equivalent to training the original [DETR](detr) model. Demo notebooks can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/DETR).

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/deformable_detr_architecture.png"
@@ -110,6 +110,8 @@ class ConditionalDetrConfig(PretrainedConfig):

Relative weight of the generalized IoU loss in the object detection loss.
eos_coefficient (`float`, *optional*, defaults to 0.1):
Relative classification weight of the 'no-object' class in the object detection loss.
focal_alpha (`float`, *optional*, defaults to 0.25):
Alpha parameter in the focal loss.

Examples:
@@ -44,8 +44,8 @@ def center_to_corners_format(x):

Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)
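The several copies of this hunk in the commit (Conditional DETR, Deformable DETR, DETR and YOLOS feature extractors) only rename the local variables of `center_to_corners_format`; the conversion itself is unchanged. A minimal sketch of what the helper computes, with made-up values (not part of the diff):

import torch

boxes = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # (center_x, center_y, width, height)
center_x, center_y, width, height = boxes.unbind(-1)
corners = torch.stack(
    [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height], dim=-1
)
# corners == tensor([[0.4000, 0.3000, 0.6000, 0.7000]]), i.e. (x_0, y_0, x_1, y_1)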
@@ -33,7 +33,6 @@ from ...utils import (

add_start_docstrings_to_model_forward,
is_scipy_available,
is_timm_available,
is_vision_available,
logging,
replace_return_docstrings,
requires_backends,

@@ -44,9 +43,6 @@ from .configuration_conditional_detr import ConditionalDetrConfig

if is_scipy_available():
from scipy.optimize import linear_sum_assignment

if is_vision_available():
from .feature_extraction_conditional_detr import center_to_corners_format

if is_timm_available():
from timm import create_model
@@ -679,11 +675,11 @@ class ConditionalDetrAttention(nn.Module):

self.out_proj = nn.Linear(out_dim, out_dim, bias=bias)

def _qk_shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def _qk_shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

def _v_shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous()
def _v_shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
return tensor.view(batch_size, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous()

def forward(
self,

@@ -695,37 +691,38 @@ class ConditionalDetrAttention(nn.Module):

) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""

bsz, tgt_len, _ = hidden_states.size()
batch_size, target_len, _ = hidden_states.size()

# get query proj
query_states = hidden_states * self.scaling
# get key, value proj
key_states = self._qk_shape(key_states, -1, bsz)
value_states = self._v_shape(value_states, -1, bsz)
key_states = self._qk_shape(key_states, -1, batch_size)
value_states = self._v_shape(value_states, -1, batch_size)

proj_shape = (bsz * self.num_heads, -1, self.head_dim)
v_proj_shape = (bsz * self.num_heads, -1, self.v_head_dim)
query_states = self._qk_shape(query_states, tgt_len, bsz).view(*proj_shape)
proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
v_proj_shape = (batch_size * self.num_heads, -1, self.v_head_dim)
query_states = self._qk_shape(query_states, target_len, batch_size).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*v_proj_shape)

src_len = key_states.size(1)
source_len = key_states.size(1)

attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
f" {attn_weights.size()}"
)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
if attention_mask.size() != (batch_size, 1, target_len, source_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
f" {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)

attn_weights = nn.functional.softmax(attn_weights, dim=-1)

@@ -734,8 +731,8 @@ class ConditionalDetrAttention(nn.Module):

# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
else:
attn_weights_reshaped = None

@@ -743,15 +740,15 @@ class ConditionalDetrAttention(nn.Module):

attn_output = torch.bmm(attn_probs, value_states)

if attn_output.size() != (bsz * self.num_heads, tgt_len, self.v_head_dim):
if attn_output.size() != (batch_size * self.num_heads, target_len, self.v_head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.v_head_dim)}, but is"
f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.v_head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.v_head_dim)
attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.v_head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.out_dim)
attn_output = attn_output.reshape(batch_size, target_len, self.out_dim)

attn_output = self.out_proj(attn_output)
@@ -887,7 +884,8 @@ class ConditionalDetrDecoderLayer(nn.Module):

Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
values.
position_embeddings (`torch.FloatTensor`, *optional*):
position embeddings that are added to the queries and keys
in the cross-attention layer.

@@ -897,7 +895,8 @@ class ConditionalDetrDecoderLayer(nn.Module):

encoder_hidden_states (`torch.FloatTensor`):
cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.

@@ -940,7 +939,7 @@ class ConditionalDetrDecoderLayer(nn.Module):

v = self.ca_v_proj(encoder_hidden_states)

batch_size, num_queries, n_model = q_content.shape
_, src_len, _ = k_content.shape
_, source_len, _ = k_content.shape

k_pos = self.ca_kpos_proj(position_embeddings)

@@ -958,9 +957,9 @@ class ConditionalDetrDecoderLayer(nn.Module):

query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
query_sine_embed = query_sine_embed.view(batch_size, num_queries, self.nhead, n_model // self.nhead)
q = torch.cat([q, query_sine_embed], dim=3).view(batch_size, num_queries, n_model * 2)
k = k.view(batch_size, src_len, self.nhead, n_model // self.nhead)
k_pos = k_pos.view(batch_size, src_len, self.nhead, n_model // self.nhead)
k = torch.cat([k, k_pos], dim=3).view(batch_size, src_len, n_model * 2)
k = k.view(batch_size, source_len, self.nhead, n_model // self.nhead)
k_pos = k_pos.view(batch_size, source_len, self.nhead, n_model // self.nhead)
k = torch.cat([k, k_pos], dim=3).view(batch_size, source_len, n_model * 2)

# Cross-Attention Block
cross_attn_weights = None

@@ -1333,14 +1332,14 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):

combined_attention_mask = None

if attention_mask is not None and combined_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
# [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
combined_attention_mask = combined_attention_mask + _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
attention_mask, inputs_embeds.dtype, target_len=input_shape[-1]
)

# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
# [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
encoder_attention_mask = _expand_mask(
encoder_attention_mask, inputs_embeds.dtype, target_len=input_shape[-1]
)

@@ -2061,7 +2060,6 @@ def _expand(tensor, length: int):

return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)

# taken from https://github.com/facebookresearch/detr/blob/master/models/segmentation.py
# Copied from transformers.models.detr.modeling_detr.DetrMaskHeadSmallConv with Detr->ConditionalDetr
class ConditionalDetrMaskHeadSmallConv(nn.Module):
"""

@@ -2172,6 +2170,7 @@ class ConditionalDetrMHAttentionMap(nn.Module):

return weights

# Copied from transformers.models.detr.modeling_detr.dice_loss
def dice_loss(inputs, targets, num_boxes):
"""
Compute the DICE loss, similar to generalized IOU for masks

@@ -2191,26 +2190,28 @@ def dice_loss(inputs, targets, num_boxes):

return loss.sum() / num_boxes
# Copied from transformers.models.detr.modeling_detr.sigmoid_focal_loss
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs (0 for the negative class and 1 for the positive
class).
alpha: (optional) Weighting factor in range (0,1) to balance
positive vs negative examples. Default = -1 (no weighting).
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
inputs (`torch.FloatTensor` of arbitrary shape):
The predictions for each example.
targets (`torch.FloatTensor` with the same shape as `inputs`)
A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
and 1 for the positive class).
alpha (`float`, *optional*, defaults to `0.25`):
Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
gamma (`int`, *optional*, defaults to `2`):
Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.

Returns:
Loss tensor
"""
prob = inputs.sigmoid()
ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
# add modulating factor
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)

@@ -2221,26 +2222,24 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f

return loss.mean(1).sum() / num_boxes
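To make the rewritten docstring above concrete, here is a self-contained sketch of the same focal-loss computation on dummy tensors. The shapes are made up, and the alpha-weighting branch (which falls in lines elided from this hunk) is an assumption based on the standard RetinaNet formulation, not a verbatim copy of the file:

import torch
import torch.nn as nn

inputs = torch.randn(2, 4)                     # logits, arbitrary shape
targets = torch.randint(0, 2, (2, 4)).float()  # binary labels of the same shape
num_boxes, alpha, gamma = 3, 0.25, 2

prob = inputs.sigmoid()
ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)  # probability of the true class
loss = ce_loss * ((1 - p_t) ** gamma)              # down-weight easy examples
if alpha >= 0:                                     # optional positive/negative balancing
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    loss = alpha_t * loss
print(loss.mean(1).sum() / num_boxes)              # normalized by the number of boxes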
# taken from https://github.com/Atten4Vis/conditionalDETR/blob/master/models/conditional_detr.py
class ConditionalDetrLoss(nn.Module):
"""
This class computes the losses for ConditionalDetrForObjectDetection/ConditionalDetrForSegmentation. The process
happens in two steps: 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 2)
we supervise each pair of matched ground-truth / prediction (supervise class and box).

Args:
matcher (`ConditionalDetrHungarianMatcher`):
Module able to compute a matching between targets and proposals.
num_classes (`int`):
Number of object categories, omitting the special no-object category.
focal_alpha (`float`):
Alpha parmeter in focal loss.
Alpha parameter in focal loss.
losses (`List[str]`):
List of all the losses to be applied. See `get_loss` for a list of all available losses.
"""

# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.__init__
def __init__(self, matcher, num_classes, focal_alpha, losses):
super().__init__()
self.matcher = matcher

@@ -2248,7 +2247,7 @@ class ConditionalDetrLoss(nn.Module):

self.focal_alpha = focal_alpha
self.losses = losses

# removed logging parameter, which was part of the original implementation
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.loss_labels
def loss_labels(self, outputs, targets, indices, num_boxes):
"""
Classification loss (Binary focal loss) targets dicts must contain the key "class_labels" containing a tensor

@@ -2256,33 +2255,34 @@ class ConditionalDetrLoss(nn.Module):

"""
if "logits" not in outputs:
raise KeyError("No logits were found in the outputs")
src_logits = outputs["logits"]
source_logits = outputs["logits"]

idx = self._get_src_permutation_idx(indices)
idx = self._get_source_permutation_idx(indices)
target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(
src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
)
target_classes[idx] = target_classes_o

target_classes_onehot = torch.zeros(
[src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
dtype=src_logits.dtype,
layout=src_logits.layout,
device=src_logits.device,
[source_logits.shape[0], source_logits.shape[1], source_logits.shape[2] + 1],
dtype=source_logits.dtype,
layout=source_logits.layout,
device=source_logits.device,
)
target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

target_classes_onehot = target_classes_onehot[:, :, :-1]
loss_ce = (
sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2)
* src_logits.shape[1]
sigmoid_focal_loss(source_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2)
* source_logits.shape[1]
)
losses = {"loss_ce": loss_ce}

return losses
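The interesting part of `loss_labels` above is how the focal-loss targets are built: a one-hot tensor is allocated with one extra column so that unmatched queries (labelled with the `num_classes` index, i.e. "no object") have a valid slot for `scatter_`, and that extra column is then dropped, leaving all-zero rows for unmatched queries. A toy illustration with dummy shapes (not part of the diff):

import torch

num_queries, num_classes = 4, 3
target_classes = torch.full((1, num_queries), num_classes, dtype=torch.long)  # every query starts as "no object"
target_classes[0, 0] = 1  # pretend query 0 was matched to a box of class 1

onehot = torch.zeros(1, num_queries, num_classes + 1)
onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
onehot = onehot[:, :, :-1]  # drop the no-object column; unmatched rows become all zeros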
@torch.no_grad()
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.loss_cardinality
def loss_cardinality(self, outputs, targets, indices, num_boxes):
"""
Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.

@@ -2291,13 +2291,14 @@ class ConditionalDetrLoss(nn.Module):

"""
logits = outputs["logits"]
device = logits.device
tgt_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
# Count the number of predictions that are NOT "no-object" (which is the last class)
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
card_err = nn.functional.l1_loss(card_pred.float(), tgt_lengths.float())
card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
losses = {"cardinality_error": card_err}
return losses
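`loss_cardinality` is purely diagnostic (note the `@torch.no_grad()`): it counts how many queries predict something other than the last, "no object" class and reports the L1 error against the number of ground-truth boxes. A toy example with made-up numbers (not part of the diff):

import torch
import torch.nn as nn

logits = torch.randn(2, 5, 4)          # 2 images, 5 queries, 3 classes + no-object
target_lengths = torch.tensor([2, 3])  # ground-truth boxes per image

card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
cardinality_error = nn.functional.l1_loss(card_pred.float(), target_lengths.float())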
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.loss_boxes
def loss_boxes(self, outputs, targets, indices, num_boxes):
"""
Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.

@@ -2307,21 +2308,22 @@ class ConditionalDetrLoss(nn.Module):

"""
if "pred_boxes" not in outputs:
raise KeyError("No predicted boxes found in outputs")
idx = self._get_src_permutation_idx(indices)
src_boxes = outputs["pred_boxes"][idx]
idx = self._get_source_permutation_idx(indices)
source_boxes = outputs["pred_boxes"][idx]
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

loss_bbox = nn.functional.l1_loss(src_boxes, target_boxes, reduction="none")
loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")

losses = {}
losses["loss_bbox"] = loss_bbox.sum() / num_boxes

loss_giou = 1 - torch.diag(
generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
)
losses["loss_giou"] = loss_giou.sum() / num_boxes
return losses

# Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_masks
def loss_masks(self, outputs, targets, indices, num_boxes):
"""
Compute the losses related to the masks: the focal loss and the dice loss.

@@ -2331,42 +2333,45 @@ class ConditionalDetrLoss(nn.Module):

if "pred_masks" not in outputs:
raise KeyError("No predicted masks found in outputs")

src_idx = self._get_src_permutation_idx(indices)
tgt_idx = self._get_tgt_permutation_idx(indices)
src_masks = outputs["pred_masks"]
src_masks = src_masks[src_idx]
source_idx = self._get_source_permutation_idx(indices)
target_idx = self._get_target_permutation_idx(indices)
source_masks = outputs["pred_masks"]
source_masks = source_masks[source_idx]
masks = [t["masks"] for t in targets]
# TODO use valid to mask invalid areas due to padding in loss
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
target_masks = target_masks.to(src_masks)
target_masks = target_masks[tgt_idx]
target_masks = target_masks.to(source_masks)
target_masks = target_masks[target_idx]

# upsample predictions to the target size
src_masks = nn.functional.interpolate(
src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
source_masks = nn.functional.interpolate(
source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
)
src_masks = src_masks[:, 0].flatten(1)
source_masks = source_masks[:, 0].flatten(1)

target_masks = target_masks.flatten(1)
target_masks = target_masks.view(src_masks.shape)
target_masks = target_masks.view(source_masks.shape)
losses = {
"loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
"loss_dice": dice_loss(src_masks, target_masks, num_boxes),
"loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
"loss_dice": dice_loss(source_masks, target_masks, num_boxes),
}
return losses

def _get_src_permutation_idx(self, indices):
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss._get_source_permutation_idx
def _get_source_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
return batch_idx, src_idx
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
source_idx = torch.cat([source for (source, _) in indices])
return batch_idx, source_idx

def _get_tgt_permutation_idx(self, indices):
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss._get_target_permutation_idx
def _get_target_permutation_idx(self, indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
return batch_idx, tgt_idx
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
target_idx = torch.cat([target for (_, target) in indices])
return batch_idx, target_idx

# Copied from transformers.models.detr.modeling_detr.DetrLoss.get_loss
def get_loss(self, loss, outputs, targets, indices, num_boxes):
loss_map = {
"labels": self.loss_labels,

@@ -2378,6 +2383,7 @@ class ConditionalDetrLoss(nn.Module):

raise ValueError(f"Loss {loss} not supported")
return loss_map[loss](outputs, targets, indices, num_boxes)

# Copied from transformers.models.detr.modeling_detr.DetrLoss.forward
def forward(self, outputs, targets):
"""
This performs the loss computation.

@@ -2386,7 +2392,7 @@ class ConditionalDetrLoss(nn.Module):

outputs (`dict`, *optional*):
Dictionary of tensors, see the output specification of the model for the format.
targets (`List[dict]`, *optional*):
List of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
losses applied, see each loss' doc.
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}

@@ -2394,7 +2400,7 @@ class ConditionalDetrLoss(nn.Module):

# Retrieve the matching between the outputs of the last layer and the targets
indices = self.matcher(outputs_without_aux, targets)

# Compute the average number of target boxes accross all nodes, for normalization purposes
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# (Niels): comment out function below, distributed training to be added

@@ -2445,6 +2451,7 @@ class ConditionalDetrMLPPredictionHead(nn.Module):

return x

# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrHungarianMatcher with DeformableDetr->ConditionalDetr
class ConditionalDetrHungarianMatcher(nn.Module):
"""
This class computes an assignment between the targets and the predictions of the network.
@@ -2500,21 +2507,21 @@ class ConditionalDetrHungarianMatcher(nn.Module):

out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

# Also concat the target labels and boxes
tgt_ids = torch.cat([v["class_labels"] for v in targets])
tgt_bbox = torch.cat([v["boxes"] for v in targets])
target_ids = torch.cat([v["class_labels"] for v in targets])
target_bbox = torch.cat([v["boxes"] for v in targets])

# Compute the classification cost.
alpha = 0.25
gamma = 2.0
neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
class_cost = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]

# Compute the L1 cost between boxes
bbox_cost = torch.cdist(out_bbox, tgt_bbox, p=1)
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)

# Compute the giou cost between boxes
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(tgt_bbox))
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))

# Final cost matrix
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost

@@ -2525,9 +2532,7 @@ class ConditionalDetrHungarianMatcher(nn.Module):

return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
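The matcher above combines a focal-style classification cost (the same alpha/gamma form the loss now uses) with L1 and GIoU box costs, then solves the assignment with `scipy.optimize.linear_sum_assignment`. A compact sketch of that flow on dummy tensors; the 2.0/5.0 weights are illustrative placeholders and the GIoU term is omitted for brevity:

import torch
from scipy.optimize import linear_sum_assignment

num_queries, num_classes = 6, 3
out_prob = torch.rand(num_queries, num_classes)  # per-class sigmoid probabilities
out_bbox = torch.rand(num_queries, 4)            # (center_x, center_y, width, height)
target_ids = torch.tensor([0, 2])                # two ground-truth boxes
target_bbox = torch.rand(2, 4)

alpha, gamma = 0.25, 2.0
neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]

bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
cost_matrix = 2.0 * class_cost + 5.0 * bbox_cost
row_ind, col_ind = linear_sum_assignment(cost_matrix.numpy())
# query row_ind[i] is matched to ground-truth box col_ind[i]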
# below: bounding box utilities taken from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py

# Copied from transformers.models.detr.modeling_detr._upcast
def _upcast(t: Tensor) -> Tensor:
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
if t.is_floating_point():

@@ -2595,9 +2600,18 @@ def generalized_box_iou(boxes1, boxes2):

return iou - (area - union) / area

# below: taken from https://github.com/facebookresearch/detr/blob/master/util/misc.py#L306
# Copied from transformers.models.detr.modeling_detr.center_to_corners_format
def center_to_corners_format(x):
"""
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)

# Copied from transformers.models.detr.modeling_detr._max_by_axis
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
@@ -114,6 +114,8 @@ class DeformableDetrConfig(PretrainedConfig):

with_box_refine (`bool`, *optional*, defaults to `False`):
Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
based on the predictions from the previous layer.
focal_alpha (`float`, *optional*, defaults to 0.25):
Alpha parameter in the focal loss.

Examples:

@@ -174,6 +176,7 @@ class DeformableDetrConfig(PretrainedConfig):

bbox_loss_coefficient=5,
giou_loss_coefficient=2,
eos_coefficient=0.1,
focal_alpha=0.25,
**kwargs
):
self.num_queries = num_queries

@@ -216,6 +219,7 @@ class DeformableDetrConfig(PretrainedConfig):

self.bbox_loss_coefficient = bbox_loss_coefficient
self.giou_loss_coefficient = giou_loss_coefficient
self.eos_coefficient = eos_coefficient
self.focal_alpha = focal_alpha
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)

@property
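These hunks expose `focal_alpha` as a regular `DeformableDetrConfig` attribute next to the existing loss coefficients. A minimal usage sketch (the non-default value is only for illustration):

from transformers import DeformableDetrConfig

config = DeformableDetrConfig(focal_alpha=0.5)  # defaults to 0.25
print(config.focal_alpha)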
@@ -44,8 +44,8 @@ def center_to_corners_format(x):

Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)
@@ -35,7 +35,6 @@ from ...file_utils import (

is_scipy_available,
is_timm_available,
is_torch_cuda_available,
is_vision_available,
replace_return_docstrings,
requires_backends,
)

@@ -111,9 +110,6 @@ class MultiScaleDeformableAttentionFunction(Function):

if is_scipy_available():
from scipy.optimize import linear_sum_assignment

if is_vision_available():
from transformers.models.detr.feature_extraction_detr import center_to_corners_format

if is_timm_available():
from timm import create_model

@@ -1952,7 +1948,7 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):

criterion = DeformableDetrLoss(
matcher=matcher,
num_classes=self.config.num_labels,
eos_coef=self.config.eos_coefficient,
focal_alpha=self.config.focal_alpha,
losses=losses,
)
criterion.to(self.device)
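The hunk above is the user-visible consequence of the new loss signature: `DeformableDetrForObjectDetection` now builds its criterion with `focal_alpha` instead of `eos_coef`. Sketched in isolation, with placeholder matcher weights and loss names rather than the model's configured values:

matcher = DeformableDetrHungarianMatcher(class_cost=1, bbox_cost=5, giou_cost=2)
criterion = DeformableDetrLoss(
    matcher=matcher,
    num_classes=91,          # e.g. COCO: highest object id 90, plus one
    focal_alpha=0.25,        # replaces the old eos_coef argument
    losses=["labels", "boxes", "cardinality"],
)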
@@ -2065,46 +2061,38 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f

return loss.mean(1).sum() / num_boxes

# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
class DeformableDetrLoss(nn.Module):
"""
This class computes the losses for DeformableDetrForObjectDetection. The process happens in two steps: 1) we
This class computes the losses for `DeformableDetrForObjectDetection`. The process happens in two steps: 1) we
compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of
matched ground-truth / prediction (supervise class and box)
matched ground-truth / prediction (supervise class and box).

Args:
matcher (`DeformableDetrHungarianMatcher`):
Module able to compute a matching between targets and proposals.
num_classes (`int`):
Number of object categories, omitting the special no-object category.
focal_alpha (`float`):
Alpha parameter in focal loss.
losses (`List[str]`):
List of all the losses to be applied. See `get_loss` for a list of all available losses.
"""

def __init__(self, matcher, num_classes, eos_coef, losses, focal_alpha=0.25):
"""
Create the criterion.

A note on the num_classes parameter (copied from original repo in detr.py): "the naming of the `num_classes`
parameter of the criterion is somewhat misleading. it indeed corresponds to `max_obj_id + 1`, where max_obj_id
is the maximum id for a class in your dataset. For example, COCO has a max_obj_id of 90, so we pass
`num_classes` to be 91. As another example, for a dataset that has a single class with id 1, you should pass
`num_classes` to be 2 (max_obj_id + 1). For more details on this, check the following discussion
https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"

Parameters:
matcher: module able to compute a matching between targets and proposals.
num_classes: number of object categories, omitting the special no-object category.
eos_coef: relative classification weight applied to the no-object category.
losses: list of all the losses to be applied. See get_loss for list of available losses.
focal_alpha: alpha in Focal Loss.
"""
def __init__(self, matcher, num_classes, focal_alpha, losses):
super().__init__()

self.matcher = matcher
self.num_classes = num_classes
self.losses = losses
self.focal_alpha = focal_alpha
self.losses = losses

def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
"""Classification loss (NLL)
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
# removed logging parameter, which was part of the original implementation
def loss_labels(self, outputs, targets, indices, num_boxes):
"""
Classification loss (Binary focal loss) targets dicts must contain the key "class_labels" containing a tensor
of dim [nb_target_boxes]
"""
if "logits" not in outputs:
raise ValueError("No logits were found in the outputs")
raise KeyError("No logits were found in the outputs")
source_logits = outputs["logits"]

idx = self._get_source_permutation_idx(indices)

@@ -2132,6 +2120,7 @@ class DeformableDetrLoss(nn.Module):

return losses

@torch.no_grad()
# Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_cardinality
def loss_cardinality(self, outputs, targets, indices, num_boxes):
"""
Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.

@@ -2147,6 +2136,7 @@ class DeformableDetrLoss(nn.Module):

losses = {"cardinality_error": card_err}
return losses

# Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_boxes
def loss_boxes(self, outputs, targets, indices, num_boxes):
"""
Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.

@@ -2155,8 +2145,7 @@ class DeformableDetrLoss(nn.Module):

are expected in format (center_x, center_y, w, h), normalized by the image size.
"""
if "pred_boxes" not in outputs:
raise ValueError("No predicted boxes found in outputs")
raise KeyError("No predicted boxes found in outputs")
idx = self._get_source_permutation_idx(indices)
source_boxes = outputs["pred_boxes"][idx]
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

@@ -2172,12 +2161,14 @@ class DeformableDetrLoss(nn.Module):

losses["loss_giou"] = loss_giou.sum() / num_boxes
return losses

# Copied from transformers.models.detr.modeling_detr.DetrLoss._get_source_permutation_idx
def _get_source_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
source_idx = torch.cat([source for (source, _) in indices])
return batch_idx, source_idx

# Copied from transformers.models.detr.modeling_detr.DetrLoss._get_target_permutation_idx
def _get_target_permutation_idx(self, indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])

@@ -2192,17 +2183,18 @@ class DeformableDetrLoss(nn.Module):

}
if loss not in loss_map:
raise ValueError(f"Loss {loss} not supported")

return loss_map[loss](outputs, targets, indices, num_boxes)

def forward(self, outputs, targets):
"""
This performs the loss computation.

Parameters:
outputs: dict of tensors, see the output specification of the model for the format
targets: list of dicts, such that len(targets) == batch_size.
The expected keys in each dict depends on the losses applied, see each loss' doc
Args:
outputs (`dict`, *optional*):
Dictionary of tensors, see the output specification of the model for the format.
targets (`List[dict]`, *optional*):
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
losses applied, see each loss' doc.
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}

@@ -2272,7 +2264,6 @@ class DeformableDetrMLPPredictionHead(nn.Module):

return x

# Copied from transformers.models.detr.modeling_detr.DetrHungarianMatcher
class DeformableDetrHungarianMatcher(nn.Module):
"""
This class computes an assignment between the targets and the predictions of the network.

@@ -2324,17 +2315,19 @@ class DeformableDetrHungarianMatcher(nn.Module):

batch_size, num_queries = outputs["logits"].shape[:2]

# We flatten to compute the cost matrices in a batch
out_prob = outputs["logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
out_prob = outputs["logits"].flatten(0, 1).sigmoid()  # [batch_size * num_queries, num_classes]
out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

# Also concat the target labels and boxes
target_ids = torch.cat([v["class_labels"] for v in targets])
target_bbox = torch.cat([v["boxes"] for v in targets])

# Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it in 1 - proba[target class].
# The 1 is a constant that doesn't change the matching, it can be ommitted.
class_cost = -out_prob[:, target_ids]
# Compute the classification cost.
alpha = 0.25
gamma = 2.0
neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]

# Compute the L1 cost between boxes
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)

@@ -2419,6 +2412,17 @@ def generalized_box_iou(boxes1, boxes2):

return iou - (area - union) / area

# Copied from transformers.models.detr.modeling_detr.center_to_corners_format
def center_to_corners_format(x):
"""
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)

# Copied from transformers.models.detr.modeling_detr._max_by_axis
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
@@ -44,8 +44,8 @@ def center_to_corners_format(x):

Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)
@@ -33,7 +33,6 @@ from ...utils import (

add_start_docstrings_to_model_forward,
is_scipy_available,
is_timm_available,
is_vision_available,
logging,
replace_return_docstrings,
requires_backends,

@@ -44,9 +43,6 @@ from .configuration_detr import DetrConfig

if is_scipy_available():
from scipy.optimize import linear_sum_assignment

if is_vision_available():
from .feature_extraction_detr import center_to_corners_format

if is_timm_available():
from timm import create_model

@@ -1964,16 +1960,16 @@ class DetrLoss(nn.Module):

"""
if "logits" not in outputs:
raise KeyError("No logits were found in the outputs")
src_logits = outputs["logits"]
source_logits = outputs["logits"]

idx = self._get_src_permutation_idx(indices)
idx = self._get_source_permutation_idx(indices)
target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(
src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
)
target_classes[idx] = target_classes_o

loss_ce = nn.functional.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
losses = {"loss_ce": loss_ce}

return losses

@@ -2003,17 +1999,17 @@ class DetrLoss(nn.Module):

"""
if "pred_boxes" not in outputs:
raise KeyError("No predicted boxes found in outputs")
idx = self._get_src_permutation_idx(indices)
src_boxes = outputs["pred_boxes"][idx]
idx = self._get_source_permutation_idx(indices)
source_boxes = outputs["pred_boxes"][idx]
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

loss_bbox = nn.functional.l1_loss(src_boxes, target_boxes, reduction="none")
loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")

losses = {}
losses["loss_bbox"] = loss_bbox.sum() / num_boxes

loss_giou = 1 - torch.diag(
generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
)
losses["loss_giou"] = loss_giou.sum() / num_boxes
return losses

@@ -2027,41 +2023,41 @@ class DetrLoss(nn.Module):

if "pred_masks" not in outputs:
raise KeyError("No predicted masks found in outputs")

src_idx = self._get_src_permutation_idx(indices)
tgt_idx = self._get_tgt_permutation_idx(indices)
src_masks = outputs["pred_masks"]
src_masks = src_masks[src_idx]
source_idx = self._get_source_permutation_idx(indices)
target_idx = self._get_target_permutation_idx(indices)
source_masks = outputs["pred_masks"]
source_masks = source_masks[source_idx]
masks = [t["masks"] for t in targets]
# TODO use valid to mask invalid areas due to padding in loss
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
target_masks = target_masks.to(src_masks)
target_masks = target_masks[tgt_idx]
target_masks = target_masks.to(source_masks)
target_masks = target_masks[target_idx]

# upsample predictions to the target size
src_masks = nn.functional.interpolate(
src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
source_masks = nn.functional.interpolate(
source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
)
src_masks = src_masks[:, 0].flatten(1)
source_masks = source_masks[:, 0].flatten(1)

target_masks = target_masks.flatten(1)
target_masks = target_masks.view(src_masks.shape)
target_masks = target_masks.view(source_masks.shape)
losses = {
"loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
"loss_dice": dice_loss(src_masks, target_masks, num_boxes),
"loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
"loss_dice": dice_loss(source_masks, target_masks, num_boxes),
}
return losses

def _get_src_permutation_idx(self, indices):
def _get_source_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
return batch_idx, src_idx
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
source_idx = torch.cat([source for (source, _) in indices])
return batch_idx, source_idx

def _get_tgt_permutation_idx(self, indices):
def _get_target_permutation_idx(self, indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
return batch_idx, tgt_idx
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
target_idx = torch.cat([target for (_, target) in indices])
return batch_idx, target_idx

def get_loss(self, loss, outputs, targets, indices, num_boxes):
loss_map = {

@@ -2082,7 +2078,7 @@ class DetrLoss(nn.Module):

outputs (`dict`, *optional*):
Dictionary of tensors, see the output specification of the model for the format.
targets (`List[dict]`, *optional*):
List of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
losses applied, see each loss' doc.
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}

@@ -2288,6 +2284,17 @@ def generalized_box_iou(boxes1, boxes2):

return iou - (area - union) / area

# Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
def center_to_corners_format(x):
"""
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)

# below: taken from https://github.com/facebookresearch/detr/blob/master/util/misc.py#L306
@@ -42,8 +42,8 @@ def center_to_corners_format(x):

Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
center_x, center_y, width, height = x.unbind(-1)
b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)
@@ -959,16 +959,16 @@ class YolosLoss(nn.Module):

"""
if "logits" not in outputs:
raise KeyError("No logits were found in the outputs")
src_logits = outputs["logits"]
source_logits = outputs["logits"]

idx = self._get_src_permutation_idx(indices)
idx = self._get_source_permutation_idx(indices)
target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(
src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
)
target_classes[idx] = target_classes_o

loss_ce = nn.functional.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
losses = {"loss_ce": loss_ce}

return losses

@@ -998,17 +998,17 @@ class YolosLoss(nn.Module):

"""
if "pred_boxes" not in outputs:
raise KeyError("No predicted boxes found in outputs")
idx = self._get_src_permutation_idx(indices)
src_boxes = outputs["pred_boxes"][idx]
idx = self._get_source_permutation_idx(indices)
source_boxes = outputs["pred_boxes"][idx]
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

loss_bbox = nn.functional.l1_loss(src_boxes, target_boxes, reduction="none")
loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")

losses = {}
losses["loss_bbox"] = loss_bbox.sum() / num_boxes

loss_giou = 1 - torch.diag(
generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
)
losses["loss_giou"] = loss_giou.sum() / num_boxes
return losses

@@ -1022,41 +1022,41 @@ class YolosLoss(nn.Module):

if "pred_masks" not in outputs:
raise KeyError("No predicted masks found in outputs")

src_idx = self._get_src_permutation_idx(indices)
tgt_idx = self._get_tgt_permutation_idx(indices)
src_masks = outputs["pred_masks"]
src_masks = src_masks[src_idx]
source_idx = self._get_source_permutation_idx(indices)
target_idx = self._get_target_permutation_idx(indices)
source_masks = outputs["pred_masks"]
source_masks = source_masks[source_idx]
masks = [t["masks"] for t in targets]
# TODO use valid to mask invalid areas due to padding in loss
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
target_masks = target_masks.to(src_masks)
target_masks = target_masks[tgt_idx]
target_masks = target_masks.to(source_masks)
target_masks = target_masks[target_idx]

# upsample predictions to the target size
src_masks = nn.functional.interpolate(
src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
source_masks = nn.functional.interpolate(
source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
)
src_masks = src_masks[:, 0].flatten(1)
source_masks = source_masks[:, 0].flatten(1)

target_masks = target_masks.flatten(1)
target_masks = target_masks.view(src_masks.shape)
target_masks = target_masks.view(source_masks.shape)
losses = {
"loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
"loss_dice": dice_loss(src_masks, target_masks, num_boxes),
"loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
"loss_dice": dice_loss(source_masks, target_masks, num_boxes),
}
return losses

def _get_src_permutation_idx(self, indices):
def _get_source_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
return batch_idx, src_idx
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
source_idx = torch.cat([source for (source, _) in indices])
return batch_idx, source_idx

def _get_tgt_permutation_idx(self, indices):
def _get_target_permutation_idx(self, indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
return batch_idx, tgt_idx
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
target_idx = torch.cat([target for (_, target) in indices])
return batch_idx, target_idx

def get_loss(self, loss, outputs, targets, indices, num_boxes):
loss_map = {

@@ -1077,7 +1077,7 @@ class YolosLoss(nn.Module):

outputs (`dict`, *optional*):
Dictionary of tensors, see the output specification of the model for the format.
targets (`List[dict]`, *optional*):
List of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
losses applied, see each loss' doc.
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
@@ -12,7 +12,7 @@

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch CONDITIONAL_DETR model. """
""" Testing suite for the PyTorch Conditional DETR model. """

import inspect

@@ -213,19 +213,19 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest

config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_conditional_detr_object_detection_head_model(*config_and_inputs)

@unittest.skip(reason="CONDITIONAL_DETR does not use inputs_embeds")
@unittest.skip(reason="Conditional DETR does not use inputs_embeds")
def test_inputs_embeds(self):
pass

@unittest.skip(reason="CONDITIONAL_DETR does not have a get_input_embeddings method")
@unittest.skip(reason="Conditional DETR does not have a get_input_embeddings method")
def test_model_common_attributes(self):
pass

@unittest.skip(reason="CONDITIONAL_DETR is not a generative model")
@unittest.skip(reason="Conditional DETR is not a generative model")
def test_generate_without_input_ids(self):
pass

@unittest.skip(reason="CONDITIONAL_DETR does not use token embeddings")
@unittest.skip(reason="Conditional DETR does not use token embeddings")
def test_resize_tokens_embeddings(self):
pass
@@ -474,7 +474,7 @@ class ConditionalDetrModelIntegrationTests(unittest.TestCase):

expected_shape = torch.Size((1, 300, 256))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
expected_slice = torch.tensor(
[[0.0616, -0.5146, -0.4032], [-0.7629, -0.4934, -1.7153], [-0.4768, -0.6403, -0.7826]]
[[0.4222, 0.7471, 0.8760], [0.6395, -0.2729, 0.7127], [-0.3090, 0.7642, 0.9529]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))

@@ -495,48 +495,13 @@ class ConditionalDetrModelIntegrationTests(unittest.TestCase):

expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels))
self.assertEqual(outputs.logits.shape, expected_shape_logits)
expected_slice_logits = torch.tensor(
[[-19.1194, -0.0893, -11.0154], [-17.3640, -1.8035, -14.0219], [-20.0461, -0.5837, -11.1060]]
[[-10.4372, -5.7558, -8.6764], [-10.5410, -5.8704, -8.0590], [-10.6827, -6.3469, -8.3923]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits, atol=1e-4))

expected_shape_boxes = torch.Size((1, model.config.num_queries, 4))
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
expected_slice_boxes = torch.tensor(
[[0.4433, 0.5302, 0.8853], [0.5494, 0.2517, 0.0529], [0.4998, 0.5360, 0.9956]]
[[0.7733, 0.6576, 0.4496], [0.5171, 0.1184, 0.9094], [0.8846, 0.5647, 0.2486]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))

def test_inference_panoptic_segmentation_head(self):
model = ConditionalDetrForSegmentation.from_pretrained("microsoft/conditional-detr-resnet-50-panoptic").to(
torch_device
)

feature_extractor = self.default_feature_extractor
image = prepare_img()
encoding = feature_extractor(images=image, return_tensors="pt").to(torch_device)
pixel_values = encoding["pixel_values"].to(torch_device)
pixel_mask = encoding["pixel_mask"].to(torch_device)

with torch.no_grad():
outputs = model(pixel_values, pixel_mask)

expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels))
self.assertEqual(outputs.logits.shape, expected_shape_logits)
expected_slice_logits = torch.tensor(
[[-18.1565, -1.7568, -13.5029], [-16.8888, -1.4138, -14.1028], [-17.5709, -2.5080, -11.8654]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits, atol=1e-4))

expected_shape_boxes = torch.Size((1, model.config.num_queries, 4))
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
expected_slice_boxes = torch.tensor(
[[0.5344, 0.1789, 0.9285], [0.4420, 0.0572, 0.0875], [0.6630, 0.6887, 0.1017]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))

expected_shape_masks = torch.Size((1, model.config.num_queries, 200, 267))
self.assertEqual(outputs.pred_masks.shape, expected_shape_masks)
expected_slice_masks = torch.tensor(
[[-7.7558, -10.8788, -11.9797], [-11.8881, -16.4329, -17.7451], [-14.7316, -19.7383, -20.3004]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.pred_masks[0, 0, :3, :3], expected_slice_masks, atol=1e-3))