Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
[GroundingDino] Fix grounding dino loss 🚨 (#31828)
* Starting to fix GroundingDinoLoss and GroundingDinoHungarianMatcher
* More updates
* More updates
* fixed: GroundingDinoLoss
* fixed: failing tests
* Update src/transformers/models/grounding_dino/modeling_grounding_dino.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update tests/models/grounding_dino/test_modeling_grounding_dino.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Update src/transformers/models/grounding_dino/modeling_grounding_dino.py, eight follow-up commits (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* Addressed comments
* Update src/transformers/models/grounding_dino/modeling_grounding_dino.py (Co-authored-by: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com>)
* add: cardinality loss and make box loss as copy from
* change: default for reduction loss is sum
* fix: vectorized generate fake box
* fix copies
* Addressed comments
* addressed comments
* addressed one-hot
* Update tests/models/grounding_dino/test_modeling_grounding_dino.py (Co-authored-by: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com>)
* Addressed comments
* fixed test
* Update src/transformers/models/grounding_dino/modeling_grounding_dino.py
* Update tests/models/grounding_dino/test_modeling_grounding_dino.py (Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>)
* Starting to fix GroundingDinoLoss and GroundingDinoHungarianMatcher
* More updates
* More updates
* fixed: GroundingDinoLoss
* add: cardinality loss and make box loss as copy from
* fix copies
* Revert "Update tests/models/grounding_dino/test_modeling_grounding_dino.py" (reverts commit aa74c4c57c430e54cc74c414d6269edb65c73e83)
* [run-slow] groundigdino
* remove nestedtensor
* [run-slow] groundig_dino
* [run-slow] grounding_dino
* [run-slow] grounding_dino
* [run-slow] grounding_dino
* check
* check
* add: encoder intermediate outputs to ImageLoss forward
* add: GroundingDinoForObjectDetectionLoss in the loss directory
* make style
* fix the loss function
* remove class_reduction since sum is the default
* remove class_reduction
* Update src/transformers/loss/loss_grounding_dino.py (Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>)
* simple fix
* Update src/transformers/loss/loss_grounding_dino.py (Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>)
* minor fix
* Update src/transformers/loss/loss_for_object_detection.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
Co-authored-by: sangbumchoi <danielsejong55@gmail.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent 482d17be60
commit 222505c7e4

271 additions: src/transformers/loss/loss_grounding_dino.py (new file)
@@ -0,0 +1,271 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn

from ..image_transforms import center_to_corners_format
from ..utils import is_scipy_available
from .loss_for_object_detection import HungarianMatcher, ImageLoss, _set_aux_loss, generalized_box_iou


if is_scipy_available():
    from scipy.optimize import linear_sum_assignment


# Similar to the one used in `DeformableDetr` but we reduce with sum and normalize by num_boxes
# instead of mean.
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_boxes: int,
    alpha: float = 0.25,
    gamma: float = 2,
):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs (`torch.FloatTensor` of arbitrary shape):
            The predictions for each example.
        targets (`torch.FloatTensor` with the same shape as `inputs`):
            A tensor storing the binary classification label for each element in `inputs` (0 for the negative class
            and 1 for the positive class).
        num_boxes (`int`):
            The total number of boxes in the batch.
        alpha (`float`, *optional*, defaults to 0.25):
            Optional weighting factor in the range (0, 1) to balance positive vs. negative examples.
        gamma (`float`, *optional*, defaults to 2):
            Exponent of the modulating factor (1 - p_t) to balance easy vs. hard examples.

    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # add modulating factor
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.sum() / num_boxes
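For reference, the new focal loss can be exercised in isolation. A minimal sketch with toy shapes, assuming the module is importable at the path this commit adds (src/transformers/loss/loss_grounding_dino.py):

import torch

from transformers.loss.loss_grounding_dino import sigmoid_focal_loss

# Toy shapes: 2 images, 10 queries, 16 text tokens
logits = torch.randn(2, 10, 16)
targets = torch.randint(0, 2, (2, 10, 16)).float()

# Reduced with sum over all elements, then normalized by the number of ground-truth boxes
loss = sigmoid_focal_loss(logits, targets, num_boxes=4)
print(loss.shape)  # torch.Size([]) -- a scalar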
class GroundingDinoHungarianMatcher(HungarianMatcher):
    @torch.no_grad()
    def forward(self, outputs, targets):
        """
        Args:
            outputs (`dict`):
                A dictionary that contains at least these entries:
                * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
                * "label_maps": Tuple of tensors of dim [num_classes, hidden_dim].
            targets (`List[dict]`):
                A list of targets (len(targets) = batch_size), where each target is a dict containing:
                * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
                  ground-truth objects in the target) containing the class labels
                * "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.

        Returns:
            `List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
            - index_i is the indices of the selected predictions (in order)
            - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        batch_size, num_queries = outputs["logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["logits"].flatten(0, 1).sigmoid()  # [batch_size * num_queries, hidden_dim]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
        label_maps = outputs["label_maps"]

        # First take the label map for each class in each batch and then concatenate them
        label_maps = torch.cat([label_map[target["class_labels"]] for label_map, target in zip(label_maps, targets)])
        # Normalize label maps based on number of tokens per class
        label_maps = label_maps / label_maps.sum(dim=-1, keepdim=True)

        # Also concat the target labels and boxes
        target_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost.
        alpha = 0.25
        gamma = 2.0
        neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
        pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
        # Compute the classification cost by taking pos and neg cost in the appropriate index
        class_cost = (pos_cost_class - neg_cost_class) @ label_maps.t()

        # Compute the L1 cost between boxes
        bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)

        # Compute the giou cost between boxes
        giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))

        # Final cost matrix
        cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
        cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
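The matcher ends by running `scipy.optimize.linear_sum_assignment` per image on the combined cost matrix. A standalone toy example of that final step (numbers invented for illustration):

import numpy as np
from scipy.optimize import linear_sum_assignment

# 3 queries x 2 ground-truth boxes; lower cost = better match
cost = np.array([[0.9, 0.1],
                 [0.4, 0.8],
                 [0.2, 0.7]])
row_ind, col_ind = linear_sum_assignment(cost)
print(row_ind, col_ind)  # [0 2] [1 0]: query 0 -> target 1, query 2 -> target 0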
class GroundingDinoImageLoss(ImageLoss):
    """
    This class computes the losses for `GroundingDinoForObjectDetection`. The process happens in two steps: 1) we
    compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of
    matched ground-truth / prediction (supervise class and box).

    Args:
        matcher (`GroundingDinoHungarianMatcher`):
            Module able to compute a matching between targets and proposals.
        focal_alpha (`float`):
            Alpha parameter in focal loss.
        losses (`List[str]`):
            List of all the losses to be applied. See `get_loss` for a list of all available losses.
    """

    def __init__(self, matcher, focal_alpha, losses):
        nn.Module.__init__(self)
        self.matcher = matcher
        self.focal_alpha = focal_alpha
        self.losses = losses

    def _get_target_classes_one_hot(self, outputs, targets, indices):
        """
        Create one_hot based on the matching indices
        """
        logits = outputs["logits"]
        # Add offsets to class_labels to select the correct label map
        class_labels = torch.cat(
            [
                target["class_labels"][J] + len(outputs["label_maps"][i]) if i > 0 else target["class_labels"][J]
                for i, (target, (_, J)) in enumerate(zip(targets, indices))
            ]
        )
        label_maps = torch.cat(outputs["label_maps"], dim=0)

        idx = self._get_source_permutation_idx(indices)
        target_classes_onehot = torch.zeros_like(logits, device=logits.device, dtype=torch.long)
        target_classes_onehot[idx] = label_maps[class_labels].to(torch.long)

        return target_classes_onehot

    def loss_labels(self, outputs, targets, indices, num_boxes):
        """
        Classification loss (binary focal loss). Target dicts must contain the key "class_labels" containing a
        tensor of dim [nb_target_boxes].
        """
        if "logits" not in outputs:
            raise KeyError("No logits were found in the outputs")
        if "text_mask" not in outputs:
            raise KeyError("No text_mask was found in the outputs")

        target_classes_onehot = self._get_target_classes_one_hot(outputs, targets, indices)
        source_logits = outputs["logits"]
        text_mask = outputs["text_mask"]

        # Select only valid logits
        source_logits = torch.masked_select(source_logits, text_mask)
        target_classes_onehot = torch.masked_select(target_classes_onehot, text_mask)

        target_classes_onehot = target_classes_onehot.float()
        loss_ce = sigmoid_focal_loss(
            inputs=source_logits,
            targets=target_classes_onehot,
            num_boxes=num_boxes,
            alpha=self.focal_alpha,
            gamma=2,
        )

        losses = {"loss_ce": loss_ce}

        return losses
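`loss_labels` relies on `torch.masked_select` returning a flattened 1-D view of only the positions where `text_mask` is True, so padded text positions never contribute to the focal loss. A tiny illustration of that semantics:

import torch

logits = torch.tensor([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0]])
text_mask = torch.tensor([[True, True, False],
                          [True, True, False]])
print(torch.masked_select(logits, text_mask))  # tensor([1., 2., 4., 5.])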
def GroundingDinoForObjectDetectionLoss(
    logits,
    labels,
    device,
    pred_boxes,
    config,
    label_maps,
    text_mask,
    outputs_class=None,
    outputs_coord=None,
    encoder_logits=None,
    encoder_pred_boxes=None,
):
    # First: create the matcher
    matcher = GroundingDinoHungarianMatcher(
        class_cost=config.class_cost, bbox_cost=config.bbox_cost, giou_cost=config.giou_cost
    )
    # Second: create the criterion
    losses = ["labels", "boxes", "cardinality"]
    criterion = GroundingDinoImageLoss(
        matcher=matcher,
        focal_alpha=config.focal_alpha,
        losses=losses,
    )
    criterion.to(device)
    # Third: compute the losses, based on outputs and labels
    outputs_loss = {}
    outputs_loss["logits"] = logits
    outputs_loss["pred_boxes"] = pred_boxes
    outputs_loss["label_maps"] = label_maps
    outputs_loss["text_mask"] = text_mask

    auxiliary_outputs = None
    if config.auxiliary_loss:
        auxiliary_outputs = _set_aux_loss(outputs_class, outputs_coord)
        for aux_output in auxiliary_outputs:
            aux_output["label_maps"] = label_maps
            aux_output["text_mask"] = text_mask
        outputs_loss["auxiliary_outputs"] = auxiliary_outputs

    loss_dict = criterion(outputs_loss, labels)

    if config.two_stage:
        encoder_outputs_loss = {
            "logits": encoder_logits,
            "pred_boxes": encoder_pred_boxes,
            "label_maps": label_maps,
            "text_mask": text_mask,
        }
        encoder_loss_dict = criterion(encoder_outputs_loss, labels)
        encoder_loss_dict = {k + "_enc": v for k, v in encoder_loss_dict.items()}
        loss_dict.update(encoder_loss_dict)
    # Fourth: compute total loss, as a weighted sum of the various losses
    weight_dict = {
        "loss_ce": 2.0,
        "loss_bbox": config.bbox_loss_coefficient,
        "loss_giou": config.giou_loss_coefficient,
    }

    if config.two_stage:
        enc_weight_dict = {k + "_enc": v for k, v in weight_dict.items()}
        weight_dict.update(enc_weight_dict)

    if config.auxiliary_loss:
        aux_weight_dict = {}
        for i in range(config.decoder_layers - 1):
            aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)

    loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
    return loss, loss_dict, auxiliary_outputs
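The final reduction is a plain weighted sum over the keys present in both dictionaries; the DETR-style cardinality diagnostic (reported as `cardinality_error` by the base criterion) carries no weight and so drops out of the total. With invented numbers (the coefficients here are illustrative, not the config defaults):

loss_dict = {"loss_ce": 1.2, "loss_bbox": 0.2, "loss_giou": 0.5, "cardinality_error": 3.0}
weight_dict = {"loss_ce": 2.0, "loss_bbox": 5.0, "loss_giou": 2.0}
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
print(loss)  # 1.2*2.0 + 0.2*5.0 + 0.5*2.0 = 4.4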
src/transformers/loss/loss_utils.py

@@ -18,6 +18,7 @@ from torch.nn import BCEWithLogitsLoss, MSELoss

 from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss
 from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss
+from .loss_grounding_dino import GroundingDinoForObjectDetectionLoss
 from .loss_rt_detr import RTDetrForObjectDetectionLoss

@@ -129,7 +130,7 @@ LOSS_MAPPING = {
     "DeformableDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
     "ConditionalDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
     "DabDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
-    "GroundingDinoForObjectDetection": DeformableDetrForObjectDetectionLoss,
+    "GroundingDinoForObjectDetection": GroundingDinoForObjectDetectionLoss,
     "ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
     "RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
     "RTDetrV2ForObjectDetection": RTDetrForObjectDetectionLoss,
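With this remap, the shared loss plumbing now resolves Grounding DINO's dedicated loss instead of the Deformable DETR one. Illustratively (the key is the model class name, exactly as written in the mapping above; module path per this file):

from transformers.loss.loss_utils import LOSS_MAPPING

loss_fn = LOSS_MAPPING["GroundingDinoForObjectDetection"]
print(loss_fn.__name__)  # GroundingDinoForObjectDetectionLoss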
src/transformers/models/grounding_dino/modeling_grounding_dino.py

@@ -252,6 +252,10 @@ class GroundingDinoModelOutput(ModelOutput):
             background).
         enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`):
             Logits of predicted bounding boxes coordinates in the first stage.
+        encoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.two_stage=True`):
+            Logits of top `config.num_queries` scoring bounding boxes in the first stage.
+        encoder_pred_boxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`):
+            Coordinates of top `config.num_queries` scoring bounding boxes in the first stage.
     """

     last_hidden_state: torch.FloatTensor = None

@@ -267,6 +271,8 @@ class GroundingDinoModelOutput(ModelOutput):
     encoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     enc_outputs_class: Optional[torch.FloatTensor] = None
     enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
+    encoder_logits: Optional[torch.FloatTensor] = None
+    encoder_pred_boxes: Optional[torch.FloatTensor] = None


 @dataclass

@@ -331,6 +337,10 @@ class GroundingDinoObjectDetectionOutput(ModelOutput):
             background).
         enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`):
             Logits of predicted bounding boxes coordinates in the first stage.
+        encoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.two_stage=True`):
+            Logits of top `config.num_queries` scoring bounding boxes in the first stage.
+        encoder_pred_boxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`):
+            Coordinates of top `config.num_queries` scoring bounding boxes in the first stage.
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Encoded candidate labels sequence. Used in processor to post process object detection result.
     """

@@ -353,6 +363,8 @@ class GroundingDinoObjectDetectionOutput(ModelOutput):
     encoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     enc_outputs_class: Optional[torch.FloatTensor] = None
     enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
+    encoder_logits: Optional[torch.FloatTensor] = None
+    encoder_pred_boxes: Optional[torch.FloatTensor] = None
     input_ids: Optional[torch.LongTensor] = None


@@ -2374,8 +2386,11 @@ class GroundingDinoModel(GroundingDinoPreTrainedModel):
         )

         # Fifth, prepare decoder inputs
         topk_proposals = None
         enc_outputs_class = None
         enc_outputs_coord_logits = None
+        encoder_logits = None
+        encoder_pred_boxes = None
         if self.config.two_stage:
             object_query_embedding, output_proposals = self.generate_encoder_output_proposals(
                 encoder_outputs[0], ~mask_flatten, spatial_shapes

@@ -2408,6 +2423,10 @@ class GroundingDinoModel(GroundingDinoPreTrainedModel):
             target = torch.gather(
                 object_query_embedding, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
             ).detach()
+
+            # Set intermediate topk proposals (coords and class) for loss computation
+            encoder_pred_boxes = reference_points
+            encoder_logits = self.encoder_output_class_embed(target, text_features, text_token_mask)
         else:
             target = query_embeds.unsqueeze(0).repeat(batch_size, 1, 1)
             reference_points = self.reference_points.weight.unsqueeze(0).repeat(batch_size, 1, 1).sigmoid()

@@ -2430,7 +2449,16 @@ class GroundingDinoModel(GroundingDinoPreTrainedModel):
         )

         if not return_dict:
-            enc_outputs = tuple(value for value in [enc_outputs_class, enc_outputs_coord_logits] if value is not None)
+            enc_outputs = tuple(
+                value
+                for value in [
+                    enc_outputs_class,
+                    enc_outputs_coord_logits,
+                    encoder_logits,
+                    encoder_pred_boxes,
+                ]
+                if value is not None
+            )
             tuple_outputs = (
                 (decoder_outputs[0], init_reference_points) + decoder_outputs[1:] + encoder_outputs + enc_outputs
             )

@@ -2451,6 +2479,8 @@ class GroundingDinoModel(GroundingDinoPreTrainedModel):
             encoder_attentions=encoder_outputs.attentions,
             enc_outputs_class=enc_outputs_class,
             enc_outputs_coord_logits=enc_outputs_coord_logits,
+            encoder_logits=encoder_logits,
+            encoder_pred_boxes=encoder_pred_boxes,
         )
@@ -2476,6 +2506,73 @@ class GroundingDinoMLPPredictionHead(nn.Module):
         return x


+def build_label_maps(logits: torch.FloatTensor, input_ids: torch.LongTensor) -> Tuple[torch.FloatTensor]:
+    """
+    Computes a mapping between tokens and their corresponding labels, where `num_labels` is determined by the number
+    of classes in the input prompt. The function identifies segments of tokens between specific delimiter tokens and
+    generates label maps for those segments.
+
+    Args:
+        logits (`torch.Tensor` of shape `(batch_size, seq_length, hidden_size)`):
+            The output logits from the model, where `hidden_size` corresponds to the dimension of the model's output features.
+        input_ids (`torch.Tensor` of shape `(batch_size, seq_length)`):
+            The input token IDs corresponding to the input prompt. For example, given the prompt "fish. shark.",
+            `input_ids` might look like `[101, 3869, 1012, 11420, 1012, 102]`, where each number corresponds to a token, including special tokens.
+
+    Returns:
+        tuple: A tuple containing label maps for each instance in the batch.
+        - label_maps (tuple of `torch.Tensor`):
+            A tuple of tensors, where each tensor in the tuple corresponds to an instance in the batch. Each tensor
+            has shape `(num_labels, hidden_size)` and contains binary values (0 or 1), where `1` indicates the tokens
+            that are associated with a specific label (class) between delimiter tokens, and `0` elsewhere.
+
+    Example:
+        Given an input prompt "fish. shark." and corresponding `input_ids` of `[101, 3869, 1012, 11420, 1012, 102]`:
+        - The function identifies the tokens for "fish" (ID `3869`) and "shark" (ID `11420`).
+        - It then constructs label maps for these tokens, where each label map indicates which tokens
+          correspond to which label between the delimiter tokens (e.g., between the periods `.`).
+        - The output is a tuple of label maps, one for each instance in the batch.
+
+    Note:
+        - `SPECIAL_TOKENS` should be a predefined list of tokens that are considered special (e.g., `[CLS]`, `[SEP]`, etc.).
+    """
+    max_seq_len = logits.shape[-1]
+    # Add [PAD] token to the list of special tokens
+    delimiter_tokens = torch.tensor(SPECIAL_TOKENS + [0], device=input_ids.device)
+
+    delimiter_token_masks = torch.isin(input_ids, delimiter_tokens)
+    label_groups = torch.cumsum(delimiter_token_masks, dim=1) * (~delimiter_token_masks).to(torch.int32)
+
+    label_maps = ()
+
+    # Iterate over batch dimension as we can have different number of labels
+    for label_group in label_groups:
+        # `label_group` is a tensor of shape `(seq_len,)` with zeros for non-label tokens and integers for label
+        # tokens; label tokens with the same integer value are part of the same label group
+
+        # Get unique labels and exclude 0 (i.e. non-label tokens)
+        unique_labels = torch.unique(label_group)[1:, None]
+        num_labels = unique_labels.shape[0]
+
+        # Create one-hot encoding for each label group
+        label_map = label_group.unsqueeze(0).repeat(num_labels, 1)
+        label_map = torch.where(label_map == unique_labels, 1, 0)
+
+        # Pad label_map to match `max_seq_len`
+        label_map = F.pad(label_map, (0, max_seq_len - label_map.shape[1]), value=0)
+
+        label_maps += (label_map,)
+
+    return label_maps
+
+
+def build_text_mask(logits, attention_mask):
+    """
+    Create text_mask based on the matching indices
+    """
+    seq_len = attention_mask.shape[1]
+    text_mask = torch.zeros_like(logits, device=logits.device, dtype=attention_mask.dtype)
+    text_mask[:, :, :seq_len] = attention_mask[:, None, :]
+
+    return text_mask.bool()
+
+
 @add_start_docstrings(
     """
     Grounding DINO Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top,
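To see the grouping trick inside `build_label_maps` concretely, here is the intermediate `label_groups` computation for the docstring's "fish. shark." example. The delimiter IDs 101/102/1012 ([CLS]/[SEP]/".") are assumed from the BERT-style example above; the real code takes them from `SPECIAL_TOKENS`:

import torch

input_ids = torch.tensor([[101, 3869, 1012, 11420, 1012, 102]])  # "fish. shark."
delimiter_tokens = torch.tensor([101, 102, 1012, 0])  # assumed SPECIAL_TOKENS + [PAD]

delimiter_token_masks = torch.isin(input_ids, delimiter_tokens)
label_groups = torch.cumsum(delimiter_token_masks, dim=1) * (~delimiter_token_masks).to(torch.int32)
print(label_groups)  # tensor([[0, 1, 0, 2, 0, 0]]): "fish" is group 1, "shark" is group 2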
@@ -2514,14 +2611,6 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

-    # taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
-    @torch.jit.unused
-    def _set_aux_loss(self, outputs_class, outputs_coord):
-        # this is a workaround to make torchscript happy, as torchscript
-        # doesn't support dictionary with non-homogeneous values, such
-        # as a dict having both a Tensor and a list.
-        return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
-
     @add_start_docstrings_to_model_forward(GROUNDING_DINO_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=GroundingDinoObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
     def forward(

@@ -2648,8 +2737,20 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel):

         loss, loss_dict, auxiliary_outputs = None, None, None
         if labels is not None:
+            label_maps = build_label_maps(logits, input_ids)
+            text_mask = build_text_mask(logits, attention_mask)
             loss, loss_dict, auxiliary_outputs = self.loss_function(
-                logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
+                logits,
+                labels,
+                self.device,
+                pred_boxes,
+                self.config,
+                label_maps,
+                text_mask,
+                outputs_class=outputs_class,
+                outputs_coord=outputs_coord,
+                encoder_logits=outputs[-2],
+                encoder_pred_boxes=outputs[-1],
             )

         if not return_dict:

@@ -2677,6 +2778,8 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel):
             init_reference_points=outputs.init_reference_points,
             enc_outputs_class=outputs.enc_outputs_class,
             enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
+            encoder_logits=outputs.encoder_logits,
+            encoder_pred_boxes=outputs.encoder_pred_boxes,
             input_ids=input_ids,
         )
tests/models/grounding_dino/test_modeling_grounding_dino.py

@@ -20,6 +20,8 @@ import math
 import re
 import unittest

+from datasets import load_dataset
+
 from transformers import (
     GroundingDinoConfig,
     SwinConfig,

@@ -28,6 +30,7 @@ from transformers import (
 )
 from transformers.file_utils import cached_property
 from transformers.testing_utils import (
+    is_flaky,
     require_timm,
     require_torch,
     require_torch_accelerator,

@@ -37,14 +40,14 @@ from transformers.testing_utils import (
 )

 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
 from ...test_pipeline_mixin import PipelineTesterMixin


 if is_torch_available():
     import torch

-    from transformers import GroundingDinoForObjectDetection, GroundingDinoModel
+    from transformers import GroundingDinoConfig, GroundingDinoForObjectDetection, GroundingDinoModel
     from transformers.pytorch_utils import id_tensor_storage

@@ -54,6 +57,39 @@ if is_vision_available():
     from transformers import AutoProcessor


+def generate_fake_bounding_boxes(n_boxes):
+    """Generate bounding boxes in the format (center_x, center_y, width, height)"""
+    # Validate the input
+    if not isinstance(n_boxes, int):
+        raise ValueError("n_boxes must be an integer")
+    if n_boxes <= 0:
+        raise ValueError("n_boxes must be a positive integer")
+
+    # Generate random bounding boxes in the format (center_x, center_y, width, height)
+    bounding_boxes = torch.rand((n_boxes, 4))
+
+    # Extract the components
+    center_x = bounding_boxes[:, 0]
+    center_y = bounding_boxes[:, 1]
+    width = bounding_boxes[:, 2]
+    height = bounding_boxes[:, 3]
+
+    # Ensure width and height do not exceed bounds
+    width = torch.min(width, torch.tensor(1.0))
+    height = torch.min(height, torch.tensor(1.0))
+
+    # Ensure the bounding box stays within the normalized space
+    center_x = torch.where(center_x - width / 2 < 0, width / 2, center_x)
+    center_x = torch.where(center_x + width / 2 > 1, 1 - width / 2, center_x)
+    center_y = torch.where(center_y - height / 2 < 0, height / 2, center_y)
+    center_y = torch.where(center_y + height / 2 > 1, 1 - height / 2, center_y)
+
+    # Combine back into bounding boxes
+    bounding_boxes = torch.stack([center_x, center_y, width, height], dim=1)
+
+    return bounding_boxes
+
+
 class GroundingDinoModelTester:
     def __init__(
         self,
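As a quick property check of the helper above (illustrative only, not part of the test suite): every generated box, converted from (center, size) to corner coordinates, should lie inside the unit square.

import torch

boxes = generate_fake_bounding_boxes(1000)  # columns: center_x, center_y, width, height
top_left = boxes[:, :2] - boxes[:, 2:] / 2
bottom_right = boxes[:, :2] + boxes[:, 2:] / 2
assert (top_left >= 0).all() and (bottom_right <= 1).all()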
@@ -72,7 +108,7 @@ class GroundingDinoModelTester:
         num_channels=3,
         image_size=98,
         n_targets=8,
-        num_labels=3,
+        num_labels=2,
         num_feature_levels=4,
         encoder_n_points=2,
         decoder_n_points=6,

@@ -115,7 +151,11 @@ class GroundingDinoModelTester:
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
         pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device)

-        input_ids = ids_tensor([self.batch_size, self.max_text_len], self.num_labels)
+        # When using `GroundingDino` the text input template is '{label1}. {label2}. {label3}. ... {labelN}.'
+        # Therefore, to avoid errors when running tests with `labels`, `input_ids` has to follow this structure.
+        # Otherwise `build_label_maps` will throw an error when trying to split the input_ids into segments.
+        input_ids = torch.tensor([101, 3869, 1012, 11420, 3869, 1012, 102], device=torch_device)
+        input_ids = input_ids.unsqueeze(0).expand(self.batch_size, -1)

         labels = None
         if self.use_labels:

@@ -126,7 +166,7 @@ class GroundingDinoModelTester:
                 target["class_labels"] = torch.randint(
                     high=self.num_labels, size=(self.n_targets,), device=torch_device
                 )
-                target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device)
+                target["boxes"] = generate_fake_bounding_boxes(self.n_targets).to(torch_device)
                 target["masks"] = torch.rand(self.n_targets, self.image_size, self.image_size, device=torch_device)
                 labels.append(target)

@@ -317,7 +357,7 @@ class GroundingDinoModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         )
         out_len = len(outputs)

-        correct_outlen = 10
+        correct_outlen = 12

         # loss is at first position
         if "labels" in inputs_dict:

@@ -677,6 +717,7 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
         self.assertListEqual(results["text_labels"], expected_labels)

     @require_torch_accelerator
+    @is_flaky()
     def test_inference_object_detection_head_equivalence_cpu_gpu(self):
         processor = self.default_processor
         image = prepare_img()

@@ -716,6 +757,7 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
         torch.testing.assert_close(results_cpu["scores"], result_gpu["scores"].cpu(), rtol=1e-3, atol=1e-3)
         torch.testing.assert_close(results_cpu["boxes"], result_gpu["boxes"].cpu(), rtol=1e-3, atol=1e-3)

+    @is_flaky()
     def test_cross_attention_mask(self):
         model = GroundingDinoForObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(torch_device)

@@ -740,4 +782,56 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
         torch.testing.assert_close(outputs1.logits, outputs_batched.logits[:1], rtol=1e-3, atol=1e-3)
         # For some reason 12 elements are > 1e-3, but the rest are fine
-        torch.testing.assert_close(outputs2.logits, outputs_batched.logits[1:], rtol=1.8e-3, atol=1.8e-3)
+        self.assertTrue(torch.allclose(outputs2.logits, outputs_batched.logits[1:], atol=1.8e-3))
+
+    def test_grounding_dino_loss(self):
+        ds = load_dataset("EduardoPacheco/aquarium-sample", split="train")
+        image_processor = self.default_processor.image_processor
+        tokenizer = self.default_processor.tokenizer
+        id2label = {0: "fish", 1: "jellyfish", 2: "penguins", 3: "sharks", 4: "puffins", 5: "stingrays", 6: "starfish"}
+        prompt = ". ".join(id2label.values()) + "."
+
+        text_inputs = tokenizer([prompt, prompt], return_tensors="pt")
+        image_inputs = image_processor(images=ds["image"], annotations=ds["annotations"], return_tensors="pt")
+
+        # Passing auxiliary_loss=True to compare with the expected loss
+        model = GroundingDinoForObjectDetection.from_pretrained(
+            "IDEA-Research/grounding-dino-tiny",
+            auxiliary_loss=True,
+        )
+        # Interested in the loss only
+        model.eval()
+        with torch.no_grad():
+            outputs = model(**text_inputs, **image_inputs)
+
+        # The loss differs between CPU and GPU, and the expected values may change in the future.
+        expected_loss_dict = {
+            "loss_ce": torch.tensor(1.1147),
+            "loss_bbox": torch.tensor(0.2031),
+            "loss_giou": torch.tensor(0.5819),
+            "loss_ce_0": torch.tensor(1.1941),
+            "loss_bbox_0": torch.tensor(0.1978),
+            "loss_giou_0": torch.tensor(0.5524),
+            "loss_ce_1": torch.tensor(1.1621),
+            "loss_bbox_1": torch.tensor(0.1909),
+            "loss_giou_1": torch.tensor(0.5892),
+            "loss_ce_2": torch.tensor(1.1641),
+            "loss_bbox_2": torch.tensor(0.1892),
+            "loss_giou_2": torch.tensor(0.5626),
+            "loss_ce_3": torch.tensor(1.1943),
+            "loss_bbox_3": torch.tensor(0.1941),
+            "loss_giou_3": torch.tensor(0.5607),
+            "loss_ce_4": torch.tensor(1.0956),
+            "loss_bbox_4": torch.tensor(0.2008),
+            "loss_giou_4": torch.tensor(0.5836),
+            "loss_ce_enc": torch.tensor(16226.3164),
+            "loss_bbox_enc": torch.tensor(0.3063),
+            "loss_giou_enc": torch.tensor(0.7380),
+        }
+
+        expected_loss = torch.tensor(32482.2305)
+
+        for key in expected_loss_dict:
+            self.assertTrue(torch.allclose(outputs.loss_dict[key], expected_loss_dict[key], atol=1e-3))
+
+        self.assertTrue(torch.allclose(outputs.loss, expected_loss, atol=1e-3))
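Putting it together, a hedged sketch of how a user would get a training loss out of the model after this change. The label format follows the DETR-style convention used in the test above (one dict per image with class indices into the prompt's label list and normalized center-format boxes); the random tensor stands in for a real image:

import torch
from transformers import AutoProcessor, GroundingDinoForObjectDetection

processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
model = GroundingDinoForObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny")

image = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)  # stand-in for a real image
inputs = processor(images=image, text="fish. shark.", return_tensors="pt")

# "fish" is class 0, "shark" is class 1; boxes are (center_x, center_y, width, height) in [0, 1]
labels = [{"class_labels": torch.tensor([0]), "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.2]])}]

outputs = model(**inputs, labels=labels)
print(outputs.loss, sorted(outputs.loss_dict))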