mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix RT-DETR inference with float16 and bfloat16 (#31639)
* [run_slow] rt_detr * Fix positional embeddings and anchors dtypes * [run slow] rt_detr * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fixup --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
3f93fd0694
commit
b1ec745475
@ -1359,7 +1359,7 @@ class RTDetrHybridEncoder(nn.Module):
|
||||
if self.training or self.eval_size is None:
|
||||
pos_embed = self.build_2d_sincos_position_embedding(
|
||||
width, height, self.encoder_hidden_dim, self.positional_encoding_temperature
|
||||
).to(src_flatten.device)
|
||||
).to(src_flatten.device, src_flatten.dtype)
|
||||
else:
|
||||
pos_embed = None
|
||||
|
||||
@ -1801,12 +1801,13 @@ class RTDetrModel(RTDetrPreTrainedModel):
|
||||
|
||||
batch_size = len(source_flatten)
|
||||
device = source_flatten.device
|
||||
dtype = source_flatten.dtype
|
||||
|
||||
# prepare input for decoder
|
||||
if self.training or self.config.anchor_image_size is None:
|
||||
anchors, valid_mask = self.generate_anchors(spatial_shapes, device=device)
|
||||
anchors, valid_mask = self.generate_anchors(spatial_shapes, device=device, dtype=dtype)
|
||||
else:
|
||||
anchors, valid_mask = self.anchors.to(device), self.valid_mask.to(device)
|
||||
anchors, valid_mask = self.anchors.to(device, dtype), self.valid_mask.to(device, dtype)
|
||||
|
||||
# use the valid_mask to selectively retain values in the feature map where the mask is `True`
|
||||
memory = valid_mask.to(source_flatten.dtype) * source_flatten
|
||||
|
@ -18,6 +18,8 @@ import inspect
|
||||
import math
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
RTDetrConfig,
|
||||
RTDetrImageProcessor,
|
||||
@ -25,7 +27,7 @@ from transformers import (
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.testing_utils import require_torch, require_vision, torch_device
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@ -606,6 +608,28 @@ class RTDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
@parameterized.expand(["float32", "float16", "bfloat16"])
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_inference_with_different_dtypes(self, torch_dtype_str):
|
||||
torch_dtype = {
|
||||
"float32": torch.float32,
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}[torch_dtype_str]
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device).to(torch_dtype)
|
||||
model.eval()
|
||||
for key, tensor in inputs_dict.items():
|
||||
if tensor.dtype == torch.float32:
|
||||
inputs_dict[key] = tensor.to(torch_dtype)
|
||||
with torch.no_grad():
|
||||
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
|
||||
TOLERANCE = 1e-4
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user