mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Fix RT-DETR cache for generate_anchors (#31671)
* Fix cache and type conversion * Add test * Fixup * nit * [run slow] rt_detr * Fix test * Fixup * [run slow] rt_detr * Update src/transformers/models/rt_detr/modeling_rt_detr.py
This commit is contained in:
parent
534cbf8a5d
commit
b97521614a
@ -1656,7 +1656,11 @@ class RTDetrModel(RTDetrPreTrainedModel):
|
||||
param.requires_grad_(True)
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def generate_anchors(self, spatial_shapes=None, grid_size=0.05, dtype=torch.float32, device="cpu"):
|
||||
def generate_anchors(self, spatial_shapes=None, grid_size=0.05):
|
||||
# We always generate anchors in float32 to preserve equivalence between
|
||||
# dynamic and static anchor inference
|
||||
dtype = torch.float32
|
||||
|
||||
if spatial_shapes is None:
|
||||
spatial_shapes = [
|
||||
[int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)]
|
||||
@ -1674,7 +1678,7 @@ class RTDetrModel(RTDetrPreTrainedModel):
|
||||
anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4))
|
||||
# define the valid range for anchor coordinates
|
||||
eps = 1e-2
|
||||
anchors = torch.concat(anchors, 1).to(device)
|
||||
anchors = torch.concat(anchors, 1)
|
||||
valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
|
||||
anchors = torch.log(anchors / (1 - anchors))
|
||||
anchors = torch.where(valid_mask, anchors, torch.inf)
|
||||
@ -1769,15 +1773,15 @@ class RTDetrModel(RTDetrPreTrainedModel):
|
||||
|
||||
# Prepare encoder inputs (by flattening)
|
||||
source_flatten = []
|
||||
spatial_shapes = []
|
||||
spatial_shapes_list = []
|
||||
for level, source in enumerate(sources):
|
||||
batch_size, num_channels, height, width = source.shape
|
||||
spatial_shape = (height, width)
|
||||
spatial_shapes.append(spatial_shape)
|
||||
spatial_shapes_list.append(spatial_shape)
|
||||
source = source.flatten(2).transpose(1, 2)
|
||||
source_flatten.append(source)
|
||||
source_flatten = torch.cat(source_flatten, 1)
|
||||
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
|
||||
spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
|
||||
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
||||
|
||||
# prepare denoising training
|
||||
@ -1805,9 +1809,14 @@ class RTDetrModel(RTDetrPreTrainedModel):
|
||||
|
||||
# prepare input for decoder
|
||||
if self.training or self.config.anchor_image_size is None:
|
||||
anchors, valid_mask = self.generate_anchors(spatial_shapes, device=device, dtype=dtype)
|
||||
# Pass spatial_shapes as tuple to make it hashable and make sure
|
||||
# lru_cache is working for generate_anchors()
|
||||
spatial_shapes_tuple = tuple(spatial_shapes_list)
|
||||
anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple)
|
||||
else:
|
||||
anchors, valid_mask = self.anchors.to(device, dtype), self.valid_mask.to(device, dtype)
|
||||
anchors, valid_mask = self.anchors, self.valid_mask
|
||||
|
||||
anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype)
|
||||
|
||||
# use the valid_mask to selectively retain values in the feature map where the mask is `True`
|
||||
memory = valid_mask.to(source_flatten.dtype) * source_flatten
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
@ -630,6 +631,48 @@ class RTDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
@parameterized.expand(["float32", "float16", "bfloat16"])
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_inference_equivalence_for_static_and_dynamic_anchors(self, torch_dtype_str):
|
||||
torch_dtype = {
|
||||
"float32": torch.float32,
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}[torch_dtype_str]
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
h, w = inputs_dict["pixel_values"].shape[-2:]
|
||||
|
||||
# convert inputs to the desired dtype
|
||||
for key, tensor in inputs_dict.items():
|
||||
if tensor.dtype == torch.float32:
|
||||
inputs_dict[key] = tensor.to(torch_dtype)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model_class(config).save_pretrained(tmpdirname)
|
||||
model_static = model_class.from_pretrained(
|
||||
tmpdirname, anchor_image_size=[h, w], device_map=torch_device, torch_dtype=torch_dtype
|
||||
).eval()
|
||||
model_dynamic = model_class.from_pretrained(
|
||||
tmpdirname, anchor_image_size=None, device_map=torch_device, torch_dtype=torch_dtype
|
||||
).eval()
|
||||
|
||||
self.assertIsNotNone(model_static.config.anchor_image_size)
|
||||
self.assertIsNone(model_dynamic.config.anchor_image_size)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs_static = model_static(**self._prepare_for_class(inputs_dict, model_class))
|
||||
outputs_dynamic = model_dynamic(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
outputs_static.last_hidden_state, outputs_dynamic.last_hidden_state, rtol=1e-4, atol=1e-4
|
||||
),
|
||||
f"Max diff: {(outputs_static.last_hidden_state - outputs_dynamic.last_hidden_state).abs().max()}",
|
||||
)
|
||||
|
||||
|
||||
TOLERANCE = 1e-4
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user