Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 04:40:06 +06:00
Use deformable_detr kernel from the Hub (#36853)
* Use `deformable_detr` kernel from the Hub

  Remove the `deformable_detr` kernel from `kernels/` and use the pre-built kernel from the Hub instead.

* Add license header

* Add `kernels` as an extra `hub-kernels`

  Also add it to `testing`, so that the kernel replacement gets tested when using CUDA in CI.
This commit is contained in:
parent
2638d54e78
commit
f94b0c59f2
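Since the kernel now comes pre-built from the Hub, users no longer need a local CUDA toolchain; installing the new extra is enough. A minimal sketch of the user-facing effect (hedged: `sensetime/deformable-detr` is the doc checkpoint referenced in this diff, and the kernel swap only applies on CUDA with `kernels` installed):

# pip install "transformers[hub-kernels]"   # pulls in `kernels`, per the setup.py change below
import torch
from transformers import DeformableDetrForObjectDetection

model = DeformableDetrForObjectDetection.from_pretrained("sensetime/deformable-detr")
if torch.cuda.is_available():
    # With `kernels` installed, MultiScaleDeformableAttention.forward is
    # transparently served by the pre-built kernel from
    # kernels-community/deformable-detr; otherwise the pure-PyTorch
    # fallback added in this PR is used.
    model = model.to("cuda")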
4  setup.py
@@ -129,6 +129,7 @@ _deps = [
     # Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
     "keras>2.9,<2.16",
     "keras-nlp>=0.3.1,<0.14.0",  # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
+    "kernels>=0.3.2,<0.4",
     "librosa",
     "natten>=0.14.6,<0.15.0",
     "nltk<=3.8.1",
@@ -301,8 +302,9 @@ extras["deepspeed"] = deps_list("deepspeed") + extras["accelerate"]
 extras["optuna"] = deps_list("optuna")
 extras["ray"] = deps_list("ray[tune]")
 extras["sigopt"] = deps_list("sigopt")
+extras["hub-kernels"] = deps_list("kernels")

-extras["integrations"] = extras["optuna"] + extras["ray"] + extras["sigopt"]
+extras["integrations"] = extras["hub-kernels"] + extras["optuna"] + extras["ray"] + extras["sigopt"]

 extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
 extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm")
src/transformers/dependency_versions_table.py
@@ -35,6 +35,7 @@ deps = {
     "kenlm": "kenlm",
     "keras": "keras>2.9,<2.16",
     "keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
+    "kernels": "kernels>=0.3.2,<0.4",
     "librosa": "librosa",
     "natten": "natten>=0.14.6,<0.15.0",
     "nltk": "nltk<=3.8.1",
src/transformers/integrations/__init__.py
@@ -70,6 +70,12 @@ _import_structure = {
         "replace_with_higgs_linear",
     ],
     "hqq": ["prepare_for_hqq_linear"],
+    "hub_kernels": [
+        "LayerRepository",
+        "register_kernel_mapping",
+        "replace_kernel_forward_from_hub",
+        "use_kernel_forward_from_hub",
+    ],
     "integration_utils": [
         "INTEGRATION_TO_CALLBACK",
         "AzureMLCallback",
@@ -198,6 +204,12 @@ if TYPE_CHECKING:
     )
     from .higgs import HiggsLinear, dequantize_higgs, quantize_with_higgs, replace_with_higgs_linear
     from .hqq import prepare_for_hqq_linear
+    from .hub_kernels import (
+        LayerRepository,
+        register_kernel_mapping,
+        replace_kernel_forward_from_hub,
+        use_kernel_forward_from_hub,
+    )
     from .integration_utils import (
         INTEGRATION_TO_CALLBACK,
         AzureMLCallback,
73  src/transformers/integrations/hub_kernels.py (new file)
@@ -0,0 +1,73 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Union


try:
    from kernels import (
        Device,
        LayerRepository,
        register_kernel_mapping,
        replace_kernel_forward_from_hub,
        use_kernel_forward_from_hub,
    )

    _hub_kernels_available = True

    _KERNEL_MAPPING: Dict[str, Dict[Union[Device, str], LayerRepository]] = {
        "MultiScaleDeformableAttention": {
            "cuda": LayerRepository(
                repo_id="kernels-community/deformable-detr",
                layer_name="MultiScaleDeformableAttention",
            )
        }
    }

    register_kernel_mapping(_KERNEL_MAPPING)

except ImportError:
    # Stub to make decorators in transformers work when `kernels`
    # is not installed.
    def use_kernel_forward_from_hub(*args, **kwargs):
        def decorator(cls):
            return cls

        return decorator

    class LayerRepository:
        def __init__(self, *args, **kwargs):
            raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")

    def replace_kernel_forward_from_hub(*args, **kwargs):
        raise RuntimeError(
            "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
        )

    def register_kernel_mapping(*args, **kwargs):
        raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")

    _hub_kernels_available = False


def is_hub_kernels_available():
    return _hub_kernels_available


__all__ = [
    "LayerRepository",
    "is_hub_kernels_available",
    "use_kernel_forward_from_hub",
    "register_kernel_mapping",
    "replace_kernel_forward_from_hub",
]
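The mapping above only covers `MultiScaleDeformableAttention`; the same API can route any decorated layer to a Hub kernel. A minimal sketch of what a registration would look like (the repo id `my-org/my-kernels` and the layer name `MyAttention` are hypothetical; `LayerRepository`, `register_kernel_mapping`, and `use_kernel_forward_from_hub` are the `kernels` entry points imported in the file above):

import torch
from kernels import LayerRepository, register_kernel_mapping, use_kernel_forward_from_hub

register_kernel_mapping(
    {
        "MyAttention": {  # hypothetical layer name
            "cuda": LayerRepository(
                repo_id="my-org/my-kernels",  # hypothetical Hub repo
                layer_name="MyAttention",
            )
        }
    }
)


@use_kernel_forward_from_hub("MyAttention")
class MyAttention(torch.nn.Module):
    # forward() below is the fallback; on CUDA, with a kernel registered
    # under "MyAttention", it is replaced by the Hub implementation.
    def forward(self, hidden_states):
        return hidden_states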
@@ -1,40 +0,0 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#include <vector>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>


at::Tensor
ms_deform_attn_cpu_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    AT_ERROR("Not implement on cpu");
}

std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{
    AT_ERROR("Not implement on cpu");
}
@@ -1,32 +0,0 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#pragma once
#include <torch/extension.h>

at::Tensor
ms_deform_attn_cpu_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step);

std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step);
@@ -1,159 +0,0 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#include <vector>
#include "cuda/ms_deform_im2col_cuda.cuh"

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>

#pragma once
#include <torch/extension.h>


at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    at::DeviceGuard guard(value.device());

    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");

    AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
    AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);

    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());

    const int batch_n = im2col_step_;
    auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
    auto per_value_size = spatial_size * num_heads * channels;
    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        auto columns = output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
            ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
                spatial_shapes.data_ptr<int64_t>(),
                level_start_index.data_ptr<int64_t>(),
                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                columns.data_ptr<scalar_t>());
        }));
    }

    output = output.view({batch, num_query, num_heads*channels});

    return output;
}


std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{
    at::DeviceGuard guard(value.device());

    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
    AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");

    AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
    AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
    AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);

    auto grad_value = at::zeros_like(value);
    auto grad_sampling_loc = at::zeros_like(sampling_loc);
    auto grad_attn_weight = at::zeros_like(attn_weight);

    const int batch_n = im2col_step_;
    auto per_value_size = spatial_size * num_heads * channels;
    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
    auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});

    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        auto grad_output_g = grad_output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
            ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
                grad_output_g.data_ptr<scalar_t>(),
                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
                spatial_shapes.data_ptr<int64_t>(),
                level_start_index.data_ptr<int64_t>(),
                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
                grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
        }));
    }

    return {
        grad_value, grad_sampling_loc, grad_attn_weight
    };
}
File diff suppressed because it is too large
@@ -1,46 +0,0 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#pragma once
#include <torch/extension.h>

at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step);

at::Tensor ms_deform_attn_cuda_forward_bf16(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step);

std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step);

std::vector<at::Tensor> ms_deform_attn_cuda_backward_bf16(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step);
File diff suppressed because it is too large
@@ -1,61 +0,0 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#pragma once

#include "cpu/ms_deform_attn_cpu.h"

#ifdef WITH_CUDA
#include "cuda/ms_deform_attn_cuda.h"
#endif


at::Tensor
ms_deform_attn_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    if (value.is_cuda())
    {
#ifdef WITH_CUDA
        return ms_deform_attn_cuda_forward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    AT_ERROR("Not implemented on the CPU");
}

std::vector<at::Tensor>
ms_deform_attn_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{
    if (value.is_cuda())
    {
#ifdef WITH_CUDA
        return ms_deform_attn_cuda_backward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    AT_ERROR("Not implemented on the CPU");
}
@@ -1,16 +0,0 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#include "ms_deform_attn.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
    m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
@@ -1,50 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loading of Deformable DETR's CUDA kernels"""

import os
from pathlib import Path


def load_cuda_kernels():
    from torch.utils.cpp_extension import load

    root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr"
    src_files = [
        root / filename
        for filename in [
            "vision.cpp",
            os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
            os.path.join("cuda", "ms_deform_attn_cuda.cu"),
        ]
    ]

    load(
        "MultiScaleDeformableAttention",
        src_files,
        with_cuda=True,
        extra_include_paths=[str(root)],
        extra_cflags=["-DWITH_CUDA=1"],
        extra_cuda_cflags=[
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ],
    )

    import MultiScaleDeformableAttention as MSDA

    return MSDA
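The deleted loader above JIT-compiled the sources with torch.utils.cpp_extension.load, which required nvcc and ninja on the user's machine. The Hub path downloads a pre-built binary instead; a hedged sketch of the equivalent direct fetch (assuming the `kernels` package's `get_kernel` helper, which this PR does not call directly — it goes through the layer mapping in hub_kernels.py instead):

from kernels import get_kernel

# Fetches a pre-built, ABI-compatible binary from the Hub; no local
# CUDA toolchain or compilation step is involved.
deformable_detr = get_kernel("kernels-community/deformable-detr")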
src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -16,19 +16,16 @@
 import copy
 import math
-import os
 import warnings
 from dataclasses import dataclass
-from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
-from torch.autograd import Function
-from torch.autograd.function import once_differentiable

 from ...activations import ACT2FN
+from ...integrations import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
@@ -37,10 +34,7 @@ from ...utils import (
     ModelOutput,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_ninja_available,
     is_timm_available,
-    is_torch_cuda_available,
-    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
     requires_backends,
@@ -51,38 +45,6 @@ from .configuration_deformable_detr import DeformableDetrConfig

 logger = logging.get_logger(__name__)

-MultiScaleDeformableAttention = None
-
-
-def load_cuda_kernels():
-    from torch.utils.cpp_extension import load
-
-    global MultiScaleDeformableAttention
-
-    root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr"
-    src_files = [
-        root / filename
-        for filename in [
-            "vision.cpp",
-            os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
-            os.path.join("cuda", "ms_deform_attn_cuda.cu"),
-        ]
-    ]
-
-    MultiScaleDeformableAttention = load(
-        "MultiScaleDeformableAttention",
-        src_files,
-        with_cuda=True,
-        extra_include_paths=[str(root)],
-        extra_cflags=["-DWITH_CUDA=1"],
-        extra_cuda_cflags=[
-            "-DCUDA_HAS_FP16=1",
-            "-D__CUDA_NO_HALF_OPERATORS__",
-            "-D__CUDA_NO_HALF_CONVERSIONS__",
-            "-D__CUDA_NO_HALF2_OPERATORS__",
-        ],
-    )
-
-
 if is_timm_available():
     from timm import create_model
@@ -94,52 +56,59 @@ _CONFIG_FOR_DOC = "DeformableDetrConfig"
 _CHECKPOINT_FOR_DOC = "sensetime/deformable-detr"


-class MultiScaleDeformableAttentionFunction(Function):
-    @staticmethod
+@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
+class MultiScaleDeformableAttention(nn.Module):
     def forward(
-        context,
-        value,
-        value_spatial_shapes,
-        value_level_start_index,
-        sampling_locations,
-        attention_weights,
-        im2col_step,
+        self,
+        value: Tensor,
+        value_spatial_shapes: Tensor,
+        value_spatial_shapes_list: List[Tuple],
+        level_start_index: Tensor,
+        sampling_locations: Tensor,
+        attention_weights: Tensor,
+        im2col_step: int,
     ):
-        context.im2col_step = im2col_step
-        output = MultiScaleDeformableAttention.ms_deform_attn_forward(
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-            context.im2col_step,
-        )
-        context.save_for_backward(
-            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
-        )
-        return output
-
-    @staticmethod
-    @once_differentiable
-    def backward(context, grad_output):
-        (
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-        ) = context.saved_tensors
-        grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-            grad_output,
-            context.im2col_step,
-        )
-
-        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
+        batch_size, _, num_heads, hidden_dim = value.shape
+        _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+        value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
+        sampling_grids = 2 * sampling_locations - 1
+        sampling_value_list = []
+        for level_id, (height, width) in enumerate(value_spatial_shapes):
+            # batch_size, height*width, num_heads, hidden_dim
+            # -> batch_size, height*width, num_heads*hidden_dim
+            # -> batch_size, num_heads*hidden_dim, height*width
+            # -> batch_size*num_heads, hidden_dim, height, width
+            value_l_ = (
+                value_list[level_id]
+                .flatten(2)
+                .transpose(1, 2)
+                .reshape(batch_size * num_heads, hidden_dim, height, width)
+            )
+            # batch_size, num_queries, num_heads, num_points, 2
+            # -> batch_size, num_heads, num_queries, num_points, 2
+            # -> batch_size*num_heads, num_queries, num_points, 2
+            sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
+            # batch_size*num_heads, hidden_dim, num_queries, num_points
+            sampling_value_l_ = nn.functional.grid_sample(
+                value_l_,
+                sampling_grid_l_,
+                mode="bilinear",
+                padding_mode="zeros",
+                align_corners=False,
+            )
+            sampling_value_list.append(sampling_value_l_)
+        # (batch_size, num_queries, num_heads, num_levels, num_points)
+        # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+        # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
+        attention_weights = attention_weights.transpose(1, 2).reshape(
+            batch_size * num_heads, 1, num_queries, num_levels * num_points
+        )
+        output = (
+            (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+            .sum(-1)
+            .view(batch_size, num_heads * hidden_dim, num_queries)
+        )
+        return output.transpose(1, 2).contiguous()


 @dataclass
@@ -564,48 +533,6 @@ def build_position_encoding(config):
     return position_embedding


-def multi_scale_deformable_attention(
-    value: Tensor,
-    value_spatial_shapes: Union[Tensor, List[Tuple]],
-    sampling_locations: Tensor,
-    attention_weights: Tensor,
-) -> Tensor:
-    batch_size, _, num_heads, hidden_dim = value.shape
-    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
-    sampling_grids = 2 * sampling_locations - 1
-    sampling_value_list = []
-    for level_id, (height, width) in enumerate(value_spatial_shapes):
-        # batch_size, height*width, num_heads, hidden_dim
-        # -> batch_size, height*width, num_heads*hidden_dim
-        # -> batch_size, num_heads*hidden_dim, height*width
-        # -> batch_size*num_heads, hidden_dim, height, width
-        value_l_ = (
-            value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width)
-        )
-        # batch_size, num_queries, num_heads, num_points, 2
-        # -> batch_size, num_heads, num_queries, num_points, 2
-        # -> batch_size*num_heads, num_queries, num_points, 2
-        sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
-        # batch_size*num_heads, hidden_dim, num_queries, num_points
-        sampling_value_l_ = nn.functional.grid_sample(
-            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
-        )
-        sampling_value_list.append(sampling_value_l_)
-    # (batch_size, num_queries, num_heads, num_levels, num_points)
-    # -> (batch_size, num_heads, num_queries, num_levels, num_points)
-    # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
-    attention_weights = attention_weights.transpose(1, 2).reshape(
-        batch_size * num_heads, 1, num_queries, num_levels * num_points
-    )
-    output = (
-        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
-        .sum(-1)
-        .view(batch_size, num_heads * hidden_dim, num_queries)
-    )
-    return output.transpose(1, 2).contiguous()
-
-
 class DeformableDetrMultiscaleDeformableAttention(nn.Module):
     """
     Multiscale deformable attention as proposed in Deformable DETR.
@@ -614,12 +541,7 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
     def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int):
         super().__init__()

-        kernel_loaded = MultiScaleDeformableAttention is not None
-        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
-            try:
-                load_cuda_kernels()
-            except Exception as e:
-                logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
+        self.attn = MultiScaleDeformableAttention()

         if config.d_model % num_heads != 0:
             raise ValueError(
@@ -706,27 +628,16 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

-        if self.disable_custom_kernels or MultiScaleDeformableAttention is None or is_torchdynamo_compiling():
-            # PyTorch implementation
-            output = multi_scale_deformable_attention(
-                value, spatial_shapes_list, sampling_locations, attention_weights
-            )
-        else:
-            try:
-                # custom kernel
-                output = MultiScaleDeformableAttentionFunction.apply(
-                    value,
-                    spatial_shapes,
-                    level_start_index,
-                    sampling_locations,
-                    attention_weights,
-                    self.im2col_step,
-                )
-            except Exception:
-                # PyTorch implementation
-                output = multi_scale_deformable_attention(
-                    value, spatial_shapes_list, sampling_locations, attention_weights
-                )
+        output = self.attn(
+            value,
+            spatial_shapes,
+            spatial_shapes_list,
+            level_start_index,
+            sampling_locations,
+            attention_weights,
+            self.im2col_step,
+        )

         output = self.output_proj(output)

         return output, attention_weights
@@ -834,7 +745,11 @@ class DeformableDetrMultiheadAttention(nn.Module):

         attn_output = torch.bmm(attn_probs, value_states)

-        if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
+        if attn_output.size() != (
+            batch_size * self.num_heads,
+            target_len,
+            self.head_dim,
+        ):
             raise ValueError(
                 f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
                 f" {attn_output.size()}"
@@ -854,7 +769,9 @@ class DeformableDetrEncoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.d_model
         self.self_attn = DeformableDetrMultiscaleDeformableAttention(
-            config, num_heads=config.encoder_attention_heads, n_points=config.encoder_n_points
+            config,
+            num_heads=config.encoder_attention_heads,
+            n_points=config.encoder_n_points,
         )
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         self.dropout = config.dropout
@@ -1054,7 +971,11 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
-    _no_split_modules = [r"DeformableDetrConvEncoder", r"DeformableDetrEncoderLayer", r"DeformableDetrDecoderLayer"]
+    _no_split_modules = [
+        r"DeformableDetrConvEncoder",
+        r"DeformableDetrEncoderLayer",
+        r"DeformableDetrDecoderLayer",
+    ]

     def _init_weights(self, module):
         std = self.config.init_std
@@ -1299,7 +1220,9 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
         if not return_dict:
             return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
         return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_states,
+            attentions=all_attentions,
         )
@@ -1525,7 +1448,13 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             for _ in range(config.num_feature_levels - num_backbone_outs):
                 input_proj_list.append(
                     nn.Sequential(
-                        nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1),
+                        nn.Conv2d(
+                            in_channels,
+                            config.d_model,
+                            kernel_size=3,
+                            stride=2,
+                            padding=1,
+                        ),
                         nn.GroupNorm(32, config.d_model),
                     )
                 )
@@ -1535,7 +1464,11 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             self.input_proj = nn.ModuleList(
                 [
                     nn.Sequential(
-                        nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1),
+                        nn.Conv2d(
+                            backbone.intermediate_channel_sizes[-1],
+                            config.d_model,
+                            kernel_size=1,
+                        ),
                         nn.GroupNorm(32, config.d_model),
                     )
                 ]
@@ -1625,8 +1558,20 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

             grid_y, grid_x = meshgrid(
-                torch.linspace(0, height - 1, height, dtype=enc_output.dtype, device=enc_output.device),
-                torch.linspace(0, width - 1, width, dtype=enc_output.dtype, device=enc_output.device),
+                torch.linspace(
+                    0,
+                    height - 1,
+                    height,
+                    dtype=enc_output.dtype,
+                    device=enc_output.device,
+                ),
+                torch.linspace(
+                    0,
+                    width - 1,
+                    width,
+                    dtype=enc_output.dtype,
+                    device=enc_output.device,
+                ),
                 indexing="ij",
             )
             grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
@@ -1802,7 +1747,9 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             topk = self.config.two_stage_num_proposals
             topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
             topk_coords_logits = torch.gather(
-                enc_outputs_coord_logits, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
+                enc_outputs_coord_logits,
+                1,
+                topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
             )

             topk_coords_logits = topk_coords_logits.detach()
@@ -1897,7 +1844,10 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
         # Detection heads on top
         self.class_embed = nn.Linear(config.d_model, config.num_labels)
         self.bbox_embed = DeformableDetrMLPPredictionHead(
-            input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
+            input_dim=config.d_model,
+            hidden_dim=config.d_model,
+            output_dim=4,
+            num_layers=3,
         )

         prior_prob = 0.01
@@ -2033,7 +1983,13 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
         loss, loss_dict, auxiliary_outputs = None, None, None
         if labels is not None:
             loss, loss_dict, auxiliary_outputs = self.loss_function(
-                logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
+                logits,
+                labels,
+                self.device,
+                pred_boxes,
+                self.config,
+                outputs_class,
+                outputs_coord,
             )
             if not return_dict:
                 if auxiliary_outputs is not None:
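As a sanity check on the new module contract, the pure-PyTorch fallback can be exercised on CPU with random tensors. A small sketch (tensor sizes are arbitrary, and importing from the internal modeling module is for illustration only, not a public API):

import torch
from transformers.models.deformable_detr.modeling_deformable_detr import (
    MultiScaleDeformableAttention,
)

batch, num_heads, hidden_dim, num_queries, num_points = 2, 8, 32, 100, 4
spatial_shapes_list = [(32, 32), (16, 16)]  # (height, width) per feature level
num_levels = len(spatial_shapes_list)
seq_len = sum(h * w for h, w in spatial_shapes_list)

value = torch.rand(batch, seq_len, num_heads, hidden_dim)
spatial_shapes = torch.tensor(spatial_shapes_list, dtype=torch.long)
level_start_index = torch.tensor([0, 32 * 32], dtype=torch.long)
sampling_locations = torch.rand(batch, num_queries, num_heads, num_levels, num_points, 2)
attention_weights = torch.rand(batch, num_queries, num_heads, num_levels, num_points)

attn = MultiScaleDeformableAttention()
# On CPU this always runs the fallback above; on CUDA (with `kernels`
# installed) the forward is served by the Hub kernel instead.
output = attn(
    value, spatial_shapes, spatial_shapes_list, level_start_index,
    sampling_locations, attention_weights, 64,  # im2col_step (unused by the fallback)
)
print(output.shape)  # torch.Size([2, 100, 256]) == (batch, num_queries, num_heads * hidden_dim)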
src/transformers/models/grounding_dino/modeling_grounding_dino.py
@@ -15,17 +15,13 @@
 """PyTorch Grounding DINO model."""

 import math
-import os
 import warnings
 from dataclasses import dataclass
-from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
-from torch.autograd import Function
-from torch.autograd.function import once_differentiable

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -33,13 +29,13 @@ from ...file_utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_timm_available,
-    is_torch_cuda_available,
     replace_return_docstrings,
     requires_backends,
 )
+from ...integrations import use_kernel_forward_from_hub
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import meshgrid
-from ...utils import is_ninja_available, logging
+from ...utils import logging
 from ...utils.backbone_utils import load_backbone
 from ..auto import AutoModel
 from .configuration_grounding_dino import GroundingDinoConfig
@@ -49,97 +45,68 @@ if is_timm_available():
     from timm import create_model


 logger = logging.get_logger(__name__)

-MultiScaleDeformableAttention = None
-
-
-# Copied from models.deformable_detr.load_cuda_kernels
-def load_cuda_kernels():
-    from torch.utils.cpp_extension import load
-
-    global MultiScaleDeformableAttention
-
-    root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr"
-    src_files = [
-        root / filename
-        for filename in [
-            "vision.cpp",
-            os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
-            os.path.join("cuda", "ms_deform_attn_cuda.cu"),
-        ]
-    ]
-
-    MultiScaleDeformableAttention = load(
-        "MultiScaleDeformableAttention",
-        src_files,
-        with_cuda=True,
-        extra_include_paths=[str(root)],
-        extra_cflags=["-DWITH_CUDA=1"],
-        extra_cuda_cflags=[
-            "-DCUDA_HAS_FP16=1",
-            "-D__CUDA_NO_HALF_OPERATORS__",
-            "-D__CUDA_NO_HALF_CONVERSIONS__",
-            "-D__CUDA_NO_HALF2_OPERATORS__",
-        ],
-    )
-
-
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction
-class MultiScaleDeformableAttentionFunction(Function):
-    @staticmethod
-    def forward(
-        context,
-        value,
-        value_spatial_shapes,
-        value_level_start_index,
-        sampling_locations,
-        attention_weights,
-        im2col_step,
-    ):
-        context.im2col_step = im2col_step
-        output = MultiScaleDeformableAttention.ms_deform_attn_forward(
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-            context.im2col_step,
-        )
-        context.save_for_backward(
-            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
-        )
-        return output
-
-    @staticmethod
-    @once_differentiable
-    def backward(context, grad_output):
-        (
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-        ) = context.saved_tensors
-        grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-            grad_output,
-            context.im2col_step,
-        )
-
-        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
-
-
-logger = logging.get_logger(__name__)
-
 _CONFIG_FOR_DOC = "GroundingDinoConfig"
 _CHECKPOINT_FOR_DOC = "IDEA-Research/grounding-dino-tiny"


+# Copied from models.deformable_detr.MultiScaleDeformableAttention
+@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
+class MultiScaleDeformableAttention(nn.Module):
+    def forward(
+        self,
+        value: Tensor,
+        value_spatial_shapes: Tensor,
+        value_spatial_shapes_list: List[Tuple],
+        level_start_index: Tensor,
+        sampling_locations: Tensor,
+        attention_weights: Tensor,
+        im2col_step: int,
+    ):
+        batch_size, _, num_heads, hidden_dim = value.shape
+        _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+        value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
+        sampling_grids = 2 * sampling_locations - 1
+        sampling_value_list = []
+        for level_id, (height, width) in enumerate(value_spatial_shapes):
+            # batch_size, height*width, num_heads, hidden_dim
+            # -> batch_size, height*width, num_heads*hidden_dim
+            # -> batch_size, num_heads*hidden_dim, height*width
+            # -> batch_size*num_heads, hidden_dim, height, width
+            value_l_ = (
+                value_list[level_id]
+                .flatten(2)
+                .transpose(1, 2)
+                .reshape(batch_size * num_heads, hidden_dim, height, width)
+            )
+            # batch_size, num_queries, num_heads, num_points, 2
+            # -> batch_size, num_heads, num_queries, num_points, 2
+            # -> batch_size*num_heads, num_queries, num_points, 2
+            sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
+            # batch_size*num_heads, hidden_dim, num_queries, num_points
+            sampling_value_l_ = nn.functional.grid_sample(
+                value_l_,
+                sampling_grid_l_,
+                mode="bilinear",
+                padding_mode="zeros",
+                align_corners=False,
+            )
+            sampling_value_list.append(sampling_value_l_)
+        # (batch_size, num_queries, num_heads, num_levels, num_points)
+        # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+        # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
+        attention_weights = attention_weights.transpose(1, 2).reshape(
+            batch_size * num_heads, 1, num_queries, num_levels * num_points
+        )
+        output = (
+            (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+            .sum(-1)
+            .view(batch_size, num_heads * hidden_dim, num_queries)
+        )
+        return output.transpose(1, 2).contiguous()


 @dataclass
 class GroundingDinoDecoderOutput(ModelOutput):
     """
@@ -583,49 +550,6 @@ def build_position_encoding(config):
     return position_embedding


-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
-def multi_scale_deformable_attention(
-    value: Tensor,
-    value_spatial_shapes: Union[Tensor, List[Tuple]],
-    sampling_locations: Tensor,
-    attention_weights: Tensor,
-) -> Tensor:
-    batch_size, _, num_heads, hidden_dim = value.shape
-    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
-    sampling_grids = 2 * sampling_locations - 1
-    sampling_value_list = []
-    for level_id, (height, width) in enumerate(value_spatial_shapes):
-        # batch_size, height*width, num_heads, hidden_dim
-        # -> batch_size, height*width, num_heads*hidden_dim
-        # -> batch_size, num_heads*hidden_dim, height*width
-        # -> batch_size*num_heads, hidden_dim, height, width
-        value_l_ = (
-            value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width)
-        )
-        # batch_size, num_queries, num_heads, num_points, 2
-        # -> batch_size, num_heads, num_queries, num_points, 2
-        # -> batch_size*num_heads, num_queries, num_points, 2
-        sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
-        # batch_size*num_heads, hidden_dim, num_queries, num_points
-        sampling_value_l_ = nn.functional.grid_sample(
-            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
-        )
-        sampling_value_list.append(sampling_value_l_)
-    # (batch_size, num_queries, num_heads, num_levels, num_points)
-    # -> (batch_size, num_heads, num_queries, num_levels, num_points)
-    # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
-    attention_weights = attention_weights.transpose(1, 2).reshape(
-        batch_size * num_heads, 1, num_queries, num_levels * num_points
-    )
-    output = (
-        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
-        .sum(-1)
-        .view(batch_size, num_heads * hidden_dim, num_queries)
-    )
-    return output.transpose(1, 2).contiguous()
-
-
 # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->GroundingDino, Deformable DETR->Grounding DINO
 class GroundingDinoMultiscaleDeformableAttention(nn.Module):
     """
@@ -635,12 +559,7 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module):
     def __init__(self, config: GroundingDinoConfig, num_heads: int, n_points: int):
         super().__init__()

-        kernel_loaded = MultiScaleDeformableAttention is not None
-        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
-            try:
-                load_cuda_kernels()
-            except Exception as e:
-                logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
+        self.attn = MultiScaleDeformableAttention()

         if config.d_model % num_heads != 0:
             raise ValueError(
@@ -727,23 +646,16 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module):
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

-        if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
-            # PyTorch implementation
-            output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
-        else:
-            try:
-                # custom kernel
-                output = MultiScaleDeformableAttentionFunction.apply(
-                    value,
-                    spatial_shapes,
-                    level_start_index,
-                    sampling_locations,
-                    attention_weights,
-                    self.im2col_step,
-                )
-            except Exception:
-                # PyTorch implementation
-                output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+        output = self.attn(
+            value,
+            spatial_shapes,
+            spatial_shapes_list,
+            level_start_index,
+            sampling_locations,
+            attention_weights,
+            self.im2col_step,
+        )

         output = self.output_proj(output)

         return output, attention_weights
src/transformers/models/mask2former/modeling_mask2former.py
@@ -799,7 +799,7 @@ class Mask2FormerLoss(nn.Module):
         return num_masks


-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
+# Copied from transformers.models.oneformer.modeling_oneformer.multi_scale_deformable_attention
 def multi_scale_deformable_attention(
     value: Tensor,
     value_spatial_shapes: Union[Tensor, List[Tuple]],
@ -15,38 +15,32 @@
|
||||
"""PyTorch OmDet-Turbo model."""
|
||||
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
|
||||
from ...activations import ACT2CLS, ACT2FN
|
||||
from ...file_utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_cuda_available,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import is_ninja_available, logging
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import load_backbone
|
||||
from ..auto import AutoModel
|
||||
from .configuration_omdet_turbo import OmDetTurboConfig
|
||||
|
||||
|
||||
MultiScaleDeformableAttention = None
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
_CONFIG_FOR_DOC = "OmDetTurboConfig"
|
||||
|
||||
@ -178,79 +172,60 @@ class OmDetTurboObjectDetectionOutput(ModelOutput):
|
||||
classes_structure: Optional[torch.LongTensor] = None
|
||||
|
||||
|
||||
# Copied from models.deformable_detr.load_cuda_kernels
|
||||
def load_cuda_kernels():
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
global MultiScaleDeformableAttention
|
||||
|
||||
root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr"
|
||||
src_files = [
|
||||
root / filename
|
||||
for filename in [
|
||||
"vision.cpp",
|
||||
os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
|
||||
os.path.join("cuda", "ms_deform_attn_cuda.cu"),
|
||||
]
|
||||
]
|
||||
|
||||
MultiScaleDeformableAttention = load(
|
||||
"MultiScaleDeformableAttention",
|
||||
src_files,
|
||||
with_cuda=True,
|
||||
extra_include_paths=[str(root)],
|
||||
extra_cflags=["-DWITH_CUDA=1"],
|
||||
extra_cuda_cflags=[
|
||||
"-DCUDA_HAS_FP16=1",
|
||||
"-D__CUDA_NO_HALF_OPERATORS__",
|
||||
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"-D__CUDA_NO_HALF2_OPERATORS__",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
|
||||
def multi_scale_deformable_attention(
|
||||
value: Tensor,
|
||||
value_spatial_shapes: Union[Tensor, List[Tuple]],
|
||||
sampling_locations: Tensor,
|
||||
attention_weights: Tensor,
|
||||
) -> Tensor:
|
||||
batch_size, _, num_heads, hidden_dim = value.shape
|
||||
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
|
||||
# Ignore copy
|
||||
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
|
||||
sampling_grids = 2 * sampling_locations - 1
|
||||
sampling_value_list = []
|
||||
for level_id, (height, width) in enumerate(value_spatial_shapes):
|
||||
# batch_size, height*width, num_heads, hidden_dim
|
||||
# -> batch_size, height*width, num_heads*hidden_dim
|
||||
# -> batch_size, num_heads*hidden_dim, height*width
|
||||
# -> batch_size*num_heads, hidden_dim, height, width
|
||||
value_l_ = (
|
||||
value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width)


# Copied from models.deformable_detr.MultiScaleDeformableAttention
@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
class MultiScaleDeformableAttention(nn.Module):
    def forward(
        self,
        value: Tensor,
        value_spatial_shapes: Tensor,
        value_spatial_shapes_list: List[Tuple],
        level_start_index: Tensor,
        sampling_locations: Tensor,
        attention_weights: Tensor,
        im2col_step: int,
    ):
        batch_size, _, num_heads, hidden_dim = value.shape
        _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
        value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
        sampling_grids = 2 * sampling_locations - 1
        sampling_value_list = []
        for level_id, (height, width) in enumerate(value_spatial_shapes):
            # batch_size, height*width, num_heads, hidden_dim
            # -> batch_size, height*width, num_heads*hidden_dim
            # -> batch_size, num_heads*hidden_dim, height*width
            # -> batch_size*num_heads, hidden_dim, height, width
            value_l_ = (
                value_list[level_id]
                .flatten(2)
                .transpose(1, 2)
                .reshape(batch_size * num_heads, hidden_dim, height, width)
            )
            # batch_size, num_queries, num_heads, num_points, 2
            # -> batch_size, num_heads, num_queries, num_points, 2
            # -> batch_size*num_heads, num_queries, num_points, 2
            sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
            # batch_size*num_heads, hidden_dim, num_queries, num_points
            sampling_value_l_ = nn.functional.grid_sample(
                value_l_,
                sampling_grid_l_,
                mode="bilinear",
                padding_mode="zeros",
                align_corners=False,
            )
            sampling_value_list.append(sampling_value_l_)
        # (batch_size, num_queries, num_heads, num_levels, num_points)
        # -> (batch_size, num_heads, num_queries, num_levels, num_points)
        # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
        attention_weights = attention_weights.transpose(1, 2).reshape(
            batch_size * num_heads, 1, num_queries, num_levels * num_points
        )
        output = (
            (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
            .sum(-1)
            .view(batch_size, num_heads * hidden_dim, num_queries)
        )
        return output.transpose(1, 2).contiguous()
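The forward above is the pure-PyTorch reference path; when the `kernels` package is installed and a matching pre-built kernel is available on the Hub, `use_kernel_forward_from_hub` swaps it out transparently. A quick shape check of the reference path (a sketch only: the sizes are arbitrary, and passing the spatial shapes as a plain list of tuples is an illustration that works for this pure-PyTorch implementation):

import torch

batch_size, num_heads, head_dim = 2, 8, 32
spatial_shapes = [(32, 32), (16, 16)]  # (height, width) per feature level
num_queries, num_levels, num_points = 100, len(spatial_shapes), 4
seq_len = sum(h * w for h, w in spatial_shapes)

value = torch.rand(batch_size, seq_len, num_heads, head_dim)
sampling_locations = torch.rand(batch_size, num_queries, num_heads, num_levels, num_points, 2)
attention_weights = torch.rand(batch_size, num_queries, num_heads, num_levels, num_points)
attention_weights = attention_weights / attention_weights.sum(dim=(-2, -1), keepdim=True)
level_start_index = torch.tensor([0, spatial_shapes[0][0] * spatial_shapes[0][1]])

attn = MultiScaleDeformableAttention()  # the class defined above
output = attn(
    value, spatial_shapes, spatial_shapes, level_start_index,
    sampling_locations, attention_weights, im2col_step=64,
)
print(output.shape)  # torch.Size([2, 100, 256]) == (batch, queries, heads * head_dim)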


class OmDetTurboLRUCache:

@ -332,55 +307,6 @@ class OmDetTurboVisionBackbone(nn.Module):
        return outputs


# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction
class MultiScaleDeformableAttentionFunction(Function):
    @staticmethod
    def forward(
        context,
        value,
        value_spatial_shapes,
        value_level_start_index,
        sampling_locations,
        attention_weights,
        im2col_step,
    ):
        context.im2col_step = im2col_step
        output = MultiScaleDeformableAttention.ms_deform_attn_forward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            context.im2col_step,
        )
        context.save_for_backward(
            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
        )
        return output

    @staticmethod
    @once_differentiable
    def backward(context, grad_output):
        (
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
        ) = context.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            grad_output,
            context.im2col_step,
        )

        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
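This `autograd.Function` (removed here) existed only to bridge into the extension that `load_cuda_kernels` JIT-compiled from `kernels/deformable_detr`; the Hub kernel ships pre-built binaries with the same ops, so both pieces can go. A minimal sketch of what fetching the pre-built kernel looks like (the repo id and exposed attribute names are assumptions for illustration; the actual wiring lives in `src/transformers/integrations/hub_kernels.py`):

from kernels import get_kernel  # provided by the new "hub-kernels" extra

# Downloads (or reuses from cache) a pre-built binary; no ninja/nvcc needed at runtime.
deformable_detr = get_kernel("kernels-community/deformable-detr")
# Assumed to expose the same ops the JIT extension provided, e.g.
# deformable_detr.ms_deform_attn_forward and deformable_detr.ms_deform_attn_backward.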


# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->OmDetTurbo, Deformable DETR->OmDet-Turbo
class OmDetTurboMultiscaleDeformableAttention(nn.Module):
    """
@ -390,12 +316,7 @@ class OmDetTurboMultiscaleDeformableAttention(nn.Module):
    def __init__(self, config: OmDetTurboConfig, num_heads: int, n_points: int):
        super().__init__()

        kernel_loaded = MultiScaleDeformableAttention is not None
        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
            try:
                load_cuda_kernels()
            except Exception as e:
                logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
        self.attn = MultiScaleDeformableAttention()

        if config.d_model % num_heads != 0:
            raise ValueError(
@ -483,27 +404,16 @@ class OmDetTurboMultiscaleDeformableAttention(nn.Module):
        else:
            raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

        if self.disable_custom_kernels:
            # PyTorch implementation
            output = multi_scale_deformable_attention(
                value, spatial_shapes_list, sampling_locations, attention_weights
            )
        else:
            try:
                # custom kernel
                output = MultiScaleDeformableAttentionFunction.apply(
                    value,
                    spatial_shapes,
                    level_start_index,
                    sampling_locations,
                    attention_weights,
                    self.im2col_step,
                )
            except Exception:
                # PyTorch implementation
                output = multi_scale_deformable_attention(
                    value, spatial_shapes_list, sampling_locations, attention_weights
                )
        output = self.attn(
            value,
            spatial_shapes,
            spatial_shapes_list,
            level_start_index,
            sampling_locations,
            attention_weights,
            self.im2col_step,
        )

        output = self.output_proj(output)

        return output, attention_weights
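Note how the branching between the custom kernel and the PyTorch fallback disappears: `__init__` now always builds `self.attn`, and kernel dispatch is handled by the `use_kernel_forward_from_hub` decorator on `MultiScaleDeformableAttention` itself. A sketch of how such a mapping might be registered through the new integration module (the repo id and the "cuda" device key are assumptions for illustration):

from transformers.integrations import LayerRepository, register_kernel_mapping

register_kernel_mapping(
    {
        "MultiScaleDeformableAttention": {
            "cuda": LayerRepository(
                repo_id="kernels-community/deformable-detr",
                layer_name="MultiScaleDeformableAttention",
            )
        }
    }
)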

@ -61,7 +61,6 @@ def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
def multi_scale_deformable_attention(
    value: Tensor,
    value_spatial_shapes: Union[Tensor, List[Tuple]],

@ -15,21 +15,18 @@
"""PyTorch RT-DETR model."""

import math
import os
import warnings
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from ...activations import ACT2CLS, ACT2FN
from ...image_transforms import center_to_corners_format, corners_to_center_format
from ...integrations import use_kernel_forward_from_hub
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import compile_compatible_method_lru_cache
@ -37,9 +34,6 @@ from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_ninja_available,
    is_torch_cuda_available,
    is_torchdynamo_compiling,
    logging,
    replace_return_docstrings,
    torch_int,
@ -50,96 +44,68 @@ from .configuration_rt_detr import RTDetrConfig

logger = logging.get_logger(__name__)

MultiScaleDeformableAttention = None


# Copied from transformers.models.deformable_detr.modeling_deformable_detr.load_cuda_kernels
def load_cuda_kernels():
    from torch.utils.cpp_extension import load

    global MultiScaleDeformableAttention

    root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr"
    src_files = [
        root / filename
        for filename in [
            "vision.cpp",
            os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
            os.path.join("cuda", "ms_deform_attn_cuda.cu"),
        ]
    ]

    MultiScaleDeformableAttention = load(
        "MultiScaleDeformableAttention",
        src_files,
        with_cuda=True,
        extra_include_paths=[str(root)],
        extra_cflags=["-DWITH_CUDA=1"],
        extra_cuda_cflags=[
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ],
    )


# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction
class MultiScaleDeformableAttentionFunction(Function):
    @staticmethod
    def forward(
        context,
        value,
        value_spatial_shapes,
        value_level_start_index,
        sampling_locations,
        attention_weights,
        im2col_step,
    ):
        context.im2col_step = im2col_step
        output = MultiScaleDeformableAttention.ms_deform_attn_forward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            context.im2col_step,
        )
        context.save_for_backward(
            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
        )
        return output

    @staticmethod
    @once_differentiable
    def backward(context, grad_output):
        (
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
        ) = context.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            grad_output,
            context.im2col_step,
        )

        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "RTDetrConfig"
# TODO: Replace all occurrences of the checkpoint with the final one
_CHECKPOINT_FOR_DOC = "PekingU/rtdetr_r50vd"


# Copied from models.deformable_detr.MultiScaleDeformableAttention
@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
class MultiScaleDeformableAttention(nn.Module):
    def forward(
        self,
        value: Tensor,
        value_spatial_shapes: Tensor,
        value_spatial_shapes_list: List[Tuple],
        level_start_index: Tensor,
        sampling_locations: Tensor,
        attention_weights: Tensor,
        im2col_step: int,
    ):
        batch_size, _, num_heads, hidden_dim = value.shape
        _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
        value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
        sampling_grids = 2 * sampling_locations - 1
        sampling_value_list = []
        for level_id, (height, width) in enumerate(value_spatial_shapes):
            # batch_size, height*width, num_heads, hidden_dim
            # -> batch_size, height*width, num_heads*hidden_dim
            # -> batch_size, num_heads*hidden_dim, height*width
            # -> batch_size*num_heads, hidden_dim, height, width
            value_l_ = (
                value_list[level_id]
                .flatten(2)
                .transpose(1, 2)
                .reshape(batch_size * num_heads, hidden_dim, height, width)
            )
            # batch_size, num_queries, num_heads, num_points, 2
            # -> batch_size, num_heads, num_queries, num_points, 2
            # -> batch_size*num_heads, num_queries, num_points, 2
            sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
            # batch_size*num_heads, hidden_dim, num_queries, num_points
            sampling_value_l_ = nn.functional.grid_sample(
                value_l_,
                sampling_grid_l_,
                mode="bilinear",
                padding_mode="zeros",
                align_corners=False,
            )
            sampling_value_list.append(sampling_value_l_)
        # (batch_size, num_queries, num_heads, num_levels, num_points)
        # -> (batch_size, num_heads, num_queries, num_levels, num_points)
        # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
        attention_weights = attention_weights.transpose(1, 2).reshape(
            batch_size * num_heads, 1, num_queries, num_levels * num_points
        )
        output = (
            (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
            .sum(-1)
            .view(batch_size, num_heads * hidden_dim, num_queries)
        )
        return output.transpose(1, 2).contiguous()
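One detail worth a worked example: `sampling_grids = 2 * sampling_locations - 1` exists because the sampling locations are normalized to [0, 1] while `grid_sample` expects coordinates in [-1, 1]. A self-contained check of that convention (illustrative values only):

import torch
import torch.nn.functional as F

feature = torch.arange(16.0).view(1, 1, 4, 4)  # one 4x4 single-channel map
location = torch.tensor([[[[0.5, 0.5]]]])      # center of the map in [0, 1] coords
grid = 2 * location - 1                        # -> (0, 0) in grid_sample coords
sample = F.grid_sample(feature, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
print(sample)  # tensor([[[[7.5000]]]]) == mean of the four central pixels (5, 6, 9, 10)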


@dataclass
class RTDetrDecoderOutput(ModelOutput):
    """
@ -728,49 +694,6 @@ class RTDetrCSPRepLayer(nn.Module):
        return self.conv3(hidden_state_1 + hidden_state_2)


# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
def multi_scale_deformable_attention(
    value: Tensor,
    value_spatial_shapes: Union[Tensor, List[Tuple]],
    sampling_locations: Tensor,
    attention_weights: Tensor,
) -> Tensor:
    batch_size, _, num_heads, hidden_dim = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level_id, (height, width) in enumerate(value_spatial_shapes):
        # batch_size, height*width, num_heads, hidden_dim
        # -> batch_size, height*width, num_heads*hidden_dim
        # -> batch_size, num_heads*hidden_dim, height*width
        # -> batch_size*num_heads, hidden_dim, height, width
        value_l_ = (
            value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width)
        )
        # batch_size, num_queries, num_heads, num_points, 2
        # -> batch_size, num_heads, num_queries, num_points, 2
        # -> batch_size*num_heads, num_queries, num_points, 2
        sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
        # batch_size*num_heads, hidden_dim, num_queries, num_points
        sampling_value_l_ = nn.functional.grid_sample(
            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
        )
        sampling_value_list.append(sampling_value_l_)
    # (batch_size, num_queries, num_heads, num_levels, num_points)
    # -> (batch_size, num_heads, num_queries, num_levels, num_points)
    # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
    attention_weights = attention_weights.transpose(1, 2).reshape(
        batch_size * num_heads, 1, num_queries, num_levels * num_points
    )
    output = (
        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
        .sum(-1)
        .view(batch_size, num_heads * hidden_dim, num_queries)
    )
    return output.transpose(1, 2).contiguous()
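Although this standalone function is removed, the decorated module's forward above implements the identical math, which one could sanity-check with both definitions in scope (a sketch with tiny arbitrary sizes, on CPU, where the Hub kernel does not apply):

import torch

b, heads, dim, q, pts = 1, 2, 4, 5, 3
shapes = [(4, 4), (2, 2)]
value = torch.rand(b, sum(h * w for h, w in shapes), heads, dim)
loc = torch.rand(b, q, heads, len(shapes), pts, 2)
weights = torch.softmax(torch.rand(b, q, heads, len(shapes) * pts), dim=-1)
weights = weights.view(b, q, heads, len(shapes), pts)
start = torch.tensor([0, 16])

ref = multi_scale_deformable_attention(value, shapes, loc, weights)
out = MultiScaleDeformableAttention()(value, shapes, shapes, start, loc, weights, im2col_step=64)
torch.testing.assert_close(out, ref)  # same pure-PyTorch computation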


# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->RTDetr
class RTDetrMultiscaleDeformableAttention(nn.Module):
    """
@ -780,12 +703,7 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
    def __init__(self, config: RTDetrConfig, num_heads: int, n_points: int):
        super().__init__()

        kernel_loaded = MultiScaleDeformableAttention is not None
        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
            try:
                load_cuda_kernels()
            except Exception as e:
                logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
        self.attn = MultiScaleDeformableAttention()

        if config.d_model % num_heads != 0:
            raise ValueError(
@ -872,27 +790,16 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
        else:
            raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

        if self.disable_custom_kernels or MultiScaleDeformableAttention is None or is_torchdynamo_compiling():
            # PyTorch implementation
            output = multi_scale_deformable_attention(
                value, spatial_shapes_list, sampling_locations, attention_weights
            )
        else:
            try:
                # custom kernel
                output = MultiScaleDeformableAttentionFunction.apply(
                    value,
                    spatial_shapes,
                    level_start_index,
                    sampling_locations,
                    attention_weights,
                    self.im2col_step,
                )
            except Exception:
                # PyTorch implementation
                output = multi_scale_deformable_attention(
                    value, spatial_shapes_list, sampling_locations, attention_weights
                )
        output = self.attn(
            value,
            spatial_shapes,
            spatial_shapes_list,
            level_start_index,
            sampling_locations,
            attention_weights,
            self.im2col_step,
        )

        output = self.output_proj(output)

        return output, attention_weights
@ -21,8 +21,6 @@ from pathlib import Path

FILES_TO_FIND = [
    "kernels/rwkv/wkv_cuda.cu",
    "kernels/rwkv/wkv_op.cpp",
    "kernels/deformable_detr/ms_deform_attn.h",
    "kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh",
    "kernels/falcon_mamba/selective_scan_with_ln_interface.py",
    "kernels/falcon_mamba/__init__.py",
    "kernels/__init__.py",
@ -475,7 +475,6 @@ src/transformers/models/deberta/modeling_tf_deberta.py
src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py
src/transformers/models/decision_transformer/modeling_decision_transformer.py
src/transformers/models/deformable_detr/convert_deformable_detr_to_pytorch.py
src/transformers/models/deformable_detr/load_custom.py
src/transformers/models/deit/convert_deit_timm_to_pytorch.py
src/transformers/models/deprecated/bort/convert_bort_original_gluonnlp_checkpoint_to_pytorch.py
src/transformers/models/deprecated/mctct/configuration_mctct.py