# tests/models/superglue/test_modeling_superglue.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
from typing import List

from datasets import load_dataset

from transformers.models.superglue.configuration_superglue import SuperGlueConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor


if is_torch_available():
    import torch

    from transformers import SuperGlueForKeypointMatching

if is_vision_available():
    from transformers import AutoImageProcessor


class SuperGlueModelTester:
    def __init__(
        self,
        parent,
        batch_size=2,
        image_width=80,
        image_height=60,
        keypoint_detector_config=None,
        hidden_size: int = 64,
        keypoint_encoder_sizes: List[int] = [32, 64],
        gnn_layers_types: List[str] = ["self", "cross"] * 2,
        num_attention_heads: int = 4,
        sinkhorn_iterations: int = 100,
        matching_threshold: float = 0.2,
    ):
        if keypoint_detector_config is None:
            keypoint_detector_config = {
                "encoder_hidden_sizes": [32, 64],
                "decoder_hidden_size": 64,
                "keypoint_decoder_dim": 65,
                "descriptor_decoder_dim": 64,
                "keypoint_threshold": 0.005,
                "max_keypoints": 256,
                "nms_radius": 4,
                "border_removal_distance": 4,
            }
        self.parent = parent
        self.batch_size = batch_size
        self.image_width = image_width
        self.image_height = image_height
        self.keypoint_detector_config = keypoint_detector_config
        self.hidden_size = hidden_size
        self.keypoint_encoder_sizes = keypoint_encoder_sizes
        self.gnn_layers_types = gnn_layers_types
        self.num_attention_heads = num_attention_heads
        self.sinkhorn_iterations = sinkhorn_iterations
        self.matching_threshold = matching_threshold

    def prepare_config_and_inputs(self):
        # SuperGlue expects a grayscale image as input
        pixel_values = floats_tensor([self.batch_size, 2, 3, self.image_height, self.image_width])
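        # i.e. (batch_size, 2 images per pair, channels, height, width)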
        config = self.get_config()
        return config, pixel_values

    def get_config(self):
        return SuperGlueConfig(
            keypoint_detector_config=self.keypoint_detector_config,
            hidden_size=self.hidden_size,
            keypoint_encoder_sizes=self.keypoint_encoder_sizes,
            gnn_layers_types=self.gnn_layers_types,
            num_attention_heads=self.num_attention_heads,
            sinkhorn_iterations=self.sinkhorn_iterations,
            matching_threshold=self.matching_threshold,
        )

    def create_and_check_model(self, config, pixel_values):
        model = SuperGlueForKeypointMatching(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)
        maximum_num_matches = result.mask.shape[-1]
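        # The last dim of `mask` gives the padded keypoint count shared by all
        # keypoint-indexed outputs (keypoints, matches, matching_scores).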
        self.parent.assertEqual(
            result.keypoints.shape,
            (self.batch_size, 2, maximum_num_matches, 2),
        )
        self.parent.assertEqual(
            result.matches.shape,
            (self.batch_size, 2, maximum_num_matches),
        )
        self.parent.assertEqual(
            result.matching_scores.shape,
            (self.batch_size, 2, maximum_num_matches),
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class SuperGlueModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (SuperGlueForKeypointMatching,) if is_torch_available() else ()
    all_generative_model_classes = ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    has_attentions = True

    def setUp(self):
        self.model_tester = SuperGlueModelTester(self)
        self.config_tester = ConfigTester(self, config_class=SuperGlueConfig, has_text_modality=False, hidden_size=64)

    def test_config(self):
        self.config_tester.create_and_test_config_to_json_string()
        self.config_tester.create_and_test_config_to_json_file()
        self.config_tester.create_and_test_config_from_and_save_pretrained()
        self.config_tester.create_and_test_config_with_num_labels()
        self.config_tester.check_config_can_be_init_without_params()
        self.config_tester.check_config_arguments_init()
    @unittest.skip(reason="SuperGlueForKeypointMatching does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching does not support input and output embeddings")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching does not use feedforward chunking")
    def test_feed_forward_chunking(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable")
    def test_training(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="SuperGlue does not output any loss term in the forward pass")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)
    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            hidden_states = outputs.hidden_states
            maximum_num_matches = outputs.mask.shape[-1]

            hidden_states_sizes = (
                self.model_tester.keypoint_encoder_sizes
                + [self.model_tester.hidden_size]
                + [self.model_tester.hidden_size, self.model_tester.hidden_size * 2]
                * len(self.model_tester.gnn_layers_types)
                + [self.model_tester.hidden_size] * 2
            )
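            # The sizes above follow the order in which SuperGlue records hidden states: the
            # keypoint-encoder layers, its hidden_size output, then a hidden_size and a
            # hidden_size * 2 state for each GNN layer, and finally two hidden_size states
            # from the last projection.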
            for i, hidden_states_size in enumerate(hidden_states_sizes):
                self.assertListEqual(
                    list(hidden_states[i].shape[-2:]),
                    [hidden_states_size, maximum_num_matches],
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also works using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True
            check_hidden_states_output(inputs_dict, config, model_class)
    def test_attention_outputs(self):
        def check_attention_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions
            maximum_num_matches = outputs.mask.shape[-1]
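            # Each attention map spans the padded keypoint set: per `gnn_layers_types`, the GNN
            # alternates self-attention (an image's keypoints attend to themselves) and
            # cross-attention (they attend to the paired image's keypoints), so both the query
            # and key dims equal the padded keypoint count.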
            expected_attention_shape = [
                self.model_tester.num_attention_heads,
                maximum_num_matches,
                maximum_num_matches,
            ]
            for attention in attentions:
                self.assertListEqual(
                    list(attention.shape[-3:]),
                    expected_attention_shape,
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            check_attention_output(inputs_dict, config, model_class)

            # check that output_attentions also works using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            check_attention_output(inputs_dict, config, model_class)
    @slow
    def test_model_from_pretrained(self):
        from_pretrained_ids = ["magic-leap-community/superglue_indoor", "magic-leap-community/superglue_outdoor"]
        for model_name in from_pretrained_ids:
            model = SuperGlueForKeypointMatching.from_pretrained(model_name)
            self.assertIsNotNone(model)

    def test_forward_labels_should_be_none(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                model_inputs = self._prepare_for_class(inputs_dict, model_class)
                # Provide an arbitrarily sized tensor as labels to the model inputs
                model_inputs["labels"] = torch.rand((128, 128))
                with self.assertRaises(ValueError) as cm:
                    model(**model_inputs)
                self.assertEqual(ValueError, cm.exception.__class__)
    def test_batching_equivalence(self):
        """
        Overwriting ModelTesterMixin.test_batching_equivalence since SuperGlue returns `matching_scores` tensors
        full of zeros, which causes the test to fail, because the cosine similarity of two zero tensors is 0.
        Discussed here: https://github.com/huggingface/transformers/pull/29886#issuecomment-2481539481
        """

        def recursive_check(batched_object, single_row_object, model_name, key):
            if isinstance(batched_object, (list, tuple)):
                for batched_object_value, single_row_object_value in zip(batched_object, single_row_object):
                    recursive_check(batched_object_value, single_row_object_value, model_name, key)
            elif isinstance(batched_object, dict):
                for batched_object_value, single_row_object_value in zip(
                    batched_object.values(), single_row_object.values()
                ):
                    recursive_check(batched_object_value, single_row_object_value, model_name, key)
            # do not compare returned loss (0-dim tensor) / codebook ids (int) / caching objects
            elif batched_object is None or not isinstance(batched_object, torch.Tensor):
                return
            elif batched_object.dim() == 0:
                return
            else:
                # indexing the first element does not always work
                # e.g. models that output similarity scores of size (N, M) would need to index [0, 0]
                slice_ids = [slice(0, index) for index in single_row_object.shape]
                batched_row = batched_object[slice_ids]
                self.assertFalse(
                    torch.isnan(batched_row).any(), f"Batched output has `nan` in {model_name} for key={key}"
                )
                self.assertFalse(
                    torch.isinf(batched_row).any(), f"Batched output has `inf` in {model_name} for key={key}"
                )
                self.assertFalse(
                    torch.isnan(single_row_object).any(), f"Single row output has `nan` in {model_name} for key={key}"
                )
                self.assertFalse(
                    torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}"
                )
                self.assertTrue(
                    equivalence(batched_row, single_row_object) <= 1e-03,
                    msg=(
                        f"Batched and Single row outputs are not equal in {model_name} for key={key}. "
                        f"Difference={equivalence(batched_row, single_row_object)}."
                    ),
                )

        def equivalence(tensor1, tensor2):
            return torch.max(torch.abs(tensor1 - tensor2))

        config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            config.output_hidden_states = True

            model_name = model_class.__name__
            batched_input_prepared = self._prepare_for_class(batched_input, model_class)
            model = model_class(config).to(torch_device).eval()

            batch_size = self.model_tester.batch_size
            single_row_input = {}
            for key, value in batched_input_prepared.items():
                if isinstance(value, torch.Tensor) and value.shape[0] % batch_size == 0:
                    # e.g. musicgen has inputs of size (bs*codebooks). in most cases value.shape[0] == batch_size
                    single_batch_shape = value.shape[0] // batch_size
                    single_row_input[key] = value[:single_batch_shape]
                else:
                    single_row_input[key] = value

            with torch.no_grad():
                model_batched_output = model(**batched_input_prepared)
                model_row_output = model(**single_row_input)

            if isinstance(model_batched_output, torch.Tensor):
                model_batched_output = {"model_output": model_batched_output}
                model_row_output = {"model_output": model_row_output}

            for key in model_batched_output:
                recursive_check(model_batched_output[key], model_row_output[key], model_name, key)
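

# Helper for the integration test: builds two image pairs from the hub fixture dataset
# (images 1 and 3 are each matched against image 2).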
def prepare_imgs():
    dataset = load_dataset("hf-internal-testing/image-matching-test-dataset", split="train")
    image1 = dataset[0]["image"]
    image2 = dataset[1]["image"]
    image3 = dataset[2]["image"]
    return [[image1, image2], [image3, image2]]


@require_torch
@require_vision
class SuperGlueModelIntegrationTest(unittest.TestCase):
    @cached_property
    def default_image_processor(self):
        return (
            AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
            if is_vision_available()
            else None
        )

    @slow
    def test_inference(self):
        model = SuperGlueForKeypointMatching.from_pretrained("magic-leap-community/superglue_outdoor").to(torch_device)
        preprocessor = self.default_image_processor
        images = prepare_imgs()
        inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
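        # `matches[b, 0, i]` holds the index of the keypoint in the second image matched to
        # keypoint i of the first image in pair b, or -1 when unmatched (hence `!= -1` below).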
        predicted_number_of_matches = torch.sum(outputs.matches[0][0] != -1).item()
        predicted_matches_values = outputs.matches[0, 0, :30]
        predicted_matching_scores_values = outputs.matching_scores[0, 0, :20]

        expected_number_of_matches = 282
        expected_matches_values = torch.tensor(
            [125, 630, 137, 138, 136, 143, 135, -1, -1, 153,
             154, 156, 117, 160, -1, 149, 147, 152, 168, -1,
             165, 182, -1, 190, 187, 188, 189, 112, -1, 193],
            device=predicted_matches_values.device,
        )  # fmt: skip
        expected_matching_scores_values = torch.tensor(
            [0.9899, 0.0033, 0.9897, 0.9889, 0.9879, 0.7464, 0.7109, 0.0, 0.0, 0.9841,
             0.9889, 0.9639, 0.0114, 0.9559, 0.0, 0.9735, 0.8018, 0.5190, 0.9157, 0.0],
            device=predicted_matches_values.device,
        )  # fmt: skip

        """
        Because of inconsistencies introduced between CUDA versions, the checks here are less strict. SuperGlue relies
        on SuperPoint, which may, depending on the CUDA version, return a different number of keypoints (866 or 867 in
        this specific test example). Having a different number of keypoints means the number of matches differs too:
        among the first 20 matches being checked, one keypoint fewer results in one fewer match. The matching scores
        also differ, since the keypoints themselves differ. The checks are therefore loosened to account for these
        inconsistencies: the test asserts that the predicted number of matches, the matches and the matching scores
        are each close to the expected values, allowing up to 3 differing values.
        This was discussed [here](https://github.com/huggingface/transformers/pull/29886#issuecomment-2482752787).
        Such CUDA inconsistencies can be found
        [here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300).
        """
        self.assertTrue(abs(predicted_number_of_matches - expected_number_of_matches) < 4)
        self.assertTrue(
            torch.sum(~torch.isclose(predicted_matching_scores_values, expected_matching_scores_values, atol=1e-2))
            < 4
        )
        self.assertTrue(torch.sum(predicted_matches_values != expected_matches_values) < 4)