diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index e9948545859..5076fa6417d 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1509,8 +1509,8 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel): """Get the valid ratio of all feature maps.""" _, height, width = mask.shape - valid_height = torch.sum(~mask[:, :, 0], 1) - valid_width = torch.sum(~mask[:, 0, :], 1) + valid_height = torch.sum(mask[:, :, 0], 1) + valid_width = torch.sum(mask[:, 0, :], 1) valid_ratio_heigth = valid_height.float() / height valid_ratio_width = valid_width.float() / width valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1) @@ -1687,9 +1687,6 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel): spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) - - # revert valid_ratios - valid_ratios = ~valid_ratios.bool() valid_ratios = valid_ratios.float() # Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder