Fix valid ratio for Deformable Detr (#20958)

* fix: valid ratio has right value

* chore: remove unnecessary line

Co-authored-by: Jeongyeon Nam <jy.nam@navercorp.com>
This commit is contained in:
JeongYeon Nam 2023-01-03 23:43:26 +09:00 committed by GitHub
parent 9c9fe89f84
commit a9653400d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1509,8 +1509,8 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
"""Get the valid ratio of all feature maps."""
_, height, width = mask.shape
valid_height = torch.sum(~mask[:, :, 0], 1)
valid_width = torch.sum(~mask[:, 0, :], 1)
valid_height = torch.sum(mask[:, :, 0], 1)
valid_width = torch.sum(mask[:, 0, :], 1)
valid_ratio_heigth = valid_height.float() / height
valid_ratio_width = valid_width.float() / width
valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
@ -1687,9 +1687,6 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
# revert valid_ratios
valid_ratios = ~valid_ratios.bool()
valid_ratios = valid_ratios.float()
# Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder