Fix RT-DETR inference with float16 and bfloat16 (#31639)

* [run_slow] rt_detr

* Fix positional embeddings and anchors dtypes

* [run slow] rt_detr

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fixup

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Pavel Iakubovskii 2024-06-26 17:50:10 +01:00 committed by GitHub
parent 3f93fd0694
commit b1ec745475
2 changed files with 29 additions and 4 deletions

src/transformers/models/rt_detr/modeling_rt_detr.py

@@ -1359,7 +1359,7 @@ class RTDetrHybridEncoder(nn.Module):
             if self.training or self.eval_size is None:
                 pos_embed = self.build_2d_sincos_position_embedding(
                     width, height, self.encoder_hidden_dim, self.positional_encoding_temperature
-                ).to(src_flatten.device)
+                ).to(src_flatten.device, src_flatten.dtype)
             else:
                 pos_embed = None
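For context, a minimal standalone repro sketch of the failure mode this hunk fixes (not part of the commit; shapes and module are arbitrary): the sincos embedding is created in float32, mixed-dtype addition silently promotes the half-precision hidden states back to float32, and the next half-precision layer then rejects the input.

import torch
import torch.nn as nn

hidden_states = torch.randn(1, 4, 8, dtype=torch.bfloat16)
pos_embed = torch.randn(1, 4, 8)  # tensor factories default to float32
layer = nn.Linear(8, 8).to(torch.bfloat16)

promoted = hidden_states + pos_embed  # silently promoted to float32
try:
    layer(promoted)  # fails: float32 input vs bfloat16 weights
except RuntimeError as e:
    print(f"RuntimeError: {e}")

# the fix applied above: cast the embedding to the hidden states' dtype
out = layer(hidden_states + pos_embed.to(hidden_states.dtype))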
@@ -1801,12 +1801,13 @@ class RTDetrModel(RTDetrPreTrainedModel):
         batch_size = len(source_flatten)
         device = source_flatten.device
+        dtype = source_flatten.dtype

         # prepare input for decoder
         if self.training or self.config.anchor_image_size is None:
-            anchors, valid_mask = self.generate_anchors(spatial_shapes, device=device)
+            anchors, valid_mask = self.generate_anchors(spatial_shapes, device=device, dtype=dtype)
         else:
-            anchors, valid_mask = self.anchors.to(device), self.valid_mask.to(device)
+            anchors, valid_mask = self.anchors.to(device, dtype), self.valid_mask.to(device, dtype)

         # use the valid_mask to selectively retain values in the feature map where the mask is `True`
         memory = valid_mask.to(source_flatten.dtype) * source_flatten
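The anchors have the same root cause: tensor factories such as torch.arange default to float32, so generate_anchors needs a dtype argument threaded through to its factories. A simplified, hypothetical sketch of that pattern (the real generate_anchors also builds anchor widths/heights and a validity mask):

import torch

def make_grid(height, width, device=None, dtype=torch.float32):
    # build the grid directly in the requested dtype instead of float32
    ys = torch.arange(height, device=device, dtype=dtype)
    xs = torch.arange(width, device=device, dtype=dtype)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    return torch.stack([grid_x, grid_y], dim=-1)

print(make_grid(2, 3, dtype=torch.float16).dtype)  # torch.float16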

tests/models/rt_detr/test_modeling_rt_detr.py

@@ -18,6 +18,8 @@ import inspect
 import math
 import unittest

+from parameterized import parameterized
+
 from transformers import (
     RTDetrConfig,
     RTDetrImageProcessor,
@@ -25,7 +27,7 @@ from transformers import (
     is_torch_available,
     is_vision_available,
 )
-from transformers.testing_utils import require_torch, require_vision, torch_device
+from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow, torch_device
 from transformers.utils import cached_property

 from ...test_configuration_common import ConfigTester
@@ -606,6 +608,28 @@ class RTDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
                     msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                 )

+    @parameterized.expand(["float32", "float16", "bfloat16"])
+    @require_torch_gpu
+    @slow
+    def test_inference_with_different_dtypes(self, torch_dtype_str):
+        torch_dtype = {
+            "float32": torch.float32,
+            "float16": torch.float16,
+            "bfloat16": torch.bfloat16,
+        }[torch_dtype_str]
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device).to(torch_dtype)
+            model.eval()
+            for key, tensor in inputs_dict.items():
+                if tensor.dtype == torch.float32:
+                    inputs_dict[key] = tensor.to(torch_dtype)
+            with torch.no_grad():
+                _ = model(**self._prepare_for_class(inputs_dict, model_class))
+

 TOLERANCE = 1e-4
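With both modeling changes in place, RT-DETR can be loaded and run directly in half precision. A hedged end-to-end usage sketch, not taken from the commit (the checkpoint id and image URL are illustrative, and a CUDA device is assumed):

import requests
import torch
from PIL import Image
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained(
    "PekingU/rtdetr_r50vd", torch_dtype=torch.float16
).to("cuda")
model.eval()

inputs = processor(images=image, return_tensors="pt").to("cuda")
# image processors return float32 pixel_values; match the model dtype
inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)

with torch.no_grad():
    outputs = model(**inputs)

results = processor.post_process_object_detection(
    outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.5
)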