mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-07 14:50:07 +06:00

* Initial commit with template code generated by transformers-cli
* Multiple additions to SuperGlue implementation :
- Added the SuperGlueConfig
- Added the SuperGlueModel and its implementation
- Added basic weight conversion script
- Added new ImageMatchingOutput dataclass
* Few changes for SuperGlue
* Multiple changes :
- Added keypoint detection config to SuperGlueConfig
- Completed convert_superglue_to_pytorch and successfully ran inference
* Reverted unintentional change
* Multiple changes :
- Added SuperGlue to a bunch of places
- Divided SuperGlue into SuperGlueForImageMatching and SuperGlueModel
- Added testing images
* Moved things in init files
* Added docs (to be finished depending on the final implementation)
* Added necessary imports and some doc
* Removed unnecessary import
* Fixed make fix-copies bug and ran it
* Deleted SuperGlueModel
Fixed convert script
* Added SuperGlueImageProcessor
* Changed SuperGlue to support batching pairs of images and modified ImageMatchingOutput accordingly
* Changed convert_superglue_to_hf.py script to experiment with different ways of reading an image and see their impact on performance
* Added initial tests for SuperGlueImageProcessor
* Added AutoModelForImageMatching in missing places and tests
* Fixed keypoint_detector_output instructions
* Fix style
* Adapted to latest main changes
* Added integration test
* Fixed bugs to pass tests
* Added keypoints returned by keypoint detector in the output of SuperGlue
* Added doc to SuperGlue
* SuperGlue returning all attention and hidden states for a fixed number of keypoints
* Make style
* Changed SuperGlueImageProcessor tests
* Revert "SuperGlue returning all attention and hidden states for a fixed number of keypoints"
Changed tests accordingly
This reverts commit 5b3b669c
* Added back hidden_states and attentions masked outputs with tests
* Renamed ImageMatching occurrences into KeypointMatching
* Changed SuperGlueImageProcessor to raise error when batch_size is not even
* Added docs and clarity to hidden state and attention grouping function
* Fixed some code and done refactoring
* Fixed typo in SuperPoint output doc
* Fixed some of the formatting and variable naming problems
* Removed useless function call
* Removed AutoModelForKeypointMatching
* Fixed SuperGlueImageProcessor to only accept pairs of images
* Added more fixes to SuperGlueImageProcessor
* Simplified the batching of attention and hidden states
* Simplified stack functions
* Moved attention instructions into class
* Removed unused do_batch_norm argument
* Moved weight initialization to the proper place
* Replaced deepcopy for instantiation
* Fixed small bug
* Changed from stevenbucaille to magic-leap repo
* Renamed London Bridge images to Tower Bridge
* Fixed formatting
* Renamed remaining "london" to "tower"
* Apply suggestions from code review
Small changes in the docs
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Added AutoModelForKeypointMatching
* Changed images used in example
* Several changes to image_processing_superglue and style
* Fixed resample type hint
* Changed SuperGlueImageProcessor and added test case for list of 2 images
* Changed list_of_tuples implementation
* Fix in dummy objects
* Added normalize_keypoint, log_sinkhorn_iterations and log_optimal_transport docstring
* Added missing docstring
* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Moved forward block at bottom
* Added docstring to forward method
* Added docstring to match_image_pair method
* Changed test_model_common_attributes to test_model_get_set_embeddings test method signature
* Removed AutoModelForKeypointMatching
* Removed image fixtures and added load_dataset
* Added padding of images in SuperGlueImageProcessor
* Cleaned up convert_superglue_to_hf script
* Added missing docs and fixed unused argument
* Fixed SuperGlueImageProcessor tests
* Transposed all hidden states from SuperGlue to reflect the standard (..., seq_len, feature_dim) shape
* Added SuperGlueForKeypointMatching back to modeling_auto
* Fixed image processor padding test
* Changed SuperGlue docs
* changes:
- Abstraction to batch, concatenate and stack inconsistent tensors
- Changed conv1d's to linears to match standard attention implementations
- Renamed all tensors to be tensor0 and not tensor_0 and be consistent
- Changed match image pair to run keypoint detection on all images first, create batching tensors and then fill these tensors match by match
- Various changes in docs, etc
* Changes to SuperGlueImageProcessor:
- Reworked the input image pairs checking function and added tests accordingly
- Added Copied from statements
- Added do_grayscale tag (also for SuperPointImageProcessor)
- Misc changes for better code
* Formatting changes
* Reverted conv1d to linear conversion because of numerical differences
* fix: changed some code to be more straightforward (e.g. filtering keypoints) and converted plot from opencv to matplotlib
* fix: removed unnecessary test
* chore: removed commented code and added back hidden states transpositions
* chore: changed from "inconsistent" to "ragged" function names as suggested
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* docs: applied suggestions
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* docs: updated to display matched output
* chore: applied suggestion for check_image_pairs_input function
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* chore: changed check_image_pairs_input function name to validate_and_format_image_pairs and used validate_preprocess_arguments function
* tests: simplified tests for image input format and shapes
* feat: converted SuperGlue's use of Conv1d with kernel_size of 1 with Linear layers. Changed tests and conversion script accordingly
* feat: several changes to address comments
Conversion script:
- Reverted fuse batchnorm to linear conversion
- Changed all 'nn.Module' to respective SuperGlue models
- Changed conversion script to use regex mapping and match other recent scripts
Modeling SuperGlue:
- Added batching with mask and padding to attention
- Removed unnecessary concat, stack and batch ragged pairs functions
- Reverted batchnorm layer
- Renamed query, key, value and merge layers into q, k, v, out proj
- Removed Union of different Module into nn.Module in _init_weights method typehint
- Changed several methods' signatures to combine image0 and image1 inputs with appropriate doc changes
- Updated SuperGlue's doc with torch.no_grad()
Updated test to reflect changes in SuperGlue model
* refactor: changed validate_and_format_image_pairs function with clarity
* refactor: changed from one SuperGlueMLP class to a list of SuperGlueMLP class
* fix: fixed forgotten init weight change from last commit
* fix: fixed rebase mistake
* fix: removed leftover commented code
* fix: added typehint and changed some of arguments default values
* fix: fixed attribute default values for SuperGlueConfig
* feat: added SuperGlueImageProcessor post process keypoint matching method with tests
* fix: fixed SuperGlue attention and hidden state tuples aggregation
* chore: fixed mask optionality and reordered tensor reshapes to be cleaner
* chore: fixed docs and error message returned in validate_and_format_image_pairs function
* fix: fixed returned keypoints to be the ones that SuperPoint returns
* fix: fixed check on number of image sizes for post process compared to the pairs in outputs of SuperGlue
* fix: fixed check on number of image sizes for post process compared to the pairs in outputs of SuperGlue (bis)
* fix: Changed SuperGlueMultiLayerPerceptron instantiation to avoid if statement
* fix: Changed convert_superglue_to_hf script to reflect latest SuperGlue changes and got rid of nn.Modules
* WIP: implement Attention from an existing class (like BERT)
* docs: Changed docs to include more appealing matching plot
* WIP: Implement Attention
* chore: minor typehint change
* chore: changed convert superglue script by removing all classes and apply conv to linear conversion in state dict + rearrange keys to comply with changes in model's layers organisation
* Revert "Fixed typo in SuperPoint output doc"
This reverts commit 2120390e82.
* chore: added comments in SuperGlueImageProcessor
* chore: changed SuperGlue organization HF repo to magic-leap-community
* [run-slow] refactor: small change in layer instantiation
* [run-slow] chore: replaced remaining stevenbucaille org to magic-leap-community
* [run-slow] chore: make style
* chore: update image matching fixture dataset HF repository
* [run-slow] superglue
* tests: overwriting test_batching_equivalence
* [run-slow] superglue
* tests: changed test to cope with values changing depending on the CUDA version
* [run-slow] superglue
* tests: changed matching_threshold value
* [run-slow] superglue
* [run-slow] superglue
* tests: changed tests for integration
* [run-slow] superglue
* fix: Changed tensor view and permutations to match original implementation results
* fix: updated convert script and integration test to include last change in model
* fix: increase tolerance for CUDA variances
* Apply suggestions from code review
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
* [run-slow] superglue
* chore: removed blank whitespaces
* [run-slow] superglue
* Revert SuperPoint image processor accident changes
* [run-slow] superglue
* refactor: reverted copy from BERT class
* tests: lower the tolerance in integration tests for SuperGlue
* [run-slow] superglue
* chore: set do_grayscale to False in SuperPoint and SuperGlue image processors
* [run-slow] superglue
* fix: fixed imports in SuperGlue files
* chore: changed do_grayscale SuperGlueImageProcessing default value to True
* docs: added typehint to post_process_keypoint_matching method in SuperGlueImageProcessor
* fix: set matching_threshold default value to 0.0 instead of 0.2
* feat: added matching_threshold to post_process_keypoint_matching method
* docs: update superglue.md to include matching_threshold parameter
* docs: updated SuperGlueConfig docstring for matching_threshold default value
* refactor: removed unnecessary parameters in SuperGlueConfig
* fix: changed from matching_threshold to threshold
* fix: re-revert changes to make SuperGlue attention classes copies of BERT
* [run-slow] superglue
* fix: added missing device argument in post_processing method
* [run-slow] superglue
* fix: add matches different from -1 to compute valid matches in post_process_keypoint_matching (and docstring)
* fix: add device to image_sizes tensor instantiation
* tests: added checks on do_grayscale test
* chore: reordered and added Optional typehint to KeypointMatchingOutput
* LightGlue PR suggestions:
- use `post_process_keypoint_matching` as default docs example
- add `post_process_keypoint_matching` in autodoc
- add `SuperPointConfig` import under TYPE_CHECKING condition
- format SuperGlueConfig docstring
- add device in convert_superglue_to_hf
- Fix typo
- Fix KeypointMatchingOutput docstring
- Removed unnecessary line
- Added missing SuperGlueConfig in __init__ methods
* LightGlue PR suggestions:
- use batching to get keypoint detection
* refactor: processing images done in 1 for loop instead of 4
* fix: use @ instead of torch.einsum for scores computation
* style: added # fmt: skip to long tensor values
* refactor: rolled back validate_and_format_image_pairs valid and invalid cases to simpler ones
* refactor: prepare_imgs
* refactor: simplified `validate_and_format_image_pairs`
* docs: fixed doc
---------
Co-authored-by: steven <steven.bucaillle@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Steven Bucaille <steven.bucaille@buawei.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
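Several bullets above mention `log_sinkhorn_iterations` and `log_optimal_transport`, the optimal-transport step that turns SuperGlue's pairwise scores into soft assignments. As a rough illustration only, the sketch below follows the procedure described in the SuperGlue paper: the score matrix is padded with a "dustbin" row and column for unmatched keypoints, then rows and columns are alternately normalized in log space. It is not the transformers implementation; both function names are hypothetical, and the dustbin score `alpha` is a fixed float here, whereas the paper learns it.

```python
import math

import torch


def log_sinkhorn(log_scores: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iterations: int) -> torch.Tensor:
    """Alternately normalize rows and columns of exp(log_scores), entirely in log space.

    log_scores: (batch, rows, cols); log_mu: (batch, rows); log_nu: (batch, cols).
    """
    u = torch.zeros_like(log_mu)
    v = torch.zeros_like(log_nu)
    for _ in range(iterations):
        # Row update: push each row's mass towards exp(log_mu).
        u = log_mu - torch.logsumexp(log_scores + v.unsqueeze(1), dim=2)
        # Column update: make each column's mass exactly exp(log_nu).
        v = log_nu - torch.logsumexp(log_scores + u.unsqueeze(2), dim=1)
    return log_scores + u.unsqueeze(2) + v.unsqueeze(1)


def log_optimal_transport_sketch(scores: torch.Tensor, alpha: float, iterations: int) -> torch.Tensor:
    """Hypothetical helper: add a dustbin row/column filled with `alpha` to a
    (batch, m, n) score matrix, then run log-space Sinkhorn on the result."""
    batch, m, n = scores.shape
    bins_col = scores.new_full((batch, m, 1), alpha)
    bins_row = scores.new_full((batch, 1, n + 1), alpha)
    couplings = torch.cat([torch.cat([scores, bins_col], dim=2), bins_row], dim=1)

    # Real keypoints carry one unit of mass each; each dustbin can absorb all keypoints
    # of the other image. Everything is divided by (m + n) so the marginals sum to 1.
    norm = -math.log(m + n)
    log_mu = torch.cat([scores.new_full((m,), norm), scores.new_full((1,), math.log(n) + norm)])
    log_nu = torch.cat([scores.new_full((n,), norm), scores.new_full((1,), math.log(m) + norm)])
    log_mu = log_mu.expand(batch, -1)
    log_nu = log_nu.expand(batch, -1)

    z = log_sinkhorn(couplings, log_mu, log_nu, iterations)
    # Undo the 1 / (m + n) normalization so real rows and columns sum to about 1.
    return z - norm
```

Working in log space with `torch.logsumexp` avoids the overflow and underflow that repeated exponentiate-and-divide steps could cause for large scores.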
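The bullets about converting SuperGlue's `Conv1d` layers with a `kernel_size` of 1 into `Linear` layers (and applying that conversion to the state dict in the conversion script) rest on a simple identity: a 1-wide convolution over `(batch, channels, length)` applies the same weight matrix at every position, just like a `Linear` layer over `(batch, length, features)`. The sketch below checks this with plain PyTorch; `conv1d_to_linear` is a hypothetical helper, not the function used by `convert_superglue_to_hf`, and the 64 to 128 sizes are arbitrary.

```python
import torch
from torch import nn


def conv1d_to_linear(conv: nn.Conv1d) -> nn.Linear:
    """Hypothetical helper: copy a kernel_size=1 Conv1d into an equivalent Linear layer."""
    if conv.kernel_size != (1,):
        raise ValueError("only kernel_size=1 convolutions map onto a Linear layer")
    linear = nn.Linear(conv.in_channels, conv.out_channels, bias=conv.bias is not None)
    with torch.no_grad():
        # Conv1d weight is (out_channels, in_channels, 1); Linear expects (out_features, in_features).
        linear.weight.copy_(conv.weight.squeeze(-1))
        if conv.bias is not None:
            linear.bias.copy_(conv.bias)
    return linear


conv = nn.Conv1d(64, 128, kernel_size=1)
linear = conv1d_to_linear(conv)
descriptors = torch.randn(2, 64, 10)  # (batch, channels, num_keypoints)
conv_out = conv(descriptors)  # (batch, 128, num_keypoints)
# Linear works on the last dimension, so move channels last and back again.
linear_out = linear(descriptors.transpose(1, 2)).transpose(1, 2)
max_error = (conv_out - linear_out).abs().max().item()
```

The two layers agree up to floating-point rounding rather than bit for bit, since their kernels may accumulate the sums in a different order.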
428 lines
18 KiB
Python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
from typing import List

from datasets import load_dataset

from transformers.models.superglue.configuration_superglue import SuperGlueConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor


if is_torch_available():
    import torch

    from transformers import SuperGlueForKeypointMatching

if is_vision_available():
    from transformers import AutoImageProcessor


class SuperGlueModelTester:
    def __init__(
        self,
        parent,
        batch_size=2,
        image_width=80,
        image_height=60,
        keypoint_detector_config=None,
        hidden_size: int = 64,
        keypoint_encoder_sizes: List[int] = [32, 64],
        gnn_layers_types: List[str] = ["self", "cross"] * 2,
        num_attention_heads: int = 4,
        sinkhorn_iterations: int = 100,
        matching_threshold: float = 0.2,
    ):
        if keypoint_detector_config is None:
            keypoint_detector_config = {
                "encoder_hidden_sizes": [32, 64],
                "decoder_hidden_size": 64,
                "keypoint_decoder_dim": 65,
                "descriptor_decoder_dim": 64,
                "keypoint_threshold": 0.005,
                "max_keypoints": 256,
                "nms_radius": 4,
                "border_removal_distance": 4,
            }
        self.parent = parent
        self.batch_size = batch_size
        self.image_width = image_width
        self.image_height = image_height

        self.keypoint_detector_config = keypoint_detector_config
        self.hidden_size = hidden_size
        self.keypoint_encoder_sizes = keypoint_encoder_sizes
        self.gnn_layers_types = gnn_layers_types
        self.num_attention_heads = num_attention_heads
        self.sinkhorn_iterations = sinkhorn_iterations
        self.matching_threshold = matching_threshold

    def prepare_config_and_inputs(self):
        # SuperGlue expects pairs of grayscale images, each pair stored as a (2, 3, height, width) tensor
        pixel_values = floats_tensor([self.batch_size, 2, 3, self.image_height, self.image_width])
        config = self.get_config()
        return config, pixel_values

    def get_config(self):
        return SuperGlueConfig(
            keypoint_detector_config=self.keypoint_detector_config,
            hidden_size=self.hidden_size,
            keypoint_encoder_sizes=self.keypoint_encoder_sizes,
            gnn_layers_types=self.gnn_layers_types,
            num_attention_heads=self.num_attention_heads,
            sinkhorn_iterations=self.sinkhorn_iterations,
            matching_threshold=self.matching_threshold,
        )

    def create_and_check_model(self, config, pixel_values):
        model = SuperGlueForKeypointMatching(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)
        maximum_num_matches = result.mask.shape[-1]
        self.parent.assertEqual(
            result.keypoints.shape,
            (self.batch_size, 2, maximum_num_matches, 2),
        )
        self.parent.assertEqual(
            result.matches.shape,
            (self.batch_size, 2, maximum_num_matches),
        )
        self.parent.assertEqual(
            result.matching_scores.shape,
            (self.batch_size, 2, maximum_num_matches),
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class SuperGlueModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (SuperGlueForKeypointMatching,) if is_torch_available() else ()
    all_generative_model_classes = () if is_torch_available() else ()

    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    has_attentions = True

    def setUp(self):
        self.model_tester = SuperGlueModelTester(self)
        self.config_tester = ConfigTester(self, config_class=SuperGlueConfig, has_text_modality=False, hidden_size=64)

    def test_config(self):
        self.config_tester.create_and_test_config_to_json_string()
        self.config_tester.create_and_test_config_to_json_file()
        self.config_tester.create_and_test_config_from_and_save_pretrained()
        self.config_tester.create_and_test_config_with_num_labels()
        self.config_tester.check_config_can_be_init_without_params()
        self.config_tester.check_config_arguments_init()

    @unittest.skip(reason="SuperGlueForKeypointMatching does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching does not support input and output embeddings")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching does not use feedforward chunking")
    def test_feed_forward_chunking(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable")
    def test_training(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="SuperGlue does not output any loss term in the forward pass")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.hidden_states
            maximum_num_matches = outputs.mask.shape[-1]

            hidden_states_sizes = (
                self.model_tester.keypoint_encoder_sizes
                + [self.model_tester.hidden_size]
                + [self.model_tester.hidden_size, self.model_tester.hidden_size * 2]
                * len(self.model_tester.gnn_layers_types)
                + [self.model_tester.hidden_size] * 2
            )

            for i, hidden_states_size in enumerate(hidden_states_sizes):
                self.assertListEqual(
                    list(hidden_states[i].shape[-2:]),
                    [hidden_states_size, maximum_num_matches],
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also works using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_attention_outputs(self):
        def check_attention_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            attentions = outputs.attentions
            maximum_num_matches = outputs.mask.shape[-1]

            expected_attention_shape = [
                self.model_tester.num_attention_heads,
                maximum_num_matches,
                maximum_num_matches,
            ]

            for i, attention in enumerate(attentions):
                self.assertListEqual(
                    list(attention.shape[-3:]),
                    expected_attention_shape,
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            check_attention_output(inputs_dict, config, model_class)

            # check that output_attentions also works using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True

            check_attention_output(inputs_dict, config, model_class)

    @slow
    def test_model_from_pretrained(self):
        from_pretrained_ids = ["magic-leap-community/superglue_indoor", "magic-leap-community/superglue_outdoor"]
        for model_name in from_pretrained_ids:
            model = SuperGlueForKeypointMatching.from_pretrained(model_name)
            self.assertIsNotNone(model)

    def test_forward_labels_should_be_none(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                model_inputs = self._prepare_for_class(inputs_dict, model_class)
                # Provide an arbitrarily sized tensor as labels to the model inputs
                model_inputs["labels"] = torch.rand((128, 128))

                with self.assertRaises(ValueError) as cm:
                    model(**model_inputs)
                self.assertEqual(ValueError, cm.exception.__class__)

    def test_batching_equivalence(self):
        """
        Overrides ModelTesterMixin.test_batching_equivalence: SuperGlue returns `matching_scores` tensors full of
        zeros, which makes the common test fail because the cosine similarity of two zero tensors is 0.
        Discussed here: https://github.com/huggingface/transformers/pull/29886#issuecomment-2481539481
        """

        def recursive_check(batched_object, single_row_object, model_name, key):
            if isinstance(batched_object, (list, tuple)):
                for batched_object_value, single_row_object_value in zip(batched_object, single_row_object):
                    recursive_check(batched_object_value, single_row_object_value, model_name, key)
            elif isinstance(batched_object, dict):
                for batched_object_value, single_row_object_value in zip(
                    batched_object.values(), single_row_object.values()
                ):
                    recursive_check(batched_object_value, single_row_object_value, model_name, key)
            # do not compare returned loss (0-dim tensor) / codebook ids (int) / caching objects
            elif batched_object is None or not isinstance(batched_object, torch.Tensor):
                return
            elif batched_object.dim() == 0:
                return
            else:
                # indexing the first element does not always work
                # e.g. models that output similarity scores of size (N, M) would need to index [0, 0]
                slice_ids = tuple(slice(0, index) for index in single_row_object.shape)
                batched_row = batched_object[slice_ids]
                self.assertFalse(
                    torch.isnan(batched_row).any(), f"Batched output has `nan` in {model_name} for key={key}"
                )
                self.assertFalse(
                    torch.isinf(batched_row).any(), f"Batched output has `inf` in {model_name} for key={key}"
                )
                self.assertFalse(
                    torch.isnan(single_row_object).any(), f"Single row output has `nan` in {model_name} for key={key}"
                )
                self.assertFalse(
                    torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}"
                )
                self.assertTrue(
                    (equivalence(batched_row, single_row_object)) <= 1e-03,
                    msg=(
                        f"Batched and Single row outputs are not equal in {model_name} for key={key}. "
                        f"Difference={equivalence(batched_row, single_row_object)}."
                    ),
                )

        def equivalence(tensor1, tensor2):
            return torch.max(torch.abs(tensor1 - tensor2))

        config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            config.output_hidden_states = True

            model_name = model_class.__name__
            batched_input_prepared = self._prepare_for_class(batched_input, model_class)
            model = model_class(config).to(torch_device).eval()

            batch_size = self.model_tester.batch_size
            single_row_input = {}
            for key, value in batched_input_prepared.items():
                if isinstance(value, torch.Tensor) and value.shape[0] % batch_size == 0:
                    # e.g. musicgen has inputs of size (bs*codebooks). in most cases value.shape[0] == batch_size
                    single_batch_shape = value.shape[0] // batch_size
                    single_row_input[key] = value[:single_batch_shape]
                else:
                    single_row_input[key] = value

            with torch.no_grad():
                model_batched_output = model(**batched_input_prepared)
                model_row_output = model(**single_row_input)

            if isinstance(model_batched_output, torch.Tensor):
                model_batched_output = {"model_output": model_batched_output}
                model_row_output = {"model_output": model_row_output}

            for key in model_batched_output:
                recursive_check(model_batched_output[key], model_row_output[key], model_name, key)


def prepare_imgs():
    dataset = load_dataset("hf-internal-testing/image-matching-test-dataset", split="train")
    image1 = dataset[0]["image"]
    image2 = dataset[1]["image"]
    image3 = dataset[2]["image"]
    return [[image1, image2], [image3, image2]]


@require_torch
@require_vision
class SuperGlueModelIntegrationTest(unittest.TestCase):
    @cached_property
    def default_image_processor(self):
        return (
            AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
            if is_vision_available()
            else None
        )

    @slow
    def test_inference(self):
        model = SuperGlueForKeypointMatching.from_pretrained("magic-leap-community/superglue_outdoor").to(torch_device)
        preprocessor = self.default_image_processor
        images = prepare_imgs()
        inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True, output_attentions=True)

        predicted_number_of_matches = torch.sum(outputs.matches[0][0] != -1).item()
        predicted_matches_values = outputs.matches[0, 0, :30]
        predicted_matching_scores_values = outputs.matching_scores[0, 0, :20]

        expected_number_of_matches = 282
        expected_matches_values = torch.tensor([125,630,137,138,136,143,135,-1,-1,153,
                                                154,156,117,160,-1,149,147,152,168,-1,
                                                165,182,-1,190,187,188,189,112,-1,193],
                                               device=predicted_matches_values.device)  # fmt:skip
        expected_matching_scores_values = torch.tensor([0.9899,0.0033,0.9897,0.9889,0.9879,0.7464,0.7109,0.0,0.0,0.9841,
                                                        0.9889,0.9639,0.0114,0.9559,0.0,0.9735,0.8018,0.5190,0.9157,0.0],
                                                       device=predicted_matches_values.device)  # fmt:skip

        """
        Because of inconsistencies introduced between CUDA versions, the checks here are less strict. SuperGlue relies
        on SuperPoint, which may, depending on the CUDA version, return a different number of keypoints (866 or 867 in
        this specific test example). A different number of keypoints leads to a different number of matches: among the
        first matches being checked, one keypoint fewer results in one match fewer. The matching scores also differ,
        since the keypoints are different.

        Therefore, the test checks that the predicted number of matches, the matches and the matching scores are each
        close to their expected values. Each check tolerates a discrepancy of at most 3 (in the number of matches, or
        in the number of differing values).

        This was discussed [here](https://github.com/huggingface/transformers/pull/29886#issuecomment-2482752787)
        Such CUDA inconsistencies can be found
        [here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300)
        """

        self.assertTrue(abs(predicted_number_of_matches - expected_number_of_matches) < 4)
        self.assertTrue(
            torch.sum(~torch.isclose(predicted_matching_scores_values, expected_matching_scores_values, atol=1e-2)) < 4
        )
        self.assertTrue(torch.sum(predicted_matches_values != expected_matches_values) < 4)
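The docstring of `test_batching_equivalence` explains that the override compares outputs with a maximum absolute difference because, for two all-zero `matching_scores` tensors, cosine similarity returns 0 even though the tensors are identical. The sketch below reproduces that effect with `torch.nn.functional.cosine_similarity`, and also shows the "count the values outside a tolerance" style of check used in `test_inference`. The helper names and sample tensors are made up for illustration; only the thresholds (`1e-3`, `atol=1e-2`, fewer than 4 mismatches) are taken from the tests above.

```python
import torch
import torch.nn.functional as F


def max_abs_difference(a: torch.Tensor, b: torch.Tensor) -> float:
    """Same metric as the `equivalence` helper in test_batching_equivalence."""
    return torch.max(torch.abs(a - b)).item()


def count_mismatches(predicted: torch.Tensor, expected: torch.Tensor, atol: float) -> int:
    """Number of entries outside `atol`, mirroring the matching-score check in test_inference."""
    return int(torch.sum(~torch.isclose(predicted, expected, atol=atol)).item())


# Two identical all-zero tensors, like the zero-filled matching_scores in the docstring.
zeros_batched = torch.zeros(8)
zeros_single = torch.zeros(8)

# Cosine similarity reports 0.0 ("orthogonal"), although the tensors are identical;
# the maximum absolute difference reports 0.0 ("no difference"), which is what we want.
cosine = F.cosine_similarity(zeros_batched, zeros_single, dim=0).item()
max_diff = max_abs_difference(zeros_batched, zeros_single)

# Illustrative values only, not real model outputs: exactly one entry is off by more than atol.
predicted_scores = torch.tensor([0.99, 0.50, 0.0, 0.965])
expected_scores = torch.tensor([0.99, 0.10, 0.0, 0.96])
mismatches = count_mismatches(predicted_scores, expected_scores, atol=1e-2)
passes = max_diff <= 1e-3 and mismatches < 4
```

Counting mismatches rather than requiring every value to be close lets a test absorb a few hardware-dependent differences, at the cost of weaker guarantees on any single value.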