Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 19:21:31 +06:00)
DeviceGuard added to use Deformable Attention more safely on multi-GPU (#32910)
* Update modeling_deformable_detr.py
* Update src/transformers/models/deformable_detr/modeling_deformable_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update ms_deform_attn_cuda.cu
* Update modeling_deformable_detr.py
* Update modeling_deformable_detr.py
* [empty] this is a empty commit

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent 371b9c1486
commit 1dbd9d3693
ms_deform_attn_cuda.cu

@@ -28,6 +28,8 @@ at::Tensor ms_deform_attn_cuda_forward(
     const at::Tensor &attn_weight,
     const int im2col_step)
 {
+    at::DeviceGuard guard(value.device());
+
     AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
     AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
     AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
@@ -92,6 +94,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     const at::Tensor &grad_output,
     const int im2col_step)
 {
+    at::DeviceGuard guard(value.device());
 
     AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
     AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
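The hunks above place at::DeviceGuard at the top of the forward and backward CUDA entry points. As a rough sketch of why this helps on multi-GPU machines (not the file's actual code; my_cuda_forward and the includes below are illustrative assumptions): without the guard, the current CUDA device can still be the process default (typically cuda:0) even when the input tensors live on another GPU, so allocations and kernel launches inside the function may target the wrong device.

#include <ATen/ATen.h>
#include <ATen/DeviceGuard.h>

// Hypothetical entry point, sketched for illustration only.
at::Tensor my_cuda_forward(const at::Tensor &value)
{
    // Pin the current CUDA device to the device holding `value` for the
    // lifetime of this scope; the previous device is restored automatically
    // when `guard` is destroyed (RAII), as in the patched kernels.
    at::DeviceGuard guard(value.device());

    // Tensors allocated here now land on value.device() (e.g. cuda:1)
    // instead of whichever device happened to be current before the call.
    return at::zeros_like(value);
}

Because the guard is scoped, every code path out of the function, including early returns and exceptions, restores the previously active device.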