fix bug in mask2former: cost matrix is infeasible (#27897)

fix bug: cost matrix is infeasible
This commit is contained in:
Chenhao Xu 2023-12-12 03:19:16 +11:00 committed by GitHub
parent 7e35f37071
commit c0a354d8d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -471,6 +471,9 @@ class Mask2FormerHungarianMatcher(nn.Module):
cost_dice = pair_wise_dice_loss(pred_mask, target_mask)
# final cost matrix
cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
# eliminate infinite values in cost_matrix to avoid the error ``ValueError: cost matrix is infeasible``
cost_matrix = torch.minimum(cost_matrix, torch.tensor(1e10))
cost_matrix = torch.maximum(cost_matrix, torch.tensor(-1e10))
# do the assigmented using the hungarian algorithm in scipy
assigned_indices: Tuple[np.array] = linear_sum_assignment(cost_matrix.cpu())
indices.append(assigned_indices)