diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py index 7bcf8d98251..33e50de7aa8 100644 --- a/src/transformers/models/superglue/modeling_superglue.py +++ b/src/transformers/models/superglue/modeling_superglue.py @@ -725,8 +725,8 @@ class SuperGlueForKeypointMatching(SuperGluePreTrainedModel): matches0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) matches1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) - matches = torch.cat([matches0, matches1]).reshape(batch_size, 2, -1) - matching_scores = torch.cat([matching_scores0, matching_scores1]).reshape(batch_size, 2, -1) + matches = torch.cat([matches0, matches1], dim=1).reshape(batch_size, 2, -1) + matching_scores = torch.cat([matching_scores0, matching_scores1], dim=1).reshape(batch_size, 2, -1) if output_hidden_states: all_hidden_states = all_hidden_states + encoded_keypoints[1]