transformers/tests/models/lightglue/test_modeling_lightglue.py
StevenBucaille e5a9ce48f7
Add LightGlue model (#31718)
* init

* chore: various changes to LightGlue

* chore: various changes to LightGlue

* chore: various changes to LightGlue

* chore: various changes to LightGlue

* Fixed dynamo bug and image padding tests

* refactor: applied refactoring changes from SuperGlue's concat, batch and stack functions to LightGlue file

* tests: removed sdpa support and changed expected values

* chore: added some docs and refactoring

* chore: fixed copy to superpoint.image_processing_superpoint.convert_to_grayscale

* feat: adding batch implementation

* feat: added validation for preprocess and post process method to LightGlueImageProcessor

* chore: changed convert_lightglue_to_hf script to comply with new standard

* chore: changed lightglue test values to match new lightglue config pushed to hub

* chore: simplified convert_lightglue_to_hf conversion map

* feat: adding batching implementation

* chore: make style

* feat: added threshold to post_process_keypoint_matching method

* fix: added missing instructions that turn keypoints back to absolute coordinates before the matching forward pass

* fix: added typehint and docs

* chore: make style

* [run-slow] lightglue

* fix: add matches different from -1 to compute valid matches in post_process_keypoint_matching

* tests: added CUDA-proof tests similar to SuperGlue

* chore: various changes to modeling_lightglue.py

- Added "Copies from" statements for copied functions from modeling_superglue.py
- Added missing docstrings
- Removed unused functions or classes
- Removed unnecessary statements
- Added missing typehints
- Added comments to the main forward method

* chore: various changes to convert_lightglue_to_hf.py

- Added model saving
- Added model reloading

* chore: fixed imports in lightglue files

* [run-slow] lightglue

* chore: make style

* [run-slow] lightglue

* Apply suggestions from code review

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* [run-slow] lightglue

* chore: Applied some suggestions from review

- Added missing typehints
- Refactor "cuda" to device variable
- Variable renaming
- LightGlue output order changed
- Make style

* fix: added missing grayscale argument in image processor in case use of SuperPoint keypoint detector

* fix: changed lightglue HF repo to lightglue_superpoint with grayscale default to True

* refactor: make keypoints `(batch_size, num_keypoints, keypoint_dim)` through forward and unsqueeze only before attention layer

* refactor: refactor do_layer_keypoint_pruning

* tests: added tests with no early stop and keypoint pruning

* refactor: various refactoring to modeling_lightglue.py

- Removed unused functions
- Renamed variables for consistency
- Added comments for clarity
- Set methods to private in LightGlueForKeypointMatching
- Replaced tensor initialization with list construction followed by concatenation
- Used more pythonic list comprehension for repetitive instructions

* refactor: added comments and renamed filter_matches to get_matches_from_scores

* tests: added copied from statement with superglue tests

* docs: added comment to prepare_keypoint_matching_output function in tests

* [run-slow] lightglue

* refactor: reordered _concat_early_stopped_outputs in LightGlue class

* [run-slow] lightglue

* docs: added lightglue.md model doc

* docs: added Optional typehint to LightGlueKeypointMatchingOutput

* chore: removed pad_images function

* chore: set do_grayscale default value to True in LightGlueImageProcessor

* Apply suggestions from code review

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Apply suggestions from code review

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* docs: added missing LightGlueConfig typehint in nn.Module __init__ methods

* docs: removed unnecessary code in docs

* docs: import SuperPointConfig only from a TYPE_CHECKING context

* chore: use PretrainedConfig arguments `num_hidden_layers` and `num_attention_heads` instead of `num_layers` and `num_heads`

* chore: added organization as arg in convert_lightglue_to_hf.py script

* refactor: set device variable

* chore: added "gelu" in LightGlueConfig as hidden_act parameter

* docs: added comments to reshape.flip.reshape instruction to perform cross attention

* refactor: used batched inference for keypoint detector forward pass

* fix: added fix for SDPA tests

* docs: fixed docstring for LightGlueImageProcessor

* [run-slow] lightglue

* refactor: removed unused line

* refactor: added missing arguments in LightGlueConfig init method

* docs: added missing LightGlueConfig typehint in init methods

* refactor: added checkpoint url as default variable to verify models output only if it is the default url

* fix: moved print message inside if statement

* fix: added log assignment r removal in convert script

* fix: got rid of confidence_thresholds as registered buffers

* refactor: applied suggestions from SuperGlue PR

* docs: changed copyright to 2025

* refactor: modular LightGlue

* fix: removed unnecessary import

* feat: added plot_keypoint_matching method to LightGlueImageProcessor with matplotlib soft dependency

* fix: added missing import error for matplotlib

* Updated convert script to push on ETH org

* fix: added missing licence

* fix: make fix-copies

* refactor: use cohere apply_rotary_pos_emb function

* fix: update model references to use ETH-CVG/lightglue_superpoint

* refactor: add and use intermediate_size attribute in config to inherit CLIPMLP for LightGlueMLP

* refactor: explicit variables instead of slicing

* refactor: use can_return_tuple decorator in LightGlue model

* fix: make fix-copies

* docs: Update model references in `lightglue.md` to use the correct pretrained model from ETH-CVG

* Refactor LightGlue configuration and processing classes

- Updated type hints for `keypoint_detector_config` in `LightGlueConfig` to use `SuperPointConfig` directly.
- Changed `size` parameter in `LightGlueImageProcessor` to be optional.
- Modified `position_embeddings` in `LightGlueAttention` and `LightGlueAttentionBlock` to be optional tuples.
- Cleaned up import statements across multiple files for better readability and consistency.

* refactor: Update LightGlue configuration to enforce eager attention implementation

- Added `attn_implementation="eager"` to `keypoint_detector_config` in `LightGlueConfig` and `LightGlueAttention` classes.
- Removed unnecessary logging related to attention implementation fallback.
- Cleaned up import statements for better readability.

* refactor: renamed message into attention_output

* fix: ensure device compatibility in LightGlueMatchAssignmentLayer descriptor normalization

- Updated the normalization of `m_descriptors` to use the correct device for the tensor, ensuring compatibility across different hardware setups.

* refactor: removed Conv layers from init_weights since LightGlue doesn't have any

* refactor: replace add_start_docstrings with auto_docstring in LightGlue models

- Updated LightGlue model classes to utilize the new auto_docstring utility for automatic documentation generation.
- Removed legacy docstring handling to streamline the code and improve maintainability.

* refactor: simplify LightGlue image processing tests by inheriting from SuperGlue

- Refactored `LightGlueImageProcessingTester` and `LightGlueImageProcessingTest` to inherit from their SuperGlue counterparts, reducing code duplication.
- Removed redundant methods and properties, streamlining the test setup and improving maintainability.

* test: forced eager attention implementation to LightGlue model tests

- Updated `LightGlueModelTester` to include `attn_implementation="eager"` in the model configuration.
- This change aligns the test setup with the recent updates in LightGlue configuration for eager attention.

* refactor: update LightGlue model references

* fix: import error

* test: enhance LightGlue image processing tests with setup method

- Added a setup method in `LightGlueImageProcessingTest` to initialize `LightGlueImageProcessingTester`.
- Included a docstring for `LightGlueImageProcessingTester` to clarify its purpose.

* refactor: added LightGlue image processing implementation to modular file

* refactor: moved attention blocks into the transformer layer

* fix: added missing import

* fix: added missing import in __all__ variable

* doc: added comment about enforcing eager attention because of SuperPoint

* refactor: added SuperPoint eager attention comment and moved functions to the closest they are used

---------

Co-authored-by: Steven Bucaille <steven.bucaille@buawei.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
2025-06-17 18:10:23 +02:00


# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
from datasets import load_dataset
from transformers.models.lightglue.configuration_lightglue import LightGlueConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor
if is_torch_available():
import torch
from transformers import LightGlueForKeypointMatching
if is_vision_available():
from transformers import AutoImageProcessor
class LightGlueModelTester:
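"""Builds a tiny LightGlue configuration and random pairs of images for the common model tests."""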
def __init__(
self,
parent,
batch_size=2,
image_width=80,
image_height=60,
keypoint_detector_config={
"encoder_hidden_sizes": [32, 32, 64],
"decoder_hidden_size": 64,
"keypoint_decoder_dim": 65,
"descriptor_decoder_dim": 64,
"keypoint_threshold": 0.005,
"max_keypoints": 256,
"nms_radius": 4,
"border_removal_distance": 4,
},
descriptor_dim: int = 64,
num_layers: int = 2,
num_heads: int = 4,
depth_confidence: float = 1.0,
width_confidence: float = 1.0,
filter_threshold: float = 0.1,
matching_threshold: float = 0.0,
):
self.parent = parent
self.batch_size = batch_size
self.image_width = image_width
self.image_height = image_height
self.keypoint_detector_config = keypoint_detector_config
self.descriptor_dim = descriptor_dim
self.num_layers = num_layers
self.num_heads = num_heads
self.depth_confidence = depth_confidence
self.width_confidence = width_confidence
self.filter_threshold = filter_threshold
self.matching_threshold = matching_threshold
def prepare_config_and_inputs(self):
# LightGlue expects pairs of images, i.e. pixel values of shape (batch_size, 2, channels, height, width)
pixel_values = floats_tensor([self.batch_size, 2, 3, self.image_height, self.image_width])
config = self.get_config()
return config, pixel_values
def get_config(self):
return LightGlueConfig(
keypoint_detector_config=self.keypoint_detector_config,
descriptor_dim=self.descriptor_dim,
num_hidden_layers=self.num_layers,
num_attention_heads=self.num_heads,
depth_confidence=self.depth_confidence,
width_confidence=self.width_confidence,
filter_threshold=self.filter_threshold,
matching_threshold=self.matching_threshold,
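# SuperPoint, the underlying keypoint detector, only supports eager attention, so it is enforced here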
attn_implementation="eager",
)
def create_and_check_model(self, config, pixel_values):
model = LightGlueForKeypointMatching(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
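# the last dimension of the mask is the padded number of keypoints per image, which bounds the number of matches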
maximum_num_matches = result.mask.shape[-1]
self.parent.assertEqual(
result.keypoints.shape,
(self.batch_size, 2, maximum_num_matches, 2),
)
self.parent.assertEqual(
result.matches.shape,
(self.batch_size, 2, maximum_num_matches),
)
self.parent.assertEqual(
result.matching_scores.shape,
(self.batch_size, 2, maximum_num_matches),
)
self.parent.assertEqual(
result.prune.shape,
(self.batch_size, 2, maximum_num_matches),
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class LightGlueModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (LightGlueForKeypointMatching,) if is_torch_available() else ()
all_generative_model_classes = ()
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
has_attentions = True
def setUp(self):
self.model_tester = LightGlueModelTester(self)
self.config_tester = ConfigTester(self, config_class=LightGlueConfig, has_text_modality=False, hidden_size=37)
def test_config(self):
self.config_tester.create_and_test_config_to_json_string()
self.config_tester.create_and_test_config_to_json_file()
self.config_tester.create_and_test_config_from_and_save_pretrained()
self.config_tester.create_and_test_config_with_num_labels()
self.config_tester.check_config_can_be_init_without_params()
self.config_tester.check_config_arguments_init()
@unittest.skip(reason="LightGlueForKeypointMatching does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="LightGlueForKeypointMatching does not support input and output embeddings")
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="LightGlueForKeypointMatching does not use feedforward chunking")
def test_feed_forward_chunking(self):
pass
@unittest.skip(reason="LightGlueForKeypointMatching is not trainable")
def test_training(self):
pass
@unittest.skip(reason="LightGlueForKeypointMatching is not trainable")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="LightGlueForKeypointMatching is not trainable")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="LightGlueForKeypointMatching is not trainable")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="LightGlue does not output any loss term in the forward pass")
def test_retain_grad_hidden_states_attentions(self):
pass
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
maximum_num_matches = outputs.mask.shape[-1]
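# Each of the num_layers transformer layers emits hidden states for its self- and cross-attention blocks;
# the descriptor_dim * 2 entries presumably correspond to the concatenated features fed to each block's MLP.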
hidden_states_sizes = [
self.model_tester.descriptor_dim,
self.model_tester.descriptor_dim,
self.model_tester.descriptor_dim * 2,
self.model_tester.descriptor_dim,
self.model_tester.descriptor_dim,
self.model_tester.descriptor_dim * 2,
self.model_tester.descriptor_dim,
] * self.model_tester.num_layers
for i, hidden_states_size in enumerate(hidden_states_sizes):
self.assertListEqual(
list(hidden_states[i].shape[-2:]),
[maximum_num_matches, hidden_states_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also works when set via the config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def test_attention_outputs(self):
def check_attention_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
maximum_num_matches = outputs.mask.shape[-1]
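# attention weights are computed between every pair of (padded) keypoints, hence the square last two dimensions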
expected_attention_shape = [self.model_tester.num_heads, maximum_num_matches, maximum_num_matches]
for i, attention in enumerate(attentions):
self.assertListEqual(
list(attention.shape[-3:]),
expected_attention_shape,
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
check_attention_output(inputs_dict, config, model_class)
# check that output_attentions also works when set via the config
del inputs_dict["output_attentions"]
config.output_attentions = True
check_attention_output(inputs_dict, config, model_class)
@slow
def test_model_from_pretrained(self):
from_pretrained_ids = ["ETH-CVG/lightglue_superpoint"]
for model_name in from_pretrained_ids:
model = LightGlueForKeypointMatching.from_pretrained(model_name)
self.assertIsNotNone(model)
# Copied from tests.models.superglue.test_modeling_superglue.SuperGlueModelTest.test_forward_labels_should_be_none
def test_forward_labels_should_be_none(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
model_inputs = self._prepare_for_class(inputs_dict, model_class)
# Provide an arbitrary sized Tensor as labels to model inputs
model_inputs["labels"] = torch.rand((128, 128))
with self.assertRaises(ValueError) as cm:
model(**model_inputs)
self.assertEqual(ValueError, cm.exception.__class__)
def prepare_imgs():
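"""Load three images from the test dataset and arrange them into two image pairs."""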
dataset = load_dataset("hf-internal-testing/image-matching-test-dataset", split="train")
image0 = dataset[0]["image"]
image1 = dataset[1]["image"]
image2 = dataset[2]["image"]
# the second pair is [image1, image1] on purpose: identical images are matched confidently, which lets the model stop early
return [[image2, image0], [image1, image1]]
@require_torch
@require_vision
class LightGlueModelIntegrationTest(unittest.TestCase):
@cached_property
def default_image_processor(self):
return AutoImageProcessor.from_pretrained("ETH-CVG/lightglue_superpoint") if is_vision_available() else None
@slow
def test_inference(self):
model = LightGlueForKeypointMatching.from_pretrained(
"ETH-CVG/lightglue_superpoint", attn_implementation="eager"
).to(torch_device)
preprocessor = self.default_image_processor
images = prepare_imgs()
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
with torch.no_grad():
outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
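# only a fixed window of 20 matches (indices 10 to 30) is compared, to keep the expected tensors compact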
predicted_number_of_matches0 = torch.sum(outputs.matches[0][0] != -1).item()
predicted_matches_values0 = outputs.matches[0, 0, 10:30]
predicted_matching_scores_values0 = outputs.matching_scores[0, 0, 10:30]
predicted_number_of_matches1 = torch.sum(outputs.matches[1][0] != -1).item()
predicted_matches_values1 = outputs.matches[1, 0, 10:30]
predicted_matching_scores_values1 = outputs.matching_scores[1, 0, 10:30]
expected_number_of_matches0 = 140
expected_matches_values0 = torch.tensor(
[14, -1, -1, 15, 17, 13, -1, -1, -1, -1, -1, -1, 5, -1, -1, 19, -1, 10, -1, 11],
dtype=torch.int64,
device=torch_device,
)
expected_matching_scores_values0 = torch.tensor(
[0.3796, 0, 0, 0.3772, 0.4439, 0.2411, 0, 0, 0.0032, 0, 0, 0, 0.2997, 0, 0, 0.6762, 0, 0.8826, 0, 0.5583],
device=torch_device,
)
expected_number_of_matches1 = 866
expected_matches_values1 = torch.tensor(
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
dtype=torch.int64,
device=torch_device,
)
expected_matching_scores_values1 = torch.tensor(
[
0.6188, 0.7817, 0.5686, 0.9353, 0.9801, 0.9193, 0.8632, 0.9111, 0.9821, 0.5496,
0.9906, 0.8682, 0.9679, 0.9914, 0.9318, 0.1910, 0.9669, 0.3240, 0.9971, 0.9923,
],
device=torch_device,
)  # fmt: skip
"""
Because of inconsistencies introduced between CUDA versions, the checks here are less strict. SuperGlue relies
on SuperPoint, which may, depending on CUDA version, return different number of keypoints (866 or 867 in this
specific test example). The consequence of having different number of keypoints is that the number of matches
will also be different. In the 20 first matches being checked, having one keypoint less will result in 1 less
match. The matching scores will also be different, as the keypoints are different. The checks here are less
strict to account for these inconsistencies.
Therefore, the test checks that the predicted number of matches, matches and matching scores are close to the
expected values, individually. Here, the tolerance of the number of values changing is set to 2.
This was discussed [here](https://github.com/huggingface/transformers/pull/29886#issuecomment-2482752787)
Such CUDA inconsistencies can be found
[here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300)
"""
self.assertTrue(abs(predicted_number_of_matches0 - expected_number_of_matches0) < 4)
self.assertTrue(abs(predicted_number_of_matches1 - expected_number_of_matches1) < 4)
self.assertTrue(
torch.sum(~torch.isclose(predicted_matching_scores_values0, expected_matching_scores_values0, atol=1e-2))
< 4
)
self.assertTrue(
torch.sum(~torch.isclose(predicted_matching_scores_values1, expected_matching_scores_values1, atol=1e-2))
< 4
)
self.assertTrue(torch.sum(predicted_matches_values0 != expected_matches_values0) < 4)
self.assertTrue(torch.sum(predicted_matches_values1 != expected_matches_values1) < 4)
@slow
def test_inference_without_early_stop(self):
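# depth_confidence=1.0 disables the early-stopping mechanism, so every layer of the model runs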
model = LightGlueForKeypointMatching.from_pretrained(
"ETH-CVG/lightglue_superpoint", attn_implementation="eager", depth_confidence=1.0
).to(torch_device)
preprocessor = self.default_image_processor
images = prepare_imgs()
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
with torch.no_grad():
outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
predicted_number_of_matches0 = torch.sum(outputs.matches[0][0] != -1).item()
predicted_matches_values0 = outputs.matches[0, 0, 10:30]
predicted_matching_scores_values0 = outputs.matching_scores[0, 0, 10:30]
predicted_number_of_matches1 = torch.sum(outputs.matches[1][0] != -1).item()
predicted_matches_values1 = outputs.matches[1, 0, 10:30]
predicted_matching_scores_values1 = outputs.matching_scores[1, 0, 10:30]
expected_number_of_matches0 = 134
expected_matches_values0 = torch.tensor(
[-1, -1, 17, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 19, -1, 10, -1, 11], dtype=torch.int64
).to(torch_device)
expected_matching_scores_values0 = torch.tensor(
[0.0083, 0, 0.2022, 0.0621, 0, 0.0828, 0, 0, 0.0003, 0, 0, 0, 0.0960, 0, 0, 0.6940, 0, 0.7167, 0, 0.1512]
).to(torch_device)
expected_number_of_matches1 = 862
expected_matches_values1 = torch.tensor(
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype=torch.int64
).to(torch_device)
expected_matching_scores_values1 = torch.tensor(
[
0.4772, 0.3781, 0.0631, 0.9559, 0.8746, 0.9271, 0.4882, 0.5406, 0.9439, 0.1526,
0.5028, 0.4107, 0.5591, 0.9130, 0.7572, 0.0302, 0.4532, 0.0893, 0.9490, 0.4880,
]
).to(torch_device)  # fmt: skip
"""
Because of inconsistencies introduced between CUDA versions, the checks here are less strict. SuperGlue relies
on SuperPoint, which may, depending on CUDA version, return different number of keypoints (866 or 867 in this
specific test example). The consequence of having different number of keypoints is that the number of matches
will also be different. In the 20 first matches being checked, having one keypoint less will result in 1 less
match. The matching scores will also be different, as the keypoints are different. The checks here are less
strict to account for these inconsistencies.
Therefore, the test checks that the predicted number of matches, matches and matching scores are close to the
expected values, individually. Here, the tolerance of the number of values changing is set to 2.
This was discussed [here](https://github.com/huggingface/transformers/pull/29886#issuecomment-2482752787)
Such CUDA inconsistencies can be found
[here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300)
"""
self.assertTrue(abs(predicted_number_of_matches0 - expected_number_of_matches0) < 4)
self.assertTrue(abs(predicted_number_of_matches1 - expected_number_of_matches1) < 4)
self.assertTrue(
torch.sum(~torch.isclose(predicted_matching_scores_values0, expected_matching_scores_values0, atol=1e-2))
< 4
)
self.assertTrue(
torch.sum(~torch.isclose(predicted_matching_scores_values1, expected_matching_scores_values1, atol=1e-2))
< 4
)
self.assertTrue(torch.sum(predicted_matches_values0 != expected_matches_values0) < 4)
self.assertTrue(torch.sum(predicted_matches_values1 != expected_matches_values1) < 4)
@slow
def test_inference_without_early_stop_and_keypoint_pruning(self):
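# depth_confidence=1.0 and width_confidence=1.0 disable early stopping and keypoint pruning, respectively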
model = LightGlueForKeypointMatching.from_pretrained(
"ETH-CVG/lightglue_superpoint",
attn_implementation="eager",
depth_confidence=1.0,
width_confidence=1.0,
).to(torch_device)
preprocessor = self.default_image_processor
images = prepare_imgs()
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
with torch.no_grad():
outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
predicted_number_of_matches0 = torch.sum(outputs.matches[0][0] != -1).item()
predicted_matches_values0 = outputs.matches[0, 0, 10:30]
predicted_matching_scores_values0 = outputs.matching_scores[0, 0, 10:30]
predicted_number_of_matches1 = torch.sum(outputs.matches[1][0] != -1).item()
predicted_matches_values1 = outputs.matches[1, 0, 10:30]
predicted_matching_scores_values1 = outputs.matching_scores[1, 0, 10:30]
expected_number_of_matches0 = 144
expected_matches_values0 = torch.tensor(
[-1, -1, 17, -1, -1, 13, -1, -1, -1, -1, -1, -1, 5, -1, -1, 19, -1, 10, -1, 11], dtype=torch.int64
).to(torch_device)
expected_matching_scores_values0 = torch.tensor(
[
0.0699, 0.0302, 0.3356, 0.0820, 0, 0.2266, 0, 0, 0.0241, 0,
0, 0, 0.1674, 0, 0, 0.8114, 0, 0.8120, 0, 0.2936,
]
).to(torch_device)  # fmt: skip
expected_number_of_matches1 = 862
expected_matches_values1 = torch.tensor(
[10, 11, -1, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, -1, 26, -1, 28, 29], dtype=torch.int64
).to(torch_device)
expected_matching_scores_values1 = torch.tensor(
[
0.4772, 0.3781, 0.0631, 0.9559, 0.8746, 0.9271, 0.4882, 0.5406, 0.9439, 0.1526,
0.5028, 0.4107, 0.5591, 0.9130, 0.7572, 0.0302, 0.4532, 0.0893, 0.9490, 0.4880,
]
).to(torch_device)  # fmt: skip
"""
Because of inconsistencies introduced between CUDA versions, the checks here are less strict. SuperGlue relies
on SuperPoint, which may, depending on CUDA version, return different number of keypoints (866 or 867 in this
specific test example). The consequence of having different number of keypoints is that the number of matches
will also be different. In the 20 first matches being checked, having one keypoint less will result in 1 less
match. The matching scores will also be different, as the keypoints are different. The checks here are less
strict to account for these inconsistencies.
Therefore, the test checks that the predicted number of matches, matches and matching scores are close to the
expected values, individually. Here, the tolerance of the number of values changing is set to 2.
This was discussed [here](https://github.com/huggingface/transformers/pull/29886#issuecomment-2482752787)
Such CUDA inconsistencies can be found
[here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300)
"""
self.assertTrue(abs(predicted_number_of_matches0 - expected_number_of_matches0) < 4)
self.assertTrue(abs(predicted_number_of_matches1 - expected_number_of_matches1) < 4)
self.assertTrue(
torch.sum(~torch.isclose(predicted_matching_scores_values0, expected_matching_scores_values0, atol=1e-2))
< 4
)
self.assertTrue(
torch.sum(~torch.isclose(predicted_matching_scores_values1, expected_matching_scores_values1, atol=1e-2))
< 4
)
self.assertTrue(torch.sum(predicted_matches_values0 != expected_matches_values0) < 4)
self.assertTrue(torch.sum(predicted_matches_values1 != expected_matches_values1) < 4)