Update output of SuperPointForKeypointDetection (#29809)

* Remove auto class

* Update ImagePointDescriptionOutput

* Update model outputs

* Rename output class

* Revert "Remove auto class"

This reverts commit ed4a8f549d.

* Address comments
This commit is contained in:
NielsRogge 2024-04-11 14:59:30 +02:00 committed by GitHub
parent 386ef34e7d
commit 5569552cf8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 26 deletions

View File

@ -79,7 +79,7 @@ def simple_nms(scores: torch.Tensor, nms_radius: int) -> torch.Tensor:
@dataclass
class ImagePointDescriptionOutput(ModelOutput):
class SuperPointKeypointDescriptionOutput(ModelOutput):
"""
Base class for outputs of image point description models. Due to the nature of keypoint detection, the number of
keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of images,
@ -88,8 +88,8 @@ class ImagePointDescriptionOutput(ModelOutput):
and which are padding.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.
loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
Loss computed during training.
keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
Relative (x, y) coordinates of predicted keypoints in a given image.
scores (`torch.FloatTensor` of shape `(batch_size, num_keypoints)`):
@ -105,7 +105,7 @@ class ImagePointDescriptionOutput(ModelOutput):
(also called feature maps) of the model at the output of each stage.
"""
last_hidden_state: torch.FloatTensor = None
loss: Optional[torch.FloatTensor] = None
keypoints: Optional[torch.IntTensor] = None
scores: Optional[torch.FloatTensor] = None
descriptors: Optional[torch.FloatTensor] = None
@ -414,11 +414,11 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
@add_start_docstrings_to_model_forward(SUPERPOINT_INPUTS_DOCSTRING)
def forward(
self,
pixel_values: torch.FloatTensor = None,
pixel_values: torch.FloatTensor,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, ImagePointDescriptionOutput]:
) -> Union[Tuple, SuperPointKeypointDescriptionOutput]:
"""
Examples:
@ -437,20 +437,15 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
>>> inputs = processor(image, return_tensors="pt")
>>> outputs = model(**inputs)
```"""
loss = None
if labels is not None:
raise ValueError(
f"SuperPoint is not trainable, no labels should be provided.Therefore, labels should be None but were {type(labels)}"
)
raise ValueError("SuperPoint does not support training for now.")
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
pixel_values = self.extract_one_channel_pixel_values(pixel_values)
batch_size = pixel_values.shape[0]
@ -493,12 +488,10 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
hidden_states = encoder_outputs[1] if output_hidden_states else None
if not return_dict:
return tuple(
v for v in [last_hidden_state, keypoints, scores, descriptors, mask, hidden_states] if v is not None
)
return tuple(v for v in [loss, keypoints, scores, descriptors, mask, hidden_states] if v is not None)
return ImagePointDescriptionOutput(
last_hidden_state=last_hidden_state,
return SuperPointKeypointDescriptionOutput(
loss=loss,
keypoints=keypoints,
scores=scores,
descriptors=descriptors,

View File

@ -85,13 +85,17 @@ class SuperPointModelTester:
border_removal_distance=self.border_removal_distance,
)
def create_and_check_model(self, config, pixel_values):
def create_and_check_keypoint_detection(self, config, pixel_values):
model = SuperPointForKeypointDetection(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
self.parent.assertEqual(result.keypoints.shape[0], self.batch_size)
self.parent.assertEqual(result.keypoints.shape[-1], 2)
result = model(pixel_values, output_hidden_states=True)
self.parent.assertEqual(
result.last_hidden_state.shape,
result.hidden_states[-1].shape,
(
self.batch_size,
self.encoder_hidden_sizes[-1],
@ -146,19 +150,19 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
def test_feed_forward_chunking(self):
pass
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
def test_training(self):
pass
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@ -166,9 +170,9 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
def test_retain_grad_hidden_states_attentions(self):
pass
def test_model(self):
def test_keypoint_detection(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
self.model_tester.create_and_check_keypoint_detection(*config_and_inputs)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs()