Mirror of https://github.com/huggingface/transformers.git
Use `deformable_detr` kernel from the Hub (#36853)

* Use `deformable_detr` kernel from the Hub

  Remove the `deformable_detr` kernel from `kernels/` and use the
  pre-built kernel from the Hub instead.

* Add license header

* Add `kernels` as an extra `hub-kernels`

  Also add it to `testing`, so that the kernel replacement gets tested
  when using CUDA in CI.

This commit is contained in:
parent 2638d54e78
commit f94b0c59f2
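For orientation before the diff: the pattern this PR introduces is a decorator that tags a layer class with a kernel name while keeping a pure-PyTorch `forward` as the fallback. A minimal sketch of the opt-in, not taken from the PR itself (`TinySoftmax` is a hypothetical layer; the decorator import path is the one this commit adds):

import torch
from torch import nn

from transformers.integrations import use_kernel_forward_from_hub


# The string is the key looked up in the kernel mapping; this PR registers
# "MultiScaleDeformableAttention" -> kernels-community/deformable-detr (CUDA).
@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
class TinySoftmax(nn.Module):
    # Pure-PyTorch fallback; with `kernels` installed and a mapping registered
    # for the active device, the Hub kernel's forward can replace this one.
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return torch.softmax(hidden_states, dim=-1)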
setup.py (4 lines changed)
@@ -129,6 +129,7 @@ _deps = [
     # Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
     "keras>2.9,<2.16",
     "keras-nlp>=0.3.1,<0.14.0",  # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
+    "kernels>=0.3.2,<0.4",
     "librosa",
     "natten>=0.14.6,<0.15.0",
     "nltk<=3.8.1",
@@ -301,8 +302,9 @@ extras["deepspeed"] = deps_list("deepspeed") + extras["accelerate"]
 extras["optuna"] = deps_list("optuna")
 extras["ray"] = deps_list("ray[tune]")
 extras["sigopt"] = deps_list("sigopt")
+extras["hub-kernels"] = deps_list("kernels")
 
-extras["integrations"] = extras["optuna"] + extras["ray"] + extras["sigopt"]
+extras["integrations"] = extras["hub-kernels"] + extras["optuna"] + extras["ray"] + extras["sigopt"]
 
 extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
 extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm")
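Since the new extra is wired through `deps_list("kernels")`, it can be installed with the standard extras syntax, e.g. `pip install "transformers[hub-kernels]"`; per the commit message it is also added to the `testing` extra (that hunk is not shown here) so the kernel replacement is exercised in CUDA CI.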
@@ -35,6 +35,7 @@ deps = {
     "kenlm": "kenlm",
     "keras": "keras>2.9,<2.16",
     "keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
+    "kernels": "kernels>=0.3.2,<0.4",
     "librosa": "librosa",
     "natten": "natten>=0.14.6,<0.15.0",
     "nltk": "nltk<=3.8.1",
@@ -70,6 +70,12 @@ _import_structure = {
         "replace_with_higgs_linear",
     ],
     "hqq": ["prepare_for_hqq_linear"],
+    "hub_kernels": [
+        "LayerRepository",
+        "register_kernel_mapping",
+        "replace_kernel_forward_from_hub",
+        "use_kernel_forward_from_hub",
+    ],
     "integration_utils": [
         "INTEGRATION_TO_CALLBACK",
         "AzureMLCallback",
@@ -198,6 +204,12 @@ if TYPE_CHECKING:
     )
     from .higgs import HiggsLinear, dequantize_higgs, quantize_with_higgs, replace_with_higgs_linear
     from .hqq import prepare_for_hqq_linear
+    from .hub_kernels import (
+        LayerRepository,
+        register_kernel_mapping,
+        replace_kernel_forward_from_hub,
+        use_kernel_forward_from_hub,
+    )
     from .integration_utils import (
         INTEGRATION_TO_CALLBACK,
         AzureMLCallback,
|
73
src/transformers/integrations/hub_kernels.py
Normal file
73
src/transformers/integrations/hub_kernels.py
Normal file
@@ -0,0 +1,73 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, Union
+
+
+try:
+    from kernels import (
+        Device,
+        LayerRepository,
+        register_kernel_mapping,
+        replace_kernel_forward_from_hub,
+        use_kernel_forward_from_hub,
+    )
+
+    _hub_kernels_available = True
+
+    _KERNEL_MAPPING: Dict[str, Dict[Union[Device, str], LayerRepository]] = {
+        "MultiScaleDeformableAttention": {
+            "cuda": LayerRepository(
+                repo_id="kernels-community/deformable-detr",
+                layer_name="MultiScaleDeformableAttention",
+            )
+        }
+    }
+
+    register_kernel_mapping(_KERNEL_MAPPING)
+
+except ImportError:
+    # Stub to make decorators in transformers work when `kernels`
+    # is not installed.
+    def use_kernel_forward_from_hub(*args, **kwargs):
+        def decorator(cls):
+            return cls
+
+        return decorator
+
+    class LayerRepository:
+        def __init__(self, *args, **kwargs):
+            raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
+
+    def replace_kernel_forward_from_hub(*args, **kwargs):
+        raise RuntimeError(
+            "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
+        )
+
+    def register_kernel_mapping(*args, **kwargs):
+        raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
+
+    _hub_kernels_available = False
+
+
+def is_hub_kernels_available():
+    return _hub_kernels_available
+
+
+__all__ = [
+    "LayerRepository",
+    "is_hub_kernels_available",
+    "use_kernel_forward_from_hub",
+    "register_kernel_mapping",
+    "replace_kernel_forward_from_hub",
+]
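The `_KERNEL_MAPPING` structure above doubles as a template for extending the mechanism. A hypothetical sketch using only the names this module exports (`MyCustomLayer` and `my-org/my-kernels` are invented for illustration):

from transformers.integrations.hub_kernels import (
    LayerRepository,
    is_hub_kernels_available,
    register_kernel_mapping,
)

# Guard on availability: without the optional `kernels` package, the stub
# `register_kernel_mapping` defined above raises RuntimeError.
if is_hub_kernels_available():
    register_kernel_mapping(
        {
            "MyCustomLayer": {  # hypothetical layer name
                "cuda": LayerRepository(
                    repo_id="my-org/my-kernels",  # hypothetical Hub repo
                    layer_name="MyCustomLayer",
                )
            }
        }
    )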
@@ -1,40 +0,0 @@
-/*!
-**************************************************************************************************
-* Deformable DETR
-* Copyright (c) 2020 SenseTime. All Rights Reserved.
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
-**************************************************************************************************
-* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
-**************************************************************************************************
-*/
-
-#include <vector>
-
-#include <ATen/ATen.h>
-#include <ATen/cuda/CUDAContext.h>
-
-
-at::Tensor
-ms_deform_attn_cpu_forward(
-    const at::Tensor &value,
-    const at::Tensor &spatial_shapes,
-    const at::Tensor &level_start_index,
-    const at::Tensor &sampling_loc,
-    const at::Tensor &attn_weight,
-    const int im2col_step)
-{
-    AT_ERROR("Not implement on cpu");
-}
-
-std::vector<at::Tensor>
-ms_deform_attn_cpu_backward(
-    const at::Tensor &value,
-    const at::Tensor &spatial_shapes,
-    const at::Tensor &level_start_index,
-    const at::Tensor &sampling_loc,
-    const at::Tensor &attn_weight,
-    const at::Tensor &grad_output,
-    const int im2col_step)
-{
-    AT_ERROR("Not implement on cpu");
-}
@@ -1,32 +0,0 @@
-/*!
-**************************************************************************************************
-* Deformable DETR
-* Copyright (c) 2020 SenseTime. All Rights Reserved.
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
-**************************************************************************************************
-* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
-**************************************************************************************************
-*/
-
-#pragma once
-#include <torch/extension.h>
-
-at::Tensor
-ms_deform_attn_cpu_forward(
-    const at::Tensor &value,
-    const at::Tensor &spatial_shapes,
-    const at::Tensor &level_start_index,
-    const at::Tensor &sampling_loc,
-    const at::Tensor &attn_weight,
-    const int im2col_step);
-
-std::vector<at::Tensor>
-ms_deform_attn_cpu_backward(
-    const at::Tensor &value,
-    const at::Tensor &spatial_shapes,
-    const at::Tensor &level_start_index,
-    const at::Tensor &sampling_loc,
-    const at::Tensor &attn_weight,
-    const at::Tensor &grad_output,
-    const int im2col_step);
-
@@ -1,159 +0,0 @@
-/*!
-**************************************************************************************************
-* Deformable DETR
-* Copyright (c) 2020 SenseTime. All Rights Reserved.
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
-**************************************************************************************************
-* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
-**************************************************************************************************
-*/
-
-#include <vector>
-#include "cuda/ms_deform_im2col_cuda.cuh"
-
-#include <ATen/ATen.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <cuda.h>
-#include <cuda_runtime.h>
-
-#pragma once
-#include <torch/extension.h>
-
-
-at::Tensor ms_deform_attn_cuda_forward(
-    const at::Tensor &value,
-    const at::Tensor &spatial_shapes,
-    const at::Tensor &level_start_index,
-    const at::Tensor &sampling_loc,
-    const at::Tensor &attn_weight,
-    const int im2col_step)
-{
-    at::DeviceGuard guard(value.device());
-
-    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
-    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
-    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
-    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
-    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
-
-    AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
-    AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
-    AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
-    AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
-    AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
-
-    const int batch = value.size(0);
-    const int spatial_size = value.size(1);
-    const int num_heads = value.size(2);
-    const int channels = value.size(3);
-
-    const int num_levels = spatial_shapes.size(0);
-
-    const int num_query = sampling_loc.size(1);
-    const int num_point = sampling_loc.size(4);
-
-    const int im2col_step_ = std::min(batch, im2col_step);
-
-    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
-
-    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
-
-    const int batch_n = im2col_step_;
-    auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
-    auto per_value_size = spatial_size * num_heads * channels;
-    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
-    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
-    for (int n = 0; n < batch/im2col_step_; ++n)
-    {
-        auto columns = output_n.select(0, n);
-        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
-            ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
-                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
-                spatial_shapes.data_ptr<int64_t>(),
-                level_start_index.data_ptr<int64_t>(),
-                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
-                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
-                columns.data_ptr<scalar_t>());
-
-        }));
-    }
-
-    output = output.view({batch, num_query, num_heads*channels});
-
-    return output;
-}
-
-
-std::vector<at::Tensor> ms_deform_attn_cuda_backward(
-    const at::Tensor &value,
-    const at::Tensor &spatial_shapes,
-    const at::Tensor &level_start_index,
-    const at::Tensor &sampling_loc,
-    const at::Tensor &attn_weight,
-    const at::Tensor &grad_output,
-    const int im2col_step)
-{
-    at::DeviceGuard guard(value.device());
-
-    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
-    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
-    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
-    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
-    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
-    AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
-
-    AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
-    AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
-    AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
-    AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
-    AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
-    AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");
-
-    const int batch = value.size(0);
-    const int spatial_size = value.size(1);
-    const int num_heads = value.size(2);
-    const int channels = value.size(3);
-
-    const int num_levels = spatial_shapes.size(0);
-
-    const int num_query = sampling_loc.size(1);
-    const int num_point = sampling_loc.size(4);
-
-    const int im2col_step_ = std::min(batch, im2col_step);
-
-    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
-
-    auto grad_value = at::zeros_like(value);
-    auto grad_sampling_loc = at::zeros_like(sampling_loc);
-    auto grad_attn_weight = at::zeros_like(attn_weight);
-
-    const int batch_n = im2col_step_;
-    auto per_value_size = spatial_size * num_heads * channels;
-    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
-    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
-    auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
-
-    for (int n = 0; n < batch/im2col_step_; ++n)
-    {
-        auto grad_output_g = grad_output_n.select(0, n);
-        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
-            ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
-                grad_output_g.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
-                spatial_shapes.data_ptr<int64_t>(),
-                level_start_index.data_ptr<int64_t>(),
-                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
-                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
-                grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
-                grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
-
-        }));
-    }
-
-    return {
-        grad_value, grad_sampling_loc, grad_attn_weight
-    };
-}
(File diff suppressed because it is too large.)
@@ -1,46 +0,0 @@
-/*!
-**************************************************************************************************
-* Deformable DETR
-* Copyright (c) 2020 SenseTime. All Rights Reserved.
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
-**************************************************************************************************
-* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
-**************************************************************************************************
-*/
-
-#pragma once
-#include <torch/extension.h>
-
-at::Tensor ms_deform_attn_cuda_forward(
-    const at::Tensor &value,
-    const at::Tensor &spatial_shapes,
-    const at::Tensor &level_start_index,
-    const at::Tensor &sampling_loc,
-    const at::Tensor &attn_weight,
-    const int im2col_step);
-
-at::Tensor ms_deform_attn_cuda_forward_bf16(
-    const at::Tensor &value,
-    const at::Tensor &spatial_shapes,
-    const at::Tensor &level_start_index,
-    const at::Tensor &sampling_loc,
-    const at::Tensor &attn_weight,
-    const int im2col_step);
-
-std::vector<at::Tensor> ms_deform_attn_cuda_backward(
-    const at::Tensor &value,
-    const at::Tensor &spatial_shapes,
-    const at::Tensor &level_start_index,
-    const at::Tensor &sampling_loc,
-    const at::Tensor &attn_weight,
-    const at::Tensor &grad_output,
-    const int im2col_step);
-
-std::vector<at::Tensor> ms_deform_attn_cuda_backward_bf16(
-    const at::Tensor &value,
-    const at::Tensor &spatial_shapes,
-    const at::Tensor &level_start_index,
-    const at::Tensor &sampling_loc,
-    const at::Tensor &attn_weight,
-    const at::Tensor &grad_output,
-    const int im2col_step);
(File diff suppressed because it is too large.)
@@ -1,61 +0,0 @@
-/*!
-**************************************************************************************************
-* Deformable DETR
-* Copyright (c) 2020 SenseTime. All Rights Reserved.
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
-**************************************************************************************************
-* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
-**************************************************************************************************
-*/
-
-#pragma once
-
-#include "cpu/ms_deform_attn_cpu.h"
-
-#ifdef WITH_CUDA
-#include "cuda/ms_deform_attn_cuda.h"
-#endif
-
-
-at::Tensor
-ms_deform_attn_forward(
-    const at::Tensor &value,
-    const at::Tensor &spatial_shapes,
-    const at::Tensor &level_start_index,
-    const at::Tensor &sampling_loc,
-    const at::Tensor &attn_weight,
-    const int im2col_step)
-{
-    if (value.is_cuda())
-    {
-#ifdef WITH_CUDA
-        return ms_deform_attn_cuda_forward(
-            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
-#else
-        AT_ERROR("Not compiled with GPU support");
-#endif
-    }
-    AT_ERROR("Not implemented on the CPU");
-}
-
-std::vector<at::Tensor>
-ms_deform_attn_backward(
-    const at::Tensor &value,
-    const at::Tensor &spatial_shapes,
-    const at::Tensor &level_start_index,
-    const at::Tensor &sampling_loc,
-    const at::Tensor &attn_weight,
-    const at::Tensor &grad_output,
-    const int im2col_step)
-{
-    if (value.is_cuda())
-    {
-#ifdef WITH_CUDA
-        return ms_deform_attn_cuda_backward(
-            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
-#else
-        AT_ERROR("Not compiled with GPU support");
-#endif
-    }
-    AT_ERROR("Not implemented on the CPU");
-}
@@ -1,16 +0,0 @@
-/*!
-**************************************************************************************************
-* Deformable DETR
-* Copyright (c) 2020 SenseTime. All Rights Reserved.
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
-**************************************************************************************************
-* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
-**************************************************************************************************
-*/
-
-#include "ms_deform_attn.h"
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
-    m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
-}
@@ -1,50 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Loading of Deformable DETR's CUDA kernels"""
-
-import os
-from pathlib import Path
-
-
-def load_cuda_kernels():
-    from torch.utils.cpp_extension import load
-
-    root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr"
-    src_files = [
-        root / filename
-        for filename in [
-            "vision.cpp",
-            os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
-            os.path.join("cuda", "ms_deform_attn_cuda.cu"),
-        ]
-    ]
-
-    load(
-        "MultiScaleDeformableAttention",
-        src_files,
-        with_cuda=True,
-        extra_include_paths=[str(root)],
-        extra_cflags=["-DWITH_CUDA=1"],
-        extra_cuda_cflags=[
-            "-DCUDA_HAS_FP16=1",
-            "-D__CUDA_NO_HALF_OPERATORS__",
-            "-D__CUDA_NO_HALF_CONVERSIONS__",
-            "-D__CUDA_NO_HALF2_OPERATORS__",
-        ],
-    )
-
-    import MultiScaleDeformableAttention as MSDA
-
-    return MSDA
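Worth noting about the deletion above: `torch.utils.cpp_extension.load` JIT-compiles the extension on first use, which requires a local CUDA toolchain plus `ninja` and adds a long warm-up; the Hub kernel that replaces it ships pre-built, which is the motivation given in the commit message.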
@@ -16,19 +16,16 @@
 
 import copy
 import math
-import os
 import warnings
 from dataclasses import dataclass
-from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
-from torch.autograd import Function
-from torch.autograd.function import once_differentiable
 
 from ...activations import ACT2FN
+from ...integrations import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
@@ -37,10 +34,7 @@ from ...utils import (
     ModelOutput,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_ninja_available,
     is_timm_available,
-    is_torch_cuda_available,
-    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
     requires_backends,
@@ -51,38 +45,6 @@ from .configuration_deformable_detr import DeformableDetrConfig
 
 logger = logging.get_logger(__name__)
 
-MultiScaleDeformableAttention = None
-
-
-def load_cuda_kernels():
-    from torch.utils.cpp_extension import load
-
-    global MultiScaleDeformableAttention
-
-    root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr"
-    src_files = [
-        root / filename
-        for filename in [
-            "vision.cpp",
-            os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
-            os.path.join("cuda", "ms_deform_attn_cuda.cu"),
-        ]
-    ]
-
-    MultiScaleDeformableAttention = load(
-        "MultiScaleDeformableAttention",
-        src_files,
-        with_cuda=True,
-        extra_include_paths=[str(root)],
-        extra_cflags=["-DWITH_CUDA=1"],
-        extra_cuda_cflags=[
-            "-DCUDA_HAS_FP16=1",
-            "-D__CUDA_NO_HALF_OPERATORS__",
-            "-D__CUDA_NO_HALF_CONVERSIONS__",
-            "-D__CUDA_NO_HALF2_OPERATORS__",
-        ],
-    )
-
 
 if is_timm_available():
     from timm import create_model
@@ -94,52 +56,59 @@ _CONFIG_FOR_DOC = "DeformableDetrConfig"
 _CHECKPOINT_FOR_DOC = "sensetime/deformable-detr"
 
 
-class MultiScaleDeformableAttentionFunction(Function):
-    @staticmethod
-    def forward(
-        context,
-        value,
-        value_spatial_shapes,
-        value_level_start_index,
-        sampling_locations,
-        attention_weights,
-        im2col_step,
-    ):
-        context.im2col_step = im2col_step
-        output = MultiScaleDeformableAttention.ms_deform_attn_forward(
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-            context.im2col_step,
-        )
-        context.save_for_backward(
-            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
-        )
-        return output
-
-    @staticmethod
-    @once_differentiable
-    def backward(context, grad_output):
-        (
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-        ) = context.saved_tensors
-        grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-            grad_output,
-            context.im2col_step,
-        )
-
-        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
+@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
+class MultiScaleDeformableAttention(nn.Module):
+    def forward(
+        self,
+        value: Tensor,
+        value_spatial_shapes: Tensor,
+        value_spatial_shapes_list: List[Tuple],
+        level_start_index: Tensor,
+        sampling_locations: Tensor,
+        attention_weights: Tensor,
+        im2col_step: int,
+    ):
+        batch_size, _, num_heads, hidden_dim = value.shape
+        _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+        value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
+        sampling_grids = 2 * sampling_locations - 1
+        sampling_value_list = []
+        for level_id, (height, width) in enumerate(value_spatial_shapes):
+            # batch_size, height*width, num_heads, hidden_dim
+            # -> batch_size, height*width, num_heads*hidden_dim
+            # -> batch_size, num_heads*hidden_dim, height*width
+            # -> batch_size*num_heads, hidden_dim, height, width
+            value_l_ = (
+                value_list[level_id]
+                .flatten(2)
+                .transpose(1, 2)
+                .reshape(batch_size * num_heads, hidden_dim, height, width)
+            )
+            # batch_size, num_queries, num_heads, num_points, 2
+            # -> batch_size, num_heads, num_queries, num_points, 2
+            # -> batch_size*num_heads, num_queries, num_points, 2
+            sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
+            # batch_size*num_heads, hidden_dim, num_queries, num_points
+            sampling_value_l_ = nn.functional.grid_sample(
+                value_l_,
+                sampling_grid_l_,
+                mode="bilinear",
+                padding_mode="zeros",
+                align_corners=False,
+            )
+            sampling_value_list.append(sampling_value_l_)
+        # (batch_size, num_queries, num_heads, num_levels, num_points)
+        # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+        # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
+        attention_weights = attention_weights.transpose(1, 2).reshape(
+            batch_size * num_heads, 1, num_queries, num_levels * num_points
+        )
+        output = (
+            (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+            .sum(-1)
+            .view(batch_size, num_heads * hidden_dim, num_queries)
+        )
+        return output.transpose(1, 2).contiguous()
 
 
 @dataclass
@@ -564,48 +533,6 @@ def build_position_encoding(config):
     return position_embedding
 
 
-def multi_scale_deformable_attention(
-    value: Tensor,
-    value_spatial_shapes: Union[Tensor, List[Tuple]],
-    sampling_locations: Tensor,
-    attention_weights: Tensor,
-) -> Tensor:
-    batch_size, _, num_heads, hidden_dim = value.shape
-    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
-    sampling_grids = 2 * sampling_locations - 1
-    sampling_value_list = []
-    for level_id, (height, width) in enumerate(value_spatial_shapes):
-        # batch_size, height*width, num_heads, hidden_dim
-        # -> batch_size, height*width, num_heads*hidden_dim
-        # -> batch_size, num_heads*hidden_dim, height*width
-        # -> batch_size*num_heads, hidden_dim, height, width
-        value_l_ = (
-            value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width)
-        )
-        # batch_size, num_queries, num_heads, num_points, 2
-        # -> batch_size, num_heads, num_queries, num_points, 2
-        # -> batch_size*num_heads, num_queries, num_points, 2
-        sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
-        # batch_size*num_heads, hidden_dim, num_queries, num_points
-        sampling_value_l_ = nn.functional.grid_sample(
-            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
-        )
-        sampling_value_list.append(sampling_value_l_)
-    # (batch_size, num_queries, num_heads, num_levels, num_points)
-    # -> (batch_size, num_heads, num_queries, num_levels, num_points)
-    # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
-    attention_weights = attention_weights.transpose(1, 2).reshape(
-        batch_size * num_heads, 1, num_queries, num_levels * num_points
-    )
-    output = (
-        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
-        .sum(-1)
-        .view(batch_size, num_heads * hidden_dim, num_queries)
-    )
-    return output.transpose(1, 2).contiguous()
-
-
 class DeformableDetrMultiscaleDeformableAttention(nn.Module):
     """
     Multiscale deformable attention as proposed in Deformable DETR.
@@ -614,12 +541,7 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
     def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int):
         super().__init__()
 
-        kernel_loaded = MultiScaleDeformableAttention is not None
-        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
-            try:
-                load_cuda_kernels()
-            except Exception as e:
-                logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
+        self.attn = MultiScaleDeformableAttention()
 
         if config.d_model % num_heads != 0:
             raise ValueError(
@@ -706,27 +628,16 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
 
-        if self.disable_custom_kernels or MultiScaleDeformableAttention is None or is_torchdynamo_compiling():
-            # PyTorch implementation
-            output = multi_scale_deformable_attention(
-                value, spatial_shapes_list, sampling_locations, attention_weights
-            )
-        else:
-            try:
-                # custom kernel
-                output = MultiScaleDeformableAttentionFunction.apply(
-                    value,
-                    spatial_shapes,
-                    level_start_index,
-                    sampling_locations,
-                    attention_weights,
-                    self.im2col_step,
-                )
-            except Exception:
-                # PyTorch implementation
-                output = multi_scale_deformable_attention(
-                    value, spatial_shapes_list, sampling_locations, attention_weights
-                )
+        output = self.attn(
+            value,
+            spatial_shapes,
+            spatial_shapes_list,
+            level_start_index,
+            sampling_locations,
+            attention_weights,
+            self.im2col_step,
+        )
         output = self.output_proj(output)
 
         return output, attention_weights
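The old three-way dispatch (custom kernels disabled, kernel missing, or running under torchdynamo) collapses into one unconditional call: `self.attn` is the decorated `MultiScaleDeformableAttention` module, so its plain PyTorch forward runs unless the Hub kernel has been substituted in by the `kernels` machinery.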
@@ -834,7 +745,11 @@ class DeformableDetrMultiheadAttention(nn.Module):
 
         attn_output = torch.bmm(attn_probs, value_states)
 
-        if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
+        if attn_output.size() != (
+            batch_size * self.num_heads,
+            target_len,
+            self.head_dim,
+        ):
             raise ValueError(
                 f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
                 f" {attn_output.size()}"
@@ -854,7 +769,9 @@ class DeformableDetrEncoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.d_model
         self.self_attn = DeformableDetrMultiscaleDeformableAttention(
-            config, num_heads=config.encoder_attention_heads, n_points=config.encoder_n_points
+            config,
+            num_heads=config.encoder_attention_heads,
+            n_points=config.encoder_n_points,
         )
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         self.dropout = config.dropout
@@ -1054,7 +971,11 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
-    _no_split_modules = [r"DeformableDetrConvEncoder", r"DeformableDetrEncoderLayer", r"DeformableDetrDecoderLayer"]
+    _no_split_modules = [
+        r"DeformableDetrConvEncoder",
+        r"DeformableDetrEncoderLayer",
+        r"DeformableDetrDecoderLayer",
+    ]
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -1299,7 +1220,9 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
         if not return_dict:
             return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
         return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_states,
+            attentions=all_attentions,
         )
 
 
@@ -1525,7 +1448,13 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             for _ in range(config.num_feature_levels - num_backbone_outs):
                 input_proj_list.append(
                     nn.Sequential(
-                        nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1),
+                        nn.Conv2d(
+                            in_channels,
+                            config.d_model,
+                            kernel_size=3,
+                            stride=2,
+                            padding=1,
+                        ),
                         nn.GroupNorm(32, config.d_model),
                     )
                 )
@@ -1535,7 +1464,11 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             self.input_proj = nn.ModuleList(
                 [
                     nn.Sequential(
-                        nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1),
+                        nn.Conv2d(
+                            backbone.intermediate_channel_sizes[-1],
+                            config.d_model,
+                            kernel_size=1,
+                        ),
                         nn.GroupNorm(32, config.d_model),
                     )
                 ]
@@ -1625,8 +1558,20 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
 
             grid_y, grid_x = meshgrid(
-                torch.linspace(0, height - 1, height, dtype=enc_output.dtype, device=enc_output.device),
-                torch.linspace(0, width - 1, width, dtype=enc_output.dtype, device=enc_output.device),
+                torch.linspace(
+                    0,
+                    height - 1,
+                    height,
+                    dtype=enc_output.dtype,
+                    device=enc_output.device,
+                ),
+                torch.linspace(
+                    0,
+                    width - 1,
+                    width,
+                    dtype=enc_output.dtype,
+                    device=enc_output.device,
+                ),
                 indexing="ij",
             )
             grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
@@ -1802,7 +1747,9 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             topk = self.config.two_stage_num_proposals
             topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
             topk_coords_logits = torch.gather(
-                enc_outputs_coord_logits, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
+                enc_outputs_coord_logits,
+                1,
+                topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
             )
 
             topk_coords_logits = topk_coords_logits.detach()
@@ -1897,7 +1844,10 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
         # Detection heads on top
         self.class_embed = nn.Linear(config.d_model, config.num_labels)
         self.bbox_embed = DeformableDetrMLPPredictionHead(
-            input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
+            input_dim=config.d_model,
+            hidden_dim=config.d_model,
+            output_dim=4,
+            num_layers=3,
         )
 
         prior_prob = 0.01
@@ -2033,7 +1983,13 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
         loss, loss_dict, auxiliary_outputs = None, None, None
         if labels is not None:
             loss, loss_dict, auxiliary_outputs = self.loss_function(
-                logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
+                logits,
+                labels,
+                self.device,
+                pred_boxes,
+                self.config,
+                outputs_class,
+                outputs_coord,
             )
         if not return_dict:
             if auxiliary_outputs is not None:
@@ -15,17 +15,13 @@
 """PyTorch Grounding DINO model."""
 
 import math
-import os
 import warnings
 from dataclasses import dataclass
-from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
-from torch.autograd import Function
-from torch.autograd.function import once_differentiable
 
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -33,13 +29,13 @@ from ...file_utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_timm_available,
-    is_torch_cuda_available,
     replace_return_docstrings,
     requires_backends,
 )
+from ...integrations import use_kernel_forward_from_hub
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import meshgrid
-from ...utils import is_ninja_available, logging
+from ...utils import logging
 from ...utils.backbone_utils import load_backbone
 from ..auto import AutoModel
 from .configuration_grounding_dino import GroundingDinoConfig
@@ -49,97 +45,68 @@ if is_timm_available():
     from timm import create_model
 
 
-logger = logging.get_logger(__name__)
-
-MultiScaleDeformableAttention = None
-
-
-# Copied from models.deformable_detr.load_cuda_kernels
-def load_cuda_kernels():
-    from torch.utils.cpp_extension import load
-
-    global MultiScaleDeformableAttention
-
-    root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr"
-    src_files = [
-        root / filename
-        for filename in [
-            "vision.cpp",
-            os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
-            os.path.join("cuda", "ms_deform_attn_cuda.cu"),
-        ]
-    ]
-
-    MultiScaleDeformableAttention = load(
-        "MultiScaleDeformableAttention",
-        src_files,
-        with_cuda=True,
-        extra_include_paths=[str(root)],
-        extra_cflags=["-DWITH_CUDA=1"],
-        extra_cuda_cflags=[
-            "-DCUDA_HAS_FP16=1",
-            "-D__CUDA_NO_HALF_OPERATORS__",
-            "-D__CUDA_NO_HALF_CONVERSIONS__",
-            "-D__CUDA_NO_HALF2_OPERATORS__",
-        ],
-    )
-
-
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction
-class MultiScaleDeformableAttentionFunction(Function):
-    @staticmethod
-    def forward(
-        context,
-        value,
-        value_spatial_shapes,
-        value_level_start_index,
-        sampling_locations,
-        attention_weights,
-        im2col_step,
-    ):
-        context.im2col_step = im2col_step
-        output = MultiScaleDeformableAttention.ms_deform_attn_forward(
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-            context.im2col_step,
-        )
-        context.save_for_backward(
-            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
-        )
-        return output
-
-    @staticmethod
-    @once_differentiable
-    def backward(context, grad_output):
-        (
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-        ) = context.saved_tensors
-        grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-            grad_output,
-            context.im2col_step,
-        )
-
-        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
-
-
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "GroundingDinoConfig"
 _CHECKPOINT_FOR_DOC = "IDEA-Research/grounding-dino-tiny"
 
 
+# Copied from models.deformable_detr.MultiScaleDeformableAttention
+@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
+class MultiScaleDeformableAttention(nn.Module):
+    def forward(
+        self,
+        value: Tensor,
+        value_spatial_shapes: Tensor,
+        value_spatial_shapes_list: List[Tuple],
+        level_start_index: Tensor,
+        sampling_locations: Tensor,
+        attention_weights: Tensor,
+        im2col_step: int,
+    ):
+        batch_size, _, num_heads, hidden_dim = value.shape
+        _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+        value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
+        sampling_grids = 2 * sampling_locations - 1
+        sampling_value_list = []
+        for level_id, (height, width) in enumerate(value_spatial_shapes):
+            # batch_size, height*width, num_heads, hidden_dim
+            # -> batch_size, height*width, num_heads*hidden_dim
+            # -> batch_size, num_heads*hidden_dim, height*width
+            # -> batch_size*num_heads, hidden_dim, height, width
+            value_l_ = (
+                value_list[level_id]
+                .flatten(2)
+                .transpose(1, 2)
+                .reshape(batch_size * num_heads, hidden_dim, height, width)
+            )
+            # batch_size, num_queries, num_heads, num_points, 2
+            # -> batch_size, num_heads, num_queries, num_points, 2
+            # -> batch_size*num_heads, num_queries, num_points, 2
+            sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
+            # batch_size*num_heads, hidden_dim, num_queries, num_points
+            sampling_value_l_ = nn.functional.grid_sample(
+                value_l_,
+                sampling_grid_l_,
+                mode="bilinear",
+                padding_mode="zeros",
+                align_corners=False,
+            )
+            sampling_value_list.append(sampling_value_l_)
+        # (batch_size, num_queries, num_heads, num_levels, num_points)
+        # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+        # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
+        attention_weights = attention_weights.transpose(1, 2).reshape(
+            batch_size * num_heads, 1, num_queries, num_levels * num_points
+        )
+        output = (
+            (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+            .sum(-1)
+            .view(batch_size, num_heads * hidden_dim, num_queries)
+        )
+        return output.transpose(1, 2).contiguous()
+
+
 @dataclass
 class GroundingDinoDecoderOutput(ModelOutput):
     """
@@ -583,49 +550,6 @@ def build_position_encoding(config):
     return position_embedding
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
-def multi_scale_deformable_attention(
-    value: Tensor,
-    value_spatial_shapes: Union[Tensor, List[Tuple]],
-    sampling_locations: Tensor,
-    attention_weights: Tensor,
-) -> Tensor:
-    batch_size, _, num_heads, hidden_dim = value.shape
-    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
-    sampling_grids = 2 * sampling_locations - 1
-    sampling_value_list = []
-    for level_id, (height, width) in enumerate(value_spatial_shapes):
-        # batch_size, height*width, num_heads, hidden_dim
-        # -> batch_size, height*width, num_heads*hidden_dim
-        # -> batch_size, num_heads*hidden_dim, height*width
-        # -> batch_size*num_heads, hidden_dim, height, width
-        value_l_ = (
-            value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width)
-        )
-        # batch_size, num_queries, num_heads, num_points, 2
-        # -> batch_size, num_heads, num_queries, num_points, 2
-        # -> batch_size*num_heads, num_queries, num_points, 2
-        sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
-        # batch_size*num_heads, hidden_dim, num_queries, num_points
-        sampling_value_l_ = nn.functional.grid_sample(
-            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
-        )
-        sampling_value_list.append(sampling_value_l_)
-    # (batch_size, num_queries, num_heads, num_levels, num_points)
-    # -> (batch_size, num_heads, num_queries, num_levels, num_points)
-    # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
-    attention_weights = attention_weights.transpose(1, 2).reshape(
-        batch_size * num_heads, 1, num_queries, num_levels * num_points
-    )
-    output = (
-        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
-        .sum(-1)
-        .view(batch_size, num_heads * hidden_dim, num_queries)
-    )
-    return output.transpose(1, 2).contiguous()
-
-
 # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->GroundingDino, Deformable DETR->Grounding DINO
 class GroundingDinoMultiscaleDeformableAttention(nn.Module):
     """
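The eager implementation removed above (and re-added as a module method elsewhere in this commit) is built on torch.nn.functional.grid_sample. A minimal standalone sketch of that primitive, with made-up shapes, shows what each per-level sampling step computes:

import torch
import torch.nn.functional as F

# A 4x4 single-channel feature map with values 0..15 (illustrative only).
feature_map = torch.arange(16.0).view(1, 1, 4, 4)  # (batch, channels, height, width)
# One sampling location in normalized [-1, 1] coordinates: the exact center.
grid = torch.zeros(1, 1, 1, 2)  # (batch, out_height, out_width, 2)
sampled = F.grid_sample(feature_map, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
print(sampled.item())  # 7.5 -> bilinear blend of the four center pixels (5, 6, 9, 10)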
@@ -635,12 +559,7 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module):
     def __init__(self, config: GroundingDinoConfig, num_heads: int, n_points: int):
         super().__init__()
 
-        kernel_loaded = MultiScaleDeformableAttention is not None
-        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
-            try:
-                load_cuda_kernels()
-            except Exception as e:
-                logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
+        self.attn = MultiScaleDeformableAttention()
 
         if config.d_model % num_heads != 0:
             raise ValueError(
@@ -727,23 +646,16 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module):
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
 
-        if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
-            # PyTorch implementation
-            output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
-        else:
-            try:
-                # custom kernel
-                output = MultiScaleDeformableAttentionFunction.apply(
-                    value,
-                    spatial_shapes,
-                    level_start_index,
-                    sampling_locations,
-                    attention_weights,
-                    self.im2col_step,
-                )
-            except Exception:
-                # PyTorch implementation
-                output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+        output = self.attn(
+            value,
+            spatial_shapes,
+            spatial_shapes_list,
+            level_start_index,
+            sampling_locations,
+            attention_weights,
+            self.im2col_step,
+        )
         output = self.output_proj(output)
 
         return output, attention_weights
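The replacement collapses the old three-way dispatch (custom CUDA kernel, try/except fallback, pure PyTorch) into a single call on self.attn. A hedged usage sketch of the new module's eager forward with dummy tensors, assuming the MultiScaleDeformableAttention module from this diff is in scope; all sizes are invented for illustration, and the argument order matches the forward signature shown in this commit:

import torch

batch_size, num_heads, hidden_dim = 2, 8, 32
spatial_shapes_list = [(32, 32), (16, 16)]  # two feature levels
num_levels, num_queries, num_points = len(spatial_shapes_list), 100, 4

value = torch.rand(batch_size, sum(h * w for h, w in spatial_shapes_list), num_heads, hidden_dim)
spatial_shapes = torch.tensor(spatial_shapes_list)
level_start_index = torch.tensor([0, 32 * 32])
sampling_locations = torch.rand(batch_size, num_queries, num_heads, num_levels, num_points, 2)
attention_weights = torch.softmax(
    torch.rand(batch_size, num_queries, num_heads, num_levels * num_points), dim=-1
).view(batch_size, num_queries, num_heads, num_levels, num_points)

attn = MultiScaleDeformableAttention()  # the module added by this commit
output = attn(
    value, spatial_shapes, spatial_shapes_list, level_start_index,
    sampling_locations, attention_weights, 64,  # im2col_step, unused by the eager path
)
print(output.shape)  # torch.Size([2, 100, 256]) == (batch_size, num_queries, num_heads * hidden_dim)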
@@ -799,7 +799,7 @@ class Mask2FormerLoss(nn.Module):
         return num_masks
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
+# Copied from transformers.models.oneformer.modeling_oneformer.multi_scale_deformable_attention
 def multi_scale_deformable_attention(
     value: Tensor,
     value_spatial_shapes: Union[Tensor, List[Tuple]],
@@ -15,38 +15,32 @@
 """PyTorch OmDet-Turbo model."""
 
 import math
-import os
 import warnings
 from collections import OrderedDict
 from dataclasses import dataclass
 from functools import lru_cache
-from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
-from torch.autograd import Function
-from torch.autograd.function import once_differentiable
 
 from ...activations import ACT2CLS, ACT2FN
 from ...file_utils import (
     ModelOutput,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_torch_cuda_available,
     replace_return_docstrings,
 )
+from ...integrations import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_utils import PreTrainedModel
-from ...utils import is_ninja_available, logging
+from ...utils import logging
 from ...utils.backbone_utils import load_backbone
 from ..auto import AutoModel
 from .configuration_omdet_turbo import OmDetTurboConfig
 
 
-MultiScaleDeformableAttention = None
-
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "OmDetTurboConfig"
@@ -178,79 +172,60 @@ class OmDetTurboObjectDetectionOutput(ModelOutput):
     classes_structure: Optional[torch.LongTensor] = None
 
 
-# Copied from models.deformable_detr.load_cuda_kernels
-def load_cuda_kernels():
-    from torch.utils.cpp_extension import load
-
-    global MultiScaleDeformableAttention
-
-    root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr"
-    src_files = [
-        root / filename
-        for filename in [
-            "vision.cpp",
-            os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
-            os.path.join("cuda", "ms_deform_attn_cuda.cu"),
-        ]
-    ]
-
-    MultiScaleDeformableAttention = load(
-        "MultiScaleDeformableAttention",
-        src_files,
-        with_cuda=True,
-        extra_include_paths=[str(root)],
-        extra_cflags=["-DWITH_CUDA=1"],
-        extra_cuda_cflags=[
-            "-DCUDA_HAS_FP16=1",
-            "-D__CUDA_NO_HALF_OPERATORS__",
-            "-D__CUDA_NO_HALF_CONVERSIONS__",
-            "-D__CUDA_NO_HALF2_OPERATORS__",
-        ],
-    )
-
-
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
-def multi_scale_deformable_attention(
-    value: Tensor,
-    value_spatial_shapes: Union[Tensor, List[Tuple]],
-    sampling_locations: Tensor,
-    attention_weights: Tensor,
-) -> Tensor:
-    batch_size, _, num_heads, hidden_dim = value.shape
-    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    # Ignore copy
-    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
-    sampling_grids = 2 * sampling_locations - 1
-    sampling_value_list = []
-    for level_id, (height, width) in enumerate(value_spatial_shapes):
-        # batch_size, height*width, num_heads, hidden_dim
-        # -> batch_size, height*width, num_heads*hidden_dim
-        # -> batch_size, num_heads*hidden_dim, height*width
-        # -> batch_size*num_heads, hidden_dim, height, width
-        value_l_ = (
-            value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width)
-        )
-        # batch_size, num_queries, num_heads, num_points, 2
-        # -> batch_size, num_heads, num_queries, num_points, 2
-        # -> batch_size*num_heads, num_queries, num_points, 2
-        sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
-        # batch_size*num_heads, hidden_dim, num_queries, num_points
-        sampling_value_l_ = nn.functional.grid_sample(
-            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
-        )
-        sampling_value_list.append(sampling_value_l_)
-    # (batch_size, num_queries, num_heads, num_levels, num_points)
-    # -> (batch_size, num_heads, num_queries, num_levels, num_points)
-    # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
-    attention_weights = attention_weights.transpose(1, 2).reshape(
-        batch_size * num_heads, 1, num_queries, num_levels * num_points
-    )
-    output = (
-        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
-        .sum(-1)
-        .view(batch_size, num_heads * hidden_dim, num_queries)
-    )
-    return output.transpose(1, 2).contiguous()
+# Copied from models.deformable_detr.MultiScaleDeformableAttention
+@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
+class MultiScaleDeformableAttention(nn.Module):
+    def forward(
+        self,
+        value: Tensor,
+        value_spatial_shapes: Tensor,
+        value_spatial_shapes_list: List[Tuple],
+        level_start_index: Tensor,
+        sampling_locations: Tensor,
+        attention_weights: Tensor,
+        im2col_step: int,
+    ):
+        batch_size, _, num_heads, hidden_dim = value.shape
+        _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+        value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
+        sampling_grids = 2 * sampling_locations - 1
+        sampling_value_list = []
+        for level_id, (height, width) in enumerate(value_spatial_shapes):
+            # batch_size, height*width, num_heads, hidden_dim
+            # -> batch_size, height*width, num_heads*hidden_dim
+            # -> batch_size, num_heads*hidden_dim, height*width
+            # -> batch_size*num_heads, hidden_dim, height, width
+            value_l_ = (
+                value_list[level_id]
+                .flatten(2)
+                .transpose(1, 2)
+                .reshape(batch_size * num_heads, hidden_dim, height, width)
+            )
+            # batch_size, num_queries, num_heads, num_points, 2
+            # -> batch_size, num_heads, num_queries, num_points, 2
+            # -> batch_size*num_heads, num_queries, num_points, 2
+            sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
+            # batch_size*num_heads, hidden_dim, num_queries, num_points
+            sampling_value_l_ = nn.functional.grid_sample(
+                value_l_,
+                sampling_grid_l_,
+                mode="bilinear",
+                padding_mode="zeros",
+                align_corners=False,
+            )
+            sampling_value_list.append(sampling_value_l_)
+        # (batch_size, num_queries, num_heads, num_levels, num_points)
+        # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+        # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
+        attention_weights = attention_weights.transpose(1, 2).reshape(
+            batch_size * num_heads, 1, num_queries, num_levels * num_points
+        )
+        output = (
+            (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+            .sum(-1)
+            .view(batch_size, num_heads * hidden_dim, num_queries)
+        )
+        return output.transpose(1, 2).contiguous()
 
 
 class OmDetTurboLRUCache:
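For context, @use_kernel_forward_from_hub only tags the class with a kernel layer name; the actual substitution is driven by a layer-name to Hub-repository mapping in the hub_kernels integration, and the decorated class keeps its eager forward as the fallback when no kernel applies. A minimal sketch of how such a mapping might be registered with the kernels package; the repo id and the exact mapping shape below are assumptions for illustration, not a statement of what this commit registers:

from kernels import LayerRepository, register_kernel_mapping

# Hypothetical registration: map the layer name used by the decorator to a
# pre-built kernel repository on the Hub, keyed by device type.
register_kernel_mapping(
    {
        "MultiScaleDeformableAttention": {
            "cuda": LayerRepository(
                repo_id="kernels-community/deformable-detr",  # assumed repo id
                layer_name="MultiScaleDeformableAttention",
            ),
        }
    }
)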
@@ -332,55 +307,6 @@ class OmDetTurboVisionBackbone(nn.Module):
         return outputs
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction
-class MultiScaleDeformableAttentionFunction(Function):
-    @staticmethod
-    def forward(
-        context,
-        value,
-        value_spatial_shapes,
-        value_level_start_index,
-        sampling_locations,
-        attention_weights,
-        im2col_step,
-    ):
-        context.im2col_step = im2col_step
-        output = MultiScaleDeformableAttention.ms_deform_attn_forward(
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-            context.im2col_step,
-        )
-        context.save_for_backward(
-            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
-        )
-        return output
-
-    @staticmethod
-    @once_differentiable
-    def backward(context, grad_output):
-        (
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-        ) = context.saved_tensors
-        grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-            grad_output,
-            context.im2col_step,
-        )
-
-        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
-
-
 # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->OmDetTurbo, Deformable DETR->OmDet-Turbo
 class OmDetTurboMultiscaleDeformableAttention(nn.Module):
     """
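The deleted MultiScaleDeformableAttentionFunction is the standard torch.autograd.Function wrapper around a hand-written forward/backward pair; with the Hub kernel supplying its own autograd support, the wrapper is no longer needed. A minimal sketch of the same pattern on a toy square function, to show what the deleted code was doing structurally:

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

class Square(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # stash inputs needed by backward
        return x * x

    @staticmethod
    @once_differentiable  # this backward is not itself differentiable
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output  # d(x^2)/dx = 2x

x = torch.randn(3, requires_grad=True)
Square.apply(x).sum().backward()
print(torch.allclose(x.grad, 2 * x.detach()))  # True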
@@ -390,12 +316,7 @@ class OmDetTurboMultiscaleDeformableAttention(nn.Module):
     def __init__(self, config: OmDetTurboConfig, num_heads: int, n_points: int):
         super().__init__()
 
-        kernel_loaded = MultiScaleDeformableAttention is not None
-        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
-            try:
-                load_cuda_kernels()
-            except Exception as e:
-                logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
+        self.attn = MultiScaleDeformableAttention()
 
         if config.d_model % num_heads != 0:
             raise ValueError(
@@ -483,27 +404,16 @@ class OmDetTurboMultiscaleDeformableAttention(nn.Module):
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
 
-        if self.disable_custom_kernels:
-            # PyTorch implementation
-            output = multi_scale_deformable_attention(
-                value, spatial_shapes_list, sampling_locations, attention_weights
-            )
-        else:
-            try:
-                # custom kernel
-                output = MultiScaleDeformableAttentionFunction.apply(
-                    value,
-                    spatial_shapes,
-                    level_start_index,
-                    sampling_locations,
-                    attention_weights,
-                    self.im2col_step,
-                )
-            except Exception:
-                # PyTorch implementation
-                output = multi_scale_deformable_attention(
-                    value, spatial_shapes_list, sampling_locations, attention_weights
-                )
+        output = self.attn(
+            value,
+            spatial_shapes,
+            spatial_shapes_list,
+            level_start_index,
+            sampling_locations,
+            attention_weights,
+            self.im2col_step,
+        )
         output = self.output_proj(output)
 
         return output, attention_weights
@@ -61,7 +61,6 @@ def _get_clones(module, N):
     return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
 def multi_scale_deformable_attention(
     value: Tensor,
     value_spatial_shapes: Union[Tensor, List[Tuple]],
@@ -15,21 +15,18 @@
 """PyTorch RT-DETR model."""
 
 import math
-import os
 import warnings
 from dataclasses import dataclass
 from functools import partial
-from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
-from torch.autograd import Function
-from torch.autograd.function import once_differentiable
 
 from ...activations import ACT2CLS, ACT2FN
 from ...image_transforms import center_to_corners_format, corners_to_center_format
+from ...integrations import use_kernel_forward_from_hub
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import compile_compatible_method_lru_cache
@@ -37,9 +34,6 @@ from ...utils import (
     ModelOutput,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_ninja_available,
-    is_torch_cuda_available,
-    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
     torch_int,
@@ -50,96 +44,68 @@ from .configuration_rt_detr import RTDetrConfig
 
 logger = logging.get_logger(__name__)
 
-MultiScaleDeformableAttention = None
-
-
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.load_cuda_kernels
-def load_cuda_kernels():
-    from torch.utils.cpp_extension import load
-
-    global MultiScaleDeformableAttention
-
-    root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr"
-    src_files = [
-        root / filename
-        for filename in [
-            "vision.cpp",
-            os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
-            os.path.join("cuda", "ms_deform_attn_cuda.cu"),
-        ]
-    ]
-
-    MultiScaleDeformableAttention = load(
-        "MultiScaleDeformableAttention",
-        src_files,
-        with_cuda=True,
-        extra_include_paths=[str(root)],
-        extra_cflags=["-DWITH_CUDA=1"],
-        extra_cuda_cflags=[
-            "-DCUDA_HAS_FP16=1",
-            "-D__CUDA_NO_HALF_OPERATORS__",
-            "-D__CUDA_NO_HALF_CONVERSIONS__",
-            "-D__CUDA_NO_HALF2_OPERATORS__",
-        ],
-    )
-
-
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction
-class MultiScaleDeformableAttentionFunction(Function):
-    @staticmethod
-    def forward(
-        context,
-        value,
-        value_spatial_shapes,
-        value_level_start_index,
-        sampling_locations,
-        attention_weights,
-        im2col_step,
-    ):
-        context.im2col_step = im2col_step
-        output = MultiScaleDeformableAttention.ms_deform_attn_forward(
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-            context.im2col_step,
-        )
-        context.save_for_backward(
-            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
-        )
-        return output
-
-    @staticmethod
-    @once_differentiable
-    def backward(context, grad_output):
-        (
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-        ) = context.saved_tensors
-        grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-            grad_output,
-            context.im2col_step,
-        )
-
-        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
-
-
-logger = logging.get_logger(__name__)
-
 _CONFIG_FOR_DOC = "RTDetrConfig"
 # TODO: Replace all occurrences of the checkpoint with the final one
 _CHECKPOINT_FOR_DOC = "PekingU/rtdetr_r50vd"
 
 
+# Copied from models.deformable_detr.MultiScaleDeformableAttention
+@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
+class MultiScaleDeformableAttention(nn.Module):
+    def forward(
+        self,
+        value: Tensor,
+        value_spatial_shapes: Tensor,
+        value_spatial_shapes_list: List[Tuple],
+        level_start_index: Tensor,
+        sampling_locations: Tensor,
+        attention_weights: Tensor,
+        im2col_step: int,
+    ):
+        batch_size, _, num_heads, hidden_dim = value.shape
+        _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+        value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
+        sampling_grids = 2 * sampling_locations - 1
+        sampling_value_list = []
+        for level_id, (height, width) in enumerate(value_spatial_shapes):
+            # batch_size, height*width, num_heads, hidden_dim
+            # -> batch_size, height*width, num_heads*hidden_dim
+            # -> batch_size, num_heads*hidden_dim, height*width
+            # -> batch_size*num_heads, hidden_dim, height, width
+            value_l_ = (
+                value_list[level_id]
+                .flatten(2)
+                .transpose(1, 2)
+                .reshape(batch_size * num_heads, hidden_dim, height, width)
+            )
+            # batch_size, num_queries, num_heads, num_points, 2
+            # -> batch_size, num_heads, num_queries, num_points, 2
+            # -> batch_size*num_heads, num_queries, num_points, 2
+            sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
+            # batch_size*num_heads, hidden_dim, num_queries, num_points
+            sampling_value_l_ = nn.functional.grid_sample(
+                value_l_,
+                sampling_grid_l_,
+                mode="bilinear",
+                padding_mode="zeros",
+                align_corners=False,
+            )
+            sampling_value_list.append(sampling_value_l_)
+        # (batch_size, num_queries, num_heads, num_levels, num_points)
+        # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+        # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
+        attention_weights = attention_weights.transpose(1, 2).reshape(
+            batch_size * num_heads, 1, num_queries, num_levels * num_points
+        )
+        output = (
+            (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+            .sum(-1)
+            .view(batch_size, num_heads * hidden_dim, num_queries)
+        )
+        return output.transpose(1, 2).contiguous()
 
 
 @dataclass
 class RTDetrDecoderOutput(ModelOutput):
     """
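One detail worth noting in the forward above: grid_sample expects sampling coordinates normalized to [-1, 1], while the model produces locations as fractions in [0, 1], hence the sampling_grids = 2 * sampling_locations - 1 line. A one-line check of that affine map:

import torch

locations = torch.tensor([0.0, 0.25, 0.5, 1.0])  # fractions of the feature-map extent
print(2 * locations - 1)  # tensor([-1.0000, -0.5000,  0.0000,  1.0000])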
@@ -728,49 +694,6 @@ class RTDetrCSPRepLayer(nn.Module):
         return self.conv3(hidden_state_1 + hidden_state_2)
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
-def multi_scale_deformable_attention(
-    value: Tensor,
-    value_spatial_shapes: Union[Tensor, List[Tuple]],
-    sampling_locations: Tensor,
-    attention_weights: Tensor,
-) -> Tensor:
-    batch_size, _, num_heads, hidden_dim = value.shape
-    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
-    sampling_grids = 2 * sampling_locations - 1
-    sampling_value_list = []
-    for level_id, (height, width) in enumerate(value_spatial_shapes):
-        # batch_size, height*width, num_heads, hidden_dim
-        # -> batch_size, height*width, num_heads*hidden_dim
-        # -> batch_size, num_heads*hidden_dim, height*width
-        # -> batch_size*num_heads, hidden_dim, height, width
-        value_l_ = (
-            value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width)
-        )
-        # batch_size, num_queries, num_heads, num_points, 2
-        # -> batch_size, num_heads, num_queries, num_points, 2
-        # -> batch_size*num_heads, num_queries, num_points, 2
-        sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
-        # batch_size*num_heads, hidden_dim, num_queries, num_points
-        sampling_value_l_ = nn.functional.grid_sample(
-            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
-        )
-        sampling_value_list.append(sampling_value_l_)
-    # (batch_size, num_queries, num_heads, num_levels, num_points)
-    # -> (batch_size, num_heads, num_queries, num_levels, num_points)
-    # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
-    attention_weights = attention_weights.transpose(1, 2).reshape(
-        batch_size * num_heads, 1, num_queries, num_levels * num_points
-    )
-    output = (
-        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
-        .sum(-1)
-        .view(batch_size, num_heads * hidden_dim, num_queries)
-    )
-    return output.transpose(1, 2).contiguous()
-
-
 # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->RTDetr
 class RTDetrMultiscaleDeformableAttention(nn.Module):
     """
@@ -780,12 +703,7 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
     def __init__(self, config: RTDetrConfig, num_heads: int, n_points: int):
         super().__init__()
 
-        kernel_loaded = MultiScaleDeformableAttention is not None
-        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
-            try:
-                load_cuda_kernels()
-            except Exception as e:
-                logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
+        self.attn = MultiScaleDeformableAttention()
 
         if config.d_model % num_heads != 0:
             raise ValueError(
@@ -872,27 +790,16 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
 
-        if self.disable_custom_kernels or MultiScaleDeformableAttention is None or is_torchdynamo_compiling():
-            # PyTorch implementation
-            output = multi_scale_deformable_attention(
-                value, spatial_shapes_list, sampling_locations, attention_weights
-            )
-        else:
-            try:
-                # custom kernel
-                output = MultiScaleDeformableAttentionFunction.apply(
-                    value,
-                    spatial_shapes,
-                    level_start_index,
-                    sampling_locations,
-                    attention_weights,
-                    self.im2col_step,
-                )
-            except Exception:
-                # PyTorch implementation
-                output = multi_scale_deformable_attention(
-                    value, spatial_shapes_list, sampling_locations, attention_weights
-                )
+        output = self.attn(
+            value,
+            spatial_shapes,
+            spatial_shapes_list,
+            level_start_index,
+            sampling_locations,
+            attention_weights,
+            self.im2col_step,
+        )
         output = self.output_proj(output)
 
         return output, attention_weights
@@ -21,8 +21,6 @@ from pathlib import Path
 FILES_TO_FIND = [
     "kernels/rwkv/wkv_cuda.cu",
     "kernels/rwkv/wkv_op.cpp",
-    "kernels/deformable_detr/ms_deform_attn.h",
-    "kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh",
     "kernels/falcon_mamba/selective_scan_with_ln_interface.py",
     "kernels/falcon_mamba/__init__.py",
     "kernels/__init__.py",
@@ -475,7 +475,6 @@ src/transformers/models/deberta/modeling_tf_deberta.py
 src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py
 src/transformers/models/decision_transformer/modeling_decision_transformer.py
 src/transformers/models/deformable_detr/convert_deformable_detr_to_pytorch.py
-src/transformers/models/deformable_detr/load_custom.py
 src/transformers/models/deit/convert_deit_timm_to_pytorch.py
 src/transformers/models/deprecated/bort/convert_bort_original_gluonnlp_checkpoint_to_pytorch.py
 src/transformers/models/deprecated/mctct/configuration_mctct.py