This commit is contained in:
rabibastinj 2025-07-02 08:12:28 +08:00 committed by GitHub
commit 17ed2cf7af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 39 additions and 1 deletions

View File

@ -14,6 +14,8 @@
# limitations under the License.
"""PyTorch Grounding DINO model."""
import copy
import math
import warnings
from dataclasses import dataclass
@ -2430,7 +2432,9 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel):
shared_head = GroundingDinoMLPPredictionHead(
input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
)
self.bbox_embed = nn.ModuleList([shared_head] * config.decoder_layers)
#self.bbox_embed = nn.ModuleList([shared_head] * config.decoder_layers)
self.bbox_embed = nn.ModuleList( [copy.deepcopy(shared_head) for _ in range(config.decoder_layers)])
else:
# each layer has its own head (implicit deep copy through a new instance)
self.bbox_embed = nn.ModuleList(

View File

@ -0,0 +1,21 @@
import pytest
import torch
from transformers.models.grounding_dino.modeling_grounding_dino import GroundingDinoModel
from transformers.models.grounding_dino.configuration_grounding_dino import GroundingDinoConfig
def test_bbox_embed_heads_are_independent_with_custom_config():
config = GroundingDinoConfig(
decoder_layers=2,
decoder_bbox_embed_share=True,
d_model=256,
num_queries=1,
)
model = GroundingDinoModel(config)
assert model.bbox_embed[0] is not model.bbox_embed[1]
original_weight = model.bbox_embed[1].layers[0].weight.clone()
with torch.no_grad():
model.bbox_embed[0].layers[0].weight.add_(10.0)
assert not torch.equal(model.bbox_embed[0].layers[0].weight, original_weight)
assert torch.equal(model.bbox_embed[1].layers[0].weight, original_weight)

View File

@ -0,0 +1,13 @@
import pytest
from transformers import AutoConfig
from transformers.models.grounding_dino.modeling_grounding_dino import GroundingDinoModel
def test_bbox_embed_instances_are_unique():
config = AutoConfig.for_model("grounding_dino")
config.decoder_layers = 2
config.decoder_bbox_embed_share = True
model = GroundingDinoModel(config)
# Ensure each layer has a unique bbox_embed head
assert model.bbox_embed[0] is not model.bbox_embed[1]