diff --git a/setup.py b/setup.py
index 9a538a9a3b8..8731f8187ab 100644
--- a/setup.py
+++ b/setup.py
@@ -129,6 +129,7 @@ _deps = [
     # Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
     "keras>2.9,<2.16",
     "keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
+    "kernels>=0.3.2,<0.4",
     "librosa",
     "natten>=0.14.6,<0.15.0",
     "nltk<=3.8.1",
@@ -301,8 +302,9 @@ extras["deepspeed"] = deps_list("deepspeed") + extras["accelerate"]
 extras["optuna"] = deps_list("optuna")
 extras["ray"] = deps_list("ray[tune]")
 extras["sigopt"] = deps_list("sigopt")
+extras["hub-kernels"] = deps_list("kernels")
 
-extras["integrations"] = extras["optuna"] + extras["ray"] + extras["sigopt"]
+extras["integrations"] = extras["hub-kernels"] + extras["optuna"] + extras["ray"] + extras["sigopt"]
 
 extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
 extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm")
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index c87bd3a27f2..51682b6a706 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -35,6 +35,7 @@ deps = {
     "kenlm": "kenlm",
     "keras": "keras>2.9,<2.16",
     "keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
+    "kernels": "kernels>=0.3.2,<0.4",
     "librosa": "librosa",
     "natten": "natten>=0.14.6,<0.15.0",
     "nltk": "nltk<=3.8.1",
diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py
index b5fd8a1e9df..da8c9cd4c6e 100755
--- a/src/transformers/integrations/__init__.py
+++ b/src/transformers/integrations/__init__.py
@@ -70,6 +70,12 @@ _import_structure = {
         "replace_with_higgs_linear",
     ],
     "hqq": ["prepare_for_hqq_linear"],
+    "hub_kernels": [
+        "LayerRepository",
+        "register_kernel_mapping",
+        "replace_kernel_forward_from_hub",
+        "use_kernel_forward_from_hub",
+    ],
     "integration_utils": [
         "INTEGRATION_TO_CALLBACK",
         "AzureMLCallback",
@@ -198,6 +204,12 @@ if TYPE_CHECKING:
     )
     from .higgs import HiggsLinear, dequantize_higgs, quantize_with_higgs, replace_with_higgs_linear
     from .hqq import prepare_for_hqq_linear
+    from .hub_kernels import (
+        LayerRepository,
+        register_kernel_mapping,
+        replace_kernel_forward_from_hub,
+        use_kernel_forward_from_hub,
+    )
     from .integration_utils import (
         INTEGRATION_TO_CALLBACK,
         AzureMLCallback,
diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py
new file mode 100644
index 00000000000..b2ec6b53715
--- /dev/null
+++ b/src/transformers/integrations/hub_kernels.py
@@ -0,0 +1,73 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, Union
+
+
+try:
+    from kernels import (
+        Device,
+        LayerRepository,
+        register_kernel_mapping,
+        replace_kernel_forward_from_hub,
+        use_kernel_forward_from_hub,
+    )
+
+    _hub_kernels_available = True
+
+    _KERNEL_MAPPING: Dict[str, Dict[Union[Device, str], LayerRepository]] = {
+        "MultiScaleDeformableAttention": {
+            "cuda": LayerRepository(
+                repo_id="kernels-community/deformable-detr",
+                layer_name="MultiScaleDeformableAttention",
+            )
+        }
+    }
+
+    register_kernel_mapping(_KERNEL_MAPPING)
+
+except ImportError:
+    # Stub to make decorators in transformers work when `kernels`
+    # is not installed.
+    def use_kernel_forward_from_hub(*args, **kwargs):
+        def decorator(cls):
+            return cls
+
+        return decorator
+
+    class LayerRepository:
+        def __init__(self, *args, **kwargs):
+            raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
+
+    def replace_kernel_forward_from_hub(*args, **kwargs):
+        raise RuntimeError(
+            "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
+        )
+
+    def register_kernel_mapping(*args, **kwargs):
+        raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
+
+    _hub_kernels_available = False
+
+
+def is_hub_kernels_available():
+    return _hub_kernels_available
+
+
+__all__ = [
+    "LayerRepository",
+    "is_hub_kernels_available",
+    "use_kernel_forward_from_hub",
+    "register_kernel_mapping",
+    "replace_kernel_forward_from_hub",
+]
diff --git a/src/transformers/kernels/deformable_detr/cpu/ms_deform_attn_cpu.cpp b/src/transformers/kernels/deformable_detr/cpu/ms_deform_attn_cpu.cpp
deleted file mode 100644
index 388a73d22d4..00000000000
--- a/src/transformers/kernels/deformable_detr/cpu/ms_deform_attn_cpu.cpp
+++ /dev/null
@@ -1,40 +0,0 @@
-/*!
-**************************************************************************************************
-* Deformable DETR
-* Copyright (c) 2020 SenseTime. All Rights Reserved.
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
-**************************************************************************************************
-* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
-**************************************************************************************************
-*/
-
-#include <vector>
-
-#include <ATen/ATen.h>
-#include <ATen/cuda/CUDAContext.h>
-
-
-at::Tensor
-ms_deform_attn_cpu_forward(
-    const at::Tensor &value,
-    const at::Tensor &spatial_shapes,
-    const at::Tensor &level_start_index,
-    const at::Tensor &sampling_loc,
-    const at::Tensor &attn_weight,
-    const int im2col_step)
-{
-    AT_ERROR("Not implement on cpu");
-}
-
-std::vector<at::Tensor>
-ms_deform_attn_cpu_backward(
-    const at::Tensor &value,
-    const at::Tensor &spatial_shapes,
-    const at::Tensor &level_start_index,
-    const at::Tensor &sampling_loc,
-    const at::Tensor &attn_weight,
-    const at::Tensor &grad_output,
-    const int im2col_step)
-{
-    AT_ERROR("Not implement on cpu");
-}
diff --git a/src/transformers/kernels/deformable_detr/cpu/ms_deform_attn_cpu.h b/src/transformers/kernels/deformable_detr/cpu/ms_deform_attn_cpu.h
deleted file mode 100644
index 7eac8c8bcd1..00000000000
--- a/src/transformers/kernels/deformable_detr/cpu/ms_deform_attn_cpu.h
+++ /dev/null
@@ -1,32 +0,0 @@
-/*!
-**************************************************************************************************
-* Deformable DETR
-* Copyright (c) 2020 SenseTime. All Rights Reserved.
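# A minimal usage sketch of the new hub_kernels integration added above, assuming
# the `kernels` package is installed and that `use_kernel_forward_from_hub` takes
# the layer name keyed in the kernel mapping. `MyDeformableAttention` and the call
# arguments below are illustrative, not part of this patch.
import torch.nn as nn

from transformers.integrations import (
    LayerRepository,
    register_kernel_mapping,
    use_kernel_forward_from_hub,
)


# Decorating a layer marks its forward as replaceable by a kernel fetched from the
# Hub when one is registered under the same name for the current device.
@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
class MyDeformableAttention(nn.Module):
    def forward(self, *args, **kwargs):
        ...  # plain PyTorch fallback, used when no hub kernel is available


# A mapping can also be (re)registered explicitly, mirroring _KERNEL_MAPPING above.
register_kernel_mapping(
    {
        "MultiScaleDeformableAttention": {
            "cuda": LayerRepository(
                repo_id="kernels-community/deformable-detr",
                layer_name="MultiScaleDeformableAttention",
            )
        }
    }
)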
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details] -************************************************************************************************** -* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 -************************************************************************************************** -*/ - -#pragma once -#include - -at::Tensor -ms_deform_attn_cpu_forward( - const at::Tensor &value, - const at::Tensor &spatial_shapes, - const at::Tensor &level_start_index, - const at::Tensor &sampling_loc, - const at::Tensor &attn_weight, - const int im2col_step); - -std::vector -ms_deform_attn_cpu_backward( - const at::Tensor &value, - const at::Tensor &spatial_shapes, - const at::Tensor &level_start_index, - const at::Tensor &sampling_loc, - const at::Tensor &attn_weight, - const at::Tensor &grad_output, - const int im2col_step); - diff --git a/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu b/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu deleted file mode 100644 index c2b3a462a1a..00000000000 --- a/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu +++ /dev/null @@ -1,159 +0,0 @@ -/*! -************************************************************************************************** -* Deformable DETR -* Copyright (c) 2020 SenseTime. All Rights Reserved. -* Licensed under the Apache License, Version 2.0 [see LICENSE for details] -************************************************************************************************** -* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 -************************************************************************************************** -*/ - -#include -#include "cuda/ms_deform_im2col_cuda.cuh" - -#include -#include -#include -#include - -#pragma once -#include - - -at::Tensor ms_deform_attn_cuda_forward( - const at::Tensor &value, - const at::Tensor &spatial_shapes, - const at::Tensor &level_start_index, - const at::Tensor &sampling_loc, - const at::Tensor &attn_weight, - const int im2col_step) -{ - at::DeviceGuard guard(value.device()); - - AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); - AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); - AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); - AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); - AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); - - AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor"); - AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor"); - AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor"); - AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor"); - AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor"); - - const int batch = value.size(0); - const int spatial_size = value.size(1); - const int num_heads = value.size(2); - const int channels = value.size(3); - - const int num_levels = spatial_shapes.size(0); - - const int num_query = sampling_loc.size(1); - const int num_point = sampling_loc.size(4); - - const int im2col_step_ = std::min(batch, im2col_step); - - AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); - - auto output = at::zeros({batch, num_query, num_heads, channels}, 
value.options()); - - const int batch_n = im2col_step_; - auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); - auto per_value_size = spatial_size * num_heads * channels; - auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; - auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; - for (int n = 0; n < batch/im2col_step_; ++n) - { - auto columns = output_n.select(0, n); - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] { - ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), - value.data_ptr() + n * im2col_step_ * per_value_size, - spatial_shapes.data_ptr(), - level_start_index.data_ptr(), - sampling_loc.data_ptr() + n * im2col_step_ * per_sample_loc_size, - attn_weight.data_ptr() + n * im2col_step_ * per_attn_weight_size, - batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, - columns.data_ptr()); - - })); - } - - output = output.view({batch, num_query, num_heads*channels}); - - return output; -} - - -std::vector ms_deform_attn_cuda_backward( - const at::Tensor &value, - const at::Tensor &spatial_shapes, - const at::Tensor &level_start_index, - const at::Tensor &sampling_loc, - const at::Tensor &attn_weight, - const at::Tensor &grad_output, - const int im2col_step) -{ - at::DeviceGuard guard(value.device()); - - AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); - AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); - AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); - AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); - AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); - AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); - - AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor"); - AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor"); - AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor"); - AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor"); - AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor"); - AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor"); - - const int batch = value.size(0); - const int spatial_size = value.size(1); - const int num_heads = value.size(2); - const int channels = value.size(3); - - const int num_levels = spatial_shapes.size(0); - - const int num_query = sampling_loc.size(1); - const int num_point = sampling_loc.size(4); - - const int im2col_step_ = std::min(batch, im2col_step); - - AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); - - auto grad_value = at::zeros_like(value); - auto grad_sampling_loc = at::zeros_like(sampling_loc); - auto grad_attn_weight = at::zeros_like(attn_weight); - - const int batch_n = im2col_step_; - auto per_value_size = spatial_size * num_heads * channels; - auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; - auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; - auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); - - for (int n = 0; n < batch/im2col_step_; ++n) - { - auto grad_output_g = grad_output_n.select(0, n); - 
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] { - ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), - grad_output_g.data_ptr(), - value.data_ptr() + n * im2col_step_ * per_value_size, - spatial_shapes.data_ptr(), - level_start_index.data_ptr(), - sampling_loc.data_ptr() + n * im2col_step_ * per_sample_loc_size, - attn_weight.data_ptr() + n * im2col_step_ * per_attn_weight_size, - batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, - grad_value.data_ptr() + n * im2col_step_ * per_value_size, - grad_sampling_loc.data_ptr() + n * im2col_step_ * per_sample_loc_size, - grad_attn_weight.data_ptr() + n * im2col_step_ * per_attn_weight_size); - - })); - } - - return { - grad_value, grad_sampling_loc, grad_attn_weight - }; -} diff --git a/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cuh b/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cuh deleted file mode 100644 index 20ae6892e4b..00000000000 --- a/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cuh +++ /dev/null @@ -1,1467 +0,0 @@ -/*! -************************************************************************************************** -* Deformable DETR -* Copyright (c) 2020 SenseTime. All Rights Reserved. -* Licensed under the Apache License, Version 2.0 [see LICENSE for details] -************************************************************************************************** -* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 -************************************************************************************************** -*/ - -#include - -#include -#include - -#include -#include -#include - -#include -#include - -#include - -#define CUDA_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ - i < (n); \ - i += blockDim.x * gridDim.x) - - -at::Tensor ms_deform_attn_cuda_forward( - const at::Tensor &value, - const at::Tensor &spatial_shapes, - const at::Tensor &level_start_index, - const at::Tensor &sampling_loc, - const at::Tensor &attn_weight, - const int im2col_step) -{ - AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); - AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); - AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); - AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); - AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); - - AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor"); - AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor"); - AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor"); - AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor"); - AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor"); - - const int batch = value.size(0); - const int spatial_size = value.size(1); - const int num_heads = value.size(2); - const int channels = value.size(3); - - const int num_levels = spatial_shapes.size(0); - - const int num_query = sampling_loc.size(1); - const int num_point = sampling_loc.size(4); - - const int im2col_step_ = std::min(batch, im2col_step); - - AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); - - auto output = at::zeros({batch, num_query, num_heads, 
channels}, value.options()); - - const int batch_n = im2col_step_; - auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); - auto per_value_size = spatial_size * num_heads * channels; - auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; - auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; - for (int n = 0; n < batch/im2col_step_; ++n) - { - auto columns = output_n.select(0, n); - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] { - ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), - value.data_ptr() + n * im2col_step_ * per_value_size, - spatial_shapes.data_ptr(), - level_start_index.data_ptr(), - sampling_loc.data_ptr() + n * im2col_step_ * per_sample_loc_size, - attn_weight.data_ptr() + n * im2col_step_ * per_attn_weight_size, - batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, - columns.data_ptr()); - - })); - } - - output = output.view({batch, num_query, num_heads*channels}); - - return output; -} - - -std::vector ms_deform_attn_cuda_backward( - const at::Tensor &value, - const at::Tensor &spatial_shapes, - const at::Tensor &level_start_index, - const at::Tensor &sampling_loc, - const at::Tensor &attn_weight, - const at::Tensor &grad_output, - const int im2col_step) -{ - - AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); - AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); - AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); - AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); - AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); - AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); - - AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor"); - AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor"); - AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor"); - AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor"); - AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor"); - AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor"); - - const int batch = value.size(0); - const int spatial_size = value.size(1); - const int num_heads = value.size(2); - const int channels = value.size(3); - - const int num_levels = spatial_shapes.size(0); - - const int num_query = sampling_loc.size(1); - const int num_point = sampling_loc.size(4); - - const int im2col_step_ = std::min(batch, im2col_step); - - AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); - - auto grad_value = at::zeros_like(value); - auto grad_sampling_loc = at::zeros_like(sampling_loc); - auto grad_attn_weight = at::zeros_like(attn_weight); - - const int batch_n = im2col_step_; - auto per_value_size = spatial_size * num_heads * channels; - auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; - auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; - auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); - - for (int n = 0; n < batch/im2col_step_; ++n) - { - auto grad_output_g = grad_output_n.select(0, n); - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, 
at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] { - ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), - grad_output_g.data_ptr(), - value.data_ptr() + n * im2col_step_ * per_value_size, - spatial_shapes.data_ptr(), - level_start_index.data_ptr(), - sampling_loc.data_ptr() + n * im2col_step_ * per_sample_loc_size, - attn_weight.data_ptr() + n * im2col_step_ * per_attn_weight_size, - batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, - grad_value.data_ptr() + n * im2col_step_ * per_value_size, - grad_sampling_loc.data_ptr() + n * im2col_step_ * per_sample_loc_size, - grad_attn_weight.data_ptr() + n * im2col_step_ * per_attn_weight_size); - - })); - } - - return { - grad_value, grad_sampling_loc, grad_attn_weight - }; -} - -const int CUDA_NUM_THREADS = 1024; -inline int GET_BLOCKS(const int N, const int num_threads) -{ - return (N + num_threads - 1) / num_threads; -} - - -template -__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, - const int &height, const int &width, const int &nheads, const int &channels, - const scalar_t &h, const scalar_t &w, const int &m, const int &c) -{ - const int h_low = floor(h); - const int w_low = floor(w); - const int h_high = h_low + 1; - const int w_high = w_low + 1; - - const scalar_t lh = h - h_low; - const scalar_t lw = w - w_low; - const scalar_t hh = 1 - lh, hw = 1 - lw; - - const int w_stride = nheads * channels; - const int h_stride = width * w_stride; - const int h_low_ptr_offset = h_low * h_stride; - const int h_high_ptr_offset = h_low_ptr_offset + h_stride; - const int w_low_ptr_offset = w_low * w_stride; - const int w_high_ptr_offset = w_low_ptr_offset + w_stride; - const int base_ptr = m * channels + c; - - scalar_t v1 = 0; - if (h_low >= 0 && w_low >= 0) - { - const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; - v1 = bottom_data[ptr1]; - } - scalar_t v2 = 0; - if (h_low >= 0 && w_high <= width - 1) - { - const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; - v2 = bottom_data[ptr2]; - } - scalar_t v3 = 0; - if (h_high <= height - 1 && w_low >= 0) - { - const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; - v3 = bottom_data[ptr3]; - } - scalar_t v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) - { - const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; - v4 = bottom_data[ptr4]; - } - - const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - - const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return val; -} - - -template -__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, - const int &height, const int &width, const int &nheads, const int &channels, - const scalar_t &h, const scalar_t &w, const int &m, const int &c, - const scalar_t &top_grad, - const scalar_t &attn_weight, - scalar_t* &grad_value, - scalar_t* grad_sampling_loc, - scalar_t* grad_attn_weight) -{ - const int h_low = floor(h); - const int w_low = floor(w); - const int h_high = h_low + 1; - const int w_high = w_low + 1; - - const scalar_t lh = h - h_low; - const scalar_t lw = w - w_low; - const scalar_t hh = 1 - lh, hw = 1 - lw; - - const int w_stride = nheads * channels; - const int h_stride = width * w_stride; - const int h_low_ptr_offset = h_low * h_stride; - const int h_high_ptr_offset = h_low_ptr_offset + h_stride; - const int w_low_ptr_offset = w_low * w_stride; - const int w_high_ptr_offset = w_low_ptr_offset + w_stride; - const int base_ptr = m * channels + c; 
- - const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - const scalar_t top_grad_value = top_grad * attn_weight; - scalar_t grad_h_weight = 0, grad_w_weight = 0; - - scalar_t v1 = 0; - if (h_low >= 0 && w_low >= 0) - { - const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; - v1 = bottom_data[ptr1]; - grad_h_weight -= hw * v1; - grad_w_weight -= hh * v1; - atomicAdd(grad_value+ptr1, w1*top_grad_value); - } - scalar_t v2 = 0; - if (h_low >= 0 && w_high <= width - 1) - { - const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; - v2 = bottom_data[ptr2]; - grad_h_weight -= lw * v2; - grad_w_weight += hh * v2; - atomicAdd(grad_value+ptr2, w2*top_grad_value); - } - scalar_t v3 = 0; - if (h_high <= height - 1 && w_low >= 0) - { - const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; - v3 = bottom_data[ptr3]; - grad_h_weight += hw * v3; - grad_w_weight -= lh * v3; - atomicAdd(grad_value+ptr3, w3*top_grad_value); - } - scalar_t v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) - { - const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; - v4 = bottom_data[ptr4]; - grad_h_weight += lw * v4; - grad_w_weight += lh * v4; - atomicAdd(grad_value+ptr4, w4*top_grad_value); - } - - const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - *grad_attn_weight = top_grad * val; - *grad_sampling_loc = width * grad_w_weight * top_grad_value; - *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; -} - - -template -__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, - const int &height, const int &width, const int &nheads, const int &channels, - const scalar_t &h, const scalar_t &w, const int &m, const int &c, - const scalar_t &top_grad, - const scalar_t &attn_weight, - scalar_t* &grad_value, - scalar_t* grad_sampling_loc, - scalar_t* grad_attn_weight) -{ - const int h_low = floor(h); - const int w_low = floor(w); - const int h_high = h_low + 1; - const int w_high = w_low + 1; - - const scalar_t lh = h - h_low; - const scalar_t lw = w - w_low; - const scalar_t hh = 1 - lh, hw = 1 - lw; - - const int w_stride = nheads * channels; - const int h_stride = width * w_stride; - const int h_low_ptr_offset = h_low * h_stride; - const int h_high_ptr_offset = h_low_ptr_offset + h_stride; - const int w_low_ptr_offset = w_low * w_stride; - const int w_high_ptr_offset = w_low_ptr_offset + w_stride; - const int base_ptr = m * channels + c; - - const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - const scalar_t top_grad_value = top_grad * attn_weight; - scalar_t grad_h_weight = 0, grad_w_weight = 0; - - scalar_t v1 = 0; - if (h_low >= 0 && w_low >= 0) - { - const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; - v1 = bottom_data[ptr1]; - grad_h_weight -= hw * v1; - grad_w_weight -= hh * v1; - atomicAdd(grad_value+ptr1, w1*top_grad_value); - } - scalar_t v2 = 0; - if (h_low >= 0 && w_high <= width - 1) - { - const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; - v2 = bottom_data[ptr2]; - grad_h_weight -= lw * v2; - grad_w_weight += hh * v2; - atomicAdd(grad_value+ptr2, w2*top_grad_value); - } - scalar_t v3 = 0; - if (h_high <= height - 1 && w_low >= 0) - { - const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; - v3 = bottom_data[ptr3]; - grad_h_weight += hw * v3; - grad_w_weight -= lh * v3; - atomicAdd(grad_value+ptr3, w3*top_grad_value); - } - scalar_t v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) - { - const int ptr4 = h_high_ptr_offset + 
w_high_ptr_offset + base_ptr; - v4 = bottom_data[ptr4]; - grad_h_weight += lw * v4; - grad_w_weight += lh * v4; - atomicAdd(grad_value+ptr4, w4*top_grad_value); - } - - const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - atomicAdd(grad_attn_weight, top_grad * val); - atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); - atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); -} - - -template -__global__ void ms_deformable_im2col_gpu_kernel(const int n, - const scalar_t *data_value, - const int64_t *data_spatial_shapes, - const int64_t *data_level_start_index, - const scalar_t *data_sampling_loc, - const scalar_t *data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t *data_col) -{ - CUDA_KERNEL_LOOP(index, n) - { - int _temp = index; - const int c_col = _temp % channels; - _temp /= channels; - const int sampling_index = _temp; - const int m_col = _temp % num_heads; - _temp /= num_heads; - [[maybe_unused]] const int q_col = _temp % num_query; - _temp /= num_query; - const int b_col = _temp; - - scalar_t *data_col_ptr = data_col + index; - int data_weight_ptr = sampling_index * num_levels * num_point; - int data_loc_w_ptr = data_weight_ptr << 1; - const int qid_stride = num_heads * channels; - const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; - scalar_t col = 0; - - for (int l_col=0; l_col < num_levels; ++l_col) - { - const int level_start_id = data_level_start_index[l_col]; - const int spatial_h_ptr = l_col << 1; - const int spatial_h = data_spatial_shapes[spatial_h_ptr]; - const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; - const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); - for (int p_col=0; p_col < num_point; ++p_col) - { - const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; - const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; - const scalar_t weight = data_attn_weight[data_weight_ptr]; - - const scalar_t h_im = loc_h * spatial_h - 0.5; - const scalar_t w_im = loc_w * spatial_w - 0.5; - - if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) - { - col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; - } - - data_weight_ptr += 1; - data_loc_w_ptr += 2; - } - } - *data_col_ptr = col; - } -} - -template -__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, - const scalar_t *grad_col, - const scalar_t *data_value, - const int64_t *data_spatial_shapes, - const int64_t *data_level_start_index, - const scalar_t *data_sampling_loc, - const scalar_t *data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t *grad_value, - scalar_t *grad_sampling_loc, - scalar_t *grad_attn_weight) -{ - CUDA_KERNEL_LOOP(index, n) - { - __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; - __shared__ scalar_t cache_grad_attn_weight[blockSize]; - unsigned int tid = threadIdx.x; - int _temp = index; - const int c_col = _temp % channels; - _temp /= channels; - const int sampling_index = _temp; - const int m_col = _temp % num_heads; - _temp /= num_heads; - [[maybe_unused]] const int q_col = _temp % num_query; - _temp /= num_query; - const int b_col = _temp; 
- - const scalar_t top_grad = grad_col[index]; - - int data_weight_ptr = sampling_index * num_levels * num_point; - int data_loc_w_ptr = data_weight_ptr << 1; - const int grad_sampling_ptr = data_weight_ptr; - grad_sampling_loc += grad_sampling_ptr << 1; - grad_attn_weight += grad_sampling_ptr; - const int grad_weight_stride = 1; - const int grad_loc_stride = 2; - const int qid_stride = num_heads * channels; - const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; - - for (int l_col=0; l_col < num_levels; ++l_col) - { - const int level_start_id = data_level_start_index[l_col]; - const int spatial_h_ptr = l_col << 1; - const int spatial_h = data_spatial_shapes[spatial_h_ptr]; - const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; - const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; - const scalar_t *data_value_ptr = data_value + value_ptr_offset; - scalar_t *grad_value_ptr = grad_value + value_ptr_offset; - - for (int p_col=0; p_col < num_point; ++p_col) - { - const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; - const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; - const scalar_t weight = data_attn_weight[data_weight_ptr]; - - const scalar_t h_im = loc_h * spatial_h - 0.5; - const scalar_t w_im = loc_w * spatial_w - 0.5; - *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; - *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; - *(cache_grad_attn_weight+threadIdx.x)=0; - if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) - { - ms_deform_attn_col2im_bilinear( - data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, - top_grad, weight, grad_value_ptr, - cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); - } - - __syncthreads(); - if (tid == 0) - { - scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; - int sid=2; - for (unsigned int tid = 1; tid < blockSize; ++tid) - { - _grad_w += cache_grad_sampling_loc[sid]; - _grad_h += cache_grad_sampling_loc[sid + 1]; - _grad_a += cache_grad_attn_weight[tid]; - sid += 2; - } - - - *grad_sampling_loc = _grad_w; - *(grad_sampling_loc + 1) = _grad_h; - *grad_attn_weight = _grad_a; - } - __syncthreads(); - - data_weight_ptr += 1; - data_loc_w_ptr += 2; - grad_attn_weight += grad_weight_stride; - grad_sampling_loc += grad_loc_stride; - } - } - } -} - - -template -__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, - const scalar_t *grad_col, - const scalar_t *data_value, - const int64_t *data_spatial_shapes, - const int64_t *data_level_start_index, - const scalar_t *data_sampling_loc, - const scalar_t *data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t *grad_value, - scalar_t *grad_sampling_loc, - scalar_t *grad_attn_weight) -{ - CUDA_KERNEL_LOOP(index, n) - { - __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; - __shared__ scalar_t cache_grad_attn_weight[blockSize]; - unsigned int tid = threadIdx.x; - int _temp = index; - const int c_col = _temp % channels; - _temp /= channels; - const int sampling_index = _temp; - const int m_col = _temp % num_heads; - _temp /= num_heads; - [[maybe_unused]] const int q_col = _temp % num_query; - _temp /= num_query; - const int b_col = _temp; - - const scalar_t top_grad = grad_col[index]; - - int data_weight_ptr = 
sampling_index * num_levels * num_point; - int data_loc_w_ptr = data_weight_ptr << 1; - const int grad_sampling_ptr = data_weight_ptr; - grad_sampling_loc += grad_sampling_ptr << 1; - grad_attn_weight += grad_sampling_ptr; - const int grad_weight_stride = 1; - const int grad_loc_stride = 2; - const int qid_stride = num_heads * channels; - const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; - - for (int l_col=0; l_col < num_levels; ++l_col) - { - const int level_start_id = data_level_start_index[l_col]; - const int spatial_h_ptr = l_col << 1; - const int spatial_h = data_spatial_shapes[spatial_h_ptr]; - const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; - const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; - const scalar_t *data_value_ptr = data_value + value_ptr_offset; - scalar_t *grad_value_ptr = grad_value + value_ptr_offset; - - for (int p_col=0; p_col < num_point; ++p_col) - { - const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; - const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; - const scalar_t weight = data_attn_weight[data_weight_ptr]; - - const scalar_t h_im = loc_h * spatial_h - 0.5; - const scalar_t w_im = loc_w * spatial_w - 0.5; - *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; - *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; - *(cache_grad_attn_weight+threadIdx.x)=0; - if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) - { - ms_deform_attn_col2im_bilinear( - data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, - top_grad, weight, grad_value_ptr, - cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); - } - - __syncthreads(); - - for (unsigned int s=blockSize/2; s>0; s>>=1) - { - if (tid < s) { - const unsigned int xid1 = tid << 1; - const unsigned int xid2 = (tid + s) << 1; - cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; - cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; - cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; - } - __syncthreads(); - } - - if (tid == 0) - { - *grad_sampling_loc = cache_grad_sampling_loc[0]; - *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; - *grad_attn_weight = cache_grad_attn_weight[0]; - } - __syncthreads(); - - data_weight_ptr += 1; - data_loc_w_ptr += 2; - grad_attn_weight += grad_weight_stride; - grad_sampling_loc += grad_loc_stride; - } - } - } -} - - -template -__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, - const scalar_t *grad_col, - const scalar_t *data_value, - const int64_t *data_spatial_shapes, - const int64_t *data_level_start_index, - const scalar_t *data_sampling_loc, - const scalar_t *data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t *grad_value, - scalar_t *grad_sampling_loc, - scalar_t *grad_attn_weight) -{ - CUDA_KERNEL_LOOP(index, n) - { - extern __shared__ int _s[]; - scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; - scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; - unsigned int tid = threadIdx.x; - int _temp = index; - const int c_col = _temp % channels; - _temp /= channels; - const int sampling_index = _temp; - const int m_col = _temp % num_heads; - _temp /= num_heads; - [[maybe_unused]] const int q_col = _temp % num_query; - _temp /= num_query; - const int b_col = _temp; - - const 
scalar_t top_grad = grad_col[index]; - - int data_weight_ptr = sampling_index * num_levels * num_point; - int data_loc_w_ptr = data_weight_ptr << 1; - const int grad_sampling_ptr = data_weight_ptr; - grad_sampling_loc += grad_sampling_ptr << 1; - grad_attn_weight += grad_sampling_ptr; - const int grad_weight_stride = 1; - const int grad_loc_stride = 2; - const int qid_stride = num_heads * channels; - const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; - - for (int l_col=0; l_col < num_levels; ++l_col) - { - const int level_start_id = data_level_start_index[l_col]; - const int spatial_h_ptr = l_col << 1; - const int spatial_h = data_spatial_shapes[spatial_h_ptr]; - const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; - const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; - const scalar_t *data_value_ptr = data_value + value_ptr_offset; - scalar_t *grad_value_ptr = grad_value + value_ptr_offset; - - for (int p_col=0; p_col < num_point; ++p_col) - { - const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; - const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; - const scalar_t weight = data_attn_weight[data_weight_ptr]; - - const scalar_t h_im = loc_h * spatial_h - 0.5; - const scalar_t w_im = loc_w * spatial_w - 0.5; - *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; - *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; - *(cache_grad_attn_weight+threadIdx.x)=0; - if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) - { - ms_deform_attn_col2im_bilinear( - data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, - top_grad, weight, grad_value_ptr, - cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); - } - - __syncthreads(); - if (tid == 0) - { - scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; - int sid=2; - for (unsigned int tid = 1; tid < blockDim.x; ++tid) - { - _grad_w += cache_grad_sampling_loc[sid]; - _grad_h += cache_grad_sampling_loc[sid + 1]; - _grad_a += cache_grad_attn_weight[tid]; - sid += 2; - } - - - *grad_sampling_loc = _grad_w; - *(grad_sampling_loc + 1) = _grad_h; - *grad_attn_weight = _grad_a; - } - __syncthreads(); - - data_weight_ptr += 1; - data_loc_w_ptr += 2; - grad_attn_weight += grad_weight_stride; - grad_sampling_loc += grad_loc_stride; - } - } - } -} - -template -__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, - const scalar_t *grad_col, - const scalar_t *data_value, - const int64_t *data_spatial_shapes, - const int64_t *data_level_start_index, - const scalar_t *data_sampling_loc, - const scalar_t *data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t *grad_value, - scalar_t *grad_sampling_loc, - scalar_t *grad_attn_weight) -{ - CUDA_KERNEL_LOOP(index, n) - { - extern __shared__ int _s[]; - scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; - scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; - unsigned int tid = threadIdx.x; - int _temp = index; - const int c_col = _temp % channels; - _temp /= channels; - const int sampling_index = _temp; - const int m_col = _temp % num_heads; - _temp /= num_heads; - [[maybe_unused]] const int q_col = _temp % num_query; - _temp /= num_query; - const int b_col = _temp; - - const scalar_t top_grad = grad_col[index]; - - int 
data_weight_ptr = sampling_index * num_levels * num_point; - int data_loc_w_ptr = data_weight_ptr << 1; - const int grad_sampling_ptr = data_weight_ptr; - grad_sampling_loc += grad_sampling_ptr << 1; - grad_attn_weight += grad_sampling_ptr; - const int grad_weight_stride = 1; - const int grad_loc_stride = 2; - const int qid_stride = num_heads * channels; - const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; - - for (int l_col=0; l_col < num_levels; ++l_col) - { - const int level_start_id = data_level_start_index[l_col]; - const int spatial_h_ptr = l_col << 1; - const int spatial_h = data_spatial_shapes[spatial_h_ptr]; - const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; - const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; - const scalar_t *data_value_ptr = data_value + value_ptr_offset; - scalar_t *grad_value_ptr = grad_value + value_ptr_offset; - - for (int p_col=0; p_col < num_point; ++p_col) - { - const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; - const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; - const scalar_t weight = data_attn_weight[data_weight_ptr]; - - const scalar_t h_im = loc_h * spatial_h - 0.5; - const scalar_t w_im = loc_w * spatial_w - 0.5; - *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; - *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; - *(cache_grad_attn_weight+threadIdx.x)=0; - if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) - { - ms_deform_attn_col2im_bilinear( - data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, - top_grad, weight, grad_value_ptr, - cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); - } - - __syncthreads(); - - for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) - { - if (tid < s) { - const unsigned int xid1 = tid << 1; - const unsigned int xid2 = (tid + s) << 1; - cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; - cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; - cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; - if (tid + (s << 1) < spre) - { - cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; - cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; - cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; - } - } - __syncthreads(); - } - - if (tid == 0) - { - *grad_sampling_loc = cache_grad_sampling_loc[0]; - *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; - *grad_attn_weight = cache_grad_attn_weight[0]; - } - __syncthreads(); - - data_weight_ptr += 1; - data_loc_w_ptr += 2; - grad_attn_weight += grad_weight_stride; - grad_sampling_loc += grad_loc_stride; - } - } - } -} - -template -__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, - const scalar_t *grad_col, - const scalar_t *data_value, - const int64_t *data_spatial_shapes, - const int64_t *data_level_start_index, - const scalar_t *data_sampling_loc, - const scalar_t *data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t *grad_value, - scalar_t *grad_sampling_loc, - scalar_t *grad_attn_weight) -{ - CUDA_KERNEL_LOOP(index, n) - { - extern __shared__ int _s[]; - scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; - scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; - 
unsigned int tid = threadIdx.x; - int _temp = index; - const int c_col = _temp % channels; - _temp /= channels; - const int sampling_index = _temp; - const int m_col = _temp % num_heads; - _temp /= num_heads; - [[maybe_unused]] const int q_col = _temp % num_query; - _temp /= num_query; - const int b_col = _temp; - - const scalar_t top_grad = grad_col[index]; - - int data_weight_ptr = sampling_index * num_levels * num_point; - int data_loc_w_ptr = data_weight_ptr << 1; - const int grad_sampling_ptr = data_weight_ptr; - grad_sampling_loc += grad_sampling_ptr << 1; - grad_attn_weight += grad_sampling_ptr; - const int grad_weight_stride = 1; - const int grad_loc_stride = 2; - const int qid_stride = num_heads * channels; - const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; - - for (int l_col=0; l_col < num_levels; ++l_col) - { - const int level_start_id = data_level_start_index[l_col]; - const int spatial_h_ptr = l_col << 1; - const int spatial_h = data_spatial_shapes[spatial_h_ptr]; - const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; - const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; - const scalar_t *data_value_ptr = data_value + value_ptr_offset; - scalar_t *grad_value_ptr = grad_value + value_ptr_offset; - - for (int p_col=0; p_col < num_point; ++p_col) - { - const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; - const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; - const scalar_t weight = data_attn_weight[data_weight_ptr]; - - const scalar_t h_im = loc_h * spatial_h - 0.5; - const scalar_t w_im = loc_w * spatial_w - 0.5; - *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; - *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; - *(cache_grad_attn_weight+threadIdx.x)=0; - if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) - { - ms_deform_attn_col2im_bilinear( - data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, - top_grad, weight, grad_value_ptr, - cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); - } - - __syncthreads(); - - for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) - { - if (tid < s) { - const unsigned int xid1 = tid << 1; - const unsigned int xid2 = (tid + s) << 1; - cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; - cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; - cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; - if (tid + (s << 1) < spre) - { - cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; - cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; - cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; - } - } - __syncthreads(); - } - - if (tid == 0) - { - atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); - atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); - atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); - } - __syncthreads(); - - data_weight_ptr += 1; - data_loc_w_ptr += 2; - grad_attn_weight += grad_weight_stride; - grad_sampling_loc += grad_loc_stride; - } - } - } -} - - -template -__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, - const scalar_t *grad_col, - const scalar_t *data_value, - const int64_t *data_spatial_shapes, - const int64_t *data_level_start_index, - const scalar_t *data_sampling_loc, - const scalar_t *data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - 
const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t *grad_value, - scalar_t *grad_sampling_loc, - scalar_t *grad_attn_weight) -{ - CUDA_KERNEL_LOOP(index, n) - { - int _temp = index; - const int c_col = _temp % channels; - _temp /= channels; - const int sampling_index = _temp; - const int m_col = _temp % num_heads; - _temp /= num_heads; - [[maybe_unused]] const int q_col = _temp % num_query; - _temp /= num_query; - const int b_col = _temp; - - const scalar_t top_grad = grad_col[index]; - - int data_weight_ptr = sampling_index * num_levels * num_point; - int data_loc_w_ptr = data_weight_ptr << 1; - const int grad_sampling_ptr = data_weight_ptr; - grad_sampling_loc += grad_sampling_ptr << 1; - grad_attn_weight += grad_sampling_ptr; - const int grad_weight_stride = 1; - const int grad_loc_stride = 2; - const int qid_stride = num_heads * channels; - const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; - - for (int l_col=0; l_col < num_levels; ++l_col) - { - const int level_start_id = data_level_start_index[l_col]; - const int spatial_h_ptr = l_col << 1; - const int spatial_h = data_spatial_shapes[spatial_h_ptr]; - const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; - const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; - const scalar_t *data_value_ptr = data_value + value_ptr_offset; - scalar_t *grad_value_ptr = grad_value + value_ptr_offset; - - for (int p_col=0; p_col < num_point; ++p_col) - { - const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; - const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; - const scalar_t weight = data_attn_weight[data_weight_ptr]; - - const scalar_t h_im = loc_h * spatial_h - 0.5; - const scalar_t w_im = loc_w * spatial_w - 0.5; - if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) - { - ms_deform_attn_col2im_bilinear_gm( - data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, - top_grad, weight, grad_value_ptr, - grad_sampling_loc, grad_attn_weight); - } - data_weight_ptr += 1; - data_loc_w_ptr += 2; - grad_attn_weight += grad_weight_stride; - grad_sampling_loc += grad_loc_stride; - } - } - } -} - - -template -void ms_deformable_im2col_cuda(cudaStream_t stream, - const scalar_t* data_value, - const int64_t* data_spatial_shapes, - const int64_t* data_level_start_index, - const scalar_t* data_sampling_loc, - const scalar_t* data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t* data_col) -{ - const int num_kernels = batch_size * num_query * num_heads * channels; - const int num_actual_kernels = batch_size * num_query * num_heads * channels; - const int num_threads = CUDA_NUM_THREADS; - ms_deformable_im2col_gpu_kernel - <<>>( - num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, - batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); - - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { - printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); - } - -} - -template -void ms_deformable_col2im_cuda(cudaStream_t stream, - const scalar_t* grad_col, - const scalar_t* data_value, - const int64_t * data_spatial_shapes, - const int64_t * data_level_start_index, - const scalar_t * data_sampling_loc, - const scalar_t * 
data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t* grad_value, - scalar_t* grad_sampling_loc, - scalar_t* grad_attn_weight) -{ - const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; - const int num_kernels = batch_size * num_query * num_heads * channels; - const int num_actual_kernels = batch_size * num_query * num_heads * channels; - if (channels > 1024) - { - if ((channels & 1023) == 0) - { - ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - } - else - { - ms_deformable_col2im_gpu_kernel_gm - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - } - } - else{ - switch(channels) - { - case 1: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 2: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 4: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 8: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 16: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 32: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 64: - 
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 128: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 256: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 512: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 1024: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - default: - if (channels < 64) - { - ms_deformable_col2im_gpu_kernel_shm_reduce_v1 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - } - else - { - ms_deformable_col2im_gpu_kernel_shm_reduce_v2 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - } - } - } - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { - printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); - } - -} diff --git a/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.h b/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.h deleted file mode 100644 index d8c21b4e54d..00000000000 --- a/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.h +++ /dev/null @@ -1,46 +0,0 @@ -/*! -************************************************************************************************** -* Deformable DETR -* Copyright (c) 2020 SenseTime. All Rights Reserved. 
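# The CUDA sources deleted above implement multi-scale deformable attention:
# multi-level value maps are bilinearly sampled at predicted sampling locations
# and the samples are combined with per-point attention weights. A rough
# pure-PyTorch sketch of that forward computation (in the spirit of the
# grid_sample-based fallback kept in the modeling code; `value_spatial_shapes`
# is assumed here to be a Python list of (height, width) pairs) follows.
import torch
import torch.nn.functional as F


def multi_scale_deformable_attention_sketch(value, value_spatial_shapes, sampling_locations, attention_weights):
    # value:              (batch, sum(H_l * W_l), num_heads, head_dim)
    # sampling_locations: (batch, num_queries, num_heads, num_levels, num_points, 2), normalized to [0, 1]
    # attention_weights:  (batch, num_queries, num_heads, num_levels, num_points)
    batch_size, _, num_heads, head_dim = value.shape
    _, num_queries, _, num_levels, num_points, _ = sampling_locations.shape
    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1  # grid_sample expects coordinates in [-1, 1]
    sampled = []
    for level, (height, width) in enumerate(value_spatial_shapes):
        # (batch * num_heads, head_dim, height, width)
        value_l = value_list[level].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, head_dim, height, width)
        # (batch * num_heads, num_queries, num_points, 2)
        grid_l = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # Bilinear sampling with zero padding outside the feature map, roughly what
        # ms_deform_attn_im2col_bilinear does element-wise in the deleted kernel.
        sampled.append(F.grid_sample(value_l, grid_l, mode="bilinear", padding_mode="zeros", align_corners=False))
    attention_weights = attention_weights.transpose(1, 2).reshape(
        batch_size * num_heads, 1, num_queries, num_levels * num_points
    )
    output = (torch.stack(sampled, dim=-2).flatten(-2) * attention_weights).sum(-1)
    return output.view(batch_size, num_heads * head_dim, num_queries).transpose(1, 2).contiguous()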
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details] -************************************************************************************************** -* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 -************************************************************************************************** -*/ - -#pragma once -#include - -at::Tensor ms_deform_attn_cuda_forward( - const at::Tensor &value, - const at::Tensor &spatial_shapes, - const at::Tensor &level_start_index, - const at::Tensor &sampling_loc, - const at::Tensor &attn_weight, - const int im2col_step); - -at::Tensor ms_deform_attn_cuda_forward_bf16( - const at::Tensor &value, - const at::Tensor &spatial_shapes, - const at::Tensor &level_start_index, - const at::Tensor &sampling_loc, - const at::Tensor &attn_weight, - const int im2col_step); - -std::vector ms_deform_attn_cuda_backward( - const at::Tensor &value, - const at::Tensor &spatial_shapes, - const at::Tensor &level_start_index, - const at::Tensor &sampling_loc, - const at::Tensor &attn_weight, - const at::Tensor &grad_output, - const int im2col_step); - -std::vector ms_deform_attn_cuda_backward_bf16( - const at::Tensor &value, - const at::Tensor &spatial_shapes, - const at::Tensor &level_start_index, - const at::Tensor &sampling_loc, - const at::Tensor &attn_weight, - const at::Tensor &grad_output, - const int im2col_step); diff --git a/src/transformers/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh b/src/transformers/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh deleted file mode 100644 index 4fb544bf791..00000000000 --- a/src/transformers/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh +++ /dev/null @@ -1,1327 +0,0 @@ -/*! -************************************************************************** -* Deformable DETR -* Copyright (c) 2020 SenseTime. All Rights Reserved. 
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details] -************************************************************************** -* Modified from DCN (https://github.com/msracver/Deformable-ConvNets) -* Copyright (c) 2018 Microsoft -************************************************************************** -*/ - -#include -#include -#include - -#include -#include - -#include - -#define CUDA_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ - i < (n); \ - i += blockDim.x * gridDim.x) - -const int CUDA_NUM_THREADS = 1024; -inline int GET_BLOCKS(const int N, const int num_threads) -{ - return (N + num_threads - 1) / num_threads; -} - - -template -__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, - const int &height, const int &width, const int &nheads, const int &channels, - const scalar_t &h, const scalar_t &w, const int &m, const int &c) -{ - const int h_low = floor(h); - const int w_low = floor(w); - const int h_high = h_low + 1; - const int w_high = w_low + 1; - - const scalar_t lh = h - h_low; - const scalar_t lw = w - w_low; - const scalar_t hh = 1 - lh, hw = 1 - lw; - - const int w_stride = nheads * channels; - const int h_stride = width * w_stride; - const int h_low_ptr_offset = h_low * h_stride; - const int h_high_ptr_offset = h_low_ptr_offset + h_stride; - const int w_low_ptr_offset = w_low * w_stride; - const int w_high_ptr_offset = w_low_ptr_offset + w_stride; - const int base_ptr = m * channels + c; - - scalar_t v1 = 0; - if (h_low >= 0 && w_low >= 0) - { - const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; - v1 = bottom_data[ptr1]; - } - scalar_t v2 = 0; - if (h_low >= 0 && w_high <= width - 1) - { - const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; - v2 = bottom_data[ptr2]; - } - scalar_t v3 = 0; - if (h_high <= height - 1 && w_low >= 0) - { - const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; - v3 = bottom_data[ptr3]; - } - scalar_t v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) - { - const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; - v4 = bottom_data[ptr4]; - } - - const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - - const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return val; -} - - -template -__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, - const int &height, const int &width, const int &nheads, const int &channels, - const scalar_t &h, const scalar_t &w, const int &m, const int &c, - const scalar_t &top_grad, - const scalar_t &attn_weight, - scalar_t* &grad_value, - scalar_t* grad_sampling_loc, - scalar_t* grad_attn_weight) -{ - const int h_low = floor(h); - const int w_low = floor(w); - const int h_high = h_low + 1; - const int w_high = w_low + 1; - - const scalar_t lh = h - h_low; - const scalar_t lw = w - w_low; - const scalar_t hh = 1 - lh, hw = 1 - lw; - - const int w_stride = nheads * channels; - const int h_stride = width * w_stride; - const int h_low_ptr_offset = h_low * h_stride; - const int h_high_ptr_offset = h_low_ptr_offset + h_stride; - const int w_low_ptr_offset = w_low * w_stride; - const int w_high_ptr_offset = w_low_ptr_offset + w_stride; - const int base_ptr = m * channels + c; - - const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - const scalar_t top_grad_value = top_grad * attn_weight; - scalar_t grad_h_weight = 0, grad_w_weight = 0; - - scalar_t v1 = 0; - if (h_low >= 0 && w_low >= 0) - { - const int ptr1 = 
h_low_ptr_offset + w_low_ptr_offset + base_ptr; - v1 = bottom_data[ptr1]; - grad_h_weight -= hw * v1; - grad_w_weight -= hh * v1; - atomicAdd(grad_value+ptr1, w1*top_grad_value); - } - scalar_t v2 = 0; - if (h_low >= 0 && w_high <= width - 1) - { - const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; - v2 = bottom_data[ptr2]; - grad_h_weight -= lw * v2; - grad_w_weight += hh * v2; - atomicAdd(grad_value+ptr2, w2*top_grad_value); - } - scalar_t v3 = 0; - if (h_high <= height - 1 && w_low >= 0) - { - const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; - v3 = bottom_data[ptr3]; - grad_h_weight += hw * v3; - grad_w_weight -= lh * v3; - atomicAdd(grad_value+ptr3, w3*top_grad_value); - } - scalar_t v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) - { - const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; - v4 = bottom_data[ptr4]; - grad_h_weight += lw * v4; - grad_w_weight += lh * v4; - atomicAdd(grad_value+ptr4, w4*top_grad_value); - } - - const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - *grad_attn_weight = top_grad * val; - *grad_sampling_loc = width * grad_w_weight * top_grad_value; - *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; -} - - -template -__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, - const int &height, const int &width, const int &nheads, const int &channels, - const scalar_t &h, const scalar_t &w, const int &m, const int &c, - const scalar_t &top_grad, - const scalar_t &attn_weight, - scalar_t* &grad_value, - scalar_t* grad_sampling_loc, - scalar_t* grad_attn_weight) -{ - const int h_low = floor(h); - const int w_low = floor(w); - const int h_high = h_low + 1; - const int w_high = w_low + 1; - - const scalar_t lh = h - h_low; - const scalar_t lw = w - w_low; - const scalar_t hh = 1 - lh, hw = 1 - lw; - - const int w_stride = nheads * channels; - const int h_stride = width * w_stride; - const int h_low_ptr_offset = h_low * h_stride; - const int h_high_ptr_offset = h_low_ptr_offset + h_stride; - const int w_low_ptr_offset = w_low * w_stride; - const int w_high_ptr_offset = w_low_ptr_offset + w_stride; - const int base_ptr = m * channels + c; - - const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - const scalar_t top_grad_value = top_grad * attn_weight; - scalar_t grad_h_weight = 0, grad_w_weight = 0; - - scalar_t v1 = 0; - if (h_low >= 0 && w_low >= 0) - { - const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; - v1 = bottom_data[ptr1]; - grad_h_weight -= hw * v1; - grad_w_weight -= hh * v1; - atomicAdd(grad_value+ptr1, w1*top_grad_value); - } - scalar_t v2 = 0; - if (h_low >= 0 && w_high <= width - 1) - { - const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; - v2 = bottom_data[ptr2]; - grad_h_weight -= lw * v2; - grad_w_weight += hh * v2; - atomicAdd(grad_value+ptr2, w2*top_grad_value); - } - scalar_t v3 = 0; - if (h_high <= height - 1 && w_low >= 0) - { - const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; - v3 = bottom_data[ptr3]; - grad_h_weight += hw * v3; - grad_w_weight -= lh * v3; - atomicAdd(grad_value+ptr3, w3*top_grad_value); - } - scalar_t v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) - { - const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; - v4 = bottom_data[ptr4]; - grad_h_weight += lw * v4; - grad_w_weight += lh * v4; - atomicAdd(grad_value+ptr4, w4*top_grad_value); - } - - const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - 
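// Unlike the shared-memory variant above, which stores into a per-thread cache slot,
// the _gm variant accumulates the attention-weight and sampling-location gradients
// with atomicAdd, since threads handling different channels of the same query and
// head all write to the same global-memory locations.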
atomicAdd(grad_attn_weight, top_grad * val); - atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); - atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); -} - - -template -__global__ void ms_deformable_im2col_gpu_kernel(const int n, - const scalar_t *data_value, - const int64_t *data_spatial_shapes, - const int64_t *data_level_start_index, - const scalar_t *data_sampling_loc, - const scalar_t *data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t *data_col) -{ - CUDA_KERNEL_LOOP(index, n) - { - int _temp = index; - const int c_col = _temp % channels; - _temp /= channels; - const int sampling_index = _temp; - const int m_col = _temp % num_heads; - _temp /= num_heads; - [[maybe_unused]] const int q_col = _temp % num_query; - _temp /= num_query; - const int b_col = _temp; - - scalar_t *data_col_ptr = data_col + index; - int data_weight_ptr = sampling_index * num_levels * num_point; - int data_loc_w_ptr = data_weight_ptr << 1; - const int qid_stride = num_heads * channels; - const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; - scalar_t col = 0; - - for (int l_col=0; l_col < num_levels; ++l_col) - { - const int level_start_id = data_level_start_index[l_col]; - const int spatial_h_ptr = l_col << 1; - const int spatial_h = data_spatial_shapes[spatial_h_ptr]; - const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; - const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); - for (int p_col=0; p_col < num_point; ++p_col) - { - const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; - const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; - const scalar_t weight = data_attn_weight[data_weight_ptr]; - - const scalar_t h_im = loc_h * spatial_h - 0.5; - const scalar_t w_im = loc_w * spatial_w - 0.5; - - if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) - { - col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; - } - - data_weight_ptr += 1; - data_loc_w_ptr += 2; - } - } - *data_col_ptr = col; - } -} - -template -__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, - const scalar_t *grad_col, - const scalar_t *data_value, - const int64_t *data_spatial_shapes, - const int64_t *data_level_start_index, - const scalar_t *data_sampling_loc, - const scalar_t *data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t *grad_value, - scalar_t *grad_sampling_loc, - scalar_t *grad_attn_weight) -{ - CUDA_KERNEL_LOOP(index, n) - { - __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; - __shared__ scalar_t cache_grad_attn_weight[blockSize]; - unsigned int tid = threadIdx.x; - int _temp = index; - const int c_col = _temp % channels; - _temp /= channels; - const int sampling_index = _temp; - const int m_col = _temp % num_heads; - _temp /= num_heads; - [[maybe_unused]] const int q_col = _temp % num_query; - _temp /= num_query; - const int b_col = _temp; - - const scalar_t top_grad = grad_col[index]; - - int data_weight_ptr = sampling_index * num_levels * num_point; - int data_loc_w_ptr = data_weight_ptr << 1; - const int grad_sampling_ptr = data_weight_ptr; - grad_sampling_loc += 
grad_sampling_ptr << 1; - grad_attn_weight += grad_sampling_ptr; - const int grad_weight_stride = 1; - const int grad_loc_stride = 2; - const int qid_stride = num_heads * channels; - const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; - - for (int l_col=0; l_col < num_levels; ++l_col) - { - const int level_start_id = data_level_start_index[l_col]; - const int spatial_h_ptr = l_col << 1; - const int spatial_h = data_spatial_shapes[spatial_h_ptr]; - const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; - const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; - const scalar_t *data_value_ptr = data_value + value_ptr_offset; - scalar_t *grad_value_ptr = grad_value + value_ptr_offset; - - for (int p_col=0; p_col < num_point; ++p_col) - { - const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; - const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; - const scalar_t weight = data_attn_weight[data_weight_ptr]; - - const scalar_t h_im = loc_h * spatial_h - 0.5; - const scalar_t w_im = loc_w * spatial_w - 0.5; - *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; - *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; - *(cache_grad_attn_weight+threadIdx.x)=0; - if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) - { - ms_deform_attn_col2im_bilinear( - data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, - top_grad, weight, grad_value_ptr, - cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); - } - - __syncthreads(); - if (tid == 0) - { - scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; - int sid=2; - for (unsigned int tid = 1; tid < blockSize; ++tid) - { - _grad_w += cache_grad_sampling_loc[sid]; - _grad_h += cache_grad_sampling_loc[sid + 1]; - _grad_a += cache_grad_attn_weight[tid]; - sid += 2; - } - - - *grad_sampling_loc = _grad_w; - *(grad_sampling_loc + 1) = _grad_h; - *grad_attn_weight = _grad_a; - } - __syncthreads(); - - data_weight_ptr += 1; - data_loc_w_ptr += 2; - grad_attn_weight += grad_weight_stride; - grad_sampling_loc += grad_loc_stride; - } - } - } -} - - -template -__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, - const scalar_t *grad_col, - const scalar_t *data_value, - const int64_t *data_spatial_shapes, - const int64_t *data_level_start_index, - const scalar_t *data_sampling_loc, - const scalar_t *data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t *grad_value, - scalar_t *grad_sampling_loc, - scalar_t *grad_attn_weight) -{ - CUDA_KERNEL_LOOP(index, n) - { - __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; - __shared__ scalar_t cache_grad_attn_weight[blockSize]; - unsigned int tid = threadIdx.x; - int _temp = index; - const int c_col = _temp % channels; - _temp /= channels; - const int sampling_index = _temp; - const int m_col = _temp % num_heads; - _temp /= num_heads; - [[maybe_unused]] const int q_col = _temp % num_query; - _temp /= num_query; - const int b_col = _temp; - - const scalar_t top_grad = grad_col[index]; - - int data_weight_ptr = sampling_index * num_levels * num_point; - int data_loc_w_ptr = data_weight_ptr << 1; - const int grad_sampling_ptr = data_weight_ptr; - grad_sampling_loc += grad_sampling_ptr << 1; - grad_attn_weight += grad_sampling_ptr; - const 
int grad_weight_stride = 1; - const int grad_loc_stride = 2; - const int qid_stride = num_heads * channels; - const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; - - for (int l_col=0; l_col < num_levels; ++l_col) - { - const int level_start_id = data_level_start_index[l_col]; - const int spatial_h_ptr = l_col << 1; - const int spatial_h = data_spatial_shapes[spatial_h_ptr]; - const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; - const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; - const scalar_t *data_value_ptr = data_value + value_ptr_offset; - scalar_t *grad_value_ptr = grad_value + value_ptr_offset; - - for (int p_col=0; p_col < num_point; ++p_col) - { - const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; - const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; - const scalar_t weight = data_attn_weight[data_weight_ptr]; - - const scalar_t h_im = loc_h * spatial_h - 0.5; - const scalar_t w_im = loc_w * spatial_w - 0.5; - *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; - *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; - *(cache_grad_attn_weight+threadIdx.x)=0; - if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) - { - ms_deform_attn_col2im_bilinear( - data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, - top_grad, weight, grad_value_ptr, - cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); - } - - __syncthreads(); - - for (unsigned int s=blockSize/2; s>0; s>>=1) - { - if (tid < s) { - const unsigned int xid1 = tid << 1; - const unsigned int xid2 = (tid + s) << 1; - cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; - cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; - cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; - } - __syncthreads(); - } - - if (tid == 0) - { - *grad_sampling_loc = cache_grad_sampling_loc[0]; - *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; - *grad_attn_weight = cache_grad_attn_weight[0]; - } - __syncthreads(); - - data_weight_ptr += 1; - data_loc_w_ptr += 2; - grad_attn_weight += grad_weight_stride; - grad_sampling_loc += grad_loc_stride; - } - } - } -} - - -template -__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, - const scalar_t *grad_col, - const scalar_t *data_value, - const int64_t *data_spatial_shapes, - const int64_t *data_level_start_index, - const scalar_t *data_sampling_loc, - const scalar_t *data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t *grad_value, - scalar_t *grad_sampling_loc, - scalar_t *grad_attn_weight) -{ - CUDA_KERNEL_LOOP(index, n) - { - extern __shared__ int _s[]; - scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; - scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; - unsigned int tid = threadIdx.x; - int _temp = index; - const int c_col = _temp % channels; - _temp /= channels; - const int sampling_index = _temp; - const int m_col = _temp % num_heads; - _temp /= num_heads; - [[maybe_unused]] const int q_col = _temp % num_query; - _temp /= num_query; - const int b_col = _temp; - - const scalar_t top_grad = grad_col[index]; - - int data_weight_ptr = sampling_index * num_levels * num_point; - int data_loc_w_ptr = data_weight_ptr << 1; - const int grad_sampling_ptr = data_weight_ptr; - grad_sampling_loc += 
grad_sampling_ptr << 1; - grad_attn_weight += grad_sampling_ptr; - const int grad_weight_stride = 1; - const int grad_loc_stride = 2; - const int qid_stride = num_heads * channels; - const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; - - for (int l_col=0; l_col < num_levels; ++l_col) - { - const int level_start_id = data_level_start_index[l_col]; - const int spatial_h_ptr = l_col << 1; - const int spatial_h = data_spatial_shapes[spatial_h_ptr]; - const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; - const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; - const scalar_t *data_value_ptr = data_value + value_ptr_offset; - scalar_t *grad_value_ptr = grad_value + value_ptr_offset; - - for (int p_col=0; p_col < num_point; ++p_col) - { - const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; - const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; - const scalar_t weight = data_attn_weight[data_weight_ptr]; - - const scalar_t h_im = loc_h * spatial_h - 0.5; - const scalar_t w_im = loc_w * spatial_w - 0.5; - *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; - *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; - *(cache_grad_attn_weight+threadIdx.x)=0; - if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) - { - ms_deform_attn_col2im_bilinear( - data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, - top_grad, weight, grad_value_ptr, - cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); - } - - __syncthreads(); - if (tid == 0) - { - scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; - int sid=2; - for (unsigned int tid = 1; tid < blockDim.x; ++tid) - { - _grad_w += cache_grad_sampling_loc[sid]; - _grad_h += cache_grad_sampling_loc[sid + 1]; - _grad_a += cache_grad_attn_weight[tid]; - sid += 2; - } - - - *grad_sampling_loc = _grad_w; - *(grad_sampling_loc + 1) = _grad_h; - *grad_attn_weight = _grad_a; - } - __syncthreads(); - - data_weight_ptr += 1; - data_loc_w_ptr += 2; - grad_attn_weight += grad_weight_stride; - grad_sampling_loc += grad_loc_stride; - } - } - } -} - -template -__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, - const scalar_t *grad_col, - const scalar_t *data_value, - const int64_t *data_spatial_shapes, - const int64_t *data_level_start_index, - const scalar_t *data_sampling_loc, - const scalar_t *data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t *grad_value, - scalar_t *grad_sampling_loc, - scalar_t *grad_attn_weight) -{ - CUDA_KERNEL_LOOP(index, n) - { - extern __shared__ int _s[]; - scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; - scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; - unsigned int tid = threadIdx.x; - int _temp = index; - const int c_col = _temp % channels; - _temp /= channels; - const int sampling_index = _temp; - const int m_col = _temp % num_heads; - _temp /= num_heads; - [[maybe_unused]] const int q_col = _temp % num_query; - _temp /= num_query; - const int b_col = _temp; - - const scalar_t top_grad = grad_col[index]; - - int data_weight_ptr = sampling_index * num_levels * num_point; - int data_loc_w_ptr = data_weight_ptr << 1; - const int grad_sampling_ptr = data_weight_ptr; - grad_sampling_loc += grad_sampling_ptr << 1; - grad_attn_weight += 
grad_sampling_ptr; - const int grad_weight_stride = 1; - const int grad_loc_stride = 2; - const int qid_stride = num_heads * channels; - const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; - - for (int l_col=0; l_col < num_levels; ++l_col) - { - const int level_start_id = data_level_start_index[l_col]; - const int spatial_h_ptr = l_col << 1; - const int spatial_h = data_spatial_shapes[spatial_h_ptr]; - const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; - const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; - const scalar_t *data_value_ptr = data_value + value_ptr_offset; - scalar_t *grad_value_ptr = grad_value + value_ptr_offset; - - for (int p_col=0; p_col < num_point; ++p_col) - { - const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; - const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; - const scalar_t weight = data_attn_weight[data_weight_ptr]; - - const scalar_t h_im = loc_h * spatial_h - 0.5; - const scalar_t w_im = loc_w * spatial_w - 0.5; - *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; - *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; - *(cache_grad_attn_weight+threadIdx.x)=0; - if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) - { - ms_deform_attn_col2im_bilinear( - data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, - top_grad, weight, grad_value_ptr, - cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); - } - - __syncthreads(); - - for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) - { - if (tid < s) { - const unsigned int xid1 = tid << 1; - const unsigned int xid2 = (tid + s) << 1; - cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; - cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; - cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; - if (tid + (s << 1) < spre) - { - cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; - cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; - cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; - } - } - __syncthreads(); - } - - if (tid == 0) - { - *grad_sampling_loc = cache_grad_sampling_loc[0]; - *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; - *grad_attn_weight = cache_grad_attn_weight[0]; - } - __syncthreads(); - - data_weight_ptr += 1; - data_loc_w_ptr += 2; - grad_attn_weight += grad_weight_stride; - grad_sampling_loc += grad_loc_stride; - } - } - } -} - -template -__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, - const scalar_t *grad_col, - const scalar_t *data_value, - const int64_t *data_spatial_shapes, - const int64_t *data_level_start_index, - const scalar_t *data_sampling_loc, - const scalar_t *data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t *grad_value, - scalar_t *grad_sampling_loc, - scalar_t *grad_attn_weight) -{ - CUDA_KERNEL_LOOP(index, n) - { - extern __shared__ int _s[]; - scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; - scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; - unsigned int tid = threadIdx.x; - int _temp = index; - const int c_col = _temp % channels; - _temp /= channels; - const int sampling_index = _temp; - const int m_col = _temp % num_heads; - _temp /= num_heads; - 
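// The remaining divisions recover the query and batch indices; the flattened index
// is laid out as ((b_col * num_query + q_col) * num_heads + m_col) * channels + c_col.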
[[maybe_unused]] const int q_col = _temp % num_query; - _temp /= num_query; - const int b_col = _temp; - - const scalar_t top_grad = grad_col[index]; - - int data_weight_ptr = sampling_index * num_levels * num_point; - int data_loc_w_ptr = data_weight_ptr << 1; - const int grad_sampling_ptr = data_weight_ptr; - grad_sampling_loc += grad_sampling_ptr << 1; - grad_attn_weight += grad_sampling_ptr; - const int grad_weight_stride = 1; - const int grad_loc_stride = 2; - const int qid_stride = num_heads * channels; - const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; - - for (int l_col=0; l_col < num_levels; ++l_col) - { - const int level_start_id = data_level_start_index[l_col]; - const int spatial_h_ptr = l_col << 1; - const int spatial_h = data_spatial_shapes[spatial_h_ptr]; - const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; - const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; - const scalar_t *data_value_ptr = data_value + value_ptr_offset; - scalar_t *grad_value_ptr = grad_value + value_ptr_offset; - - for (int p_col=0; p_col < num_point; ++p_col) - { - const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; - const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; - const scalar_t weight = data_attn_weight[data_weight_ptr]; - - const scalar_t h_im = loc_h * spatial_h - 0.5; - const scalar_t w_im = loc_w * spatial_w - 0.5; - *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; - *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; - *(cache_grad_attn_weight+threadIdx.x)=0; - if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) - { - ms_deform_attn_col2im_bilinear( - data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, - top_grad, weight, grad_value_ptr, - cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); - } - - __syncthreads(); - - for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) - { - if (tid < s) { - const unsigned int xid1 = tid << 1; - const unsigned int xid2 = (tid + s) << 1; - cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; - cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; - cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; - if (tid + (s << 1) < spre) - { - cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; - cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; - cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; - } - } - __syncthreads(); - } - - if (tid == 0) - { - atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); - atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); - atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); - } - __syncthreads(); - - data_weight_ptr += 1; - data_loc_w_ptr += 2; - grad_attn_weight += grad_weight_stride; - grad_sampling_loc += grad_loc_stride; - } - } - } -} - - -template -__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, - const scalar_t *grad_col, - const scalar_t *data_value, - const int64_t *data_spatial_shapes, - const int64_t *data_level_start_index, - const scalar_t *data_sampling_loc, - const scalar_t *data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t *grad_value, - scalar_t *grad_sampling_loc, - scalar_t *grad_attn_weight) -{ - CUDA_KERNEL_LOOP(index, n) - { - 
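// Pure global-memory backward path: no shared-memory staging, every sample's
// gradient contribution is written out through ms_deform_attn_col2im_bilinear_gm.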
int _temp = index; - const int c_col = _temp % channels; - _temp /= channels; - const int sampling_index = _temp; - const int m_col = _temp % num_heads; - _temp /= num_heads; - [[maybe_unused]] const int q_col = _temp % num_query; - _temp /= num_query; - const int b_col = _temp; - - const scalar_t top_grad = grad_col[index]; - - int data_weight_ptr = sampling_index * num_levels * num_point; - int data_loc_w_ptr = data_weight_ptr << 1; - const int grad_sampling_ptr = data_weight_ptr; - grad_sampling_loc += grad_sampling_ptr << 1; - grad_attn_weight += grad_sampling_ptr; - const int grad_weight_stride = 1; - const int grad_loc_stride = 2; - const int qid_stride = num_heads * channels; - const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; - - for (int l_col=0; l_col < num_levels; ++l_col) - { - const int level_start_id = data_level_start_index[l_col]; - const int spatial_h_ptr = l_col << 1; - const int spatial_h = data_spatial_shapes[spatial_h_ptr]; - const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; - const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; - const scalar_t *data_value_ptr = data_value + value_ptr_offset; - scalar_t *grad_value_ptr = grad_value + value_ptr_offset; - - for (int p_col=0; p_col < num_point; ++p_col) - { - const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; - const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; - const scalar_t weight = data_attn_weight[data_weight_ptr]; - - const scalar_t h_im = loc_h * spatial_h - 0.5; - const scalar_t w_im = loc_w * spatial_w - 0.5; - if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) - { - ms_deform_attn_col2im_bilinear_gm( - data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, - top_grad, weight, grad_value_ptr, - grad_sampling_loc, grad_attn_weight); - } - data_weight_ptr += 1; - data_loc_w_ptr += 2; - grad_attn_weight += grad_weight_stride; - grad_sampling_loc += grad_loc_stride; - } - } - } -} - - -template -void ms_deformable_im2col_cuda(cudaStream_t stream, - const scalar_t* data_value, - const int64_t* data_spatial_shapes, - const int64_t* data_level_start_index, - const scalar_t* data_sampling_loc, - const scalar_t* data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t* data_col) -{ - const int num_kernels = batch_size * num_query * num_heads * channels; - const int num_actual_kernels = batch_size * num_query * num_heads * channels; - const int num_threads = CUDA_NUM_THREADS; - ms_deformable_im2col_gpu_kernel - <<>>( - num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, - batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); - - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { - printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); - } - -} - -template -void ms_deformable_col2im_cuda(cudaStream_t stream, - const scalar_t* grad_col, - const scalar_t* data_value, - const int64_t * data_spatial_shapes, - const int64_t * data_level_start_index, - const scalar_t * data_sampling_loc, - const scalar_t * data_attn_weight, - const int batch_size, - const int spatial_size, - const int num_heads, - const int channels, - const int num_levels, - const int num_query, - const int num_point, - scalar_t* grad_value, - scalar_t* 
grad_sampling_loc, - scalar_t* grad_attn_weight) -{ - const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; - const int num_kernels = batch_size * num_query * num_heads * channels; - const int num_actual_kernels = batch_size * num_query * num_heads * channels; - if (channels > 1024) - { - if ((channels & 1023) == 0) - { - ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - } - else - { - ms_deformable_col2im_gpu_kernel_gm - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - } - } - else{ - switch(channels) - { - case 1: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 2: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 4: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 8: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 16: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 32: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 64: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - 
spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 128: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 256: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 512: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - case 1024: - ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - break; - default: - if (channels < 64) - { - ms_deformable_col2im_gpu_kernel_shm_reduce_v1 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - } - else - { - ms_deformable_col2im_gpu_kernel_shm_reduce_v2 - <<>>( - num_kernels, - grad_col, - data_value, - data_spatial_shapes, - data_level_start_index, - data_sampling_loc, - data_attn_weight, - batch_size, - spatial_size, - num_heads, - channels, - num_levels, - num_query, - num_point, - grad_value, - grad_sampling_loc, - grad_attn_weight); - } - } - } - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { - printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); - } - -} diff --git a/src/transformers/kernels/deformable_detr/ms_deform_attn.h b/src/transformers/kernels/deformable_detr/ms_deform_attn.h deleted file mode 100644 index e649e65e37a..00000000000 --- a/src/transformers/kernels/deformable_detr/ms_deform_attn.h +++ /dev/null @@ -1,61 +0,0 @@ -/*! -************************************************************************************************** -* Deformable DETR -* Copyright (c) 2020 SenseTime. All Rights Reserved. 
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details] -************************************************************************************************** -* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 -************************************************************************************************** -*/ - -#pragma once - -#include "cpu/ms_deform_attn_cpu.h" - -#ifdef WITH_CUDA -#include "cuda/ms_deform_attn_cuda.h" -#endif - - -at::Tensor -ms_deform_attn_forward( - const at::Tensor &value, - const at::Tensor &spatial_shapes, - const at::Tensor &level_start_index, - const at::Tensor &sampling_loc, - const at::Tensor &attn_weight, - const int im2col_step) -{ - if (value.is_cuda()) - { -#ifdef WITH_CUDA - return ms_deform_attn_cuda_forward( - value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); -#else - AT_ERROR("Not compiled with GPU support"); -#endif - } - AT_ERROR("Not implemented on the CPU"); -} - -std::vector -ms_deform_attn_backward( - const at::Tensor &value, - const at::Tensor &spatial_shapes, - const at::Tensor &level_start_index, - const at::Tensor &sampling_loc, - const at::Tensor &attn_weight, - const at::Tensor &grad_output, - const int im2col_step) -{ - if (value.is_cuda()) - { -#ifdef WITH_CUDA - return ms_deform_attn_cuda_backward( - value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); -#else - AT_ERROR("Not compiled with GPU support"); -#endif - } - AT_ERROR("Not implemented on the CPU"); -} diff --git a/src/transformers/kernels/deformable_detr/vision.cpp b/src/transformers/kernels/deformable_detr/vision.cpp deleted file mode 100644 index 6ce3875568b..00000000000 --- a/src/transformers/kernels/deformable_detr/vision.cpp +++ /dev/null @@ -1,16 +0,0 @@ -/*! -************************************************************************************************** -* Deformable DETR -* Copyright (c) 2020 SenseTime. All Rights Reserved. -* Licensed under the Apache License, Version 2.0 [see LICENSE for details] -************************************************************************************************** -* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 -************************************************************************************************** -*/ - -#include "ms_deform_attn.h" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); - m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); -} \ No newline at end of file diff --git a/src/transformers/models/deformable_detr/load_custom.py b/src/transformers/models/deformable_detr/load_custom.py deleted file mode 100644 index 3c0b3a432be..00000000000 --- a/src/transformers/models/deformable_detr/load_custom.py +++ /dev/null @@ -1,50 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -"""Loading of Deformable DETR's CUDA kernels""" - -import os -from pathlib import Path - - -def load_cuda_kernels(): - from torch.utils.cpp_extension import load - - root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr" - src_files = [ - root / filename - for filename in [ - "vision.cpp", - os.path.join("cpu", "ms_deform_attn_cpu.cpp"), - os.path.join("cuda", "ms_deform_attn_cuda.cu"), - ] - ] - - load( - "MultiScaleDeformableAttention", - src_files, - with_cuda=True, - extra_include_paths=[str(root)], - extra_cflags=["-DWITH_CUDA=1"], - extra_cuda_cflags=[ - "-DCUDA_HAS_FP16=1", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - ], - ) - - import MultiScaleDeformableAttention as MSDA - - return MSDA diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 39d189418fe..6ffebca32dc 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -16,19 +16,16 @@ import copy import math -import os import warnings from dataclasses import dataclass -from pathlib import Path from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import Tensor, nn -from torch.autograd import Function -from torch.autograd.function import once_differentiable from ...activations import ACT2FN +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel @@ -37,10 +34,7 @@ from ...utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, - is_ninja_available, is_timm_available, - is_torch_cuda_available, - is_torchdynamo_compiling, logging, replace_return_docstrings, requires_backends, @@ -51,38 +45,6 @@ from .configuration_deformable_detr import DeformableDetrConfig logger = logging.get_logger(__name__) -MultiScaleDeformableAttention = None - - -def load_cuda_kernels(): - from torch.utils.cpp_extension import load - - global MultiScaleDeformableAttention - - root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr" - src_files = [ - root / filename - for filename in [ - "vision.cpp", - os.path.join("cpu", "ms_deform_attn_cpu.cpp"), - os.path.join("cuda", "ms_deform_attn_cuda.cu"), - ] - ] - - MultiScaleDeformableAttention = load( - "MultiScaleDeformableAttention", - src_files, - with_cuda=True, - extra_include_paths=[str(root)], - extra_cflags=["-DWITH_CUDA=1"], - extra_cuda_cflags=[ - "-DCUDA_HAS_FP16=1", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - ], - ) - if is_timm_available(): from timm import create_model @@ -94,52 +56,59 @@ _CONFIG_FOR_DOC = "DeformableDetrConfig" _CHECKPOINT_FOR_DOC = "sensetime/deformable-detr" -class MultiScaleDeformableAttentionFunction(Function): - @staticmethod +@use_kernel_forward_from_hub("MultiScaleDeformableAttention") +class MultiScaleDeformableAttention(nn.Module): def forward( - context, - value, - value_spatial_shapes, - value_level_start_index, - sampling_locations, - attention_weights, - im2col_step, + self, + value: Tensor, + value_spatial_shapes: Tensor, + 
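        # value: (batch_size, sum(height * width over levels), num_heads, hidden_dim)
        # value_spatial_shapes: (num_levels, 2) tensor of per-level (height, width);
        # value_spatial_shapes_list carries the same sizes as plain Python tuples.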
value_spatial_shapes_list: List[Tuple], + level_start_index: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + im2col_step: int, ): - context.im2col_step = im2col_step - output = MultiScaleDeformableAttention.ms_deform_attn_forward( - value, - value_spatial_shapes, - value_level_start_index, - sampling_locations, - attention_weights, - context.im2col_step, + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape + value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = ( + value_list[level_id] + .flatten(2) + .transpose(1, 2) + .reshape(batch_size * num_heads, hidden_dim, height, width) + ) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) + # batch_size*num_heads, hidden_dim, num_queries, num_points + sampling_value_l_ = nn.functional.grid_sample( + value_l_, + sampling_grid_l_, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + batch_size * num_heads, 1, num_queries, num_levels * num_points ) - context.save_for_backward( - value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) ) - return output - - @staticmethod - @once_differentiable - def backward(context, grad_output): - ( - value, - value_spatial_shapes, - value_level_start_index, - sampling_locations, - attention_weights, - ) = context.saved_tensors - grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward( - value, - value_spatial_shapes, - value_level_start_index, - sampling_locations, - attention_weights, - grad_output, - context.im2col_step, - ) - - return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + return output.transpose(1, 2).contiguous() @dataclass @@ -564,48 +533,6 @@ def build_position_encoding(config): return position_embedding -def multi_scale_deformable_attention( - value: Tensor, - value_spatial_shapes: Union[Tensor, List[Tuple]], - sampling_locations: Tensor, - attention_weights: Tensor, -) -> Tensor: - batch_size, _, num_heads, hidden_dim = value.shape - _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape - value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) - sampling_grids = 2 * sampling_locations - 1 - sampling_value_list = [] - for level_id, (height, width) in enumerate(value_spatial_shapes): - # batch_size, height*width, num_heads, hidden_dim - # -> batch_size, 
height*width, num_heads*hidden_dim - # -> batch_size, num_heads*hidden_dim, height*width - # -> batch_size*num_heads, hidden_dim, height, width - value_l_ = ( - value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width) - ) - # batch_size, num_queries, num_heads, num_points, 2 - # -> batch_size, num_heads, num_queries, num_points, 2 - # -> batch_size*num_heads, num_queries, num_points, 2 - sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) - # batch_size*num_heads, hidden_dim, num_queries, num_points - sampling_value_l_ = nn.functional.grid_sample( - value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False - ) - sampling_value_list.append(sampling_value_l_) - # (batch_size, num_queries, num_heads, num_levels, num_points) - # -> (batch_size, num_heads, num_queries, num_levels, num_points) - # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) - attention_weights = attention_weights.transpose(1, 2).reshape( - batch_size * num_heads, 1, num_queries, num_levels * num_points - ) - output = ( - (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) - .sum(-1) - .view(batch_size, num_heads * hidden_dim, num_queries) - ) - return output.transpose(1, 2).contiguous() - - class DeformableDetrMultiscaleDeformableAttention(nn.Module): """ Multiscale deformable attention as proposed in Deformable DETR. @@ -614,12 +541,7 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module): def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int): super().__init__() - kernel_loaded = MultiScaleDeformableAttention is not None - if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded: - try: - load_cuda_kernels() - except Exception as e: - logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}") + self.attn = MultiScaleDeformableAttention() if config.d_model % num_heads != 0: raise ValueError( @@ -706,27 +628,16 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module): else: raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") - if self.disable_custom_kernels or MultiScaleDeformableAttention is None or is_torchdynamo_compiling(): - # PyTorch implementation - output = multi_scale_deformable_attention( - value, spatial_shapes_list, sampling_locations, attention_weights - ) - else: - try: - # custom kernel - output = MultiScaleDeformableAttentionFunction.apply( - value, - spatial_shapes, - level_start_index, - sampling_locations, - attention_weights, - self.im2col_step, - ) - except Exception: - # PyTorch implementation - output = multi_scale_deformable_attention( - value, spatial_shapes_list, sampling_locations, attention_weights - ) + output = self.attn( + value, + spatial_shapes, + spatial_shapes_list, + level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + output = self.output_proj(output) return output, attention_weights @@ -834,7 +745,11 @@ class DeformableDetrMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_probs, value_states) - if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): + if attn_output.size() != ( + batch_size * self.num_heads, + target_len, + self.head_dim, + ): raise ValueError( f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" f" {attn_output.size()}" @@ -854,7 +769,9 @@ class 
DeformableDetrEncoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model self.self_attn = DeformableDetrMultiscaleDeformableAttention( - config, num_heads=config.encoder_attention_heads, n_points=config.encoder_n_points + config, + num_heads=config.encoder_attention_heads, + n_points=config.encoder_n_points, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -1054,7 +971,11 @@ class DeformableDetrPreTrainedModel(PreTrainedModel): base_model_prefix = "model" main_input_name = "pixel_values" supports_gradient_checkpointing = True - _no_split_modules = [r"DeformableDetrConvEncoder", r"DeformableDetrEncoderLayer", r"DeformableDetrDecoderLayer"] + _no_split_modules = [ + r"DeformableDetrConvEncoder", + r"DeformableDetrEncoderLayer", + r"DeformableDetrDecoderLayer", + ] def _init_weights(self, module): std = self.config.init_std @@ -1299,7 +1220,9 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel): if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, ) @@ -1525,7 +1448,13 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel): for _ in range(config.num_feature_levels - num_backbone_outs): input_proj_list.append( nn.Sequential( - nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1), + nn.Conv2d( + in_channels, + config.d_model, + kernel_size=3, + stride=2, + padding=1, + ), nn.GroupNorm(32, config.d_model), ) ) @@ -1535,7 +1464,11 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel): self.input_proj = nn.ModuleList( [ nn.Sequential( - nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1), + nn.Conv2d( + backbone.intermediate_channel_sizes[-1], + config.d_model, + kernel_size=1, + ), nn.GroupNorm(32, config.d_model), ) ] @@ -1625,8 +1558,20 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel): valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1) grid_y, grid_x = meshgrid( - torch.linspace(0, height - 1, height, dtype=enc_output.dtype, device=enc_output.device), - torch.linspace(0, width - 1, width, dtype=enc_output.dtype, device=enc_output.device), + torch.linspace( + 0, + height - 1, + height, + dtype=enc_output.dtype, + device=enc_output.device, + ), + torch.linspace( + 0, + width - 1, + width, + dtype=enc_output.dtype, + device=enc_output.device, + ), indexing="ij", ) grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) @@ -1802,7 +1747,9 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel): topk = self.config.two_stage_num_proposals topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1] topk_coords_logits = torch.gather( - enc_outputs_coord_logits, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4) + enc_outputs_coord_logits, + 1, + topk_proposals.unsqueeze(-1).repeat(1, 1, 4), ) topk_coords_logits = topk_coords_logits.detach() @@ -1897,7 +1844,10 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel): # Detection heads on top self.class_embed = nn.Linear(config.d_model, config.num_labels) self.bbox_embed = DeformableDetrMLPPredictionHead( - input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3 + input_dim=config.d_model, + hidden_dim=config.d_model, + output_dim=4, + num_layers=3, ) prior_prob = 0.01 @@ -2033,7 
+1983,13 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel): loss, loss_dict, auxiliary_outputs = None, None, None if labels is not None: loss, loss_dict, auxiliary_outputs = self.loss_function( - logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord + logits, + labels, + self.device, + pred_boxes, + self.config, + outputs_class, + outputs_coord, ) if not return_dict: if auxiliary_outputs is not None: diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index d8a4610cfd3..a8b244be5f6 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -15,17 +15,13 @@ """PyTorch Grounding DINO model.""" import math -import os import warnings from dataclasses import dataclass -from pathlib import Path from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import Tensor, nn -from torch.autograd import Function -from torch.autograd.function import once_differentiable from ...activations import ACT2FN from ...file_utils import ( @@ -33,13 +29,13 @@ from ...file_utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_timm_available, - is_torch_cuda_available, replace_return_docstrings, requires_backends, ) +from ...integrations import use_kernel_forward_from_hub from ...modeling_utils import PreTrainedModel from ...pytorch_utils import meshgrid -from ...utils import is_ninja_available, logging +from ...utils import logging from ...utils.backbone_utils import load_backbone from ..auto import AutoModel from .configuration_grounding_dino import GroundingDinoConfig @@ -49,97 +45,68 @@ if is_timm_available(): from timm import create_model -logger = logging.get_logger(__name__) - -MultiScaleDeformableAttention = None - - -# Copied from models.deformable_detr.load_cuda_kernels -def load_cuda_kernels(): - from torch.utils.cpp_extension import load - - global MultiScaleDeformableAttention - - root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr" - src_files = [ - root / filename - for filename in [ - "vision.cpp", - os.path.join("cpu", "ms_deform_attn_cpu.cpp"), - os.path.join("cuda", "ms_deform_attn_cuda.cu"), - ] - ] - - MultiScaleDeformableAttention = load( - "MultiScaleDeformableAttention", - src_files, - with_cuda=True, - extra_include_paths=[str(root)], - extra_cflags=["-DWITH_CUDA=1"], - extra_cuda_cflags=[ - "-DCUDA_HAS_FP16=1", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - ], - ) - - -# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction -class MultiScaleDeformableAttentionFunction(Function): - @staticmethod - def forward( - context, - value, - value_spatial_shapes, - value_level_start_index, - sampling_locations, - attention_weights, - im2col_step, - ): - context.im2col_step = im2col_step - output = MultiScaleDeformableAttention.ms_deform_attn_forward( - value, - value_spatial_shapes, - value_level_start_index, - sampling_locations, - attention_weights, - context.im2col_step, - ) - context.save_for_backward( - value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights - ) - return output - - @staticmethod - @once_differentiable - def backward(context, grad_output): - ( - value, - value_spatial_shapes, - 
value_level_start_index, - sampling_locations, - attention_weights, - ) = context.saved_tensors - grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward( - value, - value_spatial_shapes, - value_level_start_index, - sampling_locations, - attention_weights, - grad_output, - context.im2col_step, - ) - - return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None - - logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "GroundingDinoConfig" _CHECKPOINT_FOR_DOC = "IDEA-Research/grounding-dino-tiny" +# Copied from models.deformable_detr.MultiScaleDeformableAttention +@use_kernel_forward_from_hub("MultiScaleDeformableAttention") +class MultiScaleDeformableAttention(nn.Module): + def forward( + self, + value: Tensor, + value_spatial_shapes: Tensor, + value_spatial_shapes_list: List[Tuple], + level_start_index: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + im2col_step: int, + ): + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape + value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = ( + value_list[level_id] + .flatten(2) + .transpose(1, 2) + .reshape(batch_size * num_heads, hidden_dim, height, width) + ) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) + # batch_size*num_heads, hidden_dim, num_queries, num_points + sampling_value_l_ = nn.functional.grid_sample( + value_l_, + sampling_grid_l_, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + batch_size * num_heads, 1, num_queries, num_levels * num_points + ) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + @dataclass class GroundingDinoDecoderOutput(ModelOutput): """ @@ -583,49 +550,6 @@ def build_position_encoding(config): return position_embedding -# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention -def multi_scale_deformable_attention( - value: Tensor, - value_spatial_shapes: Union[Tensor, List[Tuple]], - sampling_locations: Tensor, - attention_weights: Tensor, -) -> Tensor: - batch_size, _, num_heads, hidden_dim = value.shape - _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape - value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) - sampling_grids = 2 * sampling_locations - 1 - sampling_value_list = [] - for level_id, (height, width) in 
enumerate(value_spatial_shapes): - # batch_size, height*width, num_heads, hidden_dim - # -> batch_size, height*width, num_heads*hidden_dim - # -> batch_size, num_heads*hidden_dim, height*width - # -> batch_size*num_heads, hidden_dim, height, width - value_l_ = ( - value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width) - ) - # batch_size, num_queries, num_heads, num_points, 2 - # -> batch_size, num_heads, num_queries, num_points, 2 - # -> batch_size*num_heads, num_queries, num_points, 2 - sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) - # batch_size*num_heads, hidden_dim, num_queries, num_points - sampling_value_l_ = nn.functional.grid_sample( - value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False - ) - sampling_value_list.append(sampling_value_l_) - # (batch_size, num_queries, num_heads, num_levels, num_points) - # -> (batch_size, num_heads, num_queries, num_levels, num_points) - # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) - attention_weights = attention_weights.transpose(1, 2).reshape( - batch_size * num_heads, 1, num_queries, num_levels * num_points - ) - output = ( - (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) - .sum(-1) - .view(batch_size, num_heads * hidden_dim, num_queries) - ) - return output.transpose(1, 2).contiguous() - - # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->GroundingDino, Deformable DETR->Grounding DINO class GroundingDinoMultiscaleDeformableAttention(nn.Module): """ @@ -635,12 +559,7 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module): def __init__(self, config: GroundingDinoConfig, num_heads: int, n_points: int): super().__init__() - kernel_loaded = MultiScaleDeformableAttention is not None - if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded: - try: - load_cuda_kernels() - except Exception as e: - logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}") + self.attn = MultiScaleDeformableAttention() if config.d_model % num_heads != 0: raise ValueError( @@ -727,23 +646,16 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module): else: raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") - if self.disable_custom_kernels or MultiScaleDeformableAttention is None: - # PyTorch implementation - output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights) - else: - try: - # custom kernel - output = MultiScaleDeformableAttentionFunction.apply( - value, - spatial_shapes, - level_start_index, - sampling_locations, - attention_weights, - self.im2col_step, - ) - except Exception: - # PyTorch implementation - output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights) + output = self.attn( + value, + spatial_shapes, + spatial_shapes_list, + level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + output = self.output_proj(output) return output, attention_weights diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 887a617b824..ca9377e2827 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -799,7 
+799,7 @@ class Mask2FormerLoss(nn.Module): return num_masks -# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention +# Copied from transformers.models.oneformer.modeling_oneformer.multi_scale_deformable_attention def multi_scale_deformable_attention( value: Tensor, value_spatial_shapes: Union[Tensor, List[Tuple]], diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py index 39849d55844..61cc747ca75 100644 --- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py @@ -15,38 +15,32 @@ """PyTorch OmDet-Turbo model.""" import math -import os import warnings from collections import OrderedDict from dataclasses import dataclass from functools import lru_cache -from pathlib import Path from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import Tensor, nn -from torch.autograd import Function -from torch.autograd.function import once_differentiable from ...activations import ACT2CLS, ACT2FN from ...file_utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, - is_torch_cuda_available, replace_return_docstrings, ) +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_utils import PreTrainedModel -from ...utils import is_ninja_available, logging +from ...utils import logging from ...utils.backbone_utils import load_backbone from ..auto import AutoModel from .configuration_omdet_turbo import OmDetTurboConfig -MultiScaleDeformableAttention = None - logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "OmDetTurboConfig" @@ -178,79 +172,60 @@ class OmDetTurboObjectDetectionOutput(ModelOutput): classes_structure: Optional[torch.LongTensor] = None -# Copied from models.deformable_detr.load_cuda_kernels -def load_cuda_kernels(): - from torch.utils.cpp_extension import load - - global MultiScaleDeformableAttention - - root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr" - src_files = [ - root / filename - for filename in [ - "vision.cpp", - os.path.join("cpu", "ms_deform_attn_cpu.cpp"), - os.path.join("cuda", "ms_deform_attn_cuda.cu"), - ] - ] - - MultiScaleDeformableAttention = load( - "MultiScaleDeformableAttention", - src_files, - with_cuda=True, - extra_include_paths=[str(root)], - extra_cflags=["-DWITH_CUDA=1"], - extra_cuda_cflags=[ - "-DCUDA_HAS_FP16=1", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - ], - ) - - -# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention -def multi_scale_deformable_attention( - value: Tensor, - value_spatial_shapes: Union[Tensor, List[Tuple]], - sampling_locations: Tensor, - attention_weights: Tensor, -) -> Tensor: - batch_size, _, num_heads, hidden_dim = value.shape - _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape - # Ignore copy - value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) - sampling_grids = 2 * sampling_locations - 1 - sampling_value_list = [] - for level_id, (height, width) in enumerate(value_spatial_shapes): - # batch_size, height*width, num_heads, hidden_dim - # -> batch_size, height*width, num_heads*hidden_dim - # -> batch_size, num_heads*hidden_dim, height*width - # 
-> batch_size*num_heads, hidden_dim, height, width - value_l_ = ( - value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width) +# Copied from models.deformable_detr.MultiScaleDeformableAttention +@use_kernel_forward_from_hub("MultiScaleDeformableAttention") +class MultiScaleDeformableAttention(nn.Module): + def forward( + self, + value: Tensor, + value_spatial_shapes: Tensor, + value_spatial_shapes_list: List[Tuple], + level_start_index: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + im2col_step: int, + ): + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape + value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = ( + value_list[level_id] + .flatten(2) + .transpose(1, 2) + .reshape(batch_size * num_heads, hidden_dim, height, width) + ) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) + # batch_size*num_heads, hidden_dim, num_queries, num_points + sampling_value_l_ = nn.functional.grid_sample( + value_l_, + sampling_grid_l_, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + batch_size * num_heads, 1, num_queries, num_levels * num_points ) - # batch_size, num_queries, num_heads, num_points, 2 - # -> batch_size, num_heads, num_queries, num_points, 2 - # -> batch_size*num_heads, num_queries, num_points, 2 - sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) - # batch_size*num_heads, hidden_dim, num_queries, num_points - sampling_value_l_ = nn.functional.grid_sample( - value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) ) - sampling_value_list.append(sampling_value_l_) - # (batch_size, num_queries, num_heads, num_levels, num_points) - # -> (batch_size, num_heads, num_queries, num_levels, num_points) - # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) - attention_weights = attention_weights.transpose(1, 2).reshape( - batch_size * num_heads, 1, num_queries, num_levels * num_points - ) - output = ( - (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) - .sum(-1) - .view(batch_size, num_heads * hidden_dim, num_queries) - ) - return output.transpose(1, 2).contiguous() + return output.transpose(1, 2).contiguous() class OmDetTurboLRUCache: @@ -332,55 +307,6 @@ class OmDetTurboVisionBackbone(nn.Module): return outputs -# Copied from 
transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction -class MultiScaleDeformableAttentionFunction(Function): - @staticmethod - def forward( - context, - value, - value_spatial_shapes, - value_level_start_index, - sampling_locations, - attention_weights, - im2col_step, - ): - context.im2col_step = im2col_step - output = MultiScaleDeformableAttention.ms_deform_attn_forward( - value, - value_spatial_shapes, - value_level_start_index, - sampling_locations, - attention_weights, - context.im2col_step, - ) - context.save_for_backward( - value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights - ) - return output - - @staticmethod - @once_differentiable - def backward(context, grad_output): - ( - value, - value_spatial_shapes, - value_level_start_index, - sampling_locations, - attention_weights, - ) = context.saved_tensors - grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward( - value, - value_spatial_shapes, - value_level_start_index, - sampling_locations, - attention_weights, - grad_output, - context.im2col_step, - ) - - return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None - - # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->OmDetTurbo, Deformable DETR->OmDet-Turbo class OmDetTurboMultiscaleDeformableAttention(nn.Module): """ @@ -390,12 +316,7 @@ class OmDetTurboMultiscaleDeformableAttention(nn.Module): def __init__(self, config: OmDetTurboConfig, num_heads: int, n_points: int): super().__init__() - kernel_loaded = MultiScaleDeformableAttention is not None - if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded: - try: - load_cuda_kernels() - except Exception as e: - logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}") + self.attn = MultiScaleDeformableAttention() if config.d_model % num_heads != 0: raise ValueError( @@ -483,27 +404,16 @@ class OmDetTurboMultiscaleDeformableAttention(nn.Module): else: raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") - if self.disable_custom_kernels: - # PyTorch implementation - output = multi_scale_deformable_attention( - value, spatial_shapes_list, sampling_locations, attention_weights - ) - else: - try: - # custom kernel - output = MultiScaleDeformableAttentionFunction.apply( - value, - spatial_shapes, - level_start_index, - sampling_locations, - attention_weights, - self.im2col_step, - ) - except Exception: - # PyTorch implementation - output = multi_scale_deformable_attention( - value, spatial_shapes_list, sampling_locations, attention_weights - ) + output = self.attn( + value, + spatial_shapes, + spatial_shapes_list, + level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + output = self.output_proj(output) return output, attention_weights diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 585103f21b1..eef4b6a3c2d 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -61,7 +61,6 @@ def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) -# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention def 
multi_scale_deformable_attention( value: Tensor, value_spatial_shapes: Union[Tensor, List[Tuple]], diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py index 7a2ee51aa65..54dfc43bf86 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr.py @@ -15,21 +15,18 @@ """PyTorch RT-DETR model.""" import math -import os import warnings from dataclasses import dataclass from functools import partial -from pathlib import Path from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import Tensor, nn -from torch.autograd import Function -from torch.autograd.function import once_differentiable from ...activations import ACT2CLS, ACT2FN from ...image_transforms import center_to_corners_format, corners_to_center_format +from ...integrations import use_kernel_forward_from_hub from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import compile_compatible_method_lru_cache @@ -37,9 +34,6 @@ from ...utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, - is_ninja_available, - is_torch_cuda_available, - is_torchdynamo_compiling, logging, replace_return_docstrings, torch_int, @@ -50,96 +44,68 @@ from .configuration_rt_detr import RTDetrConfig logger = logging.get_logger(__name__) -MultiScaleDeformableAttention = None - - -# Copied from transformers.models.deformable_detr.modeling_deformable_detr.load_cuda_kernels -def load_cuda_kernels(): - from torch.utils.cpp_extension import load - - global MultiScaleDeformableAttention - - root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr" - src_files = [ - root / filename - for filename in [ - "vision.cpp", - os.path.join("cpu", "ms_deform_attn_cpu.cpp"), - os.path.join("cuda", "ms_deform_attn_cuda.cu"), - ] - ] - - MultiScaleDeformableAttention = load( - "MultiScaleDeformableAttention", - src_files, - with_cuda=True, - extra_include_paths=[str(root)], - extra_cflags=["-DWITH_CUDA=1"], - extra_cuda_cflags=[ - "-DCUDA_HAS_FP16=1", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - ], - ) - - -# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction -class MultiScaleDeformableAttentionFunction(Function): - @staticmethod - def forward( - context, - value, - value_spatial_shapes, - value_level_start_index, - sampling_locations, - attention_weights, - im2col_step, - ): - context.im2col_step = im2col_step - output = MultiScaleDeformableAttention.ms_deform_attn_forward( - value, - value_spatial_shapes, - value_level_start_index, - sampling_locations, - attention_weights, - context.im2col_step, - ) - context.save_for_backward( - value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights - ) - return output - - @staticmethod - @once_differentiable - def backward(context, grad_output): - ( - value, - value_spatial_shapes, - value_level_start_index, - sampling_locations, - attention_weights, - ) = context.saved_tensors - grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward( - value, - value_spatial_shapes, - value_level_start_index, - sampling_locations, - attention_weights, - grad_output, - context.im2col_step, - ) - - return grad_value, None, None, grad_sampling_loc, 
grad_attn_weight, None - - -logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "RTDetrConfig" # TODO: Replace all occurrences of the checkpoint with the final one _CHECKPOINT_FOR_DOC = "PekingU/rtdetr_r50vd" +# Copied from models.deformable_detr.MultiScaleDeformableAttention +@use_kernel_forward_from_hub("MultiScaleDeformableAttention") +class MultiScaleDeformableAttention(nn.Module): + def forward( + self, + value: Tensor, + value_spatial_shapes: Tensor, + value_spatial_shapes_list: List[Tuple], + level_start_index: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + im2col_step: int, + ): + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape + value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = ( + value_list[level_id] + .flatten(2) + .transpose(1, 2) + .reshape(batch_size * num_heads, hidden_dim, height, width) + ) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) + # batch_size*num_heads, hidden_dim, num_queries, num_points + sampling_value_l_ = nn.functional.grid_sample( + value_l_, + sampling_grid_l_, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + batch_size * num_heads, 1, num_queries, num_levels * num_points + ) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + @dataclass class RTDetrDecoderOutput(ModelOutput): """ @@ -728,49 +694,6 @@ class RTDetrCSPRepLayer(nn.Module): return self.conv3(hidden_state_1 + hidden_state_2) -# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention -def multi_scale_deformable_attention( - value: Tensor, - value_spatial_shapes: Union[Tensor, List[Tuple]], - sampling_locations: Tensor, - attention_weights: Tensor, -) -> Tensor: - batch_size, _, num_heads, hidden_dim = value.shape - _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape - value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) - sampling_grids = 2 * sampling_locations - 1 - sampling_value_list = [] - for level_id, (height, width) in enumerate(value_spatial_shapes): - # batch_size, height*width, num_heads, hidden_dim - # -> batch_size, height*width, num_heads*hidden_dim - # -> batch_size, num_heads*hidden_dim, height*width - # -> batch_size*num_heads, hidden_dim, height, width - value_l_ = ( - value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, 
height, width) - ) - # batch_size, num_queries, num_heads, num_points, 2 - # -> batch_size, num_heads, num_queries, num_points, 2 - # -> batch_size*num_heads, num_queries, num_points, 2 - sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) - # batch_size*num_heads, hidden_dim, num_queries, num_points - sampling_value_l_ = nn.functional.grid_sample( - value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False - ) - sampling_value_list.append(sampling_value_l_) - # (batch_size, num_queries, num_heads, num_levels, num_points) - # -> (batch_size, num_heads, num_queries, num_levels, num_points) - # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) - attention_weights = attention_weights.transpose(1, 2).reshape( - batch_size * num_heads, 1, num_queries, num_levels * num_points - ) - output = ( - (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) - .sum(-1) - .view(batch_size, num_heads * hidden_dim, num_queries) - ) - return output.transpose(1, 2).contiguous() - - # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->RTDetr class RTDetrMultiscaleDeformableAttention(nn.Module): """ @@ -780,12 +703,7 @@ class RTDetrMultiscaleDeformableAttention(nn.Module): def __init__(self, config: RTDetrConfig, num_heads: int, n_points: int): super().__init__() - kernel_loaded = MultiScaleDeformableAttention is not None - if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded: - try: - load_cuda_kernels() - except Exception as e: - logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}") + self.attn = MultiScaleDeformableAttention() if config.d_model % num_heads != 0: raise ValueError( @@ -872,27 +790,16 @@ class RTDetrMultiscaleDeformableAttention(nn.Module): else: raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") - if self.disable_custom_kernels or MultiScaleDeformableAttention is None or is_torchdynamo_compiling(): - # PyTorch implementation - output = multi_scale_deformable_attention( - value, spatial_shapes_list, sampling_locations, attention_weights - ) - else: - try: - # custom kernel - output = MultiScaleDeformableAttentionFunction.apply( - value, - spatial_shapes, - level_start_index, - sampling_locations, - attention_weights, - self.im2col_step, - ) - except Exception: - # PyTorch implementation - output = multi_scale_deformable_attention( - value, spatial_shapes_list, sampling_locations, attention_weights - ) + output = self.attn( + value, + spatial_shapes, + spatial_shapes_list, + level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + output = self.output_proj(output) return output, attention_weights diff --git a/utils/check_build.py b/utils/check_build.py index 9ac309bd675..929caeb7fa0 100644 --- a/utils/check_build.py +++ b/utils/check_build.py @@ -21,8 +21,6 @@ from pathlib import Path FILES_TO_FIND = [ "kernels/rwkv/wkv_cuda.cu", "kernels/rwkv/wkv_op.cpp", - "kernels/deformable_detr/ms_deform_attn.h", - "kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh", "kernels/falcon_mamba/selective_scan_with_ln_interface.py", "kernels/falcon_mamba/__init__.py", "kernels/__init__.py", diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt index 5c56ed144ca..61d452464c6 100644 --- a/utils/not_doctested.txt +++ b/utils/not_doctested.txt @@ -475,7 +475,6 @@ 
src/transformers/models/deberta/modeling_tf_deberta.py src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py src/transformers/models/decision_transformer/modeling_decision_transformer.py src/transformers/models/deformable_detr/convert_deformable_detr_to_pytorch.py -src/transformers/models/deformable_detr/load_custom.py src/transformers/models/deit/convert_deit_timm_to_pytorch.py src/transformers/models/deprecated/bort/convert_bort_original_gluonnlp_checkpoint_to_pytorch.py src/transformers/models/deprecated/mctct/configuration_mctct.py
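As a quick illustration of the refactored code path: after this patch each `*MultiscaleDeformableAttention` module owns a plain `MultiScaleDeformableAttention` submodule (the class the `# Copied from models.deformable_detr.MultiScaleDeformableAttention` comments above point at), whose pure-PyTorch `forward` replaces the old per-model CUDA-kernel/fallback branching; when the optional `kernels` package is installed, the `use_kernel_forward_from_hub("MultiScaleDeformableAttention")` decorator lets the Hub-hosted kernel take over that forward on CUDA. The snippet below is a minimal sketch under assumed tensor shapes, not an excerpt from the repository tests; it exercises the fallback forward directly on CPU.

import torch
from transformers.models.deformable_detr.modeling_deformable_detr import (
    MultiScaleDeformableAttention,
)

# Illustrative shapes only: two feature levels, 8 heads, 32-dim per head.
batch_size, num_heads, hidden_dim = 2, 8, 32
spatial_shapes_list = [(16, 16), (8, 8)]
spatial_shapes = torch.tensor(spatial_shapes_list, dtype=torch.long)
# Start offset of each level in the flattened sequence (only needed by the kernel path).
level_start_index = torch.cat(
    (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
)
num_queries, num_levels, num_points = 10, len(spatial_shapes_list), 4
seq_len = sum(h * w for h, w in spatial_shapes_list)

value = torch.rand(batch_size, seq_len, num_heads, hidden_dim)
sampling_locations = torch.rand(batch_size, num_queries, num_heads, num_levels, num_points, 2)
attention_weights = torch.rand(batch_size, num_queries, num_heads, num_levels, num_points)
# Normalize weights over levels and points, as the attention module would.
attention_weights = attention_weights / attention_weights.sum(dim=(-2, -1), keepdim=True)

attn = MultiScaleDeformableAttention()
# Pure-PyTorch path; with `kernels` installed and a CUDA device, the decorated
# forward can be served by the Hub kernel instead, with the same signature.
output = attn(
    value,
    spatial_shapes,
    spatial_shapes_list,
    level_start_index,
    sampling_locations,
    attention_weights,
    64,  # im2col_step, only consumed by the custom kernel
)
print(output.shape)  # torch.Size([2, 10, 256]) -> (batch_size, num_queries, num_heads * hidden_dim)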