mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix valid ratio for Deformable Detr (#20958)
* fix: valid ratio has right value * chore: remove unnecessary line Co-authored-by: Jeongyeon Nam <jy.nam@navercorp.com>
This commit is contained in:
parent
9c9fe89f84
commit
a9653400d3
@ -1509,8 +1509,8 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
||||
"""Get the valid ratio of all feature maps."""
|
||||
|
||||
_, height, width = mask.shape
|
||||
valid_height = torch.sum(~mask[:, :, 0], 1)
|
||||
valid_width = torch.sum(~mask[:, 0, :], 1)
|
||||
valid_height = torch.sum(mask[:, :, 0], 1)
|
||||
valid_width = torch.sum(mask[:, 0, :], 1)
|
||||
valid_ratio_heigth = valid_height.float() / height
|
||||
valid_ratio_width = valid_width.float() / width
|
||||
valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
|
||||
@ -1687,9 +1687,6 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
||||
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
|
||||
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
||||
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
|
||||
|
||||
# revert valid_ratios
|
||||
valid_ratios = ~valid_ratios.bool()
|
||||
valid_ratios = valid_ratios.float()
|
||||
|
||||
# Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder
|
||||
|
Loading…
Reference in New Issue
Block a user