DeviceGuard added to use Deformable Attention more safely on multi-GPU (#32910)

* Update modeling_deformable_detr.py

* Update src/transformers/models/deformable_detr/modeling_deformable_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update ms_deform_attn_cuda.cu

* Update modeling_deformable_detr.py

* Update modeling_deformable_detr.py

* [empty] this is a empty commit

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Donggeun Yu 2024-08-24 01:12:10 +09:00 committed by GitHub
parent 371b9c1486
commit 1dbd9d3693
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -28,6 +28,8 @@ at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &attn_weight,
const int im2col_step)
{
at::DeviceGuard guard(value.device());
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
@ -92,6 +94,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &grad_output,
const int im2col_step)
{
at::DeviceGuard guard(value.device());
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");