Use deformable_detr kernel from the Hub (#36853)

* Use `deformable_detr` kernel from the Hub

Remove the `deformable_detr` kernel from `kernels/` and use the
pre-built kernel from the Hub instead.

* Add license header

* Add `kernels` as an extra `hub-kernels`

Also add it to `testing`, so that the kernel replacement gets tested
when using CUDA in CI.
This commit is contained in:
Daniël de Kok 2025-03-21 13:08:47 +01:00 committed by GitHub
parent 2638d54e78
commit f94b0c59f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 409 additions and 3838 deletions

View File

@ -129,6 +129,7 @@ _deps = [
# Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
"keras>2.9,<2.16",
"keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
"kernels>=0.3.2,<0.4",
"librosa",
"natten>=0.14.6,<0.15.0",
"nltk<=3.8.1",
@ -301,8 +302,9 @@ extras["deepspeed"] = deps_list("deepspeed") + extras["accelerate"]
extras["optuna"] = deps_list("optuna")
extras["ray"] = deps_list("ray[tune]")
extras["sigopt"] = deps_list("sigopt")
extras["hub-kernels"] = deps_list("kernels")
extras["integrations"] = extras["optuna"] + extras["ray"] + extras["sigopt"]
extras["integrations"] = extras["hub-kernels"] + extras["optuna"] + extras["ray"] + extras["sigopt"]
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm")

View File

@ -35,6 +35,7 @@ deps = {
"kenlm": "kenlm",
"keras": "keras>2.9,<2.16",
"keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
"kernels": "kernels>=0.3.2,<0.4",
"librosa": "librosa",
"natten": "natten>=0.14.6,<0.15.0",
"nltk": "nltk<=3.8.1",

View File

@ -70,6 +70,12 @@ _import_structure = {
"replace_with_higgs_linear",
],
"hqq": ["prepare_for_hqq_linear"],
"hub_kernels": [
"LayerRepository",
"register_kernel_mapping",
"replace_kernel_forward_from_hub",
"use_kernel_forward_from_hub",
],
"integration_utils": [
"INTEGRATION_TO_CALLBACK",
"AzureMLCallback",
@ -198,6 +204,12 @@ if TYPE_CHECKING:
)
from .higgs import HiggsLinear, dequantize_higgs, quantize_with_higgs, replace_with_higgs_linear
from .hqq import prepare_for_hqq_linear
from .hub_kernels import (
LayerRepository,
register_kernel_mapping,
replace_kernel_forward_from_hub,
use_kernel_forward_from_hub,
)
from .integration_utils import (
INTEGRATION_TO_CALLBACK,
AzureMLCallback,

View File

@ -0,0 +1,73 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Union
try:
from kernels import (
Device,
LayerRepository,
register_kernel_mapping,
replace_kernel_forward_from_hub,
use_kernel_forward_from_hub,
)
_hub_kernels_available = True
_KERNEL_MAPPING: Dict[str, Dict[Union[Device, str], LayerRepository]] = {
"MultiScaleDeformableAttention": {
"cuda": LayerRepository(
repo_id="kernels-community/deformable-detr",
layer_name="MultiScaleDeformableAttention",
)
}
}
register_kernel_mapping(_KERNEL_MAPPING)
except ImportError:
# Stub to make decorators int transformers work when `kernels`
# is not installed.
def use_kernel_forward_from_hub(*args, **kwargs):
def decorator(cls):
return cls
return decorator
class LayerRepository:
def __init__(self, *args, **kwargs):
raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
def replace_kernel_forward_from_hub(*args, **kwargs):
raise RuntimeError(
"replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
)
def register_kernel_mapping(*args, **kwargs):
raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
_hub_kernels_available = False
def is_hub_kernels_available():
return _hub_kernels_available
__all__ = [
"LayerRepository",
"is_hub_kernels_available",
"use_kernel_forward_from_hub",
"register_kernel_mapping",
"replace_kernel_forward_from_hub",
]

View File

@ -1,40 +0,0 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ERROR("Not implement on cpu");
}
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ERROR("Not implement on cpu");
}

View File

@ -1,32 +0,0 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);

View File

@ -1,159 +0,0 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include "cuda/ms_deform_im2col_cuda.cuh"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#pragma once
#include <torch/extension.h>
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
at::DeviceGuard guard(value.device());
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data_ptr<int64_t>(),
level_start_index.data_ptr<int64_t>(),
sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
columns.data_ptr<scalar_t>());
}));
}
output = output.view({batch, num_query, num_heads*channels});
return output;
}
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
at::DeviceGuard guard(value.device());
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto grad_value = at::zeros_like(value);
auto grad_sampling_loc = at::zeros_like(sampling_loc);
auto grad_attn_weight = at::zeros_like(attn_weight);
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
grad_output_g.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data_ptr<int64_t>(),
level_start_index.data_ptr<int64_t>(),
sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
}));
}
return {
grad_value, grad_sampling_loc, grad_attn_weight
};
}

View File

@ -1,46 +0,0 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
at::Tensor ms_deform_attn_cuda_forward_bf16(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);
std::vector<at::Tensor> ms_deform_attn_cuda_backward_bf16(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);

View File

@ -1,61 +0,0 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include "cpu/ms_deform_attn_cpu.h"
#ifdef WITH_CUDA
#include "cuda/ms_deform_attn_cuda.h"
#endif
at::Tensor
ms_deform_attn_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
if (value.is_cuda())
{
#ifdef WITH_CUDA
return ms_deform_attn_cuda_forward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
std::vector<at::Tensor>
ms_deform_attn_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
if (value.is_cuda())
{
#ifdef WITH_CUDA
return ms_deform_attn_cuda_backward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}

View File

@ -1,16 +0,0 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "ms_deform_attn.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}

View File

@ -1,50 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loading of Deformable DETR's CUDA kernels"""
import os
from pathlib import Path
def load_cuda_kernels():
from torch.utils.cpp_extension import load
root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr"
src_files = [
root / filename
for filename in [
"vision.cpp",
os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
os.path.join("cuda", "ms_deform_attn_cuda.cu"),
]
]
load(
"MultiScaleDeformableAttention",
src_files,
with_cuda=True,
extra_include_paths=[str(root)],
extra_cflags=["-DWITH_CUDA=1"],
extra_cuda_cflags=[
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
],
)
import MultiScaleDeformableAttention as MSDA
return MSDA

View File

@ -16,19 +16,16 @@
import copy
import math
import os
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from ...activations import ACT2FN
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
@ -37,10 +34,7 @@ from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_ninja_available,
is_timm_available,
is_torch_cuda_available,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
requires_backends,
@ -51,38 +45,6 @@ from .configuration_deformable_detr import DeformableDetrConfig
logger = logging.get_logger(__name__)
MultiScaleDeformableAttention = None
def load_cuda_kernels():
from torch.utils.cpp_extension import load
global MultiScaleDeformableAttention
root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr"
src_files = [
root / filename
for filename in [
"vision.cpp",
os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
os.path.join("cuda", "ms_deform_attn_cuda.cu"),
]
]
MultiScaleDeformableAttention = load(
"MultiScaleDeformableAttention",
src_files,
with_cuda=True,
extra_include_paths=[str(root)],
extra_cflags=["-DWITH_CUDA=1"],
extra_cuda_cflags=[
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
],
)
if is_timm_available():
from timm import create_model
@ -94,52 +56,59 @@ _CONFIG_FOR_DOC = "DeformableDetrConfig"
_CHECKPOINT_FOR_DOC = "sensetime/deformable-detr"
class MultiScaleDeformableAttentionFunction(Function):
@staticmethod
@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
class MultiScaleDeformableAttention(nn.Module):
def forward(
context,
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step,
self,
value: Tensor,
value_spatial_shapes: Tensor,
value_spatial_shapes_list: List[Tuple],
level_start_index: Tensor,
sampling_locations: Tensor,
attention_weights: Tensor,
im2col_step: int,
):
context.im2col_step = im2col_step
output = MultiScaleDeformableAttention.ms_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
context.im2col_step,
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
# batch_size, height*width, num_heads, hidden_dim
# -> batch_size, height*width, num_heads*hidden_dim
# -> batch_size, num_heads*hidden_dim, height*width
# -> batch_size*num_heads, hidden_dim, height, width
value_l_ = (
value_list[level_id]
.flatten(2)
.transpose(1, 2)
.reshape(batch_size * num_heads, hidden_dim, height, width)
)
# batch_size, num_queries, num_heads, num_points, 2
# -> batch_size, num_heads, num_queries, num_points, 2
# -> batch_size*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
# batch_size*num_heads, hidden_dim, num_queries, num_points
sampling_value_l_ = nn.functional.grid_sample(
value_l_,
sampling_grid_l_,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
)
sampling_value_list.append(sampling_value_l_)
# (batch_size, num_queries, num_heads, num_levels, num_points)
# -> (batch_size, num_heads, num_queries, num_levels, num_points)
# -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
batch_size * num_heads, 1, num_queries, num_levels * num_points
)
context.save_for_backward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(batch_size, num_heads * hidden_dim, num_queries)
)
return output
@staticmethod
@once_differentiable
def backward(context, grad_output):
(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
) = context.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output,
context.im2col_step,
)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
return output.transpose(1, 2).contiguous()
@dataclass
@ -564,48 +533,6 @@ def build_position_encoding(config):
return position_embedding
def multi_scale_deformable_attention(
value: Tensor,
value_spatial_shapes: Union[Tensor, List[Tuple]],
sampling_locations: Tensor,
attention_weights: Tensor,
) -> Tensor:
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
# batch_size, height*width, num_heads, hidden_dim
# -> batch_size, height*width, num_heads*hidden_dim
# -> batch_size, num_heads*hidden_dim, height*width
# -> batch_size*num_heads, hidden_dim, height, width
value_l_ = (
value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width)
)
# batch_size, num_queries, num_heads, num_points, 2
# -> batch_size, num_heads, num_queries, num_points, 2
# -> batch_size*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
# batch_size*num_heads, hidden_dim, num_queries, num_points
sampling_value_l_ = nn.functional.grid_sample(
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
)
sampling_value_list.append(sampling_value_l_)
# (batch_size, num_queries, num_heads, num_levels, num_points)
# -> (batch_size, num_heads, num_queries, num_levels, num_points)
# -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
batch_size * num_heads, 1, num_queries, num_levels * num_points
)
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(batch_size, num_heads * hidden_dim, num_queries)
)
return output.transpose(1, 2).contiguous()
class DeformableDetrMultiscaleDeformableAttention(nn.Module):
"""
Multiscale deformable attention as proposed in Deformable DETR.
@ -614,12 +541,7 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int):
super().__init__()
kernel_loaded = MultiScaleDeformableAttention is not None
if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
try:
load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
self.attn = MultiScaleDeformableAttention()
if config.d_model % num_heads != 0:
raise ValueError(
@ -706,27 +628,16 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
if self.disable_custom_kernels or MultiScaleDeformableAttention is None or is_torchdynamo_compiling():
# PyTorch implementation
output = multi_scale_deformable_attention(
value, spatial_shapes_list, sampling_locations, attention_weights
)
else:
try:
# custom kernel
output = MultiScaleDeformableAttentionFunction.apply(
value,
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
except Exception:
# PyTorch implementation
output = multi_scale_deformable_attention(
value, spatial_shapes_list, sampling_locations, attention_weights
)
output = self.attn(
value,
spatial_shapes,
spatial_shapes_list,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
output = self.output_proj(output)
return output, attention_weights
@ -834,7 +745,11 @@ class DeformableDetrMultiheadAttention(nn.Module):
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
if attn_output.size() != (
batch_size * self.num_heads,
target_len,
self.head_dim,
):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
f" {attn_output.size()}"
@ -854,7 +769,9 @@ class DeformableDetrEncoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = DeformableDetrMultiscaleDeformableAttention(
config, num_heads=config.encoder_attention_heads, n_points=config.encoder_n_points
config,
num_heads=config.encoder_attention_heads,
n_points=config.encoder_n_points,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
@ -1054,7 +971,11 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = [r"DeformableDetrConvEncoder", r"DeformableDetrEncoderLayer", r"DeformableDetrDecoderLayer"]
_no_split_modules = [
r"DeformableDetrConvEncoder",
r"DeformableDetrEncoderLayer",
r"DeformableDetrDecoderLayer",
]
def _init_weights(self, module):
std = self.config.init_std
@ -1299,7 +1220,9 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
)
@ -1525,7 +1448,13 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
for _ in range(config.num_feature_levels - num_backbone_outs):
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1),
nn.Conv2d(
in_channels,
config.d_model,
kernel_size=3,
stride=2,
padding=1,
),
nn.GroupNorm(32, config.d_model),
)
)
@ -1535,7 +1464,11 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
self.input_proj = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1),
nn.Conv2d(
backbone.intermediate_channel_sizes[-1],
config.d_model,
kernel_size=1,
),
nn.GroupNorm(32, config.d_model),
)
]
@ -1625,8 +1558,20 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
grid_y, grid_x = meshgrid(
torch.linspace(0, height - 1, height, dtype=enc_output.dtype, device=enc_output.device),
torch.linspace(0, width - 1, width, dtype=enc_output.dtype, device=enc_output.device),
torch.linspace(
0,
height - 1,
height,
dtype=enc_output.dtype,
device=enc_output.device,
),
torch.linspace(
0,
width - 1,
width,
dtype=enc_output.dtype,
device=enc_output.device,
),
indexing="ij",
)
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
@ -1802,7 +1747,9 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
topk = self.config.two_stage_num_proposals
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
topk_coords_logits = torch.gather(
enc_outputs_coord_logits, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
enc_outputs_coord_logits,
1,
topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
)
topk_coords_logits = topk_coords_logits.detach()
@ -1897,7 +1844,10 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
# Detection heads on top
self.class_embed = nn.Linear(config.d_model, config.num_labels)
self.bbox_embed = DeformableDetrMLPPredictionHead(
input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
input_dim=config.d_model,
hidden_dim=config.d_model,
output_dim=4,
num_layers=3,
)
prior_prob = 0.01
@ -2033,7 +1983,13 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
loss, loss_dict, auxiliary_outputs = None, None, None
if labels is not None:
loss, loss_dict, auxiliary_outputs = self.loss_function(
logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
logits,
labels,
self.device,
pred_boxes,
self.config,
outputs_class,
outputs_coord,
)
if not return_dict:
if auxiliary_outputs is not None:

View File

@ -15,17 +15,13 @@
"""PyTorch Grounding DINO model."""
import math
import os
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from ...activations import ACT2FN
from ...file_utils import (
@ -33,13 +29,13 @@ from ...file_utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_timm_available,
is_torch_cuda_available,
replace_return_docstrings,
requires_backends,
)
from ...integrations import use_kernel_forward_from_hub
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
from ...utils import is_ninja_available, logging
from ...utils import logging
from ...utils.backbone_utils import load_backbone
from ..auto import AutoModel
from .configuration_grounding_dino import GroundingDinoConfig
@ -49,97 +45,68 @@ if is_timm_available():
from timm import create_model
logger = logging.get_logger(__name__)
MultiScaleDeformableAttention = None
# Copied from models.deformable_detr.load_cuda_kernels
def load_cuda_kernels():
from torch.utils.cpp_extension import load
global MultiScaleDeformableAttention
root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr"
src_files = [
root / filename
for filename in [
"vision.cpp",
os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
os.path.join("cuda", "ms_deform_attn_cuda.cu"),
]
]
MultiScaleDeformableAttention = load(
"MultiScaleDeformableAttention",
src_files,
with_cuda=True,
extra_include_paths=[str(root)],
extra_cflags=["-DWITH_CUDA=1"],
extra_cuda_cflags=[
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
],
)
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction
class MultiScaleDeformableAttentionFunction(Function):
@staticmethod
def forward(
context,
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step,
):
context.im2col_step = im2col_step
output = MultiScaleDeformableAttention.ms_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
context.im2col_step,
)
context.save_for_backward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
)
return output
@staticmethod
@once_differentiable
def backward(context, grad_output):
(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
) = context.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output,
context.im2col_step,
)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "GroundingDinoConfig"
_CHECKPOINT_FOR_DOC = "IDEA-Research/grounding-dino-tiny"
# Copied from models.deformable_detr.MultiScaleDeformableAttention
@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
class MultiScaleDeformableAttention(nn.Module):
def forward(
self,
value: Tensor,
value_spatial_shapes: Tensor,
value_spatial_shapes_list: List[Tuple],
level_start_index: Tensor,
sampling_locations: Tensor,
attention_weights: Tensor,
im2col_step: int,
):
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
# batch_size, height*width, num_heads, hidden_dim
# -> batch_size, height*width, num_heads*hidden_dim
# -> batch_size, num_heads*hidden_dim, height*width
# -> batch_size*num_heads, hidden_dim, height, width
value_l_ = (
value_list[level_id]
.flatten(2)
.transpose(1, 2)
.reshape(batch_size * num_heads, hidden_dim, height, width)
)
# batch_size, num_queries, num_heads, num_points, 2
# -> batch_size, num_heads, num_queries, num_points, 2
# -> batch_size*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
# batch_size*num_heads, hidden_dim, num_queries, num_points
sampling_value_l_ = nn.functional.grid_sample(
value_l_,
sampling_grid_l_,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
)
sampling_value_list.append(sampling_value_l_)
# (batch_size, num_queries, num_heads, num_levels, num_points)
# -> (batch_size, num_heads, num_queries, num_levels, num_points)
# -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
batch_size * num_heads, 1, num_queries, num_levels * num_points
)
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(batch_size, num_heads * hidden_dim, num_queries)
)
return output.transpose(1, 2).contiguous()
@dataclass
class GroundingDinoDecoderOutput(ModelOutput):
"""
@ -583,49 +550,6 @@ def build_position_encoding(config):
return position_embedding
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
def multi_scale_deformable_attention(
value: Tensor,
value_spatial_shapes: Union[Tensor, List[Tuple]],
sampling_locations: Tensor,
attention_weights: Tensor,
) -> Tensor:
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
# batch_size, height*width, num_heads, hidden_dim
# -> batch_size, height*width, num_heads*hidden_dim
# -> batch_size, num_heads*hidden_dim, height*width
# -> batch_size*num_heads, hidden_dim, height, width
value_l_ = (
value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width)
)
# batch_size, num_queries, num_heads, num_points, 2
# -> batch_size, num_heads, num_queries, num_points, 2
# -> batch_size*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
# batch_size*num_heads, hidden_dim, num_queries, num_points
sampling_value_l_ = nn.functional.grid_sample(
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
)
sampling_value_list.append(sampling_value_l_)
# (batch_size, num_queries, num_heads, num_levels, num_points)
# -> (batch_size, num_heads, num_queries, num_levels, num_points)
# -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
batch_size * num_heads, 1, num_queries, num_levels * num_points
)
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(batch_size, num_heads * hidden_dim, num_queries)
)
return output.transpose(1, 2).contiguous()
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->GroundingDino, Deformable DETR->Grounding DINO
class GroundingDinoMultiscaleDeformableAttention(nn.Module):
"""
@ -635,12 +559,7 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module):
def __init__(self, config: GroundingDinoConfig, num_heads: int, n_points: int):
super().__init__()
kernel_loaded = MultiScaleDeformableAttention is not None
if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
try:
load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
self.attn = MultiScaleDeformableAttention()
if config.d_model % num_heads != 0:
raise ValueError(
@ -727,23 +646,16 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module):
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
# PyTorch implementation
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
else:
try:
# custom kernel
output = MultiScaleDeformableAttentionFunction.apply(
value,
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
except Exception:
# PyTorch implementation
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
output = self.attn(
value,
spatial_shapes,
spatial_shapes_list,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
output = self.output_proj(output)
return output, attention_weights

View File

@ -799,7 +799,7 @@ class Mask2FormerLoss(nn.Module):
return num_masks
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
# Copied from transformers.models.oneformer.modeling_oneformer.multi_scale_deformable_attention
def multi_scale_deformable_attention(
value: Tensor,
value_spatial_shapes: Union[Tensor, List[Tuple]],

View File

@ -15,38 +15,32 @@
"""PyTorch OmDet-Turbo model."""
import math
import os
import warnings
from collections import OrderedDict
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from ...activations import ACT2CLS, ACT2FN
from ...file_utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_cuda_available,
replace_return_docstrings,
)
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_utils import PreTrainedModel
from ...utils import is_ninja_available, logging
from ...utils import logging
from ...utils.backbone_utils import load_backbone
from ..auto import AutoModel
from .configuration_omdet_turbo import OmDetTurboConfig
MultiScaleDeformableAttention = None
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "OmDetTurboConfig"
@ -178,79 +172,60 @@ class OmDetTurboObjectDetectionOutput(ModelOutput):
classes_structure: Optional[torch.LongTensor] = None
# Copied from models.deformable_detr.load_cuda_kernels
def load_cuda_kernels():
from torch.utils.cpp_extension import load
global MultiScaleDeformableAttention
root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr"
src_files = [
root / filename
for filename in [
"vision.cpp",
os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
os.path.join("cuda", "ms_deform_attn_cuda.cu"),
]
]
MultiScaleDeformableAttention = load(
"MultiScaleDeformableAttention",
src_files,
with_cuda=True,
extra_include_paths=[str(root)],
extra_cflags=["-DWITH_CUDA=1"],
extra_cuda_cflags=[
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
],
)
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
def multi_scale_deformable_attention(
value: Tensor,
value_spatial_shapes: Union[Tensor, List[Tuple]],
sampling_locations: Tensor,
attention_weights: Tensor,
) -> Tensor:
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
# Ignore copy
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
# batch_size, height*width, num_heads, hidden_dim
# -> batch_size, height*width, num_heads*hidden_dim
# -> batch_size, num_heads*hidden_dim, height*width
# -> batch_size*num_heads, hidden_dim, height, width
value_l_ = (
value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width)
# Copied from models.deformable_detr.MultiScaleDeformableAttention
@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
class MultiScaleDeformableAttention(nn.Module):
def forward(
self,
value: Tensor,
value_spatial_shapes: Tensor,
value_spatial_shapes_list: List[Tuple],
level_start_index: Tensor,
sampling_locations: Tensor,
attention_weights: Tensor,
im2col_step: int,
):
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
# batch_size, height*width, num_heads, hidden_dim
# -> batch_size, height*width, num_heads*hidden_dim
# -> batch_size, num_heads*hidden_dim, height*width
# -> batch_size*num_heads, hidden_dim, height, width
value_l_ = (
value_list[level_id]
.flatten(2)
.transpose(1, 2)
.reshape(batch_size * num_heads, hidden_dim, height, width)
)
# batch_size, num_queries, num_heads, num_points, 2
# -> batch_size, num_heads, num_queries, num_points, 2
# -> batch_size*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
# batch_size*num_heads, hidden_dim, num_queries, num_points
sampling_value_l_ = nn.functional.grid_sample(
value_l_,
sampling_grid_l_,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
)
sampling_value_list.append(sampling_value_l_)
# (batch_size, num_queries, num_heads, num_levels, num_points)
# -> (batch_size, num_heads, num_queries, num_levels, num_points)
# -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
batch_size * num_heads, 1, num_queries, num_levels * num_points
)
# batch_size, num_queries, num_heads, num_points, 2
# -> batch_size, num_heads, num_queries, num_points, 2
# -> batch_size*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
# batch_size*num_heads, hidden_dim, num_queries, num_points
sampling_value_l_ = nn.functional.grid_sample(
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(batch_size, num_heads * hidden_dim, num_queries)
)
sampling_value_list.append(sampling_value_l_)
# (batch_size, num_queries, num_heads, num_levels, num_points)
# -> (batch_size, num_heads, num_queries, num_levels, num_points)
# -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
batch_size * num_heads, 1, num_queries, num_levels * num_points
)
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(batch_size, num_heads * hidden_dim, num_queries)
)
return output.transpose(1, 2).contiguous()
return output.transpose(1, 2).contiguous()
class OmDetTurboLRUCache:
@ -332,55 +307,6 @@ class OmDetTurboVisionBackbone(nn.Module):
return outputs
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction
class MultiScaleDeformableAttentionFunction(Function):
@staticmethod
def forward(
context,
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step,
):
context.im2col_step = im2col_step
output = MultiScaleDeformableAttention.ms_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
context.im2col_step,
)
context.save_for_backward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
)
return output
@staticmethod
@once_differentiable
def backward(context, grad_output):
(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
) = context.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output,
context.im2col_step,
)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->OmDetTurbo, Deformable DETR->OmDet-Turbo
class OmDetTurboMultiscaleDeformableAttention(nn.Module):
"""
@ -390,12 +316,7 @@ class OmDetTurboMultiscaleDeformableAttention(nn.Module):
def __init__(self, config: OmDetTurboConfig, num_heads: int, n_points: int):
super().__init__()
kernel_loaded = MultiScaleDeformableAttention is not None
if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
try:
load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
self.attn = MultiScaleDeformableAttention()
if config.d_model % num_heads != 0:
raise ValueError(
@ -483,27 +404,16 @@ class OmDetTurboMultiscaleDeformableAttention(nn.Module):
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
if self.disable_custom_kernels:
# PyTorch implementation
output = multi_scale_deformable_attention(
value, spatial_shapes_list, sampling_locations, attention_weights
)
else:
try:
# custom kernel
output = MultiScaleDeformableAttentionFunction.apply(
value,
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
except Exception:
# PyTorch implementation
output = multi_scale_deformable_attention(
value, spatial_shapes_list, sampling_locations, attention_weights
)
output = self.attn(
value,
spatial_shapes,
spatial_shapes_list,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
output = self.output_proj(output)
return output, attention_weights

View File

@ -61,7 +61,6 @@ def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
def multi_scale_deformable_attention(
value: Tensor,
value_spatial_shapes: Union[Tensor, List[Tuple]],

View File

@ -15,21 +15,18 @@
"""PyTorch RT-DETR model."""
import math
import os
import warnings
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from ...activations import ACT2CLS, ACT2FN
from ...image_transforms import center_to_corners_format, corners_to_center_format
from ...integrations import use_kernel_forward_from_hub
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import compile_compatible_method_lru_cache
@ -37,9 +34,6 @@ from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_ninja_available,
is_torch_cuda_available,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
torch_int,
@ -50,96 +44,68 @@ from .configuration_rt_detr import RTDetrConfig
logger = logging.get_logger(__name__)
MultiScaleDeformableAttention = None
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.load_cuda_kernels
def load_cuda_kernels():
from torch.utils.cpp_extension import load
global MultiScaleDeformableAttention
root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr"
src_files = [
root / filename
for filename in [
"vision.cpp",
os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
os.path.join("cuda", "ms_deform_attn_cuda.cu"),
]
]
MultiScaleDeformableAttention = load(
"MultiScaleDeformableAttention",
src_files,
with_cuda=True,
extra_include_paths=[str(root)],
extra_cflags=["-DWITH_CUDA=1"],
extra_cuda_cflags=[
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
],
)
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction
class MultiScaleDeformableAttentionFunction(Function):
@staticmethod
def forward(
context,
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step,
):
context.im2col_step = im2col_step
output = MultiScaleDeformableAttention.ms_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
context.im2col_step,
)
context.save_for_backward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
)
return output
@staticmethod
@once_differentiable
def backward(context, grad_output):
(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
) = context.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output,
context.im2col_step,
)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "RTDetrConfig"
# TODO: Replace all occurrences of the checkpoint with the final one
_CHECKPOINT_FOR_DOC = "PekingU/rtdetr_r50vd"
# Copied from models.deformable_detr.MultiScaleDeformableAttention
@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
class MultiScaleDeformableAttention(nn.Module):
def forward(
self,
value: Tensor,
value_spatial_shapes: Tensor,
value_spatial_shapes_list: List[Tuple],
level_start_index: Tensor,
sampling_locations: Tensor,
attention_weights: Tensor,
im2col_step: int,
):
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
# batch_size, height*width, num_heads, hidden_dim
# -> batch_size, height*width, num_heads*hidden_dim
# -> batch_size, num_heads*hidden_dim, height*width
# -> batch_size*num_heads, hidden_dim, height, width
value_l_ = (
value_list[level_id]
.flatten(2)
.transpose(1, 2)
.reshape(batch_size * num_heads, hidden_dim, height, width)
)
# batch_size, num_queries, num_heads, num_points, 2
# -> batch_size, num_heads, num_queries, num_points, 2
# -> batch_size*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
# batch_size*num_heads, hidden_dim, num_queries, num_points
sampling_value_l_ = nn.functional.grid_sample(
value_l_,
sampling_grid_l_,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
)
sampling_value_list.append(sampling_value_l_)
# (batch_size, num_queries, num_heads, num_levels, num_points)
# -> (batch_size, num_heads, num_queries, num_levels, num_points)
# -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
batch_size * num_heads, 1, num_queries, num_levels * num_points
)
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(batch_size, num_heads * hidden_dim, num_queries)
)
return output.transpose(1, 2).contiguous()
@dataclass
class RTDetrDecoderOutput(ModelOutput):
"""
@ -728,49 +694,6 @@ class RTDetrCSPRepLayer(nn.Module):
return self.conv3(hidden_state_1 + hidden_state_2)
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
def multi_scale_deformable_attention(
value: Tensor,
value_spatial_shapes: Union[Tensor, List[Tuple]],
sampling_locations: Tensor,
attention_weights: Tensor,
) -> Tensor:
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
# batch_size, height*width, num_heads, hidden_dim
# -> batch_size, height*width, num_heads*hidden_dim
# -> batch_size, num_heads*hidden_dim, height*width
# -> batch_size*num_heads, hidden_dim, height, width
value_l_ = (
value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width)
)
# batch_size, num_queries, num_heads, num_points, 2
# -> batch_size, num_heads, num_queries, num_points, 2
# -> batch_size*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
# batch_size*num_heads, hidden_dim, num_queries, num_points
sampling_value_l_ = nn.functional.grid_sample(
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
)
sampling_value_list.append(sampling_value_l_)
# (batch_size, num_queries, num_heads, num_levels, num_points)
# -> (batch_size, num_heads, num_queries, num_levels, num_points)
# -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
batch_size * num_heads, 1, num_queries, num_levels * num_points
)
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(batch_size, num_heads * hidden_dim, num_queries)
)
return output.transpose(1, 2).contiguous()
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->RTDetr
class RTDetrMultiscaleDeformableAttention(nn.Module):
"""
@ -780,12 +703,7 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
def __init__(self, config: RTDetrConfig, num_heads: int, n_points: int):
super().__init__()
kernel_loaded = MultiScaleDeformableAttention is not None
if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
try:
load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
self.attn = MultiScaleDeformableAttention()
if config.d_model % num_heads != 0:
raise ValueError(
@ -872,27 +790,16 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
if self.disable_custom_kernels or MultiScaleDeformableAttention is None or is_torchdynamo_compiling():
# PyTorch implementation
output = multi_scale_deformable_attention(
value, spatial_shapes_list, sampling_locations, attention_weights
)
else:
try:
# custom kernel
output = MultiScaleDeformableAttentionFunction.apply(
value,
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
except Exception:
# PyTorch implementation
output = multi_scale_deformable_attention(
value, spatial_shapes_list, sampling_locations, attention_weights
)
output = self.attn(
value,
spatial_shapes,
spatial_shapes_list,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
output = self.output_proj(output)
return output, attention_weights

View File

@ -21,8 +21,6 @@ from pathlib import Path
FILES_TO_FIND = [
"kernels/rwkv/wkv_cuda.cu",
"kernels/rwkv/wkv_op.cpp",
"kernels/deformable_detr/ms_deform_attn.h",
"kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh",
"kernels/falcon_mamba/selective_scan_with_ln_interface.py",
"kernels/falcon_mamba/__init__.py",
"kernels/__init__.py",

View File

@ -475,7 +475,6 @@ src/transformers/models/deberta/modeling_tf_deberta.py
src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py
src/transformers/models/decision_transformer/modeling_decision_transformer.py
src/transformers/models/deformable_detr/convert_deformable_detr_to_pytorch.py
src/transformers/models/deformable_detr/load_custom.py
src/transformers/models/deit/convert_deit_timm_to_pytorch.py
src/transformers/models/deprecated/bort/convert_bort_original_gluonnlp_checkpoint_to_pytorch.py
src/transformers/models/deprecated/mctct/configuration_mctct.py