mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Update output of SuperPointForKeypointDetection (#29809)
* Remove auto class
* Update ImagePointDescriptionOutput
* Update model outputs
* Rename output class
* Revert "Remove auto class"
This reverts commit ed4a8f549d
.
* Address comments
This commit is contained in:
parent
386ef34e7d
commit
5569552cf8
@ -79,7 +79,7 @@ def simple_nms(scores: torch.Tensor, nms_radius: int) -> torch.Tensor:
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImagePointDescriptionOutput(ModelOutput):
|
||||
class SuperPointKeypointDescriptionOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of image point description models. Due to the nature of keypoint detection, the number of
|
||||
keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of images,
|
||||
@ -88,8 +88,8 @@ class ImagePointDescriptionOutput(ModelOutput):
|
||||
and which are padding.
|
||||
|
||||
Args:
|
||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
||||
Loss computed during training.
|
||||
keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
|
||||
Relative (x, y) coordinates of predicted keypoints in a given image.
|
||||
scores (`torch.FloatTensor` of shape `(batch_size, num_keypoints)`):
|
||||
@ -105,7 +105,7 @@ class ImagePointDescriptionOutput(ModelOutput):
|
||||
(also called feature maps) of the model at the output of each stage.
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor = None
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
keypoints: Optional[torch.IntTensor] = None
|
||||
scores: Optional[torch.FloatTensor] = None
|
||||
descriptors: Optional[torch.FloatTensor] = None
|
||||
@ -414,11 +414,11 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(SUPERPOINT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
pixel_values: torch.FloatTensor,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, ImagePointDescriptionOutput]:
|
||||
) -> Union[Tuple, SuperPointKeypointDescriptionOutput]:
|
||||
"""
|
||||
Examples:
|
||||
|
||||
@ -437,20 +437,15 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
|
||||
>>> inputs = processor(image, return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
```"""
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
raise ValueError(
|
||||
f"SuperPoint is not trainable, no labels should be provided.Therefore, labels should be None but were {type(labels)}"
|
||||
)
|
||||
raise ValueError("SuperPoint does not support training for now.")
|
||||
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
pixel_values = self.extract_one_channel_pixel_values(pixel_values)
|
||||
|
||||
batch_size = pixel_values.shape[0]
|
||||
@ -493,12 +488,10 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
|
||||
|
||||
hidden_states = encoder_outputs[1] if output_hidden_states else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v for v in [last_hidden_state, keypoints, scores, descriptors, mask, hidden_states] if v is not None
|
||||
)
|
||||
return tuple(v for v in [loss, keypoints, scores, descriptors, mask, hidden_states] if v is not None)
|
||||
|
||||
return ImagePointDescriptionOutput(
|
||||
last_hidden_state=last_hidden_state,
|
||||
return SuperPointKeypointDescriptionOutput(
|
||||
loss=loss,
|
||||
keypoints=keypoints,
|
||||
scores=scores,
|
||||
descriptors=descriptors,
|
||||
|
@ -85,13 +85,17 @@ class SuperPointModelTester:
|
||||
border_removal_distance=self.border_removal_distance,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values):
|
||||
def create_and_check_keypoint_detection(self, config, pixel_values):
|
||||
model = SuperPointForKeypointDetection(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(result.keypoints.shape[0], self.batch_size)
|
||||
self.parent.assertEqual(result.keypoints.shape[-1], 2)
|
||||
|
||||
result = model(pixel_values, output_hidden_states=True)
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape,
|
||||
result.hidden_states[-1].shape,
|
||||
(
|
||||
self.batch_size,
|
||||
self.encoder_hidden_sizes[-1],
|
||||
@ -146,19 +150,19 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
|
||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@ -166,9 +170,9 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
def test_model(self):
|
||||
def test_keypoint_detection(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
self.model_tester.create_and_check_keypoint_detection(*config_and_inputs)
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs()
|
||||
|
Loading…
Reference in New Issue
Block a user