mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix Gradient Accumulation issue (#34191)
* quick fix * 3 losses * oups * fix * nits * check how it scales for special models * propagate for conditiona detr * propagate * propagate * propagate * fixes * propagate changes * update * fixup * nits * f string * fixes * more fixes * ? * nit * arg annoying f string * nits * grumble * update * nit * refactor * fix fetch tests * nit * nit * Update src/transformers/loss/loss_utils.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * update * nit * fixup * make pass * nits * port code to more models * fixup * ntis * arf * update * update * nits * update * fix * update * nits * fine * agjkfslga.jsdlkgjklas * nits * fix fx? * update * update * styel * fix imports * update * update * fixup to fix the torch fx? --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
This commit is contained in:
parent
f51ac9e059
commit
c1c7e89620
@ -142,7 +142,9 @@ _import_structure = {
|
||||
"is_tensorboard_available",
|
||||
"is_wandb_available",
|
||||
],
|
||||
"loss": [],
|
||||
"modelcard": ["ModelCard"],
|
||||
# Losses
|
||||
"modeling_tf_pytorch_utils": [
|
||||
"convert_tf_weight_name_to_pt_weight_name",
|
||||
"load_pytorch_checkpoint_in_tf2_model",
|
||||
|
@ -184,6 +184,9 @@ class PretrainedConfig(PushToHubMixin):
|
||||
Whether the model should use legacy TensorFlow losses. Legacy losses have variable output shapes and may
|
||||
not be XLA-compatible. This option is here for backward compatibility and will be removed in Transformers
|
||||
v5.
|
||||
loss_type (`str`, *optional*):
|
||||
The type of loss that the model should use. It should be in `LOSS_MAPPING`'s keys, otherwise the loss will
|
||||
be automatically infered from the model architecture.
|
||||
"""
|
||||
|
||||
model_type: str = ""
|
||||
|
13
src/transformers/loss/__init__.py
Normal file
13
src/transformers/loss/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
178
src/transformers/loss/loss_deformable_detr.py
Normal file
178
src/transformers/loss/loss_deformable_detr.py
Normal file
@ -0,0 +1,178 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..image_transforms import center_to_corners_format
|
||||
from ..utils import is_scipy_available
|
||||
from .loss_for_object_detection import (
|
||||
HungarianMatcher,
|
||||
ImageLoss,
|
||||
_set_aux_loss,
|
||||
generalized_box_iou,
|
||||
sigmoid_focal_loss,
|
||||
)
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
|
||||
class DeformableDetrHungarianMatcher(HungarianMatcher):
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
Differences:
|
||||
- out_prob = outputs["logits"].flatten(0, 1).sigmoid() instead of softmax
|
||||
- class_cost uses alpha and gamma
|
||||
"""
|
||||
batch_size, num_queries = outputs["logits"].shape[:2]
|
||||
|
||||
# We flatten to compute the cost matrices in a batch
|
||||
out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes]
|
||||
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
|
||||
|
||||
# Also concat the target labels and boxes
|
||||
target_ids = torch.cat([v["class_labels"] for v in targets])
|
||||
target_bbox = torch.cat([v["boxes"] for v in targets])
|
||||
|
||||
# Compute the classification cost.
|
||||
alpha = 0.25
|
||||
gamma = 2.0
|
||||
neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
|
||||
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
|
||||
class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]
|
||||
|
||||
# Compute the L1 cost between boxes
|
||||
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
|
||||
|
||||
# Compute the giou cost between boxes
|
||||
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
|
||||
|
||||
# Final cost matrix
|
||||
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
|
||||
cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
|
||||
|
||||
sizes = [len(v["boxes"]) for v in targets]
|
||||
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
|
||||
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
||||
|
||||
|
||||
class DeformableDetrImageLoss(ImageLoss):
|
||||
def __init__(self, matcher, num_classes, focal_alpha, losses):
|
||||
nn.Module.__init__(self)
|
||||
self.matcher = matcher
|
||||
self.num_classes = num_classes
|
||||
self.focal_alpha = focal_alpha
|
||||
self.losses = losses
|
||||
|
||||
# removed logging parameter, which was part of the original implementation
|
||||
def loss_labels(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Classification loss (Binary focal loss) targets dicts must contain the key "class_labels" containing a tensor
|
||||
of dim [nb_target_boxes]
|
||||
"""
|
||||
if "logits" not in outputs:
|
||||
raise KeyError("No logits were found in the outputs")
|
||||
source_logits = outputs["logits"]
|
||||
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
|
||||
target_classes = torch.full(
|
||||
source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
|
||||
)
|
||||
target_classes[idx] = target_classes_o
|
||||
|
||||
target_classes_onehot = torch.zeros(
|
||||
[source_logits.shape[0], source_logits.shape[1], source_logits.shape[2] + 1],
|
||||
dtype=source_logits.dtype,
|
||||
layout=source_logits.layout,
|
||||
device=source_logits.device,
|
||||
)
|
||||
target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
|
||||
|
||||
target_classes_onehot = target_classes_onehot[:, :, :-1]
|
||||
loss_ce = (
|
||||
sigmoid_focal_loss(source_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2)
|
||||
* source_logits.shape[1]
|
||||
)
|
||||
losses = {"loss_ce": loss_ce}
|
||||
|
||||
return losses
|
||||
|
||||
|
||||
def DeformableDetrForSegmentationLoss(
|
||||
logits, labels, device, pred_boxes, pred_masks, config, outputs_class=None, outputs_coord=None, **kwargs
|
||||
):
|
||||
# First: create the matcher
|
||||
matcher = HungarianMatcher(class_cost=config.class_cost, bbox_cost=config.bbox_cost, giou_cost=config.giou_cost)
|
||||
# Second: create the criterion
|
||||
losses = ["labels", "boxes", "cardinality", "masks"]
|
||||
criterion = DeformableDetrImageLoss(
|
||||
matcher=matcher,
|
||||
num_classes=config.num_labels,
|
||||
focal_alpha=config.focal_alpha,
|
||||
losses=losses,
|
||||
)
|
||||
criterion.to(device)
|
||||
# Third: compute the losses, based on outputs and labels
|
||||
outputs_loss = {}
|
||||
outputs_loss["logits"] = logits
|
||||
outputs_loss["pred_boxes"] = pred_boxes
|
||||
outputs_loss["pred_masks"] = pred_masks
|
||||
|
||||
auxiliary_outputs = None
|
||||
if config.auxiliary_loss:
|
||||
auxiliary_outputs = _set_aux_loss(outputs_class, outputs_coord)
|
||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||
|
||||
loss_dict = criterion(outputs_loss, labels)
|
||||
# Fourth: compute total loss, as a weighted sum of the various losses
|
||||
weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
|
||||
weight_dict["loss_giou"] = config.giou_loss_coefficient
|
||||
weight_dict["loss_mask"] = config.mask_loss_coefficient
|
||||
weight_dict["loss_dice"] = config.dice_loss_coefficient
|
||||
if config.auxiliary_loss:
|
||||
aux_weight_dict = {}
|
||||
for i in range(config.decoder_layers - 1):
|
||||
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
||||
weight_dict.update(aux_weight_dict)
|
||||
|
||||
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
||||
return loss, loss_dict, auxiliary_outputs
|
||||
|
||||
|
||||
def DeformableDetrForObjectDetectionLoss(
|
||||
logits, labels, device, pred_boxes, config, outputs_class=None, outputs_coord=None, **kwargs
|
||||
):
|
||||
# First: create the matcher
|
||||
matcher = DeformableDetrHungarianMatcher(
|
||||
class_cost=config.class_cost, bbox_cost=config.bbox_cost, giou_cost=config.giou_cost
|
||||
)
|
||||
# Second: create the criterion
|
||||
losses = ["labels", "boxes", "cardinality"]
|
||||
criterion = DeformableDetrImageLoss(
|
||||
matcher=matcher,
|
||||
num_classes=config.num_labels,
|
||||
focal_alpha=config.focal_alpha,
|
||||
losses=losses,
|
||||
)
|
||||
criterion.to(device)
|
||||
# Third: compute the losses, based on outputs and labels
|
||||
outputs_loss = {}
|
||||
auxiliary_outputs = None
|
||||
outputs_loss["logits"] = logits
|
||||
outputs_loss["pred_boxes"] = pred_boxes
|
||||
if config.auxiliary_loss:
|
||||
auxiliary_outputs = _set_aux_loss(outputs_class, outputs_coord)
|
||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||
|
||||
loss_dict = criterion(outputs_loss, labels)
|
||||
# Fourth: compute total loss, as a weighted sum of the various losses
|
||||
weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
|
||||
weight_dict["loss_giou"] = config.giou_loss_coefficient
|
||||
if config.auxiliary_loss:
|
||||
aux_weight_dict = {}
|
||||
for i in range(config.decoder_layers - 1):
|
||||
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
||||
weight_dict.update(aux_weight_dict)
|
||||
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
||||
return loss, loss_dict, auxiliary_outputs
|
562
src/transformers/loss/loss_for_object_detection.py
Normal file
562
src/transformers/loss/loss_for_object_detection.py
Normal file
@ -0,0 +1,562 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from ..utils import is_accelerate_available, is_scipy_available, is_vision_available, requires_backends
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import PartialState
|
||||
from accelerate.utils import reduce
|
||||
|
||||
if is_scipy_available():
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from transformers.image_transforms import center_to_corners_format
|
||||
|
||||
|
||||
def dice_loss(inputs, targets, num_boxes):
|
||||
"""
|
||||
Compute the DICE loss, similar to generalized IOU for masks
|
||||
|
||||
Args:
|
||||
inputs: A float tensor of arbitrary shape.
|
||||
The predictions for each example.
|
||||
targets: A float tensor with the same shape as inputs. Stores the binary
|
||||
classification label for each element in inputs (0 for the negative class and 1 for the positive
|
||||
class).
|
||||
"""
|
||||
inputs = inputs.sigmoid()
|
||||
inputs = inputs.flatten(1)
|
||||
numerator = 2 * (inputs * targets).sum(1)
|
||||
denominator = inputs.sum(-1) + targets.sum(-1)
|
||||
loss = 1 - (numerator + 1) / (denominator + 1)
|
||||
return loss.sum() / num_boxes
|
||||
|
||||
|
||||
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
|
||||
"""
|
||||
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
|
||||
|
||||
Args:
|
||||
inputs (`torch.FloatTensor` of arbitrary shape):
|
||||
The predictions for each example.
|
||||
targets (`torch.FloatTensor` with the same shape as `inputs`)
|
||||
A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
|
||||
and 1 for the positive class).
|
||||
alpha (`float`, *optional*, defaults to `0.25`):
|
||||
Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
|
||||
gamma (`int`, *optional*, defaults to `2`):
|
||||
Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.
|
||||
|
||||
Returns:
|
||||
Loss tensor
|
||||
"""
|
||||
prob = inputs.sigmoid()
|
||||
ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
|
||||
# add modulating factor
|
||||
p_t = prob * targets + (1 - prob) * (1 - targets)
|
||||
loss = ce_loss * ((1 - p_t) ** gamma)
|
||||
|
||||
if alpha >= 0:
|
||||
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
|
||||
loss = alpha_t * loss
|
||||
|
||||
return loss.mean(1).sum() / num_boxes
|
||||
|
||||
|
||||
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
class ImageLoss(nn.Module):
|
||||
"""
|
||||
This class computes the losses for DetrForObjectDetection/DetrForSegmentation. The process happens in two steps: 1)
|
||||
we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair
|
||||
of matched ground-truth / prediction (supervise class and box).
|
||||
|
||||
A note on the `num_classes` argument (copied from original repo in detr.py): "the naming of the `num_classes`
|
||||
parameter of the criterion is somewhat misleading. It indeed corresponds to `max_obj_id` + 1, where `max_obj_id` is
|
||||
the maximum id for a class in your dataset. For example, COCO has a `max_obj_id` of 90, so we pass `num_classes` to
|
||||
be 91. As another example, for a dataset that has a single class with `id` 1, you should pass `num_classes` to be 2
|
||||
(`max_obj_id` + 1). For more details on this, check the following discussion
|
||||
https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
|
||||
|
||||
|
||||
Args:
|
||||
matcher (`DetrHungarianMatcher`):
|
||||
Module able to compute a matching between targets and proposals.
|
||||
num_classes (`int`):
|
||||
Number of object categories, omitting the special no-object category.
|
||||
eos_coef (`float`):
|
||||
Relative classification weight applied to the no-object category.
|
||||
losses (`List[str]`):
|
||||
List of all the losses to be applied. See `get_loss` for a list of all available losses.
|
||||
"""
|
||||
|
||||
def __init__(self, matcher, num_classes, eos_coef, losses):
|
||||
super().__init__()
|
||||
self.matcher = matcher
|
||||
self.num_classes = num_classes
|
||||
self.eos_coef = eos_coef
|
||||
self.losses = losses
|
||||
empty_weight = torch.ones(self.num_classes + 1)
|
||||
empty_weight[-1] = self.eos_coef
|
||||
self.register_buffer("empty_weight", empty_weight)
|
||||
|
||||
# removed logging parameter, which was part of the original implementation
|
||||
def loss_labels(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Classification loss (NLL) targets dicts must contain the key "class_labels" containing a tensor of dim
|
||||
[nb_target_boxes]
|
||||
"""
|
||||
if "logits" not in outputs:
|
||||
raise KeyError("No logits were found in the outputs")
|
||||
source_logits = outputs["logits"]
|
||||
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
|
||||
target_classes = torch.full(
|
||||
source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
|
||||
)
|
||||
target_classes[idx] = target_classes_o
|
||||
|
||||
loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
|
||||
losses = {"loss_ce": loss_ce}
|
||||
|
||||
return losses
|
||||
|
||||
@torch.no_grad()
|
||||
def loss_cardinality(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
|
||||
|
||||
This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
|
||||
"""
|
||||
logits = outputs["logits"]
|
||||
device = logits.device
|
||||
target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
|
||||
# Count the number of predictions that are NOT "no-object" (which is the last class)
|
||||
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
|
||||
card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
|
||||
losses = {"cardinality_error": card_err}
|
||||
return losses
|
||||
|
||||
def loss_boxes(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
|
||||
|
||||
Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
|
||||
are expected in format (center_x, center_y, w, h), normalized by the image size.
|
||||
"""
|
||||
if "pred_boxes" not in outputs:
|
||||
raise KeyError("No predicted boxes found in outputs")
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
source_boxes = outputs["pred_boxes"][idx]
|
||||
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
|
||||
|
||||
loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
|
||||
|
||||
losses = {}
|
||||
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
|
||||
|
||||
loss_giou = 1 - torch.diag(
|
||||
generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
|
||||
)
|
||||
losses["loss_giou"] = loss_giou.sum() / num_boxes
|
||||
return losses
|
||||
|
||||
def loss_masks(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the losses related to the masks: the focal loss and the dice loss.
|
||||
|
||||
Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
|
||||
"""
|
||||
if "pred_masks" not in outputs:
|
||||
raise KeyError("No predicted masks found in outputs")
|
||||
|
||||
source_idx = self._get_source_permutation_idx(indices)
|
||||
target_idx = self._get_target_permutation_idx(indices)
|
||||
source_masks = outputs["pred_masks"]
|
||||
source_masks = source_masks[source_idx]
|
||||
masks = [t["masks"] for t in targets]
|
||||
# TODO use valid to mask invalid areas due to padding in loss
|
||||
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
|
||||
target_masks = target_masks.to(source_masks)
|
||||
target_masks = target_masks[target_idx]
|
||||
|
||||
# upsample predictions to the target size
|
||||
source_masks = nn.functional.interpolate(
|
||||
source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
source_masks = source_masks[:, 0].flatten(1)
|
||||
|
||||
target_masks = target_masks.flatten(1)
|
||||
target_masks = target_masks.view(source_masks.shape)
|
||||
losses = {
|
||||
"loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
|
||||
"loss_dice": dice_loss(source_masks, target_masks, num_boxes),
|
||||
}
|
||||
return losses
|
||||
|
||||
def _get_source_permutation_idx(self, indices):
|
||||
# permute predictions following indices
|
||||
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
|
||||
source_idx = torch.cat([source for (source, _) in indices])
|
||||
return batch_idx, source_idx
|
||||
|
||||
def _get_target_permutation_idx(self, indices):
|
||||
# permute targets following indices
|
||||
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
|
||||
target_idx = torch.cat([target for (_, target) in indices])
|
||||
return batch_idx, target_idx
|
||||
|
||||
def get_loss(self, loss, outputs, targets, indices, num_boxes):
|
||||
loss_map = {
|
||||
"labels": self.loss_labels,
|
||||
"cardinality": self.loss_cardinality,
|
||||
"boxes": self.loss_boxes,
|
||||
"masks": self.loss_masks,
|
||||
}
|
||||
if loss not in loss_map:
|
||||
raise ValueError(f"Loss {loss} not supported")
|
||||
return loss_map[loss](outputs, targets, indices, num_boxes)
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
This performs the loss computation.
|
||||
|
||||
Args:
|
||||
outputs (`dict`, *optional*):
|
||||
Dictionary of tensors, see the output specification of the model for the format.
|
||||
targets (`List[dict]`, *optional*):
|
||||
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
|
||||
losses applied, see each loss' doc.
|
||||
"""
|
||||
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
|
||||
|
||||
# Retrieve the matching between the outputs of the last layer and the targets
|
||||
indices = self.matcher(outputs_without_aux, targets)
|
||||
|
||||
# Compute the average number of target boxes across all nodes, for normalization purposes
|
||||
num_boxes = sum(len(t["class_labels"]) for t in targets)
|
||||
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
||||
world_size = 1
|
||||
if is_accelerate_available():
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
|
||||
|
||||
# Compute all the requested losses
|
||||
losses = {}
|
||||
for loss in self.losses:
|
||||
losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
|
||||
|
||||
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
||||
if "auxiliary_outputs" in outputs:
|
||||
for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
|
||||
indices = self.matcher(auxiliary_outputs, targets)
|
||||
for loss in self.losses:
|
||||
if loss == "masks":
|
||||
# Intermediate masks losses are too costly to compute, we ignore them.
|
||||
continue
|
||||
l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
|
||||
l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
|
||||
losses.update(l_dict)
|
||||
|
||||
return losses
|
||||
|
||||
|
||||
# taken from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
|
||||
class HungarianMatcher(nn.Module):
|
||||
"""
|
||||
This class computes an assignment between the targets and the predictions of the network.
|
||||
|
||||
For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
|
||||
predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
|
||||
un-matched (and thus treated as non-objects).
|
||||
|
||||
Args:
|
||||
class_cost:
|
||||
The relative weight of the classification error in the matching cost.
|
||||
bbox_cost:
|
||||
The relative weight of the L1 error of the bounding box coordinates in the matching cost.
|
||||
giou_cost:
|
||||
The relative weight of the giou loss of the bounding box in the matching cost.
|
||||
"""
|
||||
|
||||
def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
|
||||
super().__init__()
|
||||
requires_backends(self, ["scipy"])
|
||||
|
||||
self.class_cost = class_cost
|
||||
self.bbox_cost = bbox_cost
|
||||
self.giou_cost = giou_cost
|
||||
if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
|
||||
raise ValueError("All costs of the Matcher can't be 0")
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
Args:
|
||||
outputs (`dict`):
|
||||
A dictionary that contains at least these entries:
|
||||
* "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
||||
* "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
|
||||
targets (`List[dict]`):
|
||||
A list of targets (len(targets) = batch_size), where each target is a dict containing:
|
||||
* "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
|
||||
ground-truth
|
||||
objects in the target) containing the class labels
|
||||
* "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.
|
||||
|
||||
Returns:
|
||||
`List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
|
||||
- index_i is the indices of the selected predictions (in order)
|
||||
- index_j is the indices of the corresponding selected targets (in order)
|
||||
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
||||
"""
|
||||
batch_size, num_queries = outputs["logits"].shape[:2]
|
||||
|
||||
# We flatten to compute the cost matrices in a batch
|
||||
out_prob = outputs["logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
|
||||
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
|
||||
|
||||
# Also concat the target labels and boxes
|
||||
target_ids = torch.cat([v["class_labels"] for v in targets])
|
||||
target_bbox = torch.cat([v["boxes"] for v in targets])
|
||||
|
||||
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
|
||||
# but approximate it in 1 - proba[target class].
|
||||
# The 1 is a constant that doesn't change the matching, it can be ommitted.
|
||||
class_cost = -out_prob[:, target_ids]
|
||||
|
||||
# Compute the L1 cost between boxes
|
||||
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
|
||||
|
||||
# Compute the giou cost between boxes
|
||||
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
|
||||
|
||||
# Final cost matrix
|
||||
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
|
||||
cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
|
||||
|
||||
sizes = [len(v["boxes"]) for v in targets]
|
||||
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
|
||||
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
||||
|
||||
|
||||
# below: bounding box utilities taken from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
|
||||
|
||||
|
||||
def _upcast(t: Tensor) -> Tensor:
|
||||
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
|
||||
if t.is_floating_point():
|
||||
return t if t.dtype in (torch.float32, torch.float64) else t.float()
|
||||
else:
|
||||
return t if t.dtype in (torch.int32, torch.int64) else t.int()
|
||||
|
||||
|
||||
def box_area(boxes: Tensor) -> Tensor:
|
||||
"""
|
||||
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
|
||||
|
||||
Args:
|
||||
boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
|
||||
Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
|
||||
< x2` and `0 <= y1 < y2`.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: a tensor containing the area for each box.
|
||||
"""
|
||||
boxes = _upcast(boxes)
|
||||
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
||||
|
||||
|
||||
# modified from torchvision to also return the union
|
||||
def box_iou(boxes1, boxes2):
|
||||
area1 = box_area(boxes1)
|
||||
area2 = box_area(boxes2)
|
||||
|
||||
left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||
right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||
|
||||
width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
|
||||
inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
|
||||
|
||||
union = area1[:, None] + area2 - inter
|
||||
|
||||
iou = inter / union
|
||||
return iou, union
|
||||
|
||||
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
|
||||
"""
|
||||
# degenerate boxes gives inf / nan results
|
||||
# so do an early check
|
||||
if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
|
||||
raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
|
||||
if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
|
||||
raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
|
||||
iou, union = box_iou(boxes1, boxes2)
|
||||
|
||||
top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
||||
bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
||||
|
||||
width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
|
||||
area = width_height[:, :, 0] * width_height[:, :, 1]
|
||||
|
||||
return iou - (area - union) / area
|
||||
|
||||
|
||||
# below: taken from https://github.com/facebookresearch/detr/blob/master/util/misc.py#L306
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
|
||||
class NestedTensor:
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
if tensor_list[0].ndim == 3:
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||
batch_shape = [len(tensor_list)] + max_size
|
||||
batch_size, num_channels, height, width = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], : img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError("Only 3-dimensional tensors are supported")
|
||||
return NestedTensor(tensor, mask)
|
||||
|
||||
|
||||
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
@torch.jit.unused
|
||||
def _set_aux_loss(outputs_class, outputs_coord):
|
||||
# this is a workaround to make torchscript happy, as torchscript
|
||||
# doesn't support dictionary with non-homogeneous values, such
|
||||
# as a dict having both a Tensor and a list.
|
||||
return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
|
||||
|
||||
|
||||
def ForSegmentationLoss(
|
||||
logits, labels, device, pred_boxes, pred_masks, config, outputs_class=None, outputs_coord=None, **kwargs
|
||||
):
|
||||
# First: create the matcher
|
||||
matcher = HungarianMatcher(class_cost=config.class_cost, bbox_cost=config.bbox_cost, giou_cost=config.giou_cost)
|
||||
# Second: create the criterion
|
||||
losses = ["labels", "boxes", "cardinality", "masks"]
|
||||
criterion = ImageLoss(
|
||||
matcher=matcher,
|
||||
num_classes=config.num_labels,
|
||||
eos_coef=config.eos_coefficient,
|
||||
losses=losses,
|
||||
)
|
||||
criterion.to(device)
|
||||
# Third: compute the losses, based on outputs and labels
|
||||
outputs_loss = {}
|
||||
outputs_loss["logits"] = logits
|
||||
outputs_loss["pred_boxes"] = pred_boxes
|
||||
outputs_loss["pred_masks"] = pred_masks
|
||||
|
||||
auxiliary_outputs = None
|
||||
if config.auxiliary_loss:
|
||||
auxiliary_outputs = _set_aux_loss(outputs_class, outputs_coord)
|
||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||
|
||||
loss_dict = criterion(outputs_loss, labels)
|
||||
# Fourth: compute total loss, as a weighted sum of the various losses
|
||||
weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
|
||||
weight_dict["loss_giou"] = config.giou_loss_coefficient
|
||||
weight_dict["loss_mask"] = config.mask_loss_coefficient
|
||||
weight_dict["loss_dice"] = config.dice_loss_coefficient
|
||||
if config.auxiliary_loss:
|
||||
aux_weight_dict = {}
|
||||
for i in range(config.decoder_layers - 1):
|
||||
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
||||
weight_dict.update(aux_weight_dict)
|
||||
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
||||
return loss, loss_dict, auxiliary_outputs
|
||||
|
||||
|
||||
def ForObjectDetectionLoss(
|
||||
logits, labels, device, pred_boxes, config, outputs_class=None, outputs_coord=None, **kwargs
|
||||
):
|
||||
# First: create the matcher
|
||||
matcher = HungarianMatcher(class_cost=config.class_cost, bbox_cost=config.bbox_cost, giou_cost=config.giou_cost)
|
||||
# Second: create the criterion
|
||||
losses = ["labels", "boxes", "cardinality"]
|
||||
criterion = ImageLoss(
|
||||
matcher=matcher,
|
||||
num_classes=config.num_labels,
|
||||
eos_coef=config.eos_coefficient,
|
||||
losses=losses,
|
||||
)
|
||||
criterion.to(device)
|
||||
# Third: compute the losses, based on outputs and labels
|
||||
outputs_loss = {}
|
||||
auxiliary_outputs = None
|
||||
outputs_loss["logits"] = logits
|
||||
outputs_loss["pred_boxes"] = pred_boxes
|
||||
if config.auxiliary_loss:
|
||||
auxiliary_outputs = _set_aux_loss(outputs_class, outputs_coord)
|
||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||
|
||||
loss_dict = criterion(outputs_loss, labels)
|
||||
# Fourth: compute total loss, as a weighted sum of the various losses
|
||||
weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
|
||||
weight_dict["loss_giou"] = config.giou_loss_coefficient
|
||||
if config.auxiliary_loss:
|
||||
aux_weight_dict = {}
|
||||
for i in range(config.decoder_layers - 1):
|
||||
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
||||
weight_dict.update(aux_weight_dict)
|
||||
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
||||
return loss, loss_dict, auxiliary_outputs
|
463
src/transformers/loss/loss_rt_detr.py
Normal file
463
src/transformers/loss/loss_rt_detr.py
Normal file
@ -0,0 +1,463 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import is_scipy_available, is_vision_available, requires_backends
|
||||
from .loss_for_object_detection import (
|
||||
_set_aux_loss,
|
||||
box_iou,
|
||||
dice_loss,
|
||||
generalized_box_iou,
|
||||
nested_tensor_from_tensor_list,
|
||||
sigmoid_focal_loss,
|
||||
)
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from transformers.image_transforms import center_to_corners_format
|
||||
|
||||
|
||||
class RTDetrHungarianMatcher(nn.Module):
|
||||
"""This class computes an assignment between the targets and the predictions of the network
|
||||
|
||||
For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
|
||||
predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
|
||||
un-matched (and thus treated as non-objects).
|
||||
|
||||
Args:
|
||||
config: RTDetrConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
requires_backends(self, ["scipy"])
|
||||
|
||||
self.class_cost = config.matcher_class_cost
|
||||
self.bbox_cost = config.matcher_bbox_cost
|
||||
self.giou_cost = config.matcher_giou_cost
|
||||
|
||||
self.use_focal_loss = config.use_focal_loss
|
||||
self.alpha = config.matcher_alpha
|
||||
self.gamma = config.matcher_gamma
|
||||
|
||||
if self.class_cost == self.bbox_cost == self.giou_cost == 0:
|
||||
raise ValueError("All costs of the Matcher can't be 0")
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, targets):
|
||||
"""Performs the matching
|
||||
|
||||
Params:
|
||||
outputs: This is a dict that contains at least these entries:
|
||||
"logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
||||
"pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
|
||||
|
||||
targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
|
||||
"class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
|
||||
objects in the target) containing the class labels
|
||||
"boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
|
||||
|
||||
Returns:
|
||||
A list of size batch_size, containing tuples of (index_i, index_j) where:
|
||||
- index_i is the indices of the selected predictions (in order)
|
||||
- index_j is the indices of the corresponding selected targets (in order)
|
||||
For each batch element, it holds:
|
||||
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
||||
"""
|
||||
batch_size, num_queries = outputs["logits"].shape[:2]
|
||||
|
||||
# We flatten to compute the cost matrices in a batch
|
||||
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
|
||||
# Also concat the target labels and boxes
|
||||
target_ids = torch.cat([v["class_labels"] for v in targets])
|
||||
target_bbox = torch.cat([v["boxes"] for v in targets])
|
||||
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
|
||||
# but approximate it in 1 - proba[target class].
|
||||
# The 1 is a constant that doesn't change the matching, it can be ommitted.
|
||||
if self.use_focal_loss:
|
||||
out_prob = F.sigmoid(outputs["logits"].flatten(0, 1))
|
||||
out_prob = out_prob[:, target_ids]
|
||||
neg_cost_class = (1 - self.alpha) * (out_prob**self.gamma) * (-(1 - out_prob + 1e-8).log())
|
||||
pos_cost_class = self.alpha * ((1 - out_prob) ** self.gamma) * (-(out_prob + 1e-8).log())
|
||||
class_cost = pos_cost_class - neg_cost_class
|
||||
else:
|
||||
out_prob = outputs["logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
|
||||
class_cost = -out_prob[:, target_ids]
|
||||
|
||||
# Compute the L1 cost between boxes
|
||||
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
|
||||
# Compute the giou cost betwen boxes
|
||||
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
|
||||
# Compute the final cost matrix
|
||||
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
|
||||
cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
|
||||
|
||||
sizes = [len(v["boxes"]) for v in targets]
|
||||
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
|
||||
|
||||
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
||||
|
||||
|
||||
class RTDetrLoss(nn.Module):
|
||||
"""
|
||||
This class computes the losses for RTDetr. The process happens in two steps: 1) we compute hungarian assignment
|
||||
between ground truth boxes and the outputs of the model 2) we supervise each pair of matched ground-truth /
|
||||
prediction (supervise class and box).
|
||||
|
||||
Args:
|
||||
matcher (`DetrHungarianMatcher`):
|
||||
Module able to compute a matching between targets and proposals.
|
||||
weight_dict (`Dict`):
|
||||
Dictionary relating each loss with its weights. These losses are configured in RTDetrConf as
|
||||
`weight_loss_vfl`, `weight_loss_bbox`, `weight_loss_giou`
|
||||
losses (`List[str]`):
|
||||
List of all the losses to be applied. See `get_loss` for a list of all available losses.
|
||||
alpha (`float`):
|
||||
Parameter alpha used to compute the focal loss.
|
||||
gamma (`float`):
|
||||
Parameter gamma used to compute the focal loss.
|
||||
eos_coef (`float`):
|
||||
Relative classification weight applied to the no-object category.
|
||||
num_classes (`int`):
|
||||
Number of object categories, omitting the special no-object category.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.matcher = RTDetrHungarianMatcher(config)
|
||||
self.num_classes = config.num_labels
|
||||
self.weight_dict = {
|
||||
"loss_vfl": config.weight_loss_vfl,
|
||||
"loss_bbox": config.weight_loss_bbox,
|
||||
"loss_giou": config.weight_loss_giou,
|
||||
}
|
||||
self.losses = ["vfl", "boxes"]
|
||||
self.eos_coef = config.eos_coefficient
|
||||
empty_weight = torch.ones(config.num_labels + 1)
|
||||
empty_weight[-1] = self.eos_coef
|
||||
self.register_buffer("empty_weight", empty_weight)
|
||||
self.alpha = config.focal_loss_alpha
|
||||
self.gamma = config.focal_loss_gamma
|
||||
|
||||
def loss_labels_vfl(self, outputs, targets, indices, num_boxes, log=True):
|
||||
if "pred_boxes" not in outputs:
|
||||
raise KeyError("No predicted boxes found in outputs")
|
||||
if "logits" not in outputs:
|
||||
raise KeyError("No predicted logits found in outputs")
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
|
||||
src_boxes = outputs["pred_boxes"][idx]
|
||||
target_boxes = torch.cat([_target["boxes"][i] for _target, (_, i) in zip(targets, indices)], dim=0)
|
||||
ious, _ = box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
|
||||
ious = torch.diag(ious).detach()
|
||||
|
||||
src_logits = outputs["logits"]
|
||||
target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)])
|
||||
target_classes = torch.full(
|
||||
src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
|
||||
)
|
||||
target_classes[idx] = target_classes_original
|
||||
target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]
|
||||
|
||||
target_score_original = torch.zeros_like(target_classes, dtype=src_logits.dtype)
|
||||
target_score_original[idx] = ious.to(target_score_original.dtype)
|
||||
target_score = target_score_original.unsqueeze(-1) * target
|
||||
|
||||
pred_score = F.sigmoid(src_logits).detach()
|
||||
weight = self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score
|
||||
|
||||
loss = F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction="none")
|
||||
loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
|
||||
return {"loss_vfl": loss}
|
||||
|
||||
def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
|
||||
"""Classification loss (NLL)
|
||||
targets dicts must contain the key "class_labels" containing a tensor of dim [nb_target_boxes]
|
||||
"""
|
||||
if "logits" not in outputs:
|
||||
raise KeyError("No logits were found in the outputs")
|
||||
|
||||
src_logits = outputs["logits"]
|
||||
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)])
|
||||
target_classes = torch.full(
|
||||
src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
|
||||
)
|
||||
target_classes[idx] = target_classes_original
|
||||
|
||||
loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.class_weight)
|
||||
losses = {"loss_ce": loss_ce}
|
||||
return losses
|
||||
|
||||
@torch.no_grad()
|
||||
def loss_cardinality(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes. This is not
|
||||
really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
|
||||
"""
|
||||
logits = outputs["logits"]
|
||||
device = logits.device
|
||||
target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
|
||||
# Count the number of predictions that are NOT "no-object" (which is the last class)
|
||||
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
|
||||
card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
|
||||
losses = {"cardinality_error": card_err}
|
||||
return losses
|
||||
|
||||
def loss_boxes(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss. Targets dicts must
|
||||
contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes are expected in
|
||||
format (center_x, center_y, w, h), normalized by the image size.
|
||||
"""
|
||||
if "pred_boxes" not in outputs:
|
||||
raise KeyError("No predicted boxes found in outputs")
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
src_boxes = outputs["pred_boxes"][idx]
|
||||
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
|
||||
|
||||
losses = {}
|
||||
|
||||
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
|
||||
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
|
||||
|
||||
loss_giou = 1 - torch.diag(
|
||||
generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
|
||||
)
|
||||
losses["loss_giou"] = loss_giou.sum() / num_boxes
|
||||
return losses
|
||||
|
||||
def loss_masks(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the losses related to the masks: the focal loss and the dice loss. Targets dicts must contain the key
|
||||
"masks" containing a tensor of dim [nb_target_boxes, h, w].
|
||||
"""
|
||||
if "pred_masks" not in outputs:
|
||||
raise KeyError("No predicted masks found in outputs")
|
||||
|
||||
source_idx = self._get_source_permutation_idx(indices)
|
||||
target_idx = self._get_target_permutation_idx(indices)
|
||||
source_masks = outputs["pred_masks"]
|
||||
source_masks = source_masks[source_idx]
|
||||
masks = [t["masks"] for t in targets]
|
||||
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
|
||||
target_masks = target_masks.to(source_masks)
|
||||
target_masks = target_masks[target_idx]
|
||||
|
||||
# upsample predictions to the target size
|
||||
source_masks = nn.functional.interpolate(
|
||||
source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
source_masks = source_masks[:, 0].flatten(1)
|
||||
|
||||
target_masks = target_masks.flatten(1)
|
||||
target_masks = target_masks.view(source_masks.shape)
|
||||
losses = {
|
||||
"loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
|
||||
"loss_dice": dice_loss(source_masks, target_masks, num_boxes),
|
||||
}
|
||||
return losses
|
||||
|
||||
def loss_labels_bce(self, outputs, targets, indices, num_boxes, log=True):
|
||||
src_logits = outputs["logits"]
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)])
|
||||
target_classes = torch.full(
|
||||
src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
|
||||
)
|
||||
target_classes[idx] = target_classes_original
|
||||
|
||||
target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]
|
||||
loss = F.binary_cross_entropy_with_logits(src_logits, target * 1.0, reduction="none")
|
||||
loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
|
||||
return {"loss_bce": loss}
|
||||
|
||||
def _get_source_permutation_idx(self, indices):
|
||||
# permute predictions following indices
|
||||
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
|
||||
source_idx = torch.cat([source for (source, _) in indices])
|
||||
return batch_idx, source_idx
|
||||
|
||||
def _get_target_permutation_idx(self, indices):
|
||||
# permute targets following indices
|
||||
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
|
||||
target_idx = torch.cat([target for (_, target) in indices])
|
||||
return batch_idx, target_idx
|
||||
|
||||
def loss_labels_focal(self, outputs, targets, indices, num_boxes, log=True):
|
||||
if "logits" not in outputs:
|
||||
raise KeyError("No logits found in outputs")
|
||||
|
||||
src_logits = outputs["logits"]
|
||||
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)])
|
||||
target_classes = torch.full(
|
||||
src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
|
||||
)
|
||||
target_classes[idx] = target_classes_original
|
||||
|
||||
target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]
|
||||
loss = sigmoid_focal_loss(src_logits, target, self.alpha, self.gamma)
|
||||
loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
|
||||
return {"loss_focal": loss}
|
||||
|
||||
def get_loss(self, loss, outputs, targets, indices, num_boxes):
|
||||
loss_map = {
|
||||
"labels": self.loss_labels,
|
||||
"cardinality": self.loss_cardinality,
|
||||
"boxes": self.loss_boxes,
|
||||
"masks": self.loss_masks,
|
||||
"bce": self.loss_labels_bce,
|
||||
"focal": self.loss_labels_focal,
|
||||
"vfl": self.loss_labels_vfl,
|
||||
}
|
||||
if loss not in loss_map:
|
||||
raise ValueError(f"Loss {loss} not supported")
|
||||
return loss_map[loss](outputs, targets, indices, num_boxes)
|
||||
|
||||
@staticmethod
|
||||
def get_cdn_matched_indices(dn_meta, targets):
|
||||
dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
|
||||
num_gts = [len(t["class_labels"]) for t in targets]
|
||||
device = targets[0]["class_labels"].device
|
||||
|
||||
dn_match_indices = []
|
||||
for i, num_gt in enumerate(num_gts):
|
||||
if num_gt > 0:
|
||||
gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device)
|
||||
gt_idx = gt_idx.tile(dn_num_group)
|
||||
assert len(dn_positive_idx[i]) == len(gt_idx)
|
||||
dn_match_indices.append((dn_positive_idx[i], gt_idx))
|
||||
else:
|
||||
dn_match_indices.append(
|
||||
(
|
||||
torch.zeros(0, dtype=torch.int64, device=device),
|
||||
torch.zeros(0, dtype=torch.int64, device=device),
|
||||
)
|
||||
)
|
||||
|
||||
return dn_match_indices
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
This performs the loss computation.
|
||||
|
||||
Args:
|
||||
outputs (`dict`, *optional*):
|
||||
Dictionary of tensors, see the output specification of the model for the format.
|
||||
targets (`List[dict]`, *optional*):
|
||||
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
|
||||
losses applied, see each loss' doc.
|
||||
"""
|
||||
outputs_without_aux = {k: v for k, v in outputs.items() if "auxiliary_outputs" not in k}
|
||||
|
||||
# Retrieve the matching between the outputs of the last layer and the targets
|
||||
indices = self.matcher(outputs_without_aux, targets)
|
||||
|
||||
# Compute the average number of target boxes across all nodes, for normalization purposes
|
||||
num_boxes = sum(len(t["class_labels"]) for t in targets)
|
||||
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
||||
num_boxes = torch.clamp(num_boxes, min=1).item()
|
||||
|
||||
# Compute all the requested losses
|
||||
losses = {}
|
||||
for loss in self.losses:
|
||||
l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
|
||||
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
|
||||
losses.update(l_dict)
|
||||
|
||||
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
||||
if "auxiliary_outputs" in outputs:
|
||||
for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
|
||||
indices = self.matcher(auxiliary_outputs, targets)
|
||||
for loss in self.losses:
|
||||
if loss == "masks":
|
||||
# Intermediate masks losses are too costly to compute, we ignore them.
|
||||
continue
|
||||
l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
|
||||
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
|
||||
l_dict = {k + f"_aux_{i}": v for k, v in l_dict.items()}
|
||||
losses.update(l_dict)
|
||||
|
||||
# In case of cdn auxiliary losses. For rtdetr
|
||||
if "dn_auxiliary_outputs" in outputs:
|
||||
if "denoising_meta_values" not in outputs:
|
||||
raise ValueError(
|
||||
"The output must have the 'denoising_meta_values` key. Please, ensure that 'outputs' includes a 'denoising_meta_values' entry."
|
||||
)
|
||||
indices = self.get_cdn_matched_indices(outputs["denoising_meta_values"], targets)
|
||||
num_boxes = num_boxes * outputs["denoising_meta_values"]["dn_num_group"]
|
||||
|
||||
for i, auxiliary_outputs in enumerate(outputs["dn_auxiliary_outputs"]):
|
||||
# indices = self.matcher(auxiliary_outputs, targets)
|
||||
for loss in self.losses:
|
||||
if loss == "masks":
|
||||
# Intermediate masks losses are too costly to compute, we ignore them.
|
||||
continue
|
||||
kwargs = {}
|
||||
l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes, **kwargs)
|
||||
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
|
||||
l_dict = {k + f"_dn_{i}": v for k, v in l_dict.items()}
|
||||
losses.update(l_dict)
|
||||
|
||||
return losses
|
||||
|
||||
|
||||
def RTDetrForObjectDetectionLoss(
|
||||
logits,
|
||||
labels,
|
||||
device,
|
||||
pred_boxes,
|
||||
config,
|
||||
outputs_class=None,
|
||||
outputs_coord=None,
|
||||
enc_topk_logits=None,
|
||||
enc_topk_bboxes=None,
|
||||
denoising_meta_values=None,
|
||||
**kwargs,
|
||||
):
|
||||
criterion = RTDetrLoss(config)
|
||||
criterion.to(device)
|
||||
# Second: compute the losses, based on outputs and labels
|
||||
outputs_loss = {}
|
||||
outputs_loss["logits"] = logits
|
||||
outputs_loss["pred_boxes"] = pred_boxes
|
||||
if config.auxiliary_loss:
|
||||
if denoising_meta_values is not None:
|
||||
dn_out_coord, outputs_coord = torch.split(outputs_coord, denoising_meta_values["dn_num_split"], dim=2)
|
||||
dn_out_class, outputs_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2)
|
||||
|
||||
auxiliary_outputs = _set_aux_loss(outputs_class[:, :-1].transpose(0, 1), outputs_coord[:, :-1].transpose(0, 1))
|
||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||
outputs_loss["auxiliary_outputs"].extend(_set_aux_loss([enc_topk_logits], [enc_topk_bboxes]))
|
||||
if denoising_meta_values is not None:
|
||||
outputs_loss["dn_auxiliary_outputs"] = _set_aux_loss(
|
||||
dn_out_class.transpose(0, 1), dn_out_coord.transpose(0, 1)
|
||||
)
|
||||
outputs_loss["denoising_meta_values"] = denoising_meta_values
|
||||
|
||||
loss_dict = criterion(outputs_loss, labels)
|
||||
|
||||
loss = sum(loss_dict.values())
|
||||
return loss, loss_dict, auxiliary_outputs
|
114
src/transformers/loss/loss_utils.py
Normal file
114
src/transformers/loss/loss_utils.py
Normal file
@ -0,0 +1,114 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import BCEWithLogitsLoss, MSELoss
|
||||
|
||||
from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss
|
||||
from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss
|
||||
from .loss_rt_detr import RTDetrForObjectDetectionLoss
|
||||
|
||||
|
||||
def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
|
||||
reduction = "sum" if num_items_in_batch is not None else "mean"
|
||||
loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
|
||||
if reduction == "sum":
|
||||
loss = loss / num_items_in_batch
|
||||
return loss
|
||||
|
||||
|
||||
def ForCausalLMLoss(
|
||||
logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
|
||||
):
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
|
||||
# Flatten the tokens
|
||||
shift_logits = shift_logits.view(-1, vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
|
||||
return loss
|
||||
|
||||
|
||||
def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs):
|
||||
num_labels = config.num_labels
|
||||
if config.problem_type is None:
|
||||
if num_labels == 1:
|
||||
config.problem_type = "regression"
|
||||
elif num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
config.problem_type = "single_label_classification"
|
||||
else:
|
||||
config.problem_type = "multi_label_classification"
|
||||
|
||||
if config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif config.problem_type == "single_label_classification":
|
||||
loss = fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs)
|
||||
elif config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
return loss
|
||||
|
||||
|
||||
def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions, **kwargs):
|
||||
total_loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
# If we are on multi-GPU, split add a dimension
|
||||
if len(start_positions.size()) > 1:
|
||||
start_positions = start_positions.squeeze(-1).to(start_logits.device)
|
||||
if len(end_positions.size()) > 1:
|
||||
end_positions = end_positions.squeeze(-1).to(end_logits.device)
|
||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||
ignored_index = start_logits.size(1)
|
||||
start_positions = start_positions.clamp(0, ignored_index)
|
||||
end_positions = end_positions.clamp(0, ignored_index)
|
||||
|
||||
start_loss = fixed_cross_entropy(start_logits, start_positions, ignore_index=ignored_index, **kwargs)
|
||||
end_loss = fixed_cross_entropy(end_logits, end_positions, ignore_index=ignored_index, **kwargs)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
return total_loss
|
||||
|
||||
|
||||
def ForTokenClassification(logits, labels, config, **kwargs):
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.view(-1, config.num_labels)
|
||||
labels = labels.view(-1)
|
||||
logits = logits.float()
|
||||
# Flatten the tokens
|
||||
return fixed_cross_entropy(logits, labels, **kwargs)
|
||||
|
||||
|
||||
LOSS_MAPPING = {
|
||||
"ForCausalLM": ForCausalLMLoss,
|
||||
"ForQuestionAnswering": ForQuestionAnsweringLoss,
|
||||
"ForSequenceClassification": ForSequenceClassificationLoss,
|
||||
"ForTokenClassification": ForTokenClassification,
|
||||
"ForSegmentation": ForSegmentationLoss,
|
||||
"ForObjectDetection": ForObjectDetectionLoss,
|
||||
"DeformableDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
|
||||
"ConditionalDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
|
||||
"GroundingDinoForObjectDetection": DeformableDetrForObjectDetectionLoss,
|
||||
"ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
|
||||
"RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
|
||||
}
|
@ -28,7 +28,7 @@ import tempfile
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import partial, wraps
|
||||
from functools import lru_cache, partial, wraps
|
||||
from threading import Thread
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
from zipfile import is_zipfile
|
||||
@ -45,6 +45,7 @@ from .configuration_utils import PretrainedConfig
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .generation import GenerationConfig, GenerationMixin
|
||||
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
|
||||
from .loss.loss_utils import LOSS_MAPPING
|
||||
from .pytorch_utils import ( # noqa: F401
|
||||
Conv1D,
|
||||
apply_chunking_to_forward,
|
||||
@ -4979,6 +4980,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
return self.hf_quantizer.is_trainable
|
||||
|
||||
@property
|
||||
@lru_cache
|
||||
def loss_function(self):
|
||||
if getattr(self.config, "loss_type", None) is not None:
|
||||
loss_type = self.config.loss_type
|
||||
else:
|
||||
loss_type = self.__class__.__name__
|
||||
if loss_type not in LOSS_MAPPING:
|
||||
loss_groups = f"({'|'.join(LOSS_MAPPING)})"
|
||||
loss_type = re.findall(loss_groups, self.__class__.__name__)
|
||||
if len(loss_type) > 0:
|
||||
loss_type = loss_type[0]
|
||||
else:
|
||||
loss_type = None
|
||||
if loss_type is None or loss_type not in LOSS_MAPPING and getattr(self.config, "loss_type", None) is not None:
|
||||
logger.warning_once(
|
||||
f"`loss_type={loss_type}` was set in the config but it is unrecognised."
|
||||
f"Using the default loss: `ForCausalLMLoss`."
|
||||
)
|
||||
loss_type = "ForCausalLM"
|
||||
return LOSS_MAPPING[loss_type]
|
||||
|
||||
|
||||
PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
|
||||
if PreTrainedModel.push_to_hub.__doc__ is not None:
|
||||
|
@ -28,7 +28,6 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
@ -1173,18 +1172,7 @@ class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -29,10 +29,7 @@ from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_accelerate_available,
|
||||
is_scipy_available,
|
||||
is_timm_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
requires_backends,
|
||||
@ -41,18 +38,9 @@ from ...utils.backbone_utils import load_backbone
|
||||
from .configuration_conditional_detr import ConditionalDetrConfig
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import PartialState
|
||||
from accelerate.utils import reduce
|
||||
|
||||
if is_scipy_available():
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
if is_timm_available():
|
||||
from timm import create_model
|
||||
|
||||
if is_vision_available():
|
||||
from ...image_transforms import center_to_corners_format
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@ -1610,6 +1598,28 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with Detr->ConditionalDetr
|
||||
class ConditionalDetrMLPPredictionHead(nn.Module):
|
||||
"""
|
||||
Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
|
||||
height and width of a bounding box w.r.t. an image.
|
||||
|
||||
Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||
return x
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
CONDITIONAL_DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on
|
||||
@ -1723,7 +1733,7 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
|
||||
|
||||
reference = outputs.reference_points if return_dict else outputs[-1]
|
||||
reference_before_sigmoid = inverse_sigmoid(reference).transpose(0, 1)
|
||||
outputs_coords = []
|
||||
|
||||
hs = sequence_output
|
||||
tmp = self.bbox_predictor(hs)
|
||||
tmp[..., :2] += reference_before_sigmoid
|
||||
@ -1732,47 +1742,20 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
|
||||
|
||||
loss, loss_dict, auxiliary_outputs = None, None, None
|
||||
if labels is not None:
|
||||
# First: create the matcher
|
||||
matcher = ConditionalDetrHungarianMatcher(
|
||||
class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost
|
||||
)
|
||||
# Second: create the criterion
|
||||
losses = ["labels", "boxes", "cardinality"]
|
||||
criterion = ConditionalDetrLoss(
|
||||
matcher=matcher,
|
||||
num_classes=self.config.num_labels,
|
||||
focal_alpha=self.config.focal_alpha,
|
||||
losses=losses,
|
||||
)
|
||||
criterion.to(self.device)
|
||||
# Third: compute the losses, based on outputs and labels
|
||||
outputs_loss = {}
|
||||
outputs_loss["logits"] = logits
|
||||
outputs_loss["pred_boxes"] = pred_boxes
|
||||
outputs_class, outputs_coord = None, None
|
||||
if self.config.auxiliary_loss:
|
||||
outputs_coords = []
|
||||
intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
|
||||
outputs_class = self.class_labels_classifier(intermediate)
|
||||
|
||||
for lvl in range(intermediate.shape[0]):
|
||||
tmp = self.bbox_predictor(intermediate[lvl])
|
||||
tmp[..., :2] += reference_before_sigmoid
|
||||
outputs_coord = tmp.sigmoid()
|
||||
outputs_coords.append(outputs_coord)
|
||||
outputs_coord = torch.stack(outputs_coords)
|
||||
|
||||
auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
|
||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||
|
||||
loss_dict = criterion(outputs_loss, labels)
|
||||
# Fourth: compute total loss, as a weighted sum of the various losses
|
||||
weight_dict = {"loss_ce": self.config.cls_loss_coefficient, "loss_bbox": self.config.bbox_loss_coefficient}
|
||||
weight_dict["loss_giou"] = self.config.giou_loss_coefficient
|
||||
if self.config.auxiliary_loss:
|
||||
aux_weight_dict = {}
|
||||
for i in range(self.config.decoder_layers - 1):
|
||||
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
||||
weight_dict.update(aux_weight_dict)
|
||||
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
||||
loss, loss_dict, auxiliary_outputs = self.loss_function(
|
||||
logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
if auxiliary_outputs is not None:
|
||||
@ -1977,43 +1960,14 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
|
||||
|
||||
loss, loss_dict, auxiliary_outputs = None, None, None
|
||||
if labels is not None:
|
||||
# First: create the matcher
|
||||
matcher = ConditionalDetrHungarianMatcher(
|
||||
class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost
|
||||
)
|
||||
# Second: create the criterion
|
||||
losses = ["labels", "boxes", "cardinality", "masks"]
|
||||
criterion = ConditionalDetrLoss(
|
||||
matcher=matcher,
|
||||
num_classes=self.config.num_labels,
|
||||
focal_alpha=self.config.focal_alpha,
|
||||
losses=losses,
|
||||
)
|
||||
criterion.to(self.device)
|
||||
# Third: compute the losses, based on outputs and labels
|
||||
outputs_loss = {}
|
||||
outputs_loss["logits"] = logits
|
||||
outputs_loss["pred_boxes"] = pred_boxes
|
||||
outputs_loss["pred_masks"] = pred_masks
|
||||
outputs_class, outputs_coord = None, None
|
||||
if self.config.auxiliary_loss:
|
||||
intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1]
|
||||
outputs_class = self.conditional_detr.class_labels_classifier(intermediate)
|
||||
outputs_coord = self.conditional_detr.bbox_predictor(intermediate).sigmoid()
|
||||
auxiliary_outputs = self.conditional_detr._set_aux_loss(outputs_class, outputs_coord)
|
||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||
|
||||
loss_dict = criterion(outputs_loss, labels)
|
||||
# Fourth: compute total loss, as a weighted sum of the various losses
|
||||
weight_dict = {"loss_ce": 1, "loss_bbox": self.config.bbox_loss_coefficient}
|
||||
weight_dict["loss_giou"] = self.config.giou_loss_coefficient
|
||||
weight_dict["loss_mask"] = self.config.mask_loss_coefficient
|
||||
weight_dict["loss_dice"] = self.config.dice_loss_coefficient
|
||||
if self.config.auxiliary_loss:
|
||||
aux_weight_dict = {}
|
||||
for i in range(self.config.decoder_layers - 1):
|
||||
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
||||
weight_dict.update(aux_weight_dict)
|
||||
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
||||
loss, loss_dict, auxiliary_outputs = self.loss_function(
|
||||
logits, labels, self.device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
if auxiliary_outputs is not None:
|
||||
@ -2151,485 +2105,3 @@ class ConditionalDetrMHAttentionMap(nn.Module):
|
||||
weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
|
||||
weights = self.dropout(weights)
|
||||
return weights
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.dice_loss
|
||||
def dice_loss(inputs, targets, num_boxes):
|
||||
"""
|
||||
Compute the DICE loss, similar to generalized IOU for masks
|
||||
|
||||
Args:
|
||||
inputs: A float tensor of arbitrary shape.
|
||||
The predictions for each example.
|
||||
targets: A float tensor with the same shape as inputs. Stores the binary
|
||||
classification label for each element in inputs (0 for the negative class and 1 for the positive
|
||||
class).
|
||||
"""
|
||||
inputs = inputs.sigmoid()
|
||||
inputs = inputs.flatten(1)
|
||||
numerator = 2 * (inputs * targets).sum(1)
|
||||
denominator = inputs.sum(-1) + targets.sum(-1)
|
||||
loss = 1 - (numerator + 1) / (denominator + 1)
|
||||
return loss.sum() / num_boxes
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.sigmoid_focal_loss
|
||||
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
|
||||
"""
|
||||
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
|
||||
|
||||
Args:
|
||||
inputs (`torch.FloatTensor` of arbitrary shape):
|
||||
The predictions for each example.
|
||||
targets (`torch.FloatTensor` with the same shape as `inputs`)
|
||||
A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
|
||||
and 1 for the positive class).
|
||||
alpha (`float`, *optional*, defaults to `0.25`):
|
||||
Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
|
||||
gamma (`int`, *optional*, defaults to `2`):
|
||||
Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.
|
||||
|
||||
Returns:
|
||||
Loss tensor
|
||||
"""
|
||||
prob = inputs.sigmoid()
|
||||
ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
|
||||
# add modulating factor
|
||||
p_t = prob * targets + (1 - prob) * (1 - targets)
|
||||
loss = ce_loss * ((1 - p_t) ** gamma)
|
||||
|
||||
if alpha >= 0:
|
||||
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
|
||||
loss = alpha_t * loss
|
||||
|
||||
return loss.mean(1).sum() / num_boxes
|
||||
|
||||
|
||||
class ConditionalDetrLoss(nn.Module):
|
||||
"""
|
||||
This class computes the losses for ConditionalDetrForObjectDetection/ConditionalDetrForSegmentation. The process
|
||||
happens in two steps: 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 2)
|
||||
we supervise each pair of matched ground-truth / prediction (supervise class and box).
|
||||
|
||||
Args:
|
||||
matcher (`ConditionalDetrHungarianMatcher`):
|
||||
Module able to compute a matching between targets and proposals.
|
||||
num_classes (`int`):
|
||||
Number of object categories, omitting the special no-object category.
|
||||
focal_alpha (`float`):
|
||||
Alpha parameter in focal loss.
|
||||
losses (`List[str]`):
|
||||
List of all the losses to be applied. See `get_loss` for a list of all available losses.
|
||||
"""
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.__init__
|
||||
def __init__(self, matcher, num_classes, focal_alpha, losses):
|
||||
super().__init__()
|
||||
self.matcher = matcher
|
||||
self.num_classes = num_classes
|
||||
self.focal_alpha = focal_alpha
|
||||
self.losses = losses
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.loss_labels
|
||||
def loss_labels(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Classification loss (Binary focal loss) targets dicts must contain the key "class_labels" containing a tensor
|
||||
of dim [nb_target_boxes]
|
||||
"""
|
||||
if "logits" not in outputs:
|
||||
raise KeyError("No logits were found in the outputs")
|
||||
source_logits = outputs["logits"]
|
||||
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
|
||||
target_classes = torch.full(
|
||||
source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
|
||||
)
|
||||
target_classes[idx] = target_classes_o
|
||||
|
||||
target_classes_onehot = torch.zeros(
|
||||
[source_logits.shape[0], source_logits.shape[1], source_logits.shape[2] + 1],
|
||||
dtype=source_logits.dtype,
|
||||
layout=source_logits.layout,
|
||||
device=source_logits.device,
|
||||
)
|
||||
target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
|
||||
|
||||
target_classes_onehot = target_classes_onehot[:, :, :-1]
|
||||
loss_ce = (
|
||||
sigmoid_focal_loss(source_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2)
|
||||
* source_logits.shape[1]
|
||||
)
|
||||
losses = {"loss_ce": loss_ce}
|
||||
|
||||
return losses
|
||||
|
||||
@torch.no_grad()
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.loss_cardinality
|
||||
def loss_cardinality(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
|
||||
|
||||
This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
|
||||
"""
|
||||
logits = outputs["logits"]
|
||||
device = logits.device
|
||||
target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
|
||||
# Count the number of predictions that are NOT "no-object" (which is the last class)
|
||||
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
|
||||
card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
|
||||
losses = {"cardinality_error": card_err}
|
||||
return losses
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.loss_boxes
|
||||
def loss_boxes(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
|
||||
|
||||
Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
|
||||
are expected in format (center_x, center_y, w, h), normalized by the image size.
|
||||
"""
|
||||
if "pred_boxes" not in outputs:
|
||||
raise KeyError("No predicted boxes found in outputs")
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
source_boxes = outputs["pred_boxes"][idx]
|
||||
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
|
||||
|
||||
loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
|
||||
|
||||
losses = {}
|
||||
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
|
||||
|
||||
loss_giou = 1 - torch.diag(
|
||||
generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
|
||||
)
|
||||
losses["loss_giou"] = loss_giou.sum() / num_boxes
|
||||
return losses
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_masks
|
||||
def loss_masks(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the losses related to the masks: the focal loss and the dice loss.
|
||||
|
||||
Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
|
||||
"""
|
||||
if "pred_masks" not in outputs:
|
||||
raise KeyError("No predicted masks found in outputs")
|
||||
|
||||
source_idx = self._get_source_permutation_idx(indices)
|
||||
target_idx = self._get_target_permutation_idx(indices)
|
||||
source_masks = outputs["pred_masks"]
|
||||
source_masks = source_masks[source_idx]
|
||||
masks = [t["masks"] for t in targets]
|
||||
# TODO use valid to mask invalid areas due to padding in loss
|
||||
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
|
||||
target_masks = target_masks.to(source_masks)
|
||||
target_masks = target_masks[target_idx]
|
||||
|
||||
# upsample predictions to the target size
|
||||
source_masks = nn.functional.interpolate(
|
||||
source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
source_masks = source_masks[:, 0].flatten(1)
|
||||
|
||||
target_masks = target_masks.flatten(1)
|
||||
target_masks = target_masks.view(source_masks.shape)
|
||||
losses = {
|
||||
"loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
|
||||
"loss_dice": dice_loss(source_masks, target_masks, num_boxes),
|
||||
}
|
||||
return losses
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss._get_source_permutation_idx
|
||||
def _get_source_permutation_idx(self, indices):
|
||||
# permute predictions following indices
|
||||
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
|
||||
source_idx = torch.cat([source for (source, _) in indices])
|
||||
return batch_idx, source_idx
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss._get_target_permutation_idx
|
||||
def _get_target_permutation_idx(self, indices):
|
||||
# permute targets following indices
|
||||
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
|
||||
target_idx = torch.cat([target for (_, target) in indices])
|
||||
return batch_idx, target_idx
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrLoss.get_loss
|
||||
def get_loss(self, loss, outputs, targets, indices, num_boxes):
|
||||
loss_map = {
|
||||
"labels": self.loss_labels,
|
||||
"cardinality": self.loss_cardinality,
|
||||
"boxes": self.loss_boxes,
|
||||
"masks": self.loss_masks,
|
||||
}
|
||||
if loss not in loss_map:
|
||||
raise ValueError(f"Loss {loss} not supported")
|
||||
return loss_map[loss](outputs, targets, indices, num_boxes)
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrLoss.forward
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
This performs the loss computation.
|
||||
|
||||
Args:
|
||||
outputs (`dict`, *optional*):
|
||||
Dictionary of tensors, see the output specification of the model for the format.
|
||||
targets (`List[dict]`, *optional*):
|
||||
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
|
||||
losses applied, see each loss' doc.
|
||||
"""
|
||||
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
|
||||
|
||||
# Retrieve the matching between the outputs of the last layer and the targets
|
||||
indices = self.matcher(outputs_without_aux, targets)
|
||||
|
||||
# Compute the average number of target boxes across all nodes, for normalization purposes
|
||||
num_boxes = sum(len(t["class_labels"]) for t in targets)
|
||||
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
||||
|
||||
world_size = 1
|
||||
if is_accelerate_available():
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
|
||||
|
||||
# Compute all the requested losses
|
||||
losses = {}
|
||||
for loss in self.losses:
|
||||
losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
|
||||
|
||||
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
||||
if "auxiliary_outputs" in outputs:
|
||||
for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
|
||||
indices = self.matcher(auxiliary_outputs, targets)
|
||||
for loss in self.losses:
|
||||
if loss == "masks":
|
||||
# Intermediate masks losses are too costly to compute, we ignore them.
|
||||
continue
|
||||
l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
|
||||
l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
|
||||
losses.update(l_dict)
|
||||
|
||||
return losses
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with Detr->ConditionalDetr
|
||||
class ConditionalDetrMLPPredictionHead(nn.Module):
|
||||
"""
|
||||
Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
|
||||
height and width of a bounding box w.r.t. an image.
|
||||
|
||||
Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||
return x
|
||||
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrHungarianMatcher with DeformableDetr->ConditionalDetr
|
||||
class ConditionalDetrHungarianMatcher(nn.Module):
|
||||
"""
|
||||
This class computes an assignment between the targets and the predictions of the network.
|
||||
|
||||
For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
|
||||
predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
|
||||
un-matched (and thus treated as non-objects).
|
||||
|
||||
Args:
|
||||
class_cost:
|
||||
The relative weight of the classification error in the matching cost.
|
||||
bbox_cost:
|
||||
The relative weight of the L1 error of the bounding box coordinates in the matching cost.
|
||||
giou_cost:
|
||||
The relative weight of the giou loss of the bounding box in the matching cost.
|
||||
"""
|
||||
|
||||
def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
|
||||
super().__init__()
|
||||
requires_backends(self, ["scipy"])
|
||||
|
||||
self.class_cost = class_cost
|
||||
self.bbox_cost = bbox_cost
|
||||
self.giou_cost = giou_cost
|
||||
if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
|
||||
raise ValueError("All costs of the Matcher can't be 0")
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
Args:
|
||||
outputs (`dict`):
|
||||
A dictionary that contains at least these entries:
|
||||
* "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
||||
* "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
|
||||
targets (`List[dict]`):
|
||||
A list of targets (len(targets) = batch_size), where each target is a dict containing:
|
||||
* "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
|
||||
ground-truth
|
||||
objects in the target) containing the class labels
|
||||
* "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.
|
||||
|
||||
Returns:
|
||||
`List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
|
||||
- index_i is the indices of the selected predictions (in order)
|
||||
- index_j is the indices of the corresponding selected targets (in order)
|
||||
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
||||
"""
|
||||
batch_size, num_queries = outputs["logits"].shape[:2]
|
||||
|
||||
# We flatten to compute the cost matrices in a batch
|
||||
out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes]
|
||||
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
|
||||
|
||||
# Also concat the target labels and boxes
|
||||
target_ids = torch.cat([v["class_labels"] for v in targets])
|
||||
target_bbox = torch.cat([v["boxes"] for v in targets])
|
||||
|
||||
# Compute the classification cost.
|
||||
alpha = 0.25
|
||||
gamma = 2.0
|
||||
neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
|
||||
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
|
||||
class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]
|
||||
|
||||
# Compute the L1 cost between boxes
|
||||
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
|
||||
|
||||
# Compute the giou cost between boxes
|
||||
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
|
||||
|
||||
# Final cost matrix
|
||||
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
|
||||
cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
|
||||
|
||||
sizes = [len(v["boxes"]) for v in targets]
|
||||
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
|
||||
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr._upcast
|
||||
def _upcast(t: Tensor) -> Tensor:
|
||||
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
|
||||
if t.is_floating_point():
|
||||
return t if t.dtype in (torch.float32, torch.float64) else t.float()
|
||||
else:
|
||||
return t if t.dtype in (torch.int32, torch.int64) else t.int()
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.box_area
|
||||
def box_area(boxes: Tensor) -> Tensor:
|
||||
"""
|
||||
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
|
||||
|
||||
Args:
|
||||
boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
|
||||
Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
|
||||
< x2` and `0 <= y1 < y2`.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: a tensor containing the area for each box.
|
||||
"""
|
||||
boxes = _upcast(boxes)
|
||||
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.box_iou
|
||||
def box_iou(boxes1, boxes2):
|
||||
area1 = box_area(boxes1)
|
||||
area2 = box_area(boxes2)
|
||||
|
||||
left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||
right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||
|
||||
width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
|
||||
inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
|
||||
|
||||
union = area1[:, None] + area2 - inter
|
||||
|
||||
iou = inter / union
|
||||
return iou, union
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
|
||||
"""
|
||||
# degenerate boxes gives inf / nan results
|
||||
# so do an early check
|
||||
if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
|
||||
raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
|
||||
if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
|
||||
raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
|
||||
iou, union = box_iou(boxes1, boxes2)
|
||||
|
||||
top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
||||
bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
||||
|
||||
width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
|
||||
area = width_height[:, :, 0] * width_height[:, :, 1]
|
||||
|
||||
return iou - (area - union) / area
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr._max_by_axis
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.NestedTensor
|
||||
class NestedTensor:
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.nested_tensor_from_tensor_list
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
if tensor_list[0].ndim == 3:
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||
batch_shape = [len(tensor_list)] + max_size
|
||||
batch_size, num_channels, height, width = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], : img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError("Only 3-dimensional tensors are supported")
|
||||
return NestedTensor(tensor, mask)
|
||||
|
@ -37,12 +37,9 @@ from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_accelerate_available,
|
||||
is_ninja_available,
|
||||
is_scipy_available,
|
||||
is_timm_available,
|
||||
is_torch_cuda_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
requires_backends,
|
||||
@ -86,23 +83,10 @@ def load_cuda_kernels():
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from transformers.image_transforms import center_to_corners_format
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import PartialState
|
||||
from accelerate.utils import reduce
|
||||
|
||||
|
||||
if is_timm_available():
|
||||
from timm import create_model
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "DeformableDetrConfig"
|
||||
@ -1869,6 +1853,28 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead
|
||||
class DeformableDetrMLPPredictionHead(nn.Module):
|
||||
"""
|
||||
Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
|
||||
height and width of a bounding box w.r.t. an image.
|
||||
|
||||
Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||
return x
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Deformable DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on
|
||||
@ -1887,7 +1893,6 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
|
||||
|
||||
# Deformable DETR encoder-decoder model
|
||||
self.model = DeformableDetrModel(config)
|
||||
|
||||
# Detection heads on top
|
||||
self.class_embed = nn.Linear(config.d_model, config.num_labels)
|
||||
self.bbox_embed = DeformableDetrMLPPredictionHead(
|
||||
@ -1922,14 +1927,6 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
@torch.jit.unused
|
||||
def _set_aux_loss(self, outputs_class, outputs_coord):
|
||||
# this is a workaround to make torchscript happy, as torchscript
|
||||
# doesn't support dictionary with non-homogeneous values, such
|
||||
# as a dict having both a Tensor and a list.
|
||||
return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
|
||||
|
||||
@add_start_docstrings_to_model_forward(DEFORMABLE_DETR_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=DeformableDetrObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@ -2034,41 +2031,9 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
|
||||
|
||||
loss, loss_dict, auxiliary_outputs = None, None, None
|
||||
if labels is not None:
|
||||
# First: create the matcher
|
||||
matcher = DeformableDetrHungarianMatcher(
|
||||
class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost
|
||||
loss, loss_dict, auxiliary_outputs = self.loss_function(
|
||||
logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
|
||||
)
|
||||
# Second: create the criterion
|
||||
losses = ["labels", "boxes", "cardinality"]
|
||||
criterion = DeformableDetrLoss(
|
||||
matcher=matcher,
|
||||
num_classes=self.config.num_labels,
|
||||
focal_alpha=self.config.focal_alpha,
|
||||
losses=losses,
|
||||
)
|
||||
criterion.to(self.device)
|
||||
# Third: compute the losses, based on outputs and labels
|
||||
outputs_loss = {}
|
||||
outputs_loss["logits"] = logits
|
||||
outputs_loss["pred_boxes"] = pred_boxes
|
||||
if self.config.auxiliary_loss:
|
||||
auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
|
||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||
if self.config.two_stage:
|
||||
enc_outputs_coord = outputs.enc_outputs_coord_logits.sigmoid()
|
||||
outputs_loss["enc_outputs"] = {"logits": outputs.enc_outputs_class, "pred_boxes": enc_outputs_coord}
|
||||
|
||||
loss_dict = criterion(outputs_loss, labels)
|
||||
# Fourth: compute total loss, as a weighted sum of the various losses
|
||||
weight_dict = {"loss_ce": 1, "loss_bbox": self.config.bbox_loss_coefficient}
|
||||
weight_dict["loss_giou"] = self.config.giou_loss_coefficient
|
||||
if self.config.auxiliary_loss:
|
||||
aux_weight_dict = {}
|
||||
for i in range(self.config.decoder_layers - 1):
|
||||
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
||||
weight_dict.update(aux_weight_dict)
|
||||
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
||||
|
||||
if not return_dict:
|
||||
if auxiliary_outputs is not None:
|
||||
output = (logits, pred_boxes) + auxiliary_outputs + outputs
|
||||
@ -2099,453 +2064,3 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
|
||||
)
|
||||
|
||||
return dict_outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.dice_loss
|
||||
def dice_loss(inputs, targets, num_boxes):
|
||||
"""
|
||||
Compute the DICE loss, similar to generalized IOU for masks
|
||||
|
||||
Args:
|
||||
inputs: A float tensor of arbitrary shape.
|
||||
The predictions for each example.
|
||||
targets: A float tensor with the same shape as inputs. Stores the binary
|
||||
classification label for each element in inputs (0 for the negative class and 1 for the positive
|
||||
class).
|
||||
"""
|
||||
inputs = inputs.sigmoid()
|
||||
inputs = inputs.flatten(1)
|
||||
numerator = 2 * (inputs * targets).sum(1)
|
||||
denominator = inputs.sum(-1) + targets.sum(-1)
|
||||
loss = 1 - (numerator + 1) / (denominator + 1)
|
||||
return loss.sum() / num_boxes
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.sigmoid_focal_loss
|
||||
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
|
||||
"""
|
||||
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
|
||||
|
||||
Args:
|
||||
inputs (`torch.FloatTensor` of arbitrary shape):
|
||||
The predictions for each example.
|
||||
targets (`torch.FloatTensor` with the same shape as `inputs`)
|
||||
A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
|
||||
and 1 for the positive class).
|
||||
alpha (`float`, *optional*, defaults to `0.25`):
|
||||
Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
|
||||
gamma (`int`, *optional*, defaults to `2`):
|
||||
Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.
|
||||
|
||||
Returns:
|
||||
Loss tensor
|
||||
"""
|
||||
prob = inputs.sigmoid()
|
||||
ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
|
||||
# add modulating factor
|
||||
p_t = prob * targets + (1 - prob) * (1 - targets)
|
||||
loss = ce_loss * ((1 - p_t) ** gamma)
|
||||
|
||||
if alpha >= 0:
|
||||
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
|
||||
loss = alpha_t * loss
|
||||
|
||||
return loss.mean(1).sum() / num_boxes
|
||||
|
||||
|
||||
class DeformableDetrLoss(nn.Module):
|
||||
"""
|
||||
This class computes the losses for `DeformableDetrForObjectDetection`. The process happens in two steps: 1) we
|
||||
compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of
|
||||
matched ground-truth / prediction (supervise class and box).
|
||||
|
||||
Args:
|
||||
matcher (`DeformableDetrHungarianMatcher`):
|
||||
Module able to compute a matching between targets and proposals.
|
||||
num_classes (`int`):
|
||||
Number of object categories, omitting the special no-object category.
|
||||
focal_alpha (`float`):
|
||||
Alpha parameter in focal loss.
|
||||
losses (`List[str]`):
|
||||
List of all the losses to be applied. See `get_loss` for a list of all available losses.
|
||||
"""
|
||||
|
||||
def __init__(self, matcher, num_classes, focal_alpha, losses):
|
||||
super().__init__()
|
||||
self.matcher = matcher
|
||||
self.num_classes = num_classes
|
||||
self.focal_alpha = focal_alpha
|
||||
self.losses = losses
|
||||
|
||||
# removed logging parameter, which was part of the original implementation
|
||||
def loss_labels(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Classification loss (Binary focal loss) targets dicts must contain the key "class_labels" containing a tensor
|
||||
of dim [nb_target_boxes]
|
||||
"""
|
||||
if "logits" not in outputs:
|
||||
raise KeyError("No logits were found in the outputs")
|
||||
source_logits = outputs["logits"]
|
||||
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
|
||||
target_classes = torch.full(
|
||||
source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
|
||||
)
|
||||
target_classes[idx] = target_classes_o
|
||||
|
||||
target_classes_onehot = torch.zeros(
|
||||
[source_logits.shape[0], source_logits.shape[1], source_logits.shape[2] + 1],
|
||||
dtype=source_logits.dtype,
|
||||
layout=source_logits.layout,
|
||||
device=source_logits.device,
|
||||
)
|
||||
target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
|
||||
|
||||
target_classes_onehot = target_classes_onehot[:, :, :-1]
|
||||
loss_ce = (
|
||||
sigmoid_focal_loss(source_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2)
|
||||
* source_logits.shape[1]
|
||||
)
|
||||
losses = {"loss_ce": loss_ce}
|
||||
|
||||
return losses
|
||||
|
||||
@torch.no_grad()
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_cardinality
|
||||
def loss_cardinality(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
|
||||
|
||||
This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
|
||||
"""
|
||||
logits = outputs["logits"]
|
||||
device = logits.device
|
||||
target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
|
||||
# Count the number of predictions that are NOT "no-object" (which is the last class)
|
||||
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
|
||||
card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
|
||||
losses = {"cardinality_error": card_err}
|
||||
return losses
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_boxes
|
||||
def loss_boxes(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
|
||||
|
||||
Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
|
||||
are expected in format (center_x, center_y, w, h), normalized by the image size.
|
||||
"""
|
||||
if "pred_boxes" not in outputs:
|
||||
raise KeyError("No predicted boxes found in outputs")
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
source_boxes = outputs["pred_boxes"][idx]
|
||||
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
|
||||
|
||||
loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
|
||||
|
||||
losses = {}
|
||||
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
|
||||
|
||||
loss_giou = 1 - torch.diag(
|
||||
generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
|
||||
)
|
||||
losses["loss_giou"] = loss_giou.sum() / num_boxes
|
||||
return losses
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrLoss._get_source_permutation_idx
|
||||
def _get_source_permutation_idx(self, indices):
|
||||
# permute predictions following indices
|
||||
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
|
||||
source_idx = torch.cat([source for (source, _) in indices])
|
||||
return batch_idx, source_idx
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrLoss._get_target_permutation_idx
|
||||
def _get_target_permutation_idx(self, indices):
|
||||
# permute targets following indices
|
||||
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
|
||||
target_idx = torch.cat([target for (_, target) in indices])
|
||||
return batch_idx, target_idx
|
||||
|
||||
def get_loss(self, loss, outputs, targets, indices, num_boxes):
|
||||
loss_map = {
|
||||
"labels": self.loss_labels,
|
||||
"cardinality": self.loss_cardinality,
|
||||
"boxes": self.loss_boxes,
|
||||
}
|
||||
if loss not in loss_map:
|
||||
raise ValueError(f"Loss {loss} not supported")
|
||||
return loss_map[loss](outputs, targets, indices, num_boxes)
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
This performs the loss computation.
|
||||
|
||||
Args:
|
||||
outputs (`dict`, *optional*):
|
||||
Dictionary of tensors, see the output specification of the model for the format.
|
||||
targets (`List[dict]`, *optional*):
|
||||
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
|
||||
losses applied, see each loss' doc.
|
||||
"""
|
||||
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs" and k != "enc_outputs"}
|
||||
|
||||
# Retrieve the matching between the outputs of the last layer and the targets
|
||||
indices = self.matcher(outputs_without_aux, targets)
|
||||
|
||||
# Compute the average number of target boxes accross all nodes, for normalization purposes
|
||||
num_boxes = sum(len(t["class_labels"]) for t in targets)
|
||||
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
||||
world_size = 1
|
||||
if is_accelerate_available():
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
|
||||
|
||||
# Compute all the requested losses
|
||||
losses = {}
|
||||
for loss in self.losses:
|
||||
losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
|
||||
|
||||
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
||||
if "auxiliary_outputs" in outputs:
|
||||
for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
|
||||
indices = self.matcher(auxiliary_outputs, targets)
|
||||
for loss in self.losses:
|
||||
l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
|
||||
l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
|
||||
losses.update(l_dict)
|
||||
|
||||
if "enc_outputs" in outputs:
|
||||
enc_outputs = outputs["enc_outputs"]
|
||||
bin_targets = copy.deepcopy(targets)
|
||||
for bt in bin_targets:
|
||||
bt["class_labels"] = torch.zeros_like(bt["class_labels"])
|
||||
indices = self.matcher(enc_outputs, bin_targets)
|
||||
for loss in self.losses:
|
||||
l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes)
|
||||
l_dict = {k + "_enc": v for k, v in l_dict.items()}
|
||||
losses.update(l_dict)
|
||||
|
||||
return losses
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead
|
||||
class DeformableDetrMLPPredictionHead(nn.Module):
|
||||
"""
|
||||
Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
|
||||
height and width of a bounding box w.r.t. an image.
|
||||
|
||||
Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class DeformableDetrHungarianMatcher(nn.Module):
|
||||
"""
|
||||
This class computes an assignment between the targets and the predictions of the network.
|
||||
|
||||
For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
|
||||
predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
|
||||
un-matched (and thus treated as non-objects).
|
||||
|
||||
Args:
|
||||
class_cost:
|
||||
The relative weight of the classification error in the matching cost.
|
||||
bbox_cost:
|
||||
The relative weight of the L1 error of the bounding box coordinates in the matching cost.
|
||||
giou_cost:
|
||||
The relative weight of the giou loss of the bounding box in the matching cost.
|
||||
"""
|
||||
|
||||
def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
|
||||
super().__init__()
|
||||
requires_backends(self, ["scipy"])
|
||||
|
||||
self.class_cost = class_cost
|
||||
self.bbox_cost = bbox_cost
|
||||
self.giou_cost = giou_cost
|
||||
if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
|
||||
raise ValueError("All costs of the Matcher can't be 0")
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
Args:
|
||||
outputs (`dict`):
|
||||
A dictionary that contains at least these entries:
|
||||
* "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
||||
* "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
|
||||
targets (`List[dict]`):
|
||||
A list of targets (len(targets) = batch_size), where each target is a dict containing:
|
||||
* "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
|
||||
ground-truth
|
||||
objects in the target) containing the class labels
|
||||
* "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.
|
||||
|
||||
Returns:
|
||||
`List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
|
||||
- index_i is the indices of the selected predictions (in order)
|
||||
- index_j is the indices of the corresponding selected targets (in order)
|
||||
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
||||
"""
|
||||
batch_size, num_queries = outputs["logits"].shape[:2]
|
||||
|
||||
# We flatten to compute the cost matrices in a batch
|
||||
out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes]
|
||||
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
|
||||
|
||||
# Also concat the target labels and boxes
|
||||
target_ids = torch.cat([v["class_labels"] for v in targets])
|
||||
target_bbox = torch.cat([v["boxes"] for v in targets])
|
||||
|
||||
# Compute the classification cost.
|
||||
alpha = 0.25
|
||||
gamma = 2.0
|
||||
neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
|
||||
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
|
||||
class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]
|
||||
|
||||
# Compute the L1 cost between boxes
|
||||
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
|
||||
|
||||
# Compute the giou cost between boxes
|
||||
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
|
||||
|
||||
# Final cost matrix
|
||||
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
|
||||
cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
|
||||
|
||||
sizes = [len(v["boxes"]) for v in targets]
|
||||
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
|
||||
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr._upcast
|
||||
def _upcast(t: Tensor) -> Tensor:
|
||||
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
|
||||
if t.is_floating_point():
|
||||
return t if t.dtype in (torch.float32, torch.float64) else t.float()
|
||||
else:
|
||||
return t if t.dtype in (torch.int32, torch.int64) else t.int()
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.box_area
|
||||
def box_area(boxes: Tensor) -> Tensor:
|
||||
"""
|
||||
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
|
||||
|
||||
Args:
|
||||
boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
|
||||
Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
|
||||
< x2` and `0 <= y1 < y2`.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: a tensor containing the area for each box.
|
||||
"""
|
||||
boxes = _upcast(boxes)
|
||||
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.box_iou
|
||||
def box_iou(boxes1, boxes2):
|
||||
area1 = box_area(boxes1)
|
||||
area2 = box_area(boxes2)
|
||||
|
||||
left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||
right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||
|
||||
width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
|
||||
inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
|
||||
|
||||
union = area1[:, None] + area2 - inter
|
||||
|
||||
iou = inter / union
|
||||
return iou, union
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
|
||||
"""
|
||||
# degenerate boxes gives inf / nan results
|
||||
# so do an early check
|
||||
if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
|
||||
raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
|
||||
if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
|
||||
raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
|
||||
iou, union = box_iou(boxes1, boxes2)
|
||||
|
||||
top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
||||
bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
||||
|
||||
width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
|
||||
area = width_height[:, :, 0] * width_height[:, :, 1]
|
||||
|
||||
return iou - (area - union) / area
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr._max_by_axis
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.NestedTensor
|
||||
class NestedTensor:
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.nested_tensor_from_tensor_list
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
if tensor_list[0].ndim == 3:
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||
batch_shape = [len(tensor_list)] + max_size
|
||||
batch_size, num_channels, height, width = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], : img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError("Only 3-dimensional tensors are supported")
|
||||
return NestedTensor(tensor, mask)
|
||||
|
@ -29,10 +29,7 @@ from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_accelerate_available,
|
||||
is_scipy_available,
|
||||
is_timm_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
requires_backends,
|
||||
@ -41,21 +38,10 @@ from ...utils.backbone_utils import load_backbone
|
||||
from .configuration_detr import DetrConfig
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import PartialState
|
||||
from accelerate.utils import reduce
|
||||
|
||||
if is_scipy_available():
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
|
||||
if is_timm_available():
|
||||
from timm import create_model
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from transformers.image_transforms import center_to_corners_format
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "DetrConfig"
|
||||
@ -1343,6 +1329,28 @@ class DetrModel(DetrPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
class DetrMLPPredictionHead(nn.Module):
|
||||
"""
|
||||
Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
|
||||
height and width of a bounding box w.r.t. an image.
|
||||
|
||||
Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||
return x
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, for tasks
|
||||
@ -1368,14 +1376,6 @@ class DetrForObjectDetection(DetrPreTrainedModel):
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
@torch.jit.unused
|
||||
def _set_aux_loss(self, outputs_class, outputs_coord):
|
||||
# this is a workaround to make torchscript happy, as torchscript
|
||||
# doesn't support dictionary with non-homogeneous values, such
|
||||
# as a dict having both a Tensor and a list.
|
||||
return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
|
||||
|
||||
@add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=DetrObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@ -1458,40 +1458,14 @@ class DetrForObjectDetection(DetrPreTrainedModel):
|
||||
|
||||
loss, loss_dict, auxiliary_outputs = None, None, None
|
||||
if labels is not None:
|
||||
# First: create the matcher
|
||||
matcher = DetrHungarianMatcher(
|
||||
class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost
|
||||
)
|
||||
# Second: create the criterion
|
||||
losses = ["labels", "boxes", "cardinality"]
|
||||
criterion = DetrLoss(
|
||||
matcher=matcher,
|
||||
num_classes=self.config.num_labels,
|
||||
eos_coef=self.config.eos_coefficient,
|
||||
losses=losses,
|
||||
)
|
||||
criterion.to(self.device)
|
||||
# Third: compute the losses, based on outputs and labels
|
||||
outputs_loss = {}
|
||||
outputs_loss["logits"] = logits
|
||||
outputs_loss["pred_boxes"] = pred_boxes
|
||||
outputs_class, outputs_coord = None, None
|
||||
if self.config.auxiliary_loss:
|
||||
intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
|
||||
outputs_class = self.class_labels_classifier(intermediate)
|
||||
outputs_coord = self.bbox_predictor(intermediate).sigmoid()
|
||||
auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
|
||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||
|
||||
loss_dict = criterion(outputs_loss, labels)
|
||||
# Fourth: compute total loss, as a weighted sum of the various losses
|
||||
weight_dict = {"loss_ce": 1, "loss_bbox": self.config.bbox_loss_coefficient}
|
||||
weight_dict["loss_giou"] = self.config.giou_loss_coefficient
|
||||
if self.config.auxiliary_loss:
|
||||
aux_weight_dict = {}
|
||||
for i in range(self.config.decoder_layers - 1):
|
||||
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
||||
weight_dict.update(aux_weight_dict)
|
||||
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
||||
loss, loss_dict, auxiliary_outputs = self.loss_function(
|
||||
logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
if auxiliary_outputs is not None:
|
||||
@ -1542,7 +1516,6 @@ class DetrForSegmentation(DetrPreTrainedModel):
|
||||
self.bbox_attention = DetrMHAttentionMap(
|
||||
hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std
|
||||
)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@ -1688,43 +1661,14 @@ class DetrForSegmentation(DetrPreTrainedModel):
|
||||
|
||||
loss, loss_dict, auxiliary_outputs = None, None, None
|
||||
if labels is not None:
|
||||
# First: create the matcher
|
||||
matcher = DetrHungarianMatcher(
|
||||
class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost
|
||||
)
|
||||
# Second: create the criterion
|
||||
losses = ["labels", "boxes", "cardinality", "masks"]
|
||||
criterion = DetrLoss(
|
||||
matcher=matcher,
|
||||
num_classes=self.config.num_labels,
|
||||
eos_coef=self.config.eos_coefficient,
|
||||
losses=losses,
|
||||
)
|
||||
criterion.to(self.device)
|
||||
# Third: compute the losses, based on outputs and labels
|
||||
outputs_loss = {}
|
||||
outputs_loss["logits"] = logits
|
||||
outputs_loss["pred_boxes"] = pred_boxes
|
||||
outputs_loss["pred_masks"] = pred_masks
|
||||
outputs_class, outputs_coord = None, None
|
||||
if self.config.auxiliary_loss:
|
||||
intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1]
|
||||
outputs_class = self.detr.class_labels_classifier(intermediate)
|
||||
outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid()
|
||||
auxiliary_outputs = self.detr._set_aux_loss(outputs_class, outputs_coord)
|
||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||
|
||||
loss_dict = criterion(outputs_loss, labels)
|
||||
# Fourth: compute total loss, as a weighted sum of the various losses
|
||||
weight_dict = {"loss_ce": 1, "loss_bbox": self.config.bbox_loss_coefficient}
|
||||
weight_dict["loss_giou"] = self.config.giou_loss_coefficient
|
||||
weight_dict["loss_mask"] = self.config.mask_loss_coefficient
|
||||
weight_dict["loss_dice"] = self.config.dice_loss_coefficient
|
||||
if self.config.auxiliary_loss:
|
||||
aux_weight_dict = {}
|
||||
for i in range(self.config.decoder_layers - 1):
|
||||
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
||||
weight_dict.update(aux_weight_dict)
|
||||
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
||||
loss, loss_dict, auxiliary_outputs = self.loss_function(
|
||||
logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
if auxiliary_outputs is not None:
|
||||
@ -1861,470 +1805,3 @@ class DetrMHAttentionMap(nn.Module):
|
||||
weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
|
||||
weights = self.dropout(weights)
|
||||
return weights
|
||||
|
||||
|
||||
def dice_loss(inputs, targets, num_boxes):
|
||||
"""
|
||||
Compute the DICE loss, similar to generalized IOU for masks
|
||||
|
||||
Args:
|
||||
inputs: A float tensor of arbitrary shape.
|
||||
The predictions for each example.
|
||||
targets: A float tensor with the same shape as inputs. Stores the binary
|
||||
classification label for each element in inputs (0 for the negative class and 1 for the positive
|
||||
class).
|
||||
"""
|
||||
inputs = inputs.sigmoid()
|
||||
inputs = inputs.flatten(1)
|
||||
numerator = 2 * (inputs * targets).sum(1)
|
||||
denominator = inputs.sum(-1) + targets.sum(-1)
|
||||
loss = 1 - (numerator + 1) / (denominator + 1)
|
||||
return loss.sum() / num_boxes
|
||||
|
||||
|
||||
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
|
||||
"""
|
||||
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
|
||||
|
||||
Args:
|
||||
inputs (`torch.FloatTensor` of arbitrary shape):
|
||||
The predictions for each example.
|
||||
targets (`torch.FloatTensor` with the same shape as `inputs`)
|
||||
A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
|
||||
and 1 for the positive class).
|
||||
alpha (`float`, *optional*, defaults to `0.25`):
|
||||
Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
|
||||
gamma (`int`, *optional*, defaults to `2`):
|
||||
Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.
|
||||
|
||||
Returns:
|
||||
Loss tensor
|
||||
"""
|
||||
prob = inputs.sigmoid()
|
||||
ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
|
||||
# add modulating factor
|
||||
p_t = prob * targets + (1 - prob) * (1 - targets)
|
||||
loss = ce_loss * ((1 - p_t) ** gamma)
|
||||
|
||||
if alpha >= 0:
|
||||
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
|
||||
loss = alpha_t * loss
|
||||
|
||||
return loss.mean(1).sum() / num_boxes
|
||||
|
||||
|
||||
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
class DetrLoss(nn.Module):
|
||||
"""
|
||||
This class computes the losses for DetrForObjectDetection/DetrForSegmentation. The process happens in two steps: 1)
|
||||
we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair
|
||||
of matched ground-truth / prediction (supervise class and box).
|
||||
|
||||
A note on the `num_classes` argument (copied from original repo in detr.py): "the naming of the `num_classes`
|
||||
parameter of the criterion is somewhat misleading. It indeed corresponds to `max_obj_id` + 1, where `max_obj_id` is
|
||||
the maximum id for a class in your dataset. For example, COCO has a `max_obj_id` of 90, so we pass `num_classes` to
|
||||
be 91. As another example, for a dataset that has a single class with `id` 1, you should pass `num_classes` to be 2
|
||||
(`max_obj_id` + 1). For more details on this, check the following discussion
|
||||
https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
|
||||
|
||||
|
||||
Args:
|
||||
matcher (`DetrHungarianMatcher`):
|
||||
Module able to compute a matching between targets and proposals.
|
||||
num_classes (`int`):
|
||||
Number of object categories, omitting the special no-object category.
|
||||
eos_coef (`float`):
|
||||
Relative classification weight applied to the no-object category.
|
||||
losses (`List[str]`):
|
||||
List of all the losses to be applied. See `get_loss` for a list of all available losses.
|
||||
"""
|
||||
|
||||
def __init__(self, matcher, num_classes, eos_coef, losses):
|
||||
super().__init__()
|
||||
self.matcher = matcher
|
||||
self.num_classes = num_classes
|
||||
self.eos_coef = eos_coef
|
||||
self.losses = losses
|
||||
empty_weight = torch.ones(self.num_classes + 1)
|
||||
empty_weight[-1] = self.eos_coef
|
||||
self.register_buffer("empty_weight", empty_weight)
|
||||
|
||||
# removed logging parameter, which was part of the original implementation
|
||||
def loss_labels(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Classification loss (NLL) targets dicts must contain the key "class_labels" containing a tensor of dim
|
||||
[nb_target_boxes]
|
||||
"""
|
||||
if "logits" not in outputs:
|
||||
raise KeyError("No logits were found in the outputs")
|
||||
source_logits = outputs["logits"]
|
||||
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
|
||||
target_classes = torch.full(
|
||||
source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
|
||||
)
|
||||
target_classes[idx] = target_classes_o
|
||||
|
||||
loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
|
||||
losses = {"loss_ce": loss_ce}
|
||||
|
||||
return losses
|
||||
|
||||
@torch.no_grad()
|
||||
def loss_cardinality(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
|
||||
|
||||
This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
|
||||
"""
|
||||
logits = outputs["logits"]
|
||||
device = logits.device
|
||||
target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
|
||||
# Count the number of predictions that are NOT "no-object" (which is the last class)
|
||||
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
|
||||
card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
|
||||
losses = {"cardinality_error": card_err}
|
||||
return losses
|
||||
|
||||
def loss_boxes(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
|
||||
|
||||
Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
|
||||
are expected in format (center_x, center_y, w, h), normalized by the image size.
|
||||
"""
|
||||
if "pred_boxes" not in outputs:
|
||||
raise KeyError("No predicted boxes found in outputs")
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
source_boxes = outputs["pred_boxes"][idx]
|
||||
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
|
||||
|
||||
loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
|
||||
|
||||
losses = {}
|
||||
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
|
||||
|
||||
loss_giou = 1 - torch.diag(
|
||||
generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
|
||||
)
|
||||
losses["loss_giou"] = loss_giou.sum() / num_boxes
|
||||
return losses
|
||||
|
||||
def loss_masks(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the losses related to the masks: the focal loss and the dice loss.
|
||||
|
||||
Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
|
||||
"""
|
||||
if "pred_masks" not in outputs:
|
||||
raise KeyError("No predicted masks found in outputs")
|
||||
|
||||
source_idx = self._get_source_permutation_idx(indices)
|
||||
target_idx = self._get_target_permutation_idx(indices)
|
||||
source_masks = outputs["pred_masks"]
|
||||
source_masks = source_masks[source_idx]
|
||||
masks = [t["masks"] for t in targets]
|
||||
# TODO use valid to mask invalid areas due to padding in loss
|
||||
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
|
||||
target_masks = target_masks.to(source_masks)
|
||||
target_masks = target_masks[target_idx]
|
||||
|
||||
# upsample predictions to the target size
|
||||
source_masks = nn.functional.interpolate(
|
||||
source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
source_masks = source_masks[:, 0].flatten(1)
|
||||
|
||||
target_masks = target_masks.flatten(1)
|
||||
target_masks = target_masks.view(source_masks.shape)
|
||||
losses = {
|
||||
"loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
|
||||
"loss_dice": dice_loss(source_masks, target_masks, num_boxes),
|
||||
}
|
||||
return losses
|
||||
|
||||
def _get_source_permutation_idx(self, indices):
|
||||
# permute predictions following indices
|
||||
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
|
||||
source_idx = torch.cat([source for (source, _) in indices])
|
||||
return batch_idx, source_idx
|
||||
|
||||
def _get_target_permutation_idx(self, indices):
|
||||
# permute targets following indices
|
||||
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
|
||||
target_idx = torch.cat([target for (_, target) in indices])
|
||||
return batch_idx, target_idx
|
||||
|
||||
def get_loss(self, loss, outputs, targets, indices, num_boxes):
|
||||
loss_map = {
|
||||
"labels": self.loss_labels,
|
||||
"cardinality": self.loss_cardinality,
|
||||
"boxes": self.loss_boxes,
|
||||
"masks": self.loss_masks,
|
||||
}
|
||||
if loss not in loss_map:
|
||||
raise ValueError(f"Loss {loss} not supported")
|
||||
return loss_map[loss](outputs, targets, indices, num_boxes)
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
This performs the loss computation.
|
||||
|
||||
Args:
|
||||
outputs (`dict`, *optional*):
|
||||
Dictionary of tensors, see the output specification of the model for the format.
|
||||
targets (`List[dict]`, *optional*):
|
||||
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
|
||||
losses applied, see each loss' doc.
|
||||
"""
|
||||
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
|
||||
|
||||
# Retrieve the matching between the outputs of the last layer and the targets
|
||||
indices = self.matcher(outputs_without_aux, targets)
|
||||
|
||||
# Compute the average number of target boxes across all nodes, for normalization purposes
|
||||
num_boxes = sum(len(t["class_labels"]) for t in targets)
|
||||
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
||||
world_size = 1
|
||||
if is_accelerate_available():
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
|
||||
|
||||
# Compute all the requested losses
|
||||
losses = {}
|
||||
for loss in self.losses:
|
||||
losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
|
||||
|
||||
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
||||
if "auxiliary_outputs" in outputs:
|
||||
for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
|
||||
indices = self.matcher(auxiliary_outputs, targets)
|
||||
for loss in self.losses:
|
||||
if loss == "masks":
|
||||
# Intermediate masks losses are too costly to compute, we ignore them.
|
||||
continue
|
||||
l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
|
||||
l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
|
||||
losses.update(l_dict)
|
||||
|
||||
return losses
|
||||
|
||||
|
||||
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
class DetrMLPPredictionHead(nn.Module):
|
||||
"""
|
||||
Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
|
||||
height and width of a bounding box w.r.t. an image.
|
||||
|
||||
Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||
return x
|
||||
|
||||
|
||||
# taken from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
|
||||
class DetrHungarianMatcher(nn.Module):
|
||||
"""
|
||||
This class computes an assignment between the targets and the predictions of the network.
|
||||
|
||||
For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
|
||||
predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
|
||||
un-matched (and thus treated as non-objects).
|
||||
|
||||
Args:
|
||||
class_cost:
|
||||
The relative weight of the classification error in the matching cost.
|
||||
bbox_cost:
|
||||
The relative weight of the L1 error of the bounding box coordinates in the matching cost.
|
||||
giou_cost:
|
||||
The relative weight of the giou loss of the bounding box in the matching cost.
|
||||
"""
|
||||
|
||||
def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
|
||||
super().__init__()
|
||||
requires_backends(self, ["scipy"])
|
||||
|
||||
self.class_cost = class_cost
|
||||
self.bbox_cost = bbox_cost
|
||||
self.giou_cost = giou_cost
|
||||
if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
|
||||
raise ValueError("All costs of the Matcher can't be 0")
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
Args:
|
||||
outputs (`dict`):
|
||||
A dictionary that contains at least these entries:
|
||||
* "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
||||
* "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
|
||||
targets (`List[dict]`):
|
||||
A list of targets (len(targets) = batch_size), where each target is a dict containing:
|
||||
* "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
|
||||
ground-truth
|
||||
objects in the target) containing the class labels
|
||||
* "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.
|
||||
|
||||
Returns:
|
||||
`List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
|
||||
- index_i is the indices of the selected predictions (in order)
|
||||
- index_j is the indices of the corresponding selected targets (in order)
|
||||
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
||||
"""
|
||||
batch_size, num_queries = outputs["logits"].shape[:2]
|
||||
|
||||
# We flatten to compute the cost matrices in a batch
|
||||
out_prob = outputs["logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
|
||||
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
|
||||
|
||||
# Also concat the target labels and boxes
|
||||
target_ids = torch.cat([v["class_labels"] for v in targets])
|
||||
target_bbox = torch.cat([v["boxes"] for v in targets])
|
||||
|
||||
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
|
||||
# but approximate it in 1 - proba[target class].
|
||||
# The 1 is a constant that doesn't change the matching, it can be ommitted.
|
||||
class_cost = -out_prob[:, target_ids]
|
||||
|
||||
# Compute the L1 cost between boxes
|
||||
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
|
||||
|
||||
# Compute the giou cost between boxes
|
||||
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
|
||||
|
||||
# Final cost matrix
|
||||
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
|
||||
cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
|
||||
|
||||
sizes = [len(v["boxes"]) for v in targets]
|
||||
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
|
||||
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
||||
|
||||
|
||||
# below: bounding box utilities taken from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
|
||||
|
||||
|
||||
def _upcast(t: Tensor) -> Tensor:
|
||||
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
|
||||
if t.is_floating_point():
|
||||
return t if t.dtype in (torch.float32, torch.float64) else t.float()
|
||||
else:
|
||||
return t if t.dtype in (torch.int32, torch.int64) else t.int()
|
||||
|
||||
|
||||
def box_area(boxes: Tensor) -> Tensor:
|
||||
"""
|
||||
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
|
||||
|
||||
Args:
|
||||
boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
|
||||
Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
|
||||
< x2` and `0 <= y1 < y2`.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: a tensor containing the area for each box.
|
||||
"""
|
||||
boxes = _upcast(boxes)
|
||||
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
||||
|
||||
|
||||
# modified from torchvision to also return the union
|
||||
def box_iou(boxes1, boxes2):
|
||||
area1 = box_area(boxes1)
|
||||
area2 = box_area(boxes2)
|
||||
|
||||
left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||
right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||
|
||||
width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
|
||||
inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
|
||||
|
||||
union = area1[:, None] + area2 - inter
|
||||
|
||||
iou = inter / union
|
||||
return iou, union
|
||||
|
||||
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
|
||||
"""
|
||||
# degenerate boxes gives inf / nan results
|
||||
# so do an early check
|
||||
if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
|
||||
raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
|
||||
if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
|
||||
raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
|
||||
iou, union = box_iou(boxes1, boxes2)
|
||||
|
||||
top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
||||
bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
||||
|
||||
width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
|
||||
area = width_height[:, :, 0] * width_height[:, :, 1]
|
||||
|
||||
return iou - (area - union) / area
|
||||
|
||||
|
||||
# below: taken from https://github.com/facebookresearch/detr/blob/master/util/misc.py#L306
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
|
||||
class NestedTensor:
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
if tensor_list[0].ndim == 3:
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||
batch_shape = [len(tensor_list)] + max_size
|
||||
batch_size, num_channels, height, width = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], : img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError("Only 3-dimensional tensors are supported")
|
||||
return NestedTensor(tensor, mask)
|
||||
|
@ -25,7 +25,6 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
@ -1084,18 +1083,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
@ -1199,27 +1187,8 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
@ -1302,8 +1271,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
loss = self.loss_function(logits, labels, self.config)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -20,7 +20,6 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
@ -1003,18 +1002,7 @@ class GemmaForCausalLM(LlamaForCausalLM):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -24,7 +24,6 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, HybridCache
|
||||
@ -1065,18 +1064,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
@ -1258,27 +1246,8 @@ class Gemma2ForSequenceClassification(Gemma2PreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
@ -1361,8 +1330,7 @@ class Gemma2ForTokenClassification(Gemma2PreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
loss = self.loss_function(logits, labels, self.config)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -18,7 +18,6 @@ from typing import Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, HybridCache
|
||||
@ -806,18 +805,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch Grounding DINO model."""
|
||||
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
@ -33,31 +32,19 @@ from ...file_utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_scipy_available,
|
||||
is_timm_available,
|
||||
is_torch_cuda_available,
|
||||
is_vision_available,
|
||||
replace_return_docstrings,
|
||||
requires_backends,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import meshgrid
|
||||
from ...utils import is_accelerate_available, is_ninja_available, logging
|
||||
from ...utils import is_ninja_available, logging
|
||||
from ...utils.backbone_utils import load_backbone
|
||||
from ..auto import AutoModel
|
||||
from .configuration_grounding_dino import GroundingDinoConfig
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from transformers.image_transforms import center_to_corners_format
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import PartialState
|
||||
from accelerate.utils import reduce
|
||||
|
||||
if is_scipy_available():
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
if is_timm_available():
|
||||
from timm import create_model
|
||||
|
||||
@ -2488,436 +2475,6 @@ class GroundingDinoMLPPredictionHead(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr._upcast
|
||||
def _upcast(t: Tensor) -> Tensor:
|
||||
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
|
||||
if t.is_floating_point():
|
||||
return t if t.dtype in (torch.float32, torch.float64) else t.float()
|
||||
else:
|
||||
return t if t.dtype in (torch.int32, torch.int64) else t.int()
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.box_area
|
||||
def box_area(boxes: Tensor) -> Tensor:
|
||||
"""
|
||||
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
|
||||
|
||||
Args:
|
||||
boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
|
||||
Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
|
||||
< x2` and `0 <= y1 < y2`.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: a tensor containing the area for each box.
|
||||
"""
|
||||
boxes = _upcast(boxes)
|
||||
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.box_iou
|
||||
def box_iou(boxes1, boxes2):
|
||||
area1 = box_area(boxes1)
|
||||
area2 = box_area(boxes2)
|
||||
|
||||
left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||
right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||
|
||||
width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
|
||||
inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
|
||||
|
||||
union = area1[:, None] + area2 - inter
|
||||
|
||||
iou = inter / union
|
||||
return iou, union
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
|
||||
"""
|
||||
# degenerate boxes gives inf / nan results
|
||||
# so do an early check
|
||||
if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
|
||||
raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
|
||||
if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
|
||||
raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
|
||||
iou, union = box_iou(boxes1, boxes2)
|
||||
|
||||
top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
||||
bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
||||
|
||||
width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
|
||||
area = width_height[:, :, 0] * width_height[:, :, 1]
|
||||
|
||||
return iou - (area - union) / area
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr._max_by_axis
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.dice_loss
|
||||
def dice_loss(inputs, targets, num_boxes):
|
||||
"""
|
||||
Compute the DICE loss, similar to generalized IOU for masks
|
||||
|
||||
Args:
|
||||
inputs: A float tensor of arbitrary shape.
|
||||
The predictions for each example.
|
||||
targets: A float tensor with the same shape as inputs. Stores the binary
|
||||
classification label for each element in inputs (0 for the negative class and 1 for the positive
|
||||
class).
|
||||
"""
|
||||
inputs = inputs.sigmoid()
|
||||
inputs = inputs.flatten(1)
|
||||
numerator = 2 * (inputs * targets).sum(1)
|
||||
denominator = inputs.sum(-1) + targets.sum(-1)
|
||||
loss = 1 - (numerator + 1) / (denominator + 1)
|
||||
return loss.sum() / num_boxes
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.sigmoid_focal_loss
|
||||
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
|
||||
"""
|
||||
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
|
||||
|
||||
Args:
|
||||
inputs (`torch.FloatTensor` of arbitrary shape):
|
||||
The predictions for each example.
|
||||
targets (`torch.FloatTensor` with the same shape as `inputs`)
|
||||
A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
|
||||
and 1 for the positive class).
|
||||
alpha (`float`, *optional*, defaults to `0.25`):
|
||||
Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
|
||||
gamma (`int`, *optional*, defaults to `2`):
|
||||
Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.
|
||||
|
||||
Returns:
|
||||
Loss tensor
|
||||
"""
|
||||
prob = inputs.sigmoid()
|
||||
ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
|
||||
# add modulating factor
|
||||
p_t = prob * targets + (1 - prob) * (1 - targets)
|
||||
loss = ce_loss * ((1 - p_t) ** gamma)
|
||||
|
||||
if alpha >= 0:
|
||||
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
|
||||
loss = alpha_t * loss
|
||||
|
||||
return loss.mean(1).sum() / num_boxes
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.NestedTensor
|
||||
class NestedTensor:
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.nested_tensor_from_tensor_list
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
if tensor_list[0].ndim == 3:
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||
batch_shape = [len(tensor_list)] + max_size
|
||||
batch_size, num_channels, height, width = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], : img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError("Only 3-dimensional tensors are supported")
|
||||
return NestedTensor(tensor, mask)
|
||||
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrHungarianMatcher with DeformableDetr->GroundingDino
|
||||
class GroundingDinoHungarianMatcher(nn.Module):
|
||||
"""
|
||||
This class computes an assignment between the targets and the predictions of the network.
|
||||
|
||||
For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
|
||||
predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
|
||||
un-matched (and thus treated as non-objects).
|
||||
|
||||
Args:
|
||||
class_cost:
|
||||
The relative weight of the classification error in the matching cost.
|
||||
bbox_cost:
|
||||
The relative weight of the L1 error of the bounding box coordinates in the matching cost.
|
||||
giou_cost:
|
||||
The relative weight of the giou loss of the bounding box in the matching cost.
|
||||
"""
|
||||
|
||||
def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
|
||||
super().__init__()
|
||||
requires_backends(self, ["scipy"])
|
||||
|
||||
self.class_cost = class_cost
|
||||
self.bbox_cost = bbox_cost
|
||||
self.giou_cost = giou_cost
|
||||
if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
|
||||
raise ValueError("All costs of the Matcher can't be 0")
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
Args:
|
||||
outputs (`dict`):
|
||||
A dictionary that contains at least these entries:
|
||||
* "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
||||
* "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
|
||||
targets (`List[dict]`):
|
||||
A list of targets (len(targets) = batch_size), where each target is a dict containing:
|
||||
* "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
|
||||
ground-truth
|
||||
objects in the target) containing the class labels
|
||||
* "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.
|
||||
|
||||
Returns:
|
||||
`List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
|
||||
- index_i is the indices of the selected predictions (in order)
|
||||
- index_j is the indices of the corresponding selected targets (in order)
|
||||
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
||||
"""
|
||||
batch_size, num_queries = outputs["logits"].shape[:2]
|
||||
|
||||
# We flatten to compute the cost matrices in a batch
|
||||
out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes]
|
||||
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
|
||||
|
||||
# Also concat the target labels and boxes
|
||||
target_ids = torch.cat([v["class_labels"] for v in targets])
|
||||
target_bbox = torch.cat([v["boxes"] for v in targets])
|
||||
|
||||
# Compute the classification cost.
|
||||
alpha = 0.25
|
||||
gamma = 2.0
|
||||
neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
|
||||
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
|
||||
class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]
|
||||
|
||||
# Compute the L1 cost between boxes
|
||||
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
|
||||
|
||||
# Compute the giou cost between boxes
|
||||
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
|
||||
|
||||
# Final cost matrix
|
||||
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
|
||||
cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
|
||||
|
||||
sizes = [len(v["boxes"]) for v in targets]
|
||||
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
|
||||
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
||||
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss with DeformableDetr->GroundingDino
|
||||
class GroundingDinoLoss(nn.Module):
|
||||
"""
|
||||
This class computes the losses for `GroundingDinoForObjectDetection`. The process happens in two steps: 1) we
|
||||
compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of
|
||||
matched ground-truth / prediction (supervise class and box).
|
||||
|
||||
Args:
|
||||
matcher (`GroundingDinoHungarianMatcher`):
|
||||
Module able to compute a matching between targets and proposals.
|
||||
num_classes (`int`):
|
||||
Number of object categories, omitting the special no-object category.
|
||||
focal_alpha (`float`):
|
||||
Alpha parameter in focal loss.
|
||||
losses (`List[str]`):
|
||||
List of all the losses to be applied. See `get_loss` for a list of all available losses.
|
||||
"""
|
||||
|
||||
def __init__(self, matcher, num_classes, focal_alpha, losses):
|
||||
super().__init__()
|
||||
self.matcher = matcher
|
||||
self.num_classes = num_classes
|
||||
self.focal_alpha = focal_alpha
|
||||
self.losses = losses
|
||||
|
||||
# removed logging parameter, which was part of the original implementation
|
||||
def loss_labels(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Classification loss (Binary focal loss) targets dicts must contain the key "class_labels" containing a tensor
|
||||
of dim [nb_target_boxes]
|
||||
"""
|
||||
if "logits" not in outputs:
|
||||
raise KeyError("No logits were found in the outputs")
|
||||
source_logits = outputs["logits"]
|
||||
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
|
||||
target_classes = torch.full(
|
||||
source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
|
||||
)
|
||||
target_classes[idx] = target_classes_o
|
||||
|
||||
target_classes_onehot = torch.zeros(
|
||||
[source_logits.shape[0], source_logits.shape[1], source_logits.shape[2] + 1],
|
||||
dtype=source_logits.dtype,
|
||||
layout=source_logits.layout,
|
||||
device=source_logits.device,
|
||||
)
|
||||
target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
|
||||
|
||||
target_classes_onehot = target_classes_onehot[:, :, :-1]
|
||||
loss_ce = (
|
||||
sigmoid_focal_loss(source_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2)
|
||||
* source_logits.shape[1]
|
||||
)
|
||||
losses = {"loss_ce": loss_ce}
|
||||
|
||||
return losses
|
||||
|
||||
@torch.no_grad()
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_cardinality
|
||||
def loss_cardinality(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
|
||||
|
||||
This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
|
||||
"""
|
||||
logits = outputs["logits"]
|
||||
device = logits.device
|
||||
target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
|
||||
# Count the number of predictions that are NOT "no-object" (which is the last class)
|
||||
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
|
||||
card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
|
||||
losses = {"cardinality_error": card_err}
|
||||
return losses
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_boxes
|
||||
def loss_boxes(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
|
||||
|
||||
Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
|
||||
are expected in format (center_x, center_y, w, h), normalized by the image size.
|
||||
"""
|
||||
if "pred_boxes" not in outputs:
|
||||
raise KeyError("No predicted boxes found in outputs")
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
source_boxes = outputs["pred_boxes"][idx]
|
||||
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
|
||||
|
||||
loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
|
||||
|
||||
losses = {}
|
||||
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
|
||||
|
||||
loss_giou = 1 - torch.diag(
|
||||
generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
|
||||
)
|
||||
losses["loss_giou"] = loss_giou.sum() / num_boxes
|
||||
return losses
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrLoss._get_source_permutation_idx
|
||||
def _get_source_permutation_idx(self, indices):
|
||||
# permute predictions following indices
|
||||
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
|
||||
source_idx = torch.cat([source for (source, _) in indices])
|
||||
return batch_idx, source_idx
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrLoss._get_target_permutation_idx
|
||||
def _get_target_permutation_idx(self, indices):
|
||||
# permute targets following indices
|
||||
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
|
||||
target_idx = torch.cat([target for (_, target) in indices])
|
||||
return batch_idx, target_idx
|
||||
|
||||
def get_loss(self, loss, outputs, targets, indices, num_boxes):
|
||||
loss_map = {
|
||||
"labels": self.loss_labels,
|
||||
"cardinality": self.loss_cardinality,
|
||||
"boxes": self.loss_boxes,
|
||||
}
|
||||
if loss not in loss_map:
|
||||
raise ValueError(f"Loss {loss} not supported")
|
||||
return loss_map[loss](outputs, targets, indices, num_boxes)
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
This performs the loss computation.
|
||||
|
||||
Args:
|
||||
outputs (`dict`, *optional*):
|
||||
Dictionary of tensors, see the output specification of the model for the format.
|
||||
targets (`List[dict]`, *optional*):
|
||||
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
|
||||
losses applied, see each loss' doc.
|
||||
"""
|
||||
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs" and k != "enc_outputs"}
|
||||
|
||||
# Retrieve the matching between the outputs of the last layer and the targets
|
||||
indices = self.matcher(outputs_without_aux, targets)
|
||||
|
||||
# Compute the average number of target boxes accross all nodes, for normalization purposes
|
||||
num_boxes = sum(len(t["class_labels"]) for t in targets)
|
||||
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
||||
world_size = 1
|
||||
if is_accelerate_available():
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
|
||||
|
||||
# Compute all the requested losses
|
||||
losses = {}
|
||||
for loss in self.losses:
|
||||
losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
|
||||
|
||||
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
||||
if "auxiliary_outputs" in outputs:
|
||||
for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
|
||||
indices = self.matcher(auxiliary_outputs, targets)
|
||||
for loss in self.losses:
|
||||
l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
|
||||
l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
|
||||
losses.update(l_dict)
|
||||
|
||||
if "enc_outputs" in outputs:
|
||||
enc_outputs = outputs["enc_outputs"]
|
||||
bin_targets = copy.deepcopy(targets)
|
||||
for bt in bin_targets:
|
||||
bt["class_labels"] = torch.zeros_like(bt["class_labels"])
|
||||
indices = self.matcher(enc_outputs, bin_targets)
|
||||
for loss in self.losses:
|
||||
l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes)
|
||||
l_dict = {k + "_enc": v for k, v in l_dict.items()}
|
||||
losses.update(l_dict)
|
||||
|
||||
return losses
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Grounding DINO Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top,
|
||||
@ -3079,40 +2636,9 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel):
|
||||
|
||||
loss, loss_dict, auxiliary_outputs = None, None, None
|
||||
if labels is not None:
|
||||
# First: create the matcher
|
||||
matcher = GroundingDinoHungarianMatcher(
|
||||
class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost
|
||||
loss, loss_dict, auxiliary_outputs = self.loss_function(
|
||||
logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
|
||||
)
|
||||
# Second: create the criterion
|
||||
losses = ["labels", "boxes", "cardinality"]
|
||||
criterion = GroundingDinoLoss(
|
||||
matcher=matcher,
|
||||
num_classes=self.config.num_labels,
|
||||
focal_alpha=self.config.focal_alpha,
|
||||
losses=losses,
|
||||
)
|
||||
criterion.to(self.device)
|
||||
# Third: compute the losses, based on outputs and labels
|
||||
outputs_loss = {}
|
||||
outputs_loss["logits"] = logits
|
||||
outputs_loss["pred_boxes"] = pred_boxes
|
||||
if self.config.auxiliary_loss:
|
||||
auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
|
||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||
if self.config.two_stage:
|
||||
enc_outputs_coord = outputs[-1].sigmoid()
|
||||
outputs_loss["enc_outputs"] = {"logits": outputs[-2], "pred_boxes": enc_outputs_coord}
|
||||
|
||||
loss_dict = criterion(outputs_loss, labels)
|
||||
# Fourth: compute total loss, as a weighted sum of the various losses
|
||||
weight_dict = {"loss_ce": 1, "loss_bbox": self.config.bbox_loss_coefficient}
|
||||
weight_dict["loss_giou"] = self.config.giou_loss_coefficient
|
||||
if self.config.auxiliary_loss:
|
||||
aux_weight_dict = {}
|
||||
for i in range(self.config.decoder_layers - 1):
|
||||
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
||||
weight_dict.update(aux_weight_dict)
|
||||
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
||||
|
||||
if not return_dict:
|
||||
if auxiliary_outputs is not None:
|
||||
|
@ -26,7 +26,6 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache # we need __iter__ and __len__ of pkv
|
||||
@ -1543,18 +1542,7 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size)
|
||||
|
||||
aux_loss = None
|
||||
if output_router_logits:
|
||||
@ -1729,27 +1717,8 @@ class JambaForSequenceClassification(JambaPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
@ -20,7 +20,7 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ...activations import ACT2FN
|
||||
@ -1436,27 +1436,8 @@ class JetMoeForSequenceClassification(JetMoePreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
@ -24,7 +24,6 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
@ -1147,6 +1146,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
@ -1209,18 +1209,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
@ -1324,27 +1313,8 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
@ -1396,6 +1366,7 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
||||
r"""
|
||||
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
@ -1427,29 +1398,16 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
|
||||
start_logits = start_logits.squeeze(-1).contiguous()
|
||||
end_logits = end_logits.squeeze(-1).contiguous()
|
||||
|
||||
total_loss = None
|
||||
loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
# If we are on multi-GPU, split add a dimension
|
||||
if len(start_positions.size()) > 1:
|
||||
start_positions = start_positions.squeeze(-1).to(start_logits.device)
|
||||
if len(end_positions.size()) > 1:
|
||||
end_positions = end_positions.squeeze(-1).to(end_logits.device)
|
||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||
ignored_index = start_logits.size(1)
|
||||
start_positions = start_positions.clamp(0, ignored_index)
|
||||
end_positions = end_positions.clamp(0, ignored_index)
|
||||
|
||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||
start_loss = loss_fct(start_logits, start_positions)
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + outputs[2:]
|
||||
return ((total_loss,) + output) if total_loss is not None else output
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return QuestionAnsweringModelOutput(
|
||||
loss=total_loss,
|
||||
loss=loss,
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
@ -1526,8 +1484,7 @@ class LlamaForTokenClassification(LlamaPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
loss = self.loss_function(logits, labels, self.config)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -25,7 +25,7 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||
@ -1220,27 +1220,8 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
@ -1324,8 +1305,7 @@ class MistralForTokenClassification(MistralPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
loss = self.loss_function(logits, labels, self.config)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -26,7 +26,6 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||
@ -1327,18 +1326,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size)
|
||||
|
||||
aux_loss = None
|
||||
if output_router_logits:
|
||||
@ -1458,27 +1446,8 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
@ -1562,8 +1531,7 @@ class MixtralForTokenClassification(MixtralPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
loss = self.loss_function(logits, labels, self.config)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -21,7 +21,6 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ... import PreTrainedModel
|
||||
from ...activations import ACT2FN
|
||||
@ -1950,18 +1949,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -22,7 +22,6 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import Size, Tensor, nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, StaticCache
|
||||
@ -1084,18 +1083,7 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
@ -1200,27 +1188,8 @@ class NemotronForSequenceClassification(NemotronPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
@ -1273,6 +1242,7 @@ class NemotronForQuestionAnswering(NemotronPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
||||
r"""
|
||||
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
@ -1304,29 +1274,16 @@ class NemotronForQuestionAnswering(NemotronPreTrainedModel):
|
||||
start_logits = start_logits.squeeze(-1).contiguous()
|
||||
end_logits = end_logits.squeeze(-1).contiguous()
|
||||
|
||||
total_loss = None
|
||||
loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
# If we are on multi-GPU, split add a dimension
|
||||
if len(start_positions.size()) > 1:
|
||||
start_positions = start_positions.squeeze(-1).to(start_logits.device)
|
||||
if len(end_positions.size()) > 1:
|
||||
end_positions = end_positions.squeeze(-1).to(end_logits.device)
|
||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||
ignored_index = start_logits.size(1)
|
||||
start_positions = start_positions.clamp(0, ignored_index)
|
||||
end_positions = end_positions.clamp(0, ignored_index)
|
||||
|
||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||
start_loss = loss_fct(start_logits, start_positions)
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + outputs[2:]
|
||||
return ((total_loss,) + output) if total_loss is not None else output
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return QuestionAnsweringModelOutput(
|
||||
loss=total_loss,
|
||||
loss=loss,
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
@ -1404,8 +1361,7 @@ class NemotronForTokenClassification(NemotronPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
loss = self.loss_function(logits, labels, self.config)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -26,7 +26,6 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
@ -1127,18 +1126,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -18,7 +18,6 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
@ -1291,18 +1290,7 @@ class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size)
|
||||
|
||||
aux_loss = None
|
||||
if output_router_logits:
|
||||
|
@ -98,7 +98,7 @@ class Owlv2Output(ModelOutput):
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr._upcast
|
||||
# Copied from transformers.loss.loss_for_object_detection._upcast
|
||||
def _upcast(t: Tensor) -> Tensor:
|
||||
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
|
||||
if t.is_floating_point():
|
||||
@ -107,7 +107,7 @@ def _upcast(t: Tensor) -> Tensor:
|
||||
return t if t.dtype in (torch.int32, torch.int64) else t.int()
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.box_area
|
||||
# Copied from transformers.loss.loss_for_object_detection.box_area
|
||||
def box_area(boxes: Tensor) -> Tensor:
|
||||
"""
|
||||
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
|
||||
@ -124,7 +124,7 @@ def box_area(boxes: Tensor) -> Tensor:
|
||||
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.box_iou
|
||||
# Copied from transformers.loss.loss_for_object_detection.box_iou
|
||||
def box_iou(boxes1, boxes2):
|
||||
area1 = box_area(boxes1)
|
||||
area2 = box_area(boxes2)
|
||||
@ -141,7 +141,7 @@ def box_iou(boxes1, boxes2):
|
||||
return iou, union
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
|
||||
# Copied from transformers.loss.loss_for_object_detection.generalized_box_iou
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
|
||||
|
@ -98,7 +98,7 @@ class OwlViTOutput(ModelOutput):
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr._upcast
|
||||
# Copied from transformers.loss.loss_for_object_detection._upcast
|
||||
def _upcast(t: Tensor) -> Tensor:
|
||||
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
|
||||
if t.is_floating_point():
|
||||
@ -107,7 +107,7 @@ def _upcast(t: Tensor) -> Tensor:
|
||||
return t if t.dtype in (torch.int32, torch.int64) else t.int()
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.box_area
|
||||
# Copied from transformers.loss.loss_for_object_detection.box_area
|
||||
def box_area(boxes: Tensor) -> Tensor:
|
||||
"""
|
||||
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
|
||||
@ -124,7 +124,7 @@ def box_area(boxes: Tensor) -> Tensor:
|
||||
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.box_iou
|
||||
# Copied from transformers.loss.loss_for_object_detection.box_iou
|
||||
def box_iou(boxes1, boxes2):
|
||||
area1 = box_area(boxes1)
|
||||
area2 = box_area(boxes2)
|
||||
@ -141,7 +141,7 @@ def box_iou(boxes1, boxes2):
|
||||
return iou, union
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
|
||||
# Copied from transformers.loss.loss_for_object_detection.generalized_box_iou
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
|
||||
|
@ -25,7 +25,7 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
@ -1073,27 +1073,8 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
@ -1177,8 +1158,7 @@ class PersimmonForTokenClassification(PersimmonPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
loss = self.loss_function(logits, labels, self.config)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -22,7 +22,7 @@ import torch
|
||||
import torch.utils.checkpoint
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
@ -1250,18 +1250,7 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
@ -1366,27 +1355,8 @@ class PhiForSequenceClassification(PhiPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + model_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
@ -22,7 +22,7 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||
@ -1300,18 +1300,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
@ -1455,27 +1444,8 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + model_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
@ -21,7 +21,6 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||
@ -1474,18 +1473,7 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size)
|
||||
|
||||
aux_loss = None
|
||||
if output_router_logits:
|
||||
@ -1644,27 +1632,8 @@ class PhimoeForSequenceClassification(PhimoePreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
@ -1204,18 +1204,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
@ -1423,8 +1412,7 @@ class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
loss = self.loss_function(logits, labels, self.config)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -26,7 +26,6 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||
@ -1393,18 +1392,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size)
|
||||
|
||||
aux_loss = None
|
||||
if output_router_logits:
|
||||
@ -1524,27 +1512,8 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
@ -1628,8 +1597,7 @@ class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
loss = self.loss_function(logits, labels, self.config)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -37,19 +37,14 @@ from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_ninja_available,
|
||||
is_scipy_available,
|
||||
is_torch_cuda_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
requires_backends,
|
||||
)
|
||||
from ...utils.backbone_utils import load_backbone
|
||||
from .configuration_rt_detr import RTDetrConfig
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MultiScaleDeformableAttention = None
|
||||
@ -1616,6 +1611,29 @@ def compile_compatible_lru_cache(*lru_args, **lru_kwargs):
|
||||
return decorator
|
||||
|
||||
|
||||
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
class RTDetrMLPPredictionHead(nn.Module):
|
||||
"""
|
||||
Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
|
||||
height and width of a bounding box w.r.t. an image.
|
||||
|
||||
Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_paddle/ppdet/modeling/transformers/utils.py#L453
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config, input_dim, d_model, output_dim, num_layers):
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [d_model] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||
return x
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
RT-DETR Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top.
|
||||
@ -1950,588 +1968,6 @@ class RTDetrModel(RTDetrPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.dice_loss
|
||||
def dice_loss(inputs, targets, num_boxes):
|
||||
"""
|
||||
Compute the DICE loss, similar to generalized IOU for masks
|
||||
|
||||
Args:
|
||||
inputs: A float tensor of arbitrary shape.
|
||||
The predictions for each example.
|
||||
targets: A float tensor with the same shape as inputs. Stores the binary
|
||||
classification label for each element in inputs (0 for the negative class and 1 for the positive
|
||||
class).
|
||||
"""
|
||||
inputs = inputs.sigmoid()
|
||||
inputs = inputs.flatten(1)
|
||||
numerator = 2 * (inputs * targets).sum(1)
|
||||
denominator = inputs.sum(-1) + targets.sum(-1)
|
||||
loss = 1 - (numerator + 1) / (denominator + 1)
|
||||
return loss.sum() / num_boxes
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.sigmoid_focal_loss
|
||||
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
|
||||
"""
|
||||
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
|
||||
|
||||
Args:
|
||||
inputs (`torch.FloatTensor` of arbitrary shape):
|
||||
The predictions for each example.
|
||||
targets (`torch.FloatTensor` with the same shape as `inputs`)
|
||||
A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
|
||||
and 1 for the positive class).
|
||||
alpha (`float`, *optional*, defaults to `0.25`):
|
||||
Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
|
||||
gamma (`int`, *optional*, defaults to `2`):
|
||||
Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.
|
||||
|
||||
Returns:
|
||||
Loss tensor
|
||||
"""
|
||||
prob = inputs.sigmoid()
|
||||
ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
|
||||
# add modulating factor
|
||||
p_t = prob * targets + (1 - prob) * (1 - targets)
|
||||
loss = ce_loss * ((1 - p_t) ** gamma)
|
||||
|
||||
if alpha >= 0:
|
||||
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
|
||||
loss = alpha_t * loss
|
||||
|
||||
return loss.mean(1).sum() / num_boxes
|
||||
|
||||
|
||||
class RTDetrLoss(nn.Module):
|
||||
"""
|
||||
This class computes the losses for RTDetr. The process happens in two steps: 1) we compute hungarian assignment
|
||||
between ground truth boxes and the outputs of the model 2) we supervise each pair of matched ground-truth /
|
||||
prediction (supervise class and box).
|
||||
|
||||
Args:
|
||||
matcher (`DetrHungarianMatcher`):
|
||||
Module able to compute a matching between targets and proposals.
|
||||
weight_dict (`Dict`):
|
||||
Dictionary relating each loss with its weights. These losses are configured in RTDetrConf as
|
||||
`weight_loss_vfl`, `weight_loss_bbox`, `weight_loss_giou`
|
||||
losses (`List[str]`):
|
||||
List of all the losses to be applied. See `get_loss` for a list of all available losses.
|
||||
alpha (`float`):
|
||||
Parameter alpha used to compute the focal loss.
|
||||
gamma (`float`):
|
||||
Parameter gamma used to compute the focal loss.
|
||||
eos_coef (`float`):
|
||||
Relative classification weight applied to the no-object category.
|
||||
num_classes (`int`):
|
||||
Number of object categories, omitting the special no-object category.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.matcher = RTDetrHungarianMatcher(config)
|
||||
self.num_classes = config.num_labels
|
||||
self.weight_dict = {
|
||||
"loss_vfl": config.weight_loss_vfl,
|
||||
"loss_bbox": config.weight_loss_bbox,
|
||||
"loss_giou": config.weight_loss_giou,
|
||||
}
|
||||
self.losses = ["vfl", "boxes"]
|
||||
self.eos_coef = config.eos_coefficient
|
||||
empty_weight = torch.ones(config.num_labels + 1)
|
||||
empty_weight[-1] = self.eos_coef
|
||||
self.register_buffer("empty_weight", empty_weight)
|
||||
self.alpha = config.focal_loss_alpha
|
||||
self.gamma = config.focal_loss_gamma
|
||||
|
||||
def loss_labels_vfl(self, outputs, targets, indices, num_boxes, log=True):
|
||||
if "pred_boxes" not in outputs:
|
||||
raise KeyError("No predicted boxes found in outputs")
|
||||
if "logits" not in outputs:
|
||||
raise KeyError("No predicted logits found in outputs")
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
|
||||
src_boxes = outputs["pred_boxes"][idx]
|
||||
target_boxes = torch.cat([_target["boxes"][i] for _target, (_, i) in zip(targets, indices)], dim=0)
|
||||
ious, _ = box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
|
||||
ious = torch.diag(ious).detach()
|
||||
|
||||
src_logits = outputs["logits"]
|
||||
target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)])
|
||||
target_classes = torch.full(
|
||||
src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
|
||||
)
|
||||
target_classes[idx] = target_classes_original
|
||||
target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]
|
||||
|
||||
target_score_original = torch.zeros_like(target_classes, dtype=src_logits.dtype)
|
||||
target_score_original[idx] = ious.to(target_score_original.dtype)
|
||||
target_score = target_score_original.unsqueeze(-1) * target
|
||||
|
||||
pred_score = F.sigmoid(src_logits).detach()
|
||||
weight = self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score
|
||||
|
||||
loss = F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction="none")
|
||||
loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
|
||||
return {"loss_vfl": loss}
|
||||
|
||||
def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
|
||||
"""Classification loss (NLL)
|
||||
targets dicts must contain the key "class_labels" containing a tensor of dim [nb_target_boxes]
|
||||
"""
|
||||
if "logits" not in outputs:
|
||||
raise KeyError("No logits were found in the outputs")
|
||||
|
||||
src_logits = outputs["logits"]
|
||||
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)])
|
||||
target_classes = torch.full(
|
||||
src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
|
||||
)
|
||||
target_classes[idx] = target_classes_original
|
||||
|
||||
loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.class_weight)
|
||||
losses = {"loss_ce": loss_ce}
|
||||
return losses
|
||||
|
||||
@torch.no_grad()
|
||||
def loss_cardinality(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes. This is not
|
||||
really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
|
||||
"""
|
||||
logits = outputs["logits"]
|
||||
device = logits.device
|
||||
target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
|
||||
# Count the number of predictions that are NOT "no-object" (which is the last class)
|
||||
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
|
||||
card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
|
||||
losses = {"cardinality_error": card_err}
|
||||
return losses
|
||||
|
||||
def loss_boxes(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss. Targets dicts must
|
||||
contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes are expected in
|
||||
format (center_x, center_y, w, h), normalized by the image size.
|
||||
"""
|
||||
if "pred_boxes" not in outputs:
|
||||
raise KeyError("No predicted boxes found in outputs")
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
src_boxes = outputs["pred_boxes"][idx]
|
||||
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
|
||||
|
||||
losses = {}
|
||||
|
||||
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
|
||||
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
|
||||
|
||||
loss_giou = 1 - torch.diag(
|
||||
generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
|
||||
)
|
||||
losses["loss_giou"] = loss_giou.sum() / num_boxes
|
||||
return losses
|
||||
|
||||
def loss_masks(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the losses related to the masks: the focal loss and the dice loss. Targets dicts must contain the key
|
||||
"masks" containing a tensor of dim [nb_target_boxes, h, w].
|
||||
"""
|
||||
if "pred_masks" not in outputs:
|
||||
raise KeyError("No predicted masks found in outputs")
|
||||
|
||||
source_idx = self._get_source_permutation_idx(indices)
|
||||
target_idx = self._get_target_permutation_idx(indices)
|
||||
source_masks = outputs["pred_masks"]
|
||||
source_masks = source_masks[source_idx]
|
||||
masks = [t["masks"] for t in targets]
|
||||
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
|
||||
target_masks = target_masks.to(source_masks)
|
||||
target_masks = target_masks[target_idx]
|
||||
|
||||
# upsample predictions to the target size
|
||||
source_masks = nn.functional.interpolate(
|
||||
source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
source_masks = source_masks[:, 0].flatten(1)
|
||||
|
||||
target_masks = target_masks.flatten(1)
|
||||
target_masks = target_masks.view(source_masks.shape)
|
||||
losses = {
|
||||
"loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
|
||||
"loss_dice": dice_loss(source_masks, target_masks, num_boxes),
|
||||
}
|
||||
return losses
|
||||
|
||||
def loss_labels_bce(self, outputs, targets, indices, num_boxes, log=True):
|
||||
src_logits = outputs["logits"]
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)])
|
||||
target_classes = torch.full(
|
||||
src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
|
||||
)
|
||||
target_classes[idx] = target_classes_original
|
||||
|
||||
target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]
|
||||
loss = F.binary_cross_entropy_with_logits(src_logits, target * 1.0, reduction="none")
|
||||
loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
|
||||
return {"loss_bce": loss}
|
||||
|
||||
def _get_source_permutation_idx(self, indices):
|
||||
# permute predictions following indices
|
||||
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
|
||||
source_idx = torch.cat([source for (source, _) in indices])
|
||||
return batch_idx, source_idx
|
||||
|
||||
def _get_target_permutation_idx(self, indices):
|
||||
# permute targets following indices
|
||||
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
|
||||
target_idx = torch.cat([target for (_, target) in indices])
|
||||
return batch_idx, target_idx
|
||||
|
||||
def loss_labels_focal(self, outputs, targets, indices, num_boxes, log=True):
|
||||
if "logits" not in outputs:
|
||||
raise KeyError("No logits found in outputs")
|
||||
|
||||
src_logits = outputs["logits"]
|
||||
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)])
|
||||
target_classes = torch.full(
|
||||
src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
|
||||
)
|
||||
target_classes[idx] = target_classes_original
|
||||
|
||||
target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]
|
||||
loss = sigmoid_focal_loss(src_logits, target, self.alpha, self.gamma)
|
||||
loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
|
||||
return {"loss_focal": loss}
|
||||
|
||||
def get_loss(self, loss, outputs, targets, indices, num_boxes):
|
||||
loss_map = {
|
||||
"labels": self.loss_labels,
|
||||
"cardinality": self.loss_cardinality,
|
||||
"boxes": self.loss_boxes,
|
||||
"masks": self.loss_masks,
|
||||
"bce": self.loss_labels_bce,
|
||||
"focal": self.loss_labels_focal,
|
||||
"vfl": self.loss_labels_vfl,
|
||||
}
|
||||
if loss not in loss_map:
|
||||
raise ValueError(f"Loss {loss} not supported")
|
||||
return loss_map[loss](outputs, targets, indices, num_boxes)
|
||||
|
||||
@staticmethod
|
||||
def get_cdn_matched_indices(dn_meta, targets):
|
||||
dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
|
||||
num_gts = [len(t["class_labels"]) for t in targets]
|
||||
device = targets[0]["class_labels"].device
|
||||
|
||||
dn_match_indices = []
|
||||
for i, num_gt in enumerate(num_gts):
|
||||
if num_gt > 0:
|
||||
gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device)
|
||||
gt_idx = gt_idx.tile(dn_num_group)
|
||||
assert len(dn_positive_idx[i]) == len(gt_idx)
|
||||
dn_match_indices.append((dn_positive_idx[i], gt_idx))
|
||||
else:
|
||||
dn_match_indices.append(
|
||||
(
|
||||
torch.zeros(0, dtype=torch.int64, device=device),
|
||||
torch.zeros(0, dtype=torch.int64, device=device),
|
||||
)
|
||||
)
|
||||
|
||||
return dn_match_indices
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
This performs the loss computation.
|
||||
|
||||
Args:
|
||||
outputs (`dict`, *optional*):
|
||||
Dictionary of tensors, see the output specification of the model for the format.
|
||||
targets (`List[dict]`, *optional*):
|
||||
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
|
||||
losses applied, see each loss' doc.
|
||||
"""
|
||||
outputs_without_aux = {k: v for k, v in outputs.items() if "auxiliary_outputs" not in k}
|
||||
|
||||
# Retrieve the matching between the outputs of the last layer and the targets
|
||||
indices = self.matcher(outputs_without_aux, targets)
|
||||
|
||||
# Compute the average number of target boxes across all nodes, for normalization purposes
|
||||
num_boxes = sum(len(t["class_labels"]) for t in targets)
|
||||
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
||||
num_boxes = torch.clamp(num_boxes, min=1).item()
|
||||
|
||||
# Compute all the requested losses
|
||||
losses = {}
|
||||
for loss in self.losses:
|
||||
l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
|
||||
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
|
||||
losses.update(l_dict)
|
||||
|
||||
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
||||
if "auxiliary_outputs" in outputs:
|
||||
for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
|
||||
indices = self.matcher(auxiliary_outputs, targets)
|
||||
for loss in self.losses:
|
||||
if loss == "masks":
|
||||
# Intermediate masks losses are too costly to compute, we ignore them.
|
||||
continue
|
||||
l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
|
||||
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
|
||||
l_dict = {k + f"_aux_{i}": v for k, v in l_dict.items()}
|
||||
losses.update(l_dict)
|
||||
|
||||
# In case of cdn auxiliary losses. For rtdetr
|
||||
if "dn_auxiliary_outputs" in outputs:
|
||||
if "denoising_meta_values" not in outputs:
|
||||
raise ValueError(
|
||||
"The output must have the 'denoising_meta_values` key. Please, ensure that 'outputs' includes a 'denoising_meta_values' entry."
|
||||
)
|
||||
indices = self.get_cdn_matched_indices(outputs["denoising_meta_values"], targets)
|
||||
num_boxes = num_boxes * outputs["denoising_meta_values"]["dn_num_group"]
|
||||
|
||||
for i, auxiliary_outputs in enumerate(outputs["dn_auxiliary_outputs"]):
|
||||
# indices = self.matcher(auxiliary_outputs, targets)
|
||||
for loss in self.losses:
|
||||
if loss == "masks":
|
||||
# Intermediate masks losses are too costly to compute, we ignore them.
|
||||
continue
|
||||
kwargs = {}
|
||||
l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes, **kwargs)
|
||||
l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
|
||||
l_dict = {k + f"_dn_{i}": v for k, v in l_dict.items()}
|
||||
losses.update(l_dict)
|
||||
|
||||
return losses
|
||||
|
||||
|
||||
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
class RTDetrMLPPredictionHead(nn.Module):
|
||||
"""
|
||||
Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
|
||||
height and width of a bounding box w.r.t. an image.
|
||||
|
||||
Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_paddle/ppdet/modeling/transformers/utils.py#L453
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config, input_dim, d_model, output_dim, num_layers):
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [d_model] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class RTDetrHungarianMatcher(nn.Module):
|
||||
"""This class computes an assignment between the targets and the predictions of the network
|
||||
|
||||
For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
|
||||
predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
|
||||
un-matched (and thus treated as non-objects).
|
||||
|
||||
Args:
|
||||
config: RTDetrConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
requires_backends(self, ["scipy"])
|
||||
|
||||
self.class_cost = config.matcher_class_cost
|
||||
self.bbox_cost = config.matcher_bbox_cost
|
||||
self.giou_cost = config.matcher_giou_cost
|
||||
|
||||
self.use_focal_loss = config.use_focal_loss
|
||||
self.alpha = config.matcher_alpha
|
||||
self.gamma = config.matcher_gamma
|
||||
|
||||
if self.class_cost == self.bbox_cost == self.giou_cost == 0:
|
||||
raise ValueError("All costs of the Matcher can't be 0")
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, targets):
|
||||
"""Performs the matching
|
||||
|
||||
Params:
|
||||
outputs: This is a dict that contains at least these entries:
|
||||
"logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
||||
"pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
|
||||
|
||||
targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
|
||||
"class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
|
||||
objects in the target) containing the class labels
|
||||
"boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
|
||||
|
||||
Returns:
|
||||
A list of size batch_size, containing tuples of (index_i, index_j) where:
|
||||
- index_i is the indices of the selected predictions (in order)
|
||||
- index_j is the indices of the corresponding selected targets (in order)
|
||||
For each batch element, it holds:
|
||||
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
||||
"""
|
||||
batch_size, num_queries = outputs["logits"].shape[:2]
|
||||
|
||||
# We flatten to compute the cost matrices in a batch
|
||||
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
|
||||
# Also concat the target labels and boxes
|
||||
target_ids = torch.cat([v["class_labels"] for v in targets])
|
||||
target_bbox = torch.cat([v["boxes"] for v in targets])
|
||||
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
|
||||
# but approximate it in 1 - proba[target class].
|
||||
# The 1 is a constant that doesn't change the matching, it can be ommitted.
|
||||
if self.use_focal_loss:
|
||||
out_prob = F.sigmoid(outputs["logits"].flatten(0, 1))
|
||||
out_prob = out_prob[:, target_ids]
|
||||
neg_cost_class = (1 - self.alpha) * (out_prob**self.gamma) * (-(1 - out_prob + 1e-8).log())
|
||||
pos_cost_class = self.alpha * ((1 - out_prob) ** self.gamma) * (-(out_prob + 1e-8).log())
|
||||
class_cost = pos_cost_class - neg_cost_class
|
||||
else:
|
||||
out_prob = outputs["logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
|
||||
class_cost = -out_prob[:, target_ids]
|
||||
|
||||
# Compute the L1 cost between boxes
|
||||
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
|
||||
# Compute the giou cost betwen boxes
|
||||
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
|
||||
# Compute the final cost matrix
|
||||
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
|
||||
cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
|
||||
|
||||
sizes = [len(v["boxes"]) for v in targets]
|
||||
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
|
||||
|
||||
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr._upcast
|
||||
def _upcast(t: Tensor) -> Tensor:
|
||||
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
|
||||
if t.is_floating_point():
|
||||
return t if t.dtype in (torch.float32, torch.float64) else t.float()
|
||||
else:
|
||||
return t if t.dtype in (torch.int32, torch.int64) else t.int()
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.box_area
|
||||
def box_area(boxes: Tensor) -> Tensor:
|
||||
"""
|
||||
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
|
||||
|
||||
Args:
|
||||
boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
|
||||
Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
|
||||
< x2` and `0 <= y1 < y2`.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: a tensor containing the area for each box.
|
||||
"""
|
||||
boxes = _upcast(boxes)
|
||||
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.box_iou
|
||||
def box_iou(boxes1, boxes2):
|
||||
area1 = box_area(boxes1)
|
||||
area2 = box_area(boxes2)
|
||||
|
||||
left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||
right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||
|
||||
width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
|
||||
inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
|
||||
|
||||
union = area1[:, None] + area2 - inter
|
||||
|
||||
iou = inter / union
|
||||
return iou, union
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
|
||||
"""
|
||||
# degenerate boxes gives inf / nan results
|
||||
# so do an early check
|
||||
if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
|
||||
raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
|
||||
if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
|
||||
raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
|
||||
iou, union = box_iou(boxes1, boxes2)
|
||||
|
||||
top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
||||
bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
||||
|
||||
width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
|
||||
area = width_height[:, :, 0] * width_height[:, :, 1]
|
||||
|
||||
return iou - (area - union) / area
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr._max_by_axis
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.NestedTensor
|
||||
class NestedTensor:
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.nested_tensor_from_tensor_list
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
if tensor_list[0].ndim == 3:
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||
batch_shape = [len(tensor_list)] + max_size
|
||||
batch_size, num_channels, height, width = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], : img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError("Only 3-dimensional tensors are supported")
|
||||
return NestedTensor(tensor, mask)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
RT-DETR Model (consisting of a backbone and encoder-decoder) outputting bounding boxes and logits to be further
|
||||
@ -2673,39 +2109,26 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
|
||||
outputs_class = outputs.intermediate_logits if return_dict else outputs[2]
|
||||
outputs_coord = outputs.intermediate_reference_points if return_dict else outputs[3]
|
||||
|
||||
if self.training and denoising_meta_values is not None:
|
||||
dn_out_coord, outputs_coord = torch.split(outputs_coord, denoising_meta_values["dn_num_split"], dim=2)
|
||||
dn_out_class, outputs_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2)
|
||||
|
||||
logits = outputs_class[:, -1]
|
||||
pred_boxes = outputs_coord[:, -1]
|
||||
|
||||
loss, loss_dict, auxiliary_outputs = None, None, None
|
||||
loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None
|
||||
if labels is not None:
|
||||
# First: create the criterion
|
||||
criterion = RTDetrLoss(self.config)
|
||||
criterion.to(self.device)
|
||||
# Second: compute the losses, based on outputs and labels
|
||||
outputs_loss = {}
|
||||
outputs_loss["logits"] = logits
|
||||
outputs_loss["pred_boxes"] = pred_boxes
|
||||
if self.config.auxiliary_loss:
|
||||
if self.training and denoising_meta_values is not None:
|
||||
enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5]
|
||||
enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4]
|
||||
auxiliary_outputs = self._set_aux_loss(
|
||||
outputs_class[:, :-1].transpose(0, 1), outputs_coord[:, :-1].transpose(0, 1)
|
||||
)
|
||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||
outputs_loss["auxiliary_outputs"].extend(self._set_aux_loss([enc_topk_logits], [enc_topk_bboxes]))
|
||||
if self.training and denoising_meta_values is not None:
|
||||
outputs_loss["dn_auxiliary_outputs"] = self._set_aux_loss(
|
||||
dn_out_class.transpose(0, 1), dn_out_coord.transpose(0, 1)
|
||||
)
|
||||
outputs_loss["denoising_meta_values"] = denoising_meta_values
|
||||
|
||||
loss_dict = criterion(outputs_loss, labels)
|
||||
|
||||
loss = sum(loss_dict.values())
|
||||
loss, loss_dict, auxiliary_outputs = self.loss_function(
|
||||
logits,
|
||||
labels,
|
||||
self.device,
|
||||
pred_boxes,
|
||||
self.config,
|
||||
outputs_class,
|
||||
outputs_coord,
|
||||
enc_topk_logits=enc_topk_logits,
|
||||
enc_topk_bboxes=enc_topk_bboxes,
|
||||
denoising_meta_values=denoising_meta_values,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
if auxiliary_outputs is not None:
|
||||
|
@ -25,7 +25,7 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
@ -1349,27 +1349,8 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
@ -1453,8 +1434,7 @@ class StableLmForTokenClassification(StableLmPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
loss = self.loss_function(logits, labels, self.config)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -25,7 +25,7 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||
@ -1295,27 +1295,8 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
@ -1399,8 +1380,7 @@ class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
loss = self.loss_function(logits, labels, self.config)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
@ -29,10 +29,7 @@ from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_accelerate_available,
|
||||
is_scipy_available,
|
||||
is_timm_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
requires_backends,
|
||||
@ -41,18 +38,9 @@ from ...utils.backbone_utils import load_backbone
|
||||
from .configuration_table_transformer import TableTransformerConfig
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
if is_timm_available():
|
||||
from timm import create_model
|
||||
|
||||
if is_vision_available():
|
||||
from transformers.image_transforms import center_to_corners_format
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import PartialState
|
||||
from accelerate.utils import reduce
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@ -1312,14 +1300,6 @@ class TableTransformerForObjectDetection(TableTransformerPreTrainedModel):
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@torch.jit.unused
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrForObjectDetection._set_aux_loss
|
||||
def _set_aux_loss(self, outputs_class, outputs_coord):
|
||||
# this is a workaround to make torchscript happy, as torchscript
|
||||
# doesn't support dictionary with non-homogeneous values, such
|
||||
# as a dict having both a Tensor and a list.
|
||||
return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
|
||||
|
||||
@add_start_docstrings_to_model_forward(TABLE_TRANSFORMER_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TableTransformerObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@ -1398,40 +1378,14 @@ class TableTransformerForObjectDetection(TableTransformerPreTrainedModel):
|
||||
|
||||
loss, loss_dict, auxiliary_outputs = None, None, None
|
||||
if labels is not None:
|
||||
# First: create the matcher
|
||||
matcher = TableTransformerHungarianMatcher(
|
||||
class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost
|
||||
)
|
||||
# Second: create the criterion
|
||||
losses = ["labels", "boxes", "cardinality"]
|
||||
criterion = TableTransformerLoss(
|
||||
matcher=matcher,
|
||||
num_classes=self.config.num_labels,
|
||||
eos_coef=self.config.eos_coefficient,
|
||||
losses=losses,
|
||||
)
|
||||
criterion.to(self.device)
|
||||
# Third: compute the losses, based on outputs and labels
|
||||
outputs_loss = {}
|
||||
outputs_loss["logits"] = logits
|
||||
outputs_loss["pred_boxes"] = pred_boxes
|
||||
outputs_class, outputs_coord = None, None
|
||||
if self.config.auxiliary_loss:
|
||||
intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
|
||||
outputs_class = self.class_labels_classifier(intermediate)
|
||||
outputs_coord = self.bbox_predictor(intermediate).sigmoid()
|
||||
auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
|
||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||
|
||||
loss_dict = criterion(outputs_loss, labels)
|
||||
# Fourth: compute total loss, as a weighted sum of the various losses
|
||||
weight_dict = {"loss_ce": 1, "loss_bbox": self.config.bbox_loss_coefficient}
|
||||
weight_dict["loss_giou"] = self.config.giou_loss_coefficient
|
||||
if self.config.auxiliary_loss:
|
||||
aux_weight_dict = {}
|
||||
for i in range(self.config.decoder_layers - 1):
|
||||
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
||||
weight_dict.update(aux_weight_dict)
|
||||
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
||||
loss, loss_dict, auxiliary_outputs = self.loss_function(
|
||||
logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
if auxiliary_outputs is not None:
|
||||
@ -1456,258 +1410,6 @@ class TableTransformerForObjectDetection(TableTransformerPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.dice_loss
|
||||
def dice_loss(inputs, targets, num_boxes):
|
||||
"""
|
||||
Compute the DICE loss, similar to generalized IOU for masks
|
||||
|
||||
Args:
|
||||
inputs: A float tensor of arbitrary shape.
|
||||
The predictions for each example.
|
||||
targets: A float tensor with the same shape as inputs. Stores the binary
|
||||
classification label for each element in inputs (0 for the negative class and 1 for the positive
|
||||
class).
|
||||
"""
|
||||
inputs = inputs.sigmoid()
|
||||
inputs = inputs.flatten(1)
|
||||
numerator = 2 * (inputs * targets).sum(1)
|
||||
denominator = inputs.sum(-1) + targets.sum(-1)
|
||||
loss = 1 - (numerator + 1) / (denominator + 1)
|
||||
return loss.sum() / num_boxes
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.sigmoid_focal_loss
|
||||
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
|
||||
"""
|
||||
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
|
||||
|
||||
Args:
|
||||
inputs (`torch.FloatTensor` of arbitrary shape):
|
||||
The predictions for each example.
|
||||
targets (`torch.FloatTensor` with the same shape as `inputs`)
|
||||
A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
|
||||
and 1 for the positive class).
|
||||
alpha (`float`, *optional*, defaults to `0.25`):
|
||||
Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
|
||||
gamma (`int`, *optional*, defaults to `2`):
|
||||
Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.
|
||||
|
||||
Returns:
|
||||
Loss tensor
|
||||
"""
|
||||
prob = inputs.sigmoid()
|
||||
ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
|
||||
# add modulating factor
|
||||
p_t = prob * targets + (1 - prob) * (1 - targets)
|
||||
loss = ce_loss * ((1 - p_t) ** gamma)
|
||||
|
||||
if alpha >= 0:
|
||||
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
|
||||
loss = alpha_t * loss
|
||||
|
||||
return loss.mean(1).sum() / num_boxes
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrLoss with Detr->TableTransformer,detr->table_transformer
|
||||
class TableTransformerLoss(nn.Module):
|
||||
"""
|
||||
This class computes the losses for TableTransformerForObjectDetection/TableTransformerForSegmentation. The process happens in two steps: 1)
|
||||
we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair
|
||||
of matched ground-truth / prediction (supervise class and box).
|
||||
|
||||
A note on the `num_classes` argument (copied from original repo in table_transformer.py): "the naming of the `num_classes`
|
||||
parameter of the criterion is somewhat misleading. It indeed corresponds to `max_obj_id` + 1, where `max_obj_id` is
|
||||
the maximum id for a class in your dataset. For example, COCO has a `max_obj_id` of 90, so we pass `num_classes` to
|
||||
be 91. As another example, for a dataset that has a single class with `id` 1, you should pass `num_classes` to be 2
|
||||
(`max_obj_id` + 1). For more details on this, check the following discussion
|
||||
https://github.com/facebookresearch/table_transformer/issues/108#issuecomment-650269223"
|
||||
|
||||
|
||||
Args:
|
||||
matcher (`TableTransformerHungarianMatcher`):
|
||||
Module able to compute a matching between targets and proposals.
|
||||
num_classes (`int`):
|
||||
Number of object categories, omitting the special no-object category.
|
||||
eos_coef (`float`):
|
||||
Relative classification weight applied to the no-object category.
|
||||
losses (`List[str]`):
|
||||
List of all the losses to be applied. See `get_loss` for a list of all available losses.
|
||||
"""
|
||||
|
||||
def __init__(self, matcher, num_classes, eos_coef, losses):
|
||||
super().__init__()
|
||||
self.matcher = matcher
|
||||
self.num_classes = num_classes
|
||||
self.eos_coef = eos_coef
|
||||
self.losses = losses
|
||||
empty_weight = torch.ones(self.num_classes + 1)
|
||||
empty_weight[-1] = self.eos_coef
|
||||
self.register_buffer("empty_weight", empty_weight)
|
||||
|
||||
# removed logging parameter, which was part of the original implementation
|
||||
def loss_labels(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Classification loss (NLL) targets dicts must contain the key "class_labels" containing a tensor of dim
|
||||
[nb_target_boxes]
|
||||
"""
|
||||
if "logits" not in outputs:
|
||||
raise KeyError("No logits were found in the outputs")
|
||||
source_logits = outputs["logits"]
|
||||
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
|
||||
target_classes = torch.full(
|
||||
source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
|
||||
)
|
||||
target_classes[idx] = target_classes_o
|
||||
|
||||
loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
|
||||
losses = {"loss_ce": loss_ce}
|
||||
|
||||
return losses
|
||||
|
||||
@torch.no_grad()
|
||||
def loss_cardinality(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
|
||||
|
||||
This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
|
||||
"""
|
||||
logits = outputs["logits"]
|
||||
device = logits.device
|
||||
target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
|
||||
# Count the number of predictions that are NOT "no-object" (which is the last class)
|
||||
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
|
||||
card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
|
||||
losses = {"cardinality_error": card_err}
|
||||
return losses
|
||||
|
||||
def loss_boxes(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
|
||||
|
||||
Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
|
||||
are expected in format (center_x, center_y, w, h), normalized by the image size.
|
||||
"""
|
||||
if "pred_boxes" not in outputs:
|
||||
raise KeyError("No predicted boxes found in outputs")
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
source_boxes = outputs["pred_boxes"][idx]
|
||||
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
|
||||
|
||||
loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
|
||||
|
||||
losses = {}
|
||||
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
|
||||
|
||||
loss_giou = 1 - torch.diag(
|
||||
generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
|
||||
)
|
||||
losses["loss_giou"] = loss_giou.sum() / num_boxes
|
||||
return losses
|
||||
|
||||
def loss_masks(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the losses related to the masks: the focal loss and the dice loss.
|
||||
|
||||
Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
|
||||
"""
|
||||
if "pred_masks" not in outputs:
|
||||
raise KeyError("No predicted masks found in outputs")
|
||||
|
||||
source_idx = self._get_source_permutation_idx(indices)
|
||||
target_idx = self._get_target_permutation_idx(indices)
|
||||
source_masks = outputs["pred_masks"]
|
||||
source_masks = source_masks[source_idx]
|
||||
masks = [t["masks"] for t in targets]
|
||||
# TODO use valid to mask invalid areas due to padding in loss
|
||||
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
|
||||
target_masks = target_masks.to(source_masks)
|
||||
target_masks = target_masks[target_idx]
|
||||
|
||||
# upsample predictions to the target size
|
||||
source_masks = nn.functional.interpolate(
|
||||
source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
source_masks = source_masks[:, 0].flatten(1)
|
||||
|
||||
target_masks = target_masks.flatten(1)
|
||||
target_masks = target_masks.view(source_masks.shape)
|
||||
losses = {
|
||||
"loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
|
||||
"loss_dice": dice_loss(source_masks, target_masks, num_boxes),
|
||||
}
|
||||
return losses
|
||||
|
||||
def _get_source_permutation_idx(self, indices):
|
||||
# permute predictions following indices
|
||||
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
|
||||
source_idx = torch.cat([source for (source, _) in indices])
|
||||
return batch_idx, source_idx
|
||||
|
||||
def _get_target_permutation_idx(self, indices):
|
||||
# permute targets following indices
|
||||
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
|
||||
target_idx = torch.cat([target for (_, target) in indices])
|
||||
return batch_idx, target_idx
|
||||
|
||||
def get_loss(self, loss, outputs, targets, indices, num_boxes):
|
||||
loss_map = {
|
||||
"labels": self.loss_labels,
|
||||
"cardinality": self.loss_cardinality,
|
||||
"boxes": self.loss_boxes,
|
||||
"masks": self.loss_masks,
|
||||
}
|
||||
if loss not in loss_map:
|
||||
raise ValueError(f"Loss {loss} not supported")
|
||||
return loss_map[loss](outputs, targets, indices, num_boxes)
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
This performs the loss computation.
|
||||
|
||||
Args:
|
||||
outputs (`dict`, *optional*):
|
||||
Dictionary of tensors, see the output specification of the model for the format.
|
||||
targets (`List[dict]`, *optional*):
|
||||
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
|
||||
losses applied, see each loss' doc.
|
||||
"""
|
||||
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
|
||||
|
||||
# Retrieve the matching between the outputs of the last layer and the targets
|
||||
indices = self.matcher(outputs_without_aux, targets)
|
||||
|
||||
# Compute the average number of target boxes across all nodes, for normalization purposes
|
||||
num_boxes = sum(len(t["class_labels"]) for t in targets)
|
||||
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
||||
world_size = 1
|
||||
if is_accelerate_available():
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
|
||||
|
||||
# Compute all the requested losses
|
||||
losses = {}
|
||||
for loss in self.losses:
|
||||
losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
|
||||
|
||||
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
||||
if "auxiliary_outputs" in outputs:
|
||||
for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
|
||||
indices = self.matcher(auxiliary_outputs, targets)
|
||||
for loss in self.losses:
|
||||
if loss == "masks":
|
||||
# Intermediate masks losses are too costly to compute, we ignore them.
|
||||
continue
|
||||
l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
|
||||
l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
|
||||
losses.update(l_dict)
|
||||
|
||||
return losses
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with Detr->TableTransformer,detr->table_transformer
|
||||
class TableTransformerMLPPredictionHead(nn.Module):
|
||||
"""
|
||||
@ -1728,200 +1430,3 @@ class TableTransformerMLPPredictionHead(nn.Module):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||
return x
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrHungarianMatcher with Detr->TableTransformer
|
||||
class TableTransformerHungarianMatcher(nn.Module):
|
||||
"""
|
||||
This class computes an assignment between the targets and the predictions of the network.
|
||||
|
||||
For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
|
||||
predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
|
||||
un-matched (and thus treated as non-objects).
|
||||
|
||||
Args:
|
||||
class_cost:
|
||||
The relative weight of the classification error in the matching cost.
|
||||
bbox_cost:
|
||||
The relative weight of the L1 error of the bounding box coordinates in the matching cost.
|
||||
giou_cost:
|
||||
The relative weight of the giou loss of the bounding box in the matching cost.
|
||||
"""
|
||||
|
||||
def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
|
||||
super().__init__()
|
||||
requires_backends(self, ["scipy"])
|
||||
|
||||
self.class_cost = class_cost
|
||||
self.bbox_cost = bbox_cost
|
||||
self.giou_cost = giou_cost
|
||||
if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
|
||||
raise ValueError("All costs of the Matcher can't be 0")
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
Args:
|
||||
outputs (`dict`):
|
||||
A dictionary that contains at least these entries:
|
||||
* "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
||||
* "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
|
||||
targets (`List[dict]`):
|
||||
A list of targets (len(targets) = batch_size), where each target is a dict containing:
|
||||
* "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
|
||||
ground-truth
|
||||
objects in the target) containing the class labels
|
||||
* "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.
|
||||
|
||||
Returns:
|
||||
`List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
|
||||
- index_i is the indices of the selected predictions (in order)
|
||||
- index_j is the indices of the corresponding selected targets (in order)
|
||||
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
||||
"""
|
||||
batch_size, num_queries = outputs["logits"].shape[:2]
|
||||
|
||||
# We flatten to compute the cost matrices in a batch
|
||||
out_prob = outputs["logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
|
||||
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
|
||||
|
||||
# Also concat the target labels and boxes
|
||||
target_ids = torch.cat([v["class_labels"] for v in targets])
|
||||
target_bbox = torch.cat([v["boxes"] for v in targets])
|
||||
|
||||
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
|
||||
# but approximate it in 1 - proba[target class].
|
||||
# The 1 is a constant that doesn't change the matching, it can be ommitted.
|
||||
class_cost = -out_prob[:, target_ids]
|
||||
|
||||
# Compute the L1 cost between boxes
|
||||
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
|
||||
|
||||
# Compute the giou cost between boxes
|
||||
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
|
||||
|
||||
# Final cost matrix
|
||||
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
|
||||
cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
|
||||
|
||||
sizes = [len(v["boxes"]) for v in targets]
|
||||
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
|
||||
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr._upcast
|
||||
def _upcast(t: Tensor) -> Tensor:
|
||||
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
|
||||
if t.is_floating_point():
|
||||
return t if t.dtype in (torch.float32, torch.float64) else t.float()
|
||||
else:
|
||||
return t if t.dtype in (torch.int32, torch.int64) else t.int()
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.box_area
|
||||
def box_area(boxes: Tensor) -> Tensor:
|
||||
"""
|
||||
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
|
||||
|
||||
Args:
|
||||
boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
|
||||
Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
|
||||
< x2` and `0 <= y1 < y2`.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: a tensor containing the area for each box.
|
||||
"""
|
||||
boxes = _upcast(boxes)
|
||||
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.box_iou
|
||||
def box_iou(boxes1, boxes2):
|
||||
area1 = box_area(boxes1)
|
||||
area2 = box_area(boxes2)
|
||||
|
||||
left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||
right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||
|
||||
width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
|
||||
inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
|
||||
|
||||
union = area1[:, None] + area2 - inter
|
||||
|
||||
iou = inter / union
|
||||
return iou, union
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
|
||||
"""
|
||||
# degenerate boxes gives inf / nan results
|
||||
# so do an early check
|
||||
if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
|
||||
raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
|
||||
if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
|
||||
raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
|
||||
iou, union = box_iou(boxes1, boxes2)
|
||||
|
||||
top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
||||
bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
||||
|
||||
width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
|
||||
area = width_height[:, :, 0] * width_height[:, :, 1]
|
||||
|
||||
return iou - (area - union) / area
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr._max_by_axis
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.NestedTensor
|
||||
class NestedTensor:
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.nested_tensor_from_tensor_list
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
if tensor_list[0].ndim == 3:
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||
batch_shape = [len(tensor_list)] + max_size
|
||||
batch_size, num_channels, height, width = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], : img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError("Only 3-dimensional tensors are supported")
|
||||
return NestedTensor(tensor, mask)
|
||||
|
@ -21,7 +21,7 @@ from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import Tensor, nn
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
@ -32,26 +32,12 @@ from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_accelerate_available,
|
||||
is_scipy_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
requires_backends,
|
||||
)
|
||||
from .configuration_yolos import YolosConfig
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
if is_vision_available():
|
||||
from transformers.image_transforms import center_to_corners_format
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import PartialState
|
||||
from accelerate.utils import reduce
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# General docstring
|
||||
@ -728,6 +714,28 @@ class YolosPooler(nn.Module):
|
||||
return pooled_output
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with Detr->Yolos
|
||||
class YolosMLPPredictionHead(nn.Module):
|
||||
"""
|
||||
Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
|
||||
height and width of a bounding box w.r.t. an image.
|
||||
|
||||
Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||
return x
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
YOLOS Model (consisting of a ViT encoder) with object detection heads on top, for tasks such as COCO detection.
|
||||
@ -837,40 +845,14 @@ class YolosForObjectDetection(YolosPreTrainedModel):
|
||||
|
||||
loss, loss_dict, auxiliary_outputs = None, None, None
|
||||
if labels is not None:
|
||||
# First: create the matcher
|
||||
matcher = YolosHungarianMatcher(
|
||||
class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost
|
||||
)
|
||||
# Second: create the criterion
|
||||
losses = ["labels", "boxes", "cardinality"]
|
||||
criterion = YolosLoss(
|
||||
matcher=matcher,
|
||||
num_classes=self.config.num_labels,
|
||||
eos_coef=self.config.eos_coefficient,
|
||||
losses=losses,
|
||||
)
|
||||
criterion.to(self.device)
|
||||
# Third: compute the losses, based on outputs and labels
|
||||
outputs_loss = {}
|
||||
outputs_loss["logits"] = logits
|
||||
outputs_loss["pred_boxes"] = pred_boxes
|
||||
outputs_class, outputs_coord = None, None
|
||||
if self.config.auxiliary_loss:
|
||||
intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
|
||||
outputs_class = self.class_labels_classifier(intermediate)
|
||||
outputs_coord = self.bbox_predictor(intermediate).sigmoid()
|
||||
auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
|
||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||
|
||||
loss_dict = criterion(outputs_loss, labels)
|
||||
# Fourth: compute total loss, as a weighted sum of the various losses
|
||||
weight_dict = {"loss_ce": 1, "loss_bbox": self.config.bbox_loss_coefficient}
|
||||
weight_dict["loss_giou"] = self.config.giou_loss_coefficient
|
||||
if self.config.auxiliary_loss:
|
||||
aux_weight_dict = {}
|
||||
for i in range(self.config.decoder_layers - 1):
|
||||
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
||||
weight_dict.update(aux_weight_dict)
|
||||
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
||||
loss, loss_dict, auxiliary_outputs = self.loss_function(
|
||||
logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
if auxiliary_outputs is not None:
|
||||
@ -889,474 +871,3 @@ class YolosForObjectDetection(YolosPreTrainedModel):
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.dice_loss
|
||||
def dice_loss(inputs, targets, num_boxes):
|
||||
"""
|
||||
Compute the DICE loss, similar to generalized IOU for masks
|
||||
|
||||
Args:
|
||||
inputs: A float tensor of arbitrary shape.
|
||||
The predictions for each example.
|
||||
targets: A float tensor with the same shape as inputs. Stores the binary
|
||||
classification label for each element in inputs (0 for the negative class and 1 for the positive
|
||||
class).
|
||||
"""
|
||||
inputs = inputs.sigmoid()
|
||||
inputs = inputs.flatten(1)
|
||||
numerator = 2 * (inputs * targets).sum(1)
|
||||
denominator = inputs.sum(-1) + targets.sum(-1)
|
||||
loss = 1 - (numerator + 1) / (denominator + 1)
|
||||
return loss.sum() / num_boxes
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.sigmoid_focal_loss
|
||||
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
|
||||
"""
|
||||
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
|
||||
|
||||
Args:
|
||||
inputs (`torch.FloatTensor` of arbitrary shape):
|
||||
The predictions for each example.
|
||||
targets (`torch.FloatTensor` with the same shape as `inputs`)
|
||||
A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
|
||||
and 1 for the positive class).
|
||||
alpha (`float`, *optional*, defaults to `0.25`):
|
||||
Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
|
||||
gamma (`int`, *optional*, defaults to `2`):
|
||||
Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.
|
||||
|
||||
Returns:
|
||||
Loss tensor
|
||||
"""
|
||||
prob = inputs.sigmoid()
|
||||
ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
|
||||
# add modulating factor
|
||||
p_t = prob * targets + (1 - prob) * (1 - targets)
|
||||
loss = ce_loss * ((1 - p_t) ** gamma)
|
||||
|
||||
if alpha >= 0:
|
||||
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
|
||||
loss = alpha_t * loss
|
||||
|
||||
return loss.mean(1).sum() / num_boxes
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrLoss with Detr->Yolos
|
||||
class YolosLoss(nn.Module):
|
||||
"""
|
||||
This class computes the losses for YolosForObjectDetection/YolosForSegmentation. The process happens in two steps: 1)
|
||||
we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair
|
||||
of matched ground-truth / prediction (supervise class and box).
|
||||
|
||||
A note on the `num_classes` argument (copied from original repo in detr.py): "the naming of the `num_classes`
|
||||
parameter of the criterion is somewhat misleading. It indeed corresponds to `max_obj_id` + 1, where `max_obj_id` is
|
||||
the maximum id for a class in your dataset. For example, COCO has a `max_obj_id` of 90, so we pass `num_classes` to
|
||||
be 91. As another example, for a dataset that has a single class with `id` 1, you should pass `num_classes` to be 2
|
||||
(`max_obj_id` + 1). For more details on this, check the following discussion
|
||||
https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
|
||||
|
||||
|
||||
Args:
|
||||
matcher (`YolosHungarianMatcher`):
|
||||
Module able to compute a matching between targets and proposals.
|
||||
num_classes (`int`):
|
||||
Number of object categories, omitting the special no-object category.
|
||||
eos_coef (`float`):
|
||||
Relative classification weight applied to the no-object category.
|
||||
losses (`List[str]`):
|
||||
List of all the losses to be applied. See `get_loss` for a list of all available losses.
|
||||
"""
|
||||
|
||||
def __init__(self, matcher, num_classes, eos_coef, losses):
|
||||
super().__init__()
|
||||
self.matcher = matcher
|
||||
self.num_classes = num_classes
|
||||
self.eos_coef = eos_coef
|
||||
self.losses = losses
|
||||
empty_weight = torch.ones(self.num_classes + 1)
|
||||
empty_weight[-1] = self.eos_coef
|
||||
self.register_buffer("empty_weight", empty_weight)
|
||||
|
||||
# removed logging parameter, which was part of the original implementation
|
||||
def loss_labels(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Classification loss (NLL) targets dicts must contain the key "class_labels" containing a tensor of dim
|
||||
[nb_target_boxes]
|
||||
"""
|
||||
if "logits" not in outputs:
|
||||
raise KeyError("No logits were found in the outputs")
|
||||
source_logits = outputs["logits"]
|
||||
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
|
||||
target_classes = torch.full(
|
||||
source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
|
||||
)
|
||||
target_classes[idx] = target_classes_o
|
||||
|
||||
loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
|
||||
losses = {"loss_ce": loss_ce}
|
||||
|
||||
return losses
|
||||
|
||||
@torch.no_grad()
|
||||
def loss_cardinality(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
|
||||
|
||||
This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
|
||||
"""
|
||||
logits = outputs["logits"]
|
||||
device = logits.device
|
||||
target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
|
||||
# Count the number of predictions that are NOT "no-object" (which is the last class)
|
||||
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
|
||||
card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
|
||||
losses = {"cardinality_error": card_err}
|
||||
return losses
|
||||
|
||||
def loss_boxes(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
|
||||
|
||||
Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
|
||||
are expected in format (center_x, center_y, w, h), normalized by the image size.
|
||||
"""
|
||||
if "pred_boxes" not in outputs:
|
||||
raise KeyError("No predicted boxes found in outputs")
|
||||
idx = self._get_source_permutation_idx(indices)
|
||||
source_boxes = outputs["pred_boxes"][idx]
|
||||
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
|
||||
|
||||
loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
|
||||
|
||||
losses = {}
|
||||
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
|
||||
|
||||
loss_giou = 1 - torch.diag(
|
||||
generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
|
||||
)
|
||||
losses["loss_giou"] = loss_giou.sum() / num_boxes
|
||||
return losses
|
||||
|
||||
def loss_masks(self, outputs, targets, indices, num_boxes):
|
||||
"""
|
||||
Compute the losses related to the masks: the focal loss and the dice loss.
|
||||
|
||||
Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
|
||||
"""
|
||||
if "pred_masks" not in outputs:
|
||||
raise KeyError("No predicted masks found in outputs")
|
||||
|
||||
source_idx = self._get_source_permutation_idx(indices)
|
||||
target_idx = self._get_target_permutation_idx(indices)
|
||||
source_masks = outputs["pred_masks"]
|
||||
source_masks = source_masks[source_idx]
|
||||
masks = [t["masks"] for t in targets]
|
||||
# TODO use valid to mask invalid areas due to padding in loss
|
||||
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
|
||||
target_masks = target_masks.to(source_masks)
|
||||
target_masks = target_masks[target_idx]
|
||||
|
||||
# upsample predictions to the target size
|
||||
source_masks = nn.functional.interpolate(
|
||||
source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
source_masks = source_masks[:, 0].flatten(1)
|
||||
|
||||
target_masks = target_masks.flatten(1)
|
||||
target_masks = target_masks.view(source_masks.shape)
|
||||
losses = {
|
||||
"loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
|
||||
"loss_dice": dice_loss(source_masks, target_masks, num_boxes),
|
||||
}
|
||||
return losses
|
||||
|
||||
def _get_source_permutation_idx(self, indices):
|
||||
# permute predictions following indices
|
||||
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
|
||||
source_idx = torch.cat([source for (source, _) in indices])
|
||||
return batch_idx, source_idx
|
||||
|
||||
def _get_target_permutation_idx(self, indices):
|
||||
# permute targets following indices
|
||||
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
|
||||
target_idx = torch.cat([target for (_, target) in indices])
|
||||
return batch_idx, target_idx
|
||||
|
||||
def get_loss(self, loss, outputs, targets, indices, num_boxes):
|
||||
loss_map = {
|
||||
"labels": self.loss_labels,
|
||||
"cardinality": self.loss_cardinality,
|
||||
"boxes": self.loss_boxes,
|
||||
"masks": self.loss_masks,
|
||||
}
|
||||
if loss not in loss_map:
|
||||
raise ValueError(f"Loss {loss} not supported")
|
||||
return loss_map[loss](outputs, targets, indices, num_boxes)
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
This performs the loss computation.
|
||||
|
||||
Args:
|
||||
outputs (`dict`, *optional*):
|
||||
Dictionary of tensors, see the output specification of the model for the format.
|
||||
targets (`List[dict]`, *optional*):
|
||||
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
|
||||
losses applied, see each loss' doc.
|
||||
"""
|
||||
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
|
||||
|
||||
# Retrieve the matching between the outputs of the last layer and the targets
|
||||
indices = self.matcher(outputs_without_aux, targets)
|
||||
|
||||
# Compute the average number of target boxes across all nodes, for normalization purposes
|
||||
num_boxes = sum(len(t["class_labels"]) for t in targets)
|
||||
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
||||
world_size = 1
|
||||
if is_accelerate_available():
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
|
||||
|
||||
# Compute all the requested losses
|
||||
losses = {}
|
||||
for loss in self.losses:
|
||||
losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
|
||||
|
||||
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
||||
if "auxiliary_outputs" in outputs:
|
||||
for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
|
||||
indices = self.matcher(auxiliary_outputs, targets)
|
||||
for loss in self.losses:
|
||||
if loss == "masks":
|
||||
# Intermediate masks losses are too costly to compute, we ignore them.
|
||||
continue
|
||||
l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
|
||||
l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
|
||||
losses.update(l_dict)
|
||||
|
||||
return losses
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with Detr->Yolos
|
||||
class YolosMLPPredictionHead(nn.Module):
|
||||
"""
|
||||
Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
|
||||
height and width of a bounding box w.r.t. an image.
|
||||
|
||||
Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||
return x
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrHungarianMatcher with Detr->Yolos
|
||||
class YolosHungarianMatcher(nn.Module):
|
||||
"""
|
||||
This class computes an assignment between the targets and the predictions of the network.
|
||||
|
||||
For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
|
||||
predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
|
||||
un-matched (and thus treated as non-objects).
|
||||
|
||||
Args:
|
||||
class_cost:
|
||||
The relative weight of the classification error in the matching cost.
|
||||
bbox_cost:
|
||||
The relative weight of the L1 error of the bounding box coordinates in the matching cost.
|
||||
giou_cost:
|
||||
The relative weight of the giou loss of the bounding box in the matching cost.
|
||||
"""
|
||||
|
||||
def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
|
||||
super().__init__()
|
||||
requires_backends(self, ["scipy"])
|
||||
|
||||
self.class_cost = class_cost
|
||||
self.bbox_cost = bbox_cost
|
||||
self.giou_cost = giou_cost
|
||||
if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
|
||||
raise ValueError("All costs of the Matcher can't be 0")
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
Args:
|
||||
outputs (`dict`):
|
||||
A dictionary that contains at least these entries:
|
||||
* "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
||||
* "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
|
||||
targets (`List[dict]`):
|
||||
A list of targets (len(targets) = batch_size), where each target is a dict containing:
|
||||
* "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
|
||||
ground-truth
|
||||
objects in the target) containing the class labels
|
||||
* "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.
|
||||
|
||||
Returns:
|
||||
`List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
|
||||
- index_i is the indices of the selected predictions (in order)
|
||||
- index_j is the indices of the corresponding selected targets (in order)
|
||||
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
||||
"""
|
||||
batch_size, num_queries = outputs["logits"].shape[:2]
|
||||
|
||||
# We flatten to compute the cost matrices in a batch
|
||||
out_prob = outputs["logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
|
||||
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
|
||||
|
||||
# Also concat the target labels and boxes
|
||||
target_ids = torch.cat([v["class_labels"] for v in targets])
|
||||
target_bbox = torch.cat([v["boxes"] for v in targets])
|
||||
|
||||
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
|
||||
# but approximate it in 1 - proba[target class].
|
||||
# The 1 is a constant that doesn't change the matching, it can be ommitted.
|
||||
class_cost = -out_prob[:, target_ids]
|
||||
|
||||
# Compute the L1 cost between boxes
|
||||
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
|
||||
|
||||
# Compute the giou cost between boxes
|
||||
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
|
||||
|
||||
# Final cost matrix
|
||||
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
|
||||
cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
|
||||
|
||||
sizes = [len(v["boxes"]) for v in targets]
|
||||
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
|
||||
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr._upcast
|
||||
def _upcast(t: Tensor) -> Tensor:
|
||||
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
|
||||
if t.is_floating_point():
|
||||
return t if t.dtype in (torch.float32, torch.float64) else t.float()
|
||||
else:
|
||||
return t if t.dtype in (torch.int32, torch.int64) else t.int()
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.box_area
|
||||
def box_area(boxes: Tensor) -> Tensor:
|
||||
"""
|
||||
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
|
||||
|
||||
Args:
|
||||
boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
|
||||
Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
|
||||
< x2` and `0 <= y1 < y2`.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: a tensor containing the area for each box.
|
||||
"""
|
||||
boxes = _upcast(boxes)
|
||||
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.box_iou
|
||||
def box_iou(boxes1, boxes2):
|
||||
area1 = box_area(boxes1)
|
||||
area2 = box_area(boxes2)
|
||||
|
||||
left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||
right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||
|
||||
width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
|
||||
inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
|
||||
|
||||
union = area1[:, None] + area2 - inter
|
||||
|
||||
iou = inter / union
|
||||
return iou, union
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
|
||||
"""
|
||||
# degenerate boxes gives inf / nan results
|
||||
# so do an early check
|
||||
if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
|
||||
raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
|
||||
if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
|
||||
raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
|
||||
iou, union = box_iou(boxes1, boxes2)
|
||||
|
||||
top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
||||
bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
||||
|
||||
width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
|
||||
area = width_height[:, :, 0] * width_height[:, :, 1]
|
||||
|
||||
return iou - (area - union) / area
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr._max_by_axis
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.NestedTensor
|
||||
class NestedTensor:
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.nested_tensor_from_tensor_list
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
if tensor_list[0].ndim == 3:
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||
batch_shape = [len(tensor_list)] + max_size
|
||||
batch_size, num_channels, height, width = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], : img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError("Only 3-dimensional tensors are supported")
|
||||
return NestedTensor(tensor, mask)
|
||||
|
@ -1483,18 +1483,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -144,6 +144,57 @@ SPECIAL_CASES_TO_ALLOW = {
|
||||
"initializer_range",
|
||||
"supported_aspect_ratios",
|
||||
],
|
||||
"ConditionalDetrConfig": [
|
||||
"bbox_cost",
|
||||
"bbox_loss_coefficient",
|
||||
"class_cost",
|
||||
"cls_loss_coefficient",
|
||||
"dice_loss_coefficient",
|
||||
"focal_alpha",
|
||||
"giou_cost",
|
||||
"giou_loss_coefficient",
|
||||
"mask_loss_coefficient",
|
||||
],
|
||||
"DetrConfig": [
|
||||
"bbox_cost",
|
||||
"bbox_loss_coefficient",
|
||||
"class_cost",
|
||||
"dice_loss_coefficient",
|
||||
"eos_coefficient",
|
||||
"giou_cost",
|
||||
"giou_loss_coefficient",
|
||||
"mask_loss_coefficient",
|
||||
],
|
||||
"GroundingDinoConfig": [
|
||||
"bbox_cost",
|
||||
"bbox_loss_coefficient",
|
||||
"class_cost",
|
||||
"focal_alpha",
|
||||
"giou_cost",
|
||||
"giou_loss_coefficient",
|
||||
],
|
||||
"RTDetrConfig": [
|
||||
"eos_coefficient",
|
||||
"focal_loss_alpha",
|
||||
"focal_loss_gamma",
|
||||
"matcher_alpha",
|
||||
"matcher_bbox_cost",
|
||||
"matcher_class_cost",
|
||||
"matcher_gamma",
|
||||
"matcher_giou_cost",
|
||||
"use_focal_loss",
|
||||
"weight_loss_bbox",
|
||||
"weight_loss_giou",
|
||||
"weight_loss_vfl",
|
||||
],
|
||||
"YolosConfig": [
|
||||
"bbox_cost",
|
||||
"bbox_loss_coefficient",
|
||||
"class_cost",
|
||||
"eos_coefficient",
|
||||
"giou_cost",
|
||||
"giou_loss_coefficient",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user