Add SuperGlue model (#29886)
* Initial commit with template code generated by transformers-cli
* Multiple additions to the SuperGlue implementation:
- Added the SuperGlueConfig
- Added the SuperGlueModel and its implementation
- Added basic weight conversion script
- Added new ImageMatchingOutput dataclass
* Few changes for SuperGlue
* Multiple changes:
- Added keypoint detection config to SuperGlueConfig
- Completed convert_superglue_to_pytorch and successfully ran inference
* Reverted unintentional change
* Multiple changes:
- Added SuperGlue to a bunch of places
- Divided SuperGlue into SuperGlueForImageMatching and SuperGlueModel
- Added testing images
* Moved things in init files
* Added docs (to be finished depending on the final implementation)
* Added necessary imports and some doc
* Removed unnecessary import
* Fixed make fix-copies bug and ran it
* Deleted SuperGlueModel
Fixed convert script
* Added SuperGlueImageProcessor
* Changed SuperGlue to support batching pairs of images and modified ImageMatchingOutput accordingly
* Changed convert_superglue_to_hf.py script to experiment with different ways of reading an image and measure their impact on performance
* Added initial tests for SuperGlueImageProcessor
* Added AutoModelForImageMatching in missing places and tests
* Fixed keypoint_detector_output instructions
* Fix style
* Adapted to latest main changes
* Added integration test
* Fixed bugs to pass tests
* Added keypoints returned by keypoint detector in the output of SuperGlue
* Added doc to SuperGlue
* SuperGlue returning all attention and hidden states for a fixed number of keypoints
* Make style
* Changed SuperGlueImageProcessor tests
* Revert "SuperGlue returning all attention and hidden states for a fixed number of keypoints"
Changed tests accordingly
This reverts commit 5b3b669c
* Added back hidden_states and attentions masked outputs with tests
* Renamed ImageMatching occurrences to KeypointMatching
* Changed SuperGlueImageProcessor to raise error when batch_size is not even
* Added docs and clarity to hidden state and attention grouping function
* Fixed some code and done refactoring
* Fixed typo in SuperPoint output doc
* Fixed some of the formatting and variable naming problems
* Removed useless function call
* Removed AutoModelForKeypointMatching
* Fixed SuperGlueImageProcessor to only accept pairs of images
* Added more fixes to SuperGlueImageProcessor
* Simplified the batching of attention and hidden states
* Simplified stack functions
* Moved attention instructions into class
* Removed unused do_batch_norm argument
* Moved weight initialization to the proper place
* Replaced deepcopy for instantiation
* Fixed small bug
* Changed from stevenbucaille to magic-leap repo
* Renamed London Bridge images to Tower Bridge
* Fixed formatting
* Renamed remaining "london" to "tower"
* Apply suggestions from code review
Small changes in the docs
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Added AutoModelForKeypointMatching
* Changed images used in example
* Several changes to image_processing_superglue and style
* Fixed resample type hint
* Changed SuperGlueImageProcessor and added test case for list of 2 images
* Changed list_of_tuples implementation
* Fix in dummy objects
* Added normalize_keypoint, log_sinkhorn_iterations and log_optimal_transport docstring
* Added missing docstring
* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Moved forward block at bottom
* Added docstring to forward method
* Added docstring to match_image_pair method
* Changed test_model_common_attributes to test_model_get_set_embeddings test method signature
* Removed AutoModelForKeypointMatching
* Removed image fixtures and added load_dataset
* Added padding of images in SuperGlueImageProcessor
* Cleaned up convert_superglue_to_hf script
* Added missing docs and fixed unused argument
* Fixed SuperGlueImageProcessor tests
* Transposed all hidden states from SuperGlue to reflect the standard (..., seq_len, feature_dim) shape
* Added SuperGlueForKeypointMatching back to modeling_auto
* Fixed image processor padding test
* Changed SuperGlue docs
* changes:
- Abstraction to batch, concat and stack of inconsistent tensors
- Changed conv1d's to linears to match standard attention implementations
- Renamed all tensors to tensor0 instead of tensor_0 for consistency
- Changed match_image_pair to run keypoint detection on all images first, create the batched tensors, and then fill them match by match
- Various changes in docs, etc
* Changes to SuperGlueImageProcessor:
- Reworked the input image pairs checking function and added tests accordingly
- Added Copied from statements
- Added do_grayscale tag (also for SuperPointImageProcessor)
- Misc changes for better code
* Formatting changes
* Reverted conv1d to linear conversion because of numerical differences
* fix: changed some code to be more straightforward (e.g. filtering keypoints) and converted plot from opencv to matplotlib
* fix: removed unnecessary test
* chore: removed commented code and added back hidden states transpositions
* chore: changed from "inconsistent" to "ragged" function names as suggested
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* docs: applied suggestions
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* docs: updated to display matched output
* chore: applied suggestion for check_image_pairs_input function
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* chore: changed check_image_pairs_input function name to validate_and_format_image_pairs and used validate_preprocess_arguments function
* tests: simplified tests for image input format and shapes
* feat: converted SuperGlue's use of Conv1d with kernel_size of 1 with Linear layers. Changed tests and conversion script accordingly
* feat: several changes to address comments
Conversion script:
- Reverted fuse batchnorm to linear conversion
- Changed all 'nn.Module' to respective SuperGlue models
- Changed conversion script to use regex mapping and match other recent scripts
Modeling SuperGlue:
- Added batching with mask and padding to attention
- Removed unnecessary concat, stack and batch ragged pairs functions
- Reverted batchnorm layer
- Renamed query, key, value and merge layers into q, k, v, out proj
- Removed Union of different Module into nn.Module in _init_weights method typehint
- Changed several methods' signatures to combine image0 and image1 inputs, with corresponding doc changes
- Updated SuperGlue's doc with torch.no_grad()
Updated test to reflect changes in SuperGlue model
* refactor: changed validate_and_format_image_pairs function with clarity
* refactor: changed from one SuperGlueMLP class to a list of SuperGlueMLP modules
* fix: fixed forgotten init weight change from last commit
* fix: fixed rebase mistake
* fix: removed leftover commented code
* fix: added typehint and changed some argument default values
* fix: fixed attribute default values for SuperGlueConfig
* feat: added SuperGlueImageProcessor post process keypoint matching method with tests
* fix: fixed SuperGlue attention and hidden state tuples aggregation
* chore: fixed mask optionality and reordered tensor reshapes to be cleaner
* chore: fixed docs and error message returned in validate_and_format_image_pairs function
* fix: fixed returned keypoints to be the ones that SuperPoint returns
* fix: fixed check on number of image sizes for post process compared to the pairs in outputs of SuperGlue
* fix: fixed check on number of image sizes for post process compared to the pairs in outputs of SuperGlue (bis)
* fix: Changed SuperGlueMultiLayerPerceptron instantiation to avoid if statement
* fix: Changed convert_superglue_to_hf script to reflect latest SuperGlue changes and got rid of nn.Modules
* WIP: implement Attention from an existing class (like BERT)
* docs: Changed docs to include more appealing matching plot
* WIP: Implement Attention
* chore: minor typehint change
* chore: changed the convert superglue script by removing all classes, applying the conv-to-linear conversion in the state dict, and rearranging keys to comply with changes in the model's layer organisation
* Revert "Fixed typo in SuperPoint output doc"
This reverts commit 2120390e82.
* chore: added comments in SuperGlueImageProcessor
* chore: changed SuperGlue organization HF repo to magic-leap-community
* [run-slow] refactor: small change in layer instantiation
* [run-slow] chore: replaced remaining stevenbucaille org to magic-leap-community
* [run-slow] chore: make style
* chore: update image matching fixture dataset HF repository
* [run-slow] superglue
* tests: overwriting test_batching_equivalence
* [run-slow] superglue
* tests: changed test to cope with value changing depending on cuda version
* [run-slow] superglue
* tests: changed matching_threshold value
* [run-slow] superglue
* [run-slow] superglue
* tests: changed tests for integration
* [run-slow] superglue
* fix: Changed tensor view and permutations to match original implementation results
* fix: updated convert script and integration test to include last change in model
* fix: increase tolerance for CUDA variances
* Apply suggestions from code review
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
* [run-slow] superglue
* chore: removed blank whitespaces
* [run-slow] superglue
* Revert SuperPoint image processor accident changes
* [run-slow] superglue
* refactor: reverted copy from BERT class
* tests: lower the tolerance in integration tests for SuperGlue
* [run-slow] superglue
* chore: set do_grayscale to False in SuperPoint and SuperGlue image processors
* [run-slow] superglue
* fix: fixed imports in SuperGlue files
* chore: changed do_grayscale SuperGlueImageProcessing default value to True
* docs: added typehint to post_process_keypoint_matching method in SuperGlueImageProcessor
* fix: set matching_threshold default value to 0.0 instead of 0.2
* feat: added matching_threshold to post_process_keypoint_matching method
* docs: update superglue.md to include matching_threshold parameter
* docs: updated SuperGlueConfig docstring for matching_threshold default value
* refactor: removed unnecessary parameters in SuperGlueConfig
* fix: changed from matching_threshold to threshold
* fix: re-revert changes to make SuperGlue attention classes copies of BERT
* [run-slow] superglue
* fix: added missing device argument in post_processing method
* [run-slow] superglue
* fix: add matches different from -1 to compute valid matches in post_process_keypoint_matching (and docstring)
* fix: add device to image_sizes tensor instantiation
* tests: added checks on do_grayscale test
* chore: reordered and added Optional typehint to KeypointMatchingOutput
* LightGluePR suggestions:
- use `post_process_keypoint_matching` as default docs example
- add `post_process_keypoint_matching` in autodoc
- add `SuperPointConfig` import under TYPE_CHECKING condition
- format SuperGlueConfig docstring
- add device in convert_superglue_to_hf
- Fix typo
- Fix KeypointMatchingOutput docstring
- Removed unnecessary line
- Added missing SuperGlueConfig in __init__ methods
* LightGluePR suggestions:
- use batching to get keypoint detection
* refactor: processing images done in 1 for loop instead of 4
* fix: use @ instead of torch.einsum for scores computation
* style: added #fmt skip to long tensor values
* refactor: rolled back validate_and_format_image_pairs valid and invalid cases to simpler ones
* refactor: prepare_imgs
* refactor: simplified `validate_and_format_image_pairs`
* docs: fixed doc
---------
Co-authored-by: steven <steven.bucaillle@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Steven Bucaille <steven.bucaille@buawei.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in: parent 872dfbdd46, commit abe57b6f17
@@ -713,6 +713,8 @@
      title: SegFormer
    - local: model_doc/seggpt
      title: SegGpt
    - local: model_doc/superglue
      title: SuperGlue
    - local: model_doc/superpoint
      title: SuperPoint
    - local: model_doc/swiftformer
@@ -318,6 +318,7 @@ Flax), PyTorch, and/or TensorFlow.
| [SqueezeBERT](model_doc/squeezebert) | ✅ | ❌ | ❌ |
| [StableLm](model_doc/stablelm) | ✅ | ❌ | ❌ |
| [Starcoder2](model_doc/starcoder2) | ✅ | ❌ | ❌ |
| [SuperGlue](model_doc/superglue) | ✅ | ❌ | ❌ |
| [SuperPoint](model_doc/superpoint) | ✅ | ❌ | ❌ |
| [SwiftFormer](model_doc/swiftformer) | ✅ | ✅ | ❌ |
| [Swin Transformer](model_doc/swin) | ✅ | ✅ | ❌ |
docs/source/en/model_doc/superglue.md (new file, 138 lines)
@@ -0,0 +1,138 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the MIT License; you may not use this file except in compliance with
the License.

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# SuperGlue

## Overview

The SuperGlue model was proposed in [SuperGlue: Learning Feature Matching with Graph Neural Networks](https://arxiv.org/abs/1911.11763) by Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz and Andrew Rabinovich.

The model matches two sets of interest points detected in a pair of images. Paired with the
[SuperPoint model](https://huggingface.co/magic-leap-community/superpoint), it can be used to match two images and
estimate the pose between them. It is useful for tasks such as image matching and homography estimation.

The abstract from the paper is the following:

*This paper introduces SuperGlue, a neural network that matches two sets of local features by jointly finding correspondences
and rejecting non-matchable points. Assignments are estimated by solving a differentiable optimal transport problem, whose costs
are predicted by a graph neural network. We introduce a flexible context aggregation mechanism based on attention, enabling
SuperGlue to reason about the underlying 3D scene and feature assignments jointly. Compared to traditional, hand-designed heuristics,
our technique learns priors over geometric transformations and regularities of the 3D world through end-to-end training from image
pairs. SuperGlue outperforms other learned approaches and achieves state-of-the-art results on the task of pose estimation in
challenging real-world indoor and outdoor environments. The proposed method performs matching in real-time on a modern GPU and
can be readily integrated into modern SfM or SLAM systems. The code and trained weights are publicly available at this [URL](https://github.com/magicleap/SuperGluePretrainedNetwork).*

## How to use

Here is a quick example of using the model. Since this model is an image matching model, it requires pairs of images to be matched.
The raw outputs contain the list of keypoints detected by the keypoint detector as well as the list of matches with their corresponding
matching scores.
```python
from transformers import AutoImageProcessor, AutoModel
import torch
from PIL import Image
import requests

url_image1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
image1 = Image.open(requests.get(url_image1, stream=True).raw)
url_image2 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"
image2 = Image.open(requests.get(url_image2, stream=True).raw)

images = [image1, image2]

processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")

inputs = processor(images, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
```

You can use the `post_process_keypoint_matching` method from the `SuperGlueImageProcessor` to get the keypoints and matches in a more readable format:

```python
image_sizes = [[(image.height, image.width) for image in images]]
outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
for i, output in enumerate(outputs):
    print("For the image pair", i)
    for keypoint0, keypoint1, matching_score in zip(
        output["keypoints0"], output["keypoints1"], output["matching_scores"]
    ):
        print(
            f"Keypoint at coordinate {keypoint0.numpy()} in the first image matches with keypoint at coordinate {keypoint1.numpy()} in the second image with a score of {matching_score}."
        )
```

From the outputs, you can visualize the matches between the two images using the following code:

```python
import matplotlib.pyplot as plt
import numpy as np

# Create side by side image
merged_image = np.zeros((max(image1.height, image2.height), image1.width + image2.width, 3))
merged_image[: image1.height, : image1.width] = np.array(image1) / 255.0
merged_image[: image2.height, image1.width :] = np.array(image2) / 255.0
plt.imshow(merged_image)
plt.axis("off")

# Retrieve the keypoints and matches
output = outputs[0]
keypoints0 = output["keypoints0"]
keypoints1 = output["keypoints1"]
matching_scores = output["matching_scores"]
keypoints0_x, keypoints0_y = keypoints0[:, 0].numpy(), keypoints0[:, 1].numpy()
keypoints1_x, keypoints1_y = keypoints1[:, 0].numpy(), keypoints1[:, 1].numpy()

# Plot the matches
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
    keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, matching_scores
):
    plt.plot(
        [keypoint0_x, keypoint1_x + image1.width],
        [keypoint0_y, keypoint1_y],
        color=plt.get_cmap("RdYlGn")(matching_score.item()),
        alpha=0.9,
        linewidth=0.5,
    )
    plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2)
    plt.scatter(keypoint1_x + image1.width, keypoint1_y, c="black", s=2)

# Save the plot
plt.savefig("matched_image.png", dpi=300, bbox_inches='tight')
plt.close()
```


|
||||
|
||||
This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
|
||||
The original code can be found [here](https://github.com/magicleap/SuperGluePretrainedNetwork).
|
||||
|
||||
## SuperGlueConfig
|
||||
|
||||
[[autodoc]] SuperGlueConfig
|
||||
|
||||
## SuperGlueImageProcessor
|
||||
|
||||
[[autodoc]] SuperGlueImageProcessor
|
||||
|
||||
- preprocess
|
||||
|
||||
## SuperGlueForKeypointMatching
|
||||
|
||||
[[autodoc]] SuperGlueForKeypointMatching
|
||||
|
||||
- forward
|
||||
- post_process_keypoint_matching
|
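Since the overview mentions homography estimation as a typical downstream use, here is a minimal sketch (not part of the original documentation) of feeding the post-processed matches to OpenCV's RANSAC homography estimator; it assumes `opencv-python` is installed and reuses the `outputs` list produced by `post_process_keypoint_matching` above:

```python
import cv2  # assumes opencv-python is installed
import numpy as np

# Each post-processed entry holds matched keypoint coordinates in (x, y) order for one image pair.
output = outputs[0]
points0 = output["keypoints0"].numpy().astype(np.float32)
points1 = output["keypoints1"].numpy().astype(np.float32)

# RANSAC needs at least 4 correspondences; it rejects outliers while fitting the 3x3 homography.
homography, inlier_mask = cv2.findHomography(points0, points1, cv2.RANSAC, ransacReprojThreshold=5.0)
print("Estimated homography:\n", homography)
print("Inlier ratio:", inlier_mask.ravel().mean())
```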
@@ -785,6 +785,7 @@ _import_structure = {
    ],
    "models.stablelm": ["StableLmConfig"],
    "models.starcoder2": ["Starcoder2Config"],
    "models.superglue": ["SuperGlueConfig"],
    "models.superpoint": ["SuperPointConfig"],
    "models.swiftformer": ["SwiftFormerConfig"],
    "models.swin": ["SwinConfig"],
@@ -1268,6 +1269,7 @@ else:
    _import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"])
    _import_structure["models.seggpt"].extend(["SegGptImageProcessor"])
    _import_structure["models.siglip"].append("SiglipImageProcessor")
    _import_structure["models.superglue"].extend(["SuperGlueImageProcessor"])
    _import_structure["models.superpoint"].extend(["SuperPointImageProcessor"])
    _import_structure["models.swin2sr"].append("Swin2SRImageProcessor")
    _import_structure["models.textnet"].extend(["TextNetImageProcessor"])
@@ -3545,6 +3547,12 @@ else:
            "Starcoder2PreTrainedModel",
        ]
    )
    _import_structure["models.superglue"].extend(
        [
            "SuperGlueForKeypointMatching",
            "SuperGluePreTrainedModel",
        ]
    )
    _import_structure["models.superpoint"].extend(
        [
            "SuperPointForKeypointDetection",
@@ -5861,6 +5869,7 @@ if TYPE_CHECKING:
        )
        from .models.stablelm import StableLmConfig
        from .models.starcoder2 import Starcoder2Config
        from .models.superglue import SuperGlueConfig
        from .models.superpoint import SuperPointConfig
        from .models.swiftformer import (
            SwiftFormerConfig,
@@ -6361,6 +6370,7 @@ if TYPE_CHECKING:
        from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor
        from .models.seggpt import SegGptImageProcessor
        from .models.siglip import SiglipImageProcessor
        from .models.superglue import SuperGlueImageProcessor
        from .models.superpoint import SuperPointImageProcessor
        from .models.swin2sr import Swin2SRImageProcessor
        from .models.textnet import TextNetImageProcessor
@@ -8186,6 +8196,10 @@ if TYPE_CHECKING:
            Starcoder2Model,
            Starcoder2PreTrainedModel,
        )
        from .models.superglue import (
            SuperGlueForKeypointMatching,
            SuperGluePreTrainedModel,
        )
        from .models.superpoint import (
            SuperPointForKeypointDetection,
            SuperPointPreTrainedModel,
@@ -246,6 +246,7 @@ from . import (
    squeezebert,
    stablelm,
    starcoder2,
    superglue,
    superpoint,
    swiftformer,
    swin,
@@ -273,6 +273,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
        ("squeezebert", "SqueezeBertConfig"),
        ("stablelm", "StableLmConfig"),
        ("starcoder2", "Starcoder2Config"),
        ("superglue", "SuperGlueConfig"),
        ("superpoint", "SuperPointConfig"),
        ("swiftformer", "SwiftFormerConfig"),
        ("swin", "SwinConfig"),
@@ -608,6 +609,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("squeezebert", "SqueezeBERT"),
        ("stablelm", "StableLm"),
        ("starcoder2", "Starcoder2"),
        ("superglue", "SuperGlue"),
        ("superpoint", "SuperPoint"),
        ("swiftformer", "SwiftFormer"),
        ("swin", "Swin Transformer"),
@@ -133,6 +133,7 @@ else:
        ("segformer", ("SegformerImageProcessor",)),
        ("seggpt", ("SegGptImageProcessor",)),
        ("siglip", ("SiglipImageProcessor",)),
        ("superglue", "SuperGlueImageProcessor"),
        ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
        ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
        ("swin2sr", ("Swin2SRImageProcessor",)),
@@ -251,6 +251,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("squeezebert", "SqueezeBertModel"),
        ("stablelm", "StableLmModel"),
        ("starcoder2", "Starcoder2Model"),
        ("superglue", "SuperGlueForKeypointMatching"),
        ("swiftformer", "SwiftFormerModel"),
        ("swin", "SwinModel"),
        ("swin2sr", "Swin2SRModel"),
src/transformers/models/superglue/__init__.py (new file, 28 lines)
@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_superglue import *
    from .image_processing_superglue import *
    from .modeling_superglue import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
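As a brief aside (not part of the file), the `_LazyModule` indirection above means that importing a symbol from the subpackage only resolves the corresponding submodule on first access. A minimal sketch, assuming a transformers install that contains this commit:

```python
# This import goes through the _LazyModule defined above; the heavier modeling and
# image-processing submodules are only loaded when their attributes are accessed.
from transformers.models.superglue import SuperGlueConfig

config = SuperGlueConfig()
print(config.model_type)  # "superglue"
```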
src/transformers/models/superglue/configuration_superglue.py (new file, 120 lines)
@@ -0,0 +1,120 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING


if TYPE_CHECKING:
    from ..superpoint import SuperPointConfig

logger = logging.get_logger(__name__)


class SuperGlueConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`SuperGlueForKeypointMatching`]. It is used to
    instantiate a SuperGlue model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the SuperGlue
    [magic-leap-community/superglue_indoor](https://huggingface.co/magic-leap-community/superglue_indoor) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        keypoint_detector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SuperPointConfig`):
            The config object or dictionary of the keypoint detector.
        hidden_size (`int`, *optional*, defaults to 256):
            The dimension of the descriptors.
        keypoint_encoder_sizes (`List[int]`, *optional*, defaults to `[32, 64, 128, 256]`):
            The sizes of the keypoint encoder layers.
        gnn_layers_types (`List[str]`, *optional*, defaults to `['self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross']`):
            The types of the GNN layers. Must be either 'self' or 'cross'.
        num_attention_heads (`int`, *optional*, defaults to 4):
            The number of heads in the GNN layers.
        sinkhorn_iterations (`int`, *optional*, defaults to 100):
            The number of Sinkhorn iterations.
        matching_threshold (`float`, *optional*, defaults to 0.0):
            The matching threshold.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

    Examples:
    ```python
    >>> from transformers import SuperGlueConfig, SuperGlueForKeypointMatching

    >>> # Initializing a SuperGlue superglue style configuration
    >>> configuration = SuperGlueConfig()

    >>> # Initializing a model from the superglue style configuration
    >>> model = SuperGlueForKeypointMatching(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """
    model_type = "superglue"

    def __init__(
        self,
        keypoint_detector_config: "SuperPointConfig" = None,
        hidden_size: int = 256,
        keypoint_encoder_sizes: List[int] = None,
        gnn_layers_types: List[str] = None,
        num_attention_heads: int = 4,
        sinkhorn_iterations: int = 100,
        matching_threshold: float = 0.0,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        self.gnn_layers_types = gnn_layers_types if gnn_layers_types is not None else ["self", "cross"] * 9
        # Check whether all gnn_layers_types are either 'self' or 'cross'
        if not all(layer_type in ["self", "cross"] for layer_type in self.gnn_layers_types):
            raise ValueError("All gnn_layers_types must be either 'self' or 'cross'")

        if hidden_size % num_attention_heads != 0:
            raise ValueError("hidden_size % num_attention_heads is different from zero")

        self.keypoint_encoder_sizes = (
            keypoint_encoder_sizes if keypoint_encoder_sizes is not None else [32, 64, 128, 256]
        )
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.sinkhorn_iterations = sinkhorn_iterations
        self.matching_threshold = matching_threshold

        if isinstance(keypoint_detector_config, dict):
            keypoint_detector_config["model_type"] = (
                keypoint_detector_config["model_type"] if "model_type" in keypoint_detector_config else "superpoint"
            )
            keypoint_detector_config = CONFIG_MAPPING[keypoint_detector_config["model_type"]](
                **keypoint_detector_config
            )
        if keypoint_detector_config is None:
            keypoint_detector_config = CONFIG_MAPPING["superpoint"]()

        self.keypoint_detector_config = keypoint_detector_config
        self.initializer_range = initializer_range
        self.attention_probs_dropout_prob = 0
        self.is_decoder = False

        super().__init__(**kwargs)


__all__ = ["SuperGlueConfig"]
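As a quick illustration of the nested config resolution implemented above (a sketch only; it assumes a transformers build containing this model and that `max_keypoints` is a field of `SuperPointConfig`):

```python
from transformers import SuperGlueConfig

# A dict is resolved through CONFIG_MAPPING into a SuperPointConfig instance.
config = SuperGlueConfig(keypoint_detector_config={"model_type": "superpoint", "max_keypoints": 512})
print(type(config.keypoint_detector_config).__name__)  # SuperPointConfig
print(config.keypoint_detector_config.max_keypoints)   # 512

# Omitting the argument falls back to the default SuperPointConfig.
default_config = SuperGlueConfig()
print(type(default_config.keypoint_detector_config).__name__)  # SuperPointConfig
```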
src/transformers/models/superglue/convert_superglue_to_hf.py (new file, 342 lines)
@@ -0,0 +1,342 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import os
import re
from typing import List

import torch
from datasets import load_dataset

from transformers import (
    AutoModelForKeypointDetection,
    SuperGlueConfig,
    SuperGlueForKeypointMatching,
    SuperGlueImageProcessor,
)

def prepare_imgs():
    dataset = load_dataset("hf-internal-testing/image-matching-test-dataset", split="train")
    image1 = dataset[0]["image"]
    image2 = dataset[1]["image"]
    image3 = dataset[2]["image"]
    return [[image1, image2], [image3, image2]]

def verify_model_outputs(model, model_name, device):
    images = prepare_imgs()
    preprocessor = SuperGlueImageProcessor()
    inputs = preprocessor(images=images, return_tensors="pt").to(device)
    model.to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True, output_attentions=True)

    predicted_matches_values = outputs.matches[0, 0, :10]
    predicted_matching_scores_values = outputs.matching_scores[0, 0, :10]

    predicted_number_of_matches = torch.sum(outputs.matches[0][0] != -1).item()

    if "outdoor" in model_name:
        expected_max_number_keypoints = 865
        expected_matches_shape = torch.Size((len(images), 2, expected_max_number_keypoints))
        expected_matching_scores_shape = torch.Size((len(images), 2, expected_max_number_keypoints))

        expected_matches_values = torch.tensor(
            [125, 630, 137, 138, 136, 143, 135, -1, -1, 153], dtype=torch.int64, device=device
        )
        expected_matching_scores_values = torch.tensor(
            [0.9899, 0.0033, 0.9897, 0.9889, 0.9879, 0.7464, 0.7109, 0, 0, 0.9841], device=device
        )

        expected_number_of_matches = 281
    elif "indoor" in model_name:
        expected_max_number_keypoints = 865
        expected_matches_shape = torch.Size((len(images), 2, expected_max_number_keypoints))
        expected_matching_scores_shape = torch.Size((len(images), 2, expected_max_number_keypoints))

        expected_matches_values = torch.tensor(
            [125, 144, 137, 138, 136, 155, 135, -1, -1, 153], dtype=torch.int64, device=device
        )
        expected_matching_scores_values = torch.tensor(
            [0.9694, 0.0010, 0.9006, 0.8753, 0.8521, 0.5688, 0.6321, 0.0, 0.0, 0.7235], device=device
        )

        expected_number_of_matches = 282

    assert outputs.matches.shape == expected_matches_shape
    assert outputs.matching_scores.shape == expected_matching_scores_shape

    assert torch.allclose(predicted_matches_values, expected_matches_values, atol=1e-4)
    assert torch.allclose(predicted_matching_scores_values, expected_matching_scores_values, atol=1e-4)

    assert predicted_number_of_matches == expected_number_of_matches

ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    r"kenc.encoder.(\d+)": r"keypoint_encoder.encoder.\1.old",
    r"gnn.layers.(\d+).attn.proj.0": r"gnn.layers.\1.attention.self.query",
    r"gnn.layers.(\d+).attn.proj.1": r"gnn.layers.\1.attention.self.key",
    r"gnn.layers.(\d+).attn.proj.2": r"gnn.layers.\1.attention.self.value",
    r"gnn.layers.(\d+).attn.merge": r"gnn.layers.\1.attention.output.dense",
    r"gnn.layers.(\d+).mlp.0": r"gnn.layers.\1.mlp.0.linear",
    r"gnn.layers.(\d+).mlp.1": r"gnn.layers.\1.mlp.0.batch_norm",
    r"gnn.layers.(\d+).mlp.3": r"gnn.layers.\1.mlp.1",
    r"final_proj": r"final_projection.final_proj",
}

def convert_old_keys_to_new_keys(state_dict_keys: List[str], conversion_mapping=ORIGINAL_TO_CONVERTED_KEY_MAPPING):
    """
    This function should be applied only once, on the concatenated keys to efficiently rename using
    the key mappings.
    """
    output_dict = {}
    if state_dict_keys is not None:
        old_text = "\n".join(state_dict_keys)
        new_text = old_text
        for pattern, replacement in conversion_mapping.items():
            if replacement is None:
                new_text = re.sub(pattern, "", new_text)  # an empty line
                continue
            new_text = re.sub(pattern, replacement, new_text)
        output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
    return output_dict

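For illustration only, here is a small sketch of what `convert_old_keys_to_new_keys` returns for a few hypothetical keys written in the original SuperGlue checkpoint naming scheme (the sample key names below are assumptions for demonstration, not taken from a real checkpoint):

```python
# Hypothetical sample keys following the original checkpoint naming (illustration only).
sample_keys = [
    "kenc.encoder.3.weight",            # -> "keypoint_encoder.encoder.3.old.weight"
    "gnn.layers.0.attn.proj.0.weight",  # -> "gnn.layers.0.attention.self.query.weight"
    "gnn.layers.0.attn.merge.bias",     # -> "gnn.layers.0.attention.output.dense.bias"
    "final_proj.weight",                # -> "final_projection.final_proj.weight"
]
print(convert_old_keys_to_new_keys(sample_keys))
```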
def replace_state_dict_keys(all_keys, new_keys, original_state_dict):
    state_dict = {}
    for key in all_keys:
        new_key = new_keys[key]
        state_dict[new_key] = original_state_dict.pop(key).contiguous().clone()
    return state_dict

def convert_state_dict(state_dict, config):
    converted_to_final_key_mapping = {}

    def convert_conv_to_linear(keys):
        for key in keys:
            state_dict[key] = state_dict[key].squeeze(-1)

    def qkv_permute_weights_and_biases(keys, num_heads=4):
        for key in keys:
            tensor = state_dict[key]
            shape = tensor.shape
            dim_out = shape[0]
            if len(shape) == 2:
                dim_in = shape[1]
                tensor = (
                    tensor.reshape(dim_out // num_heads, num_heads, dim_in).permute(1, 0, 2).reshape(dim_out, dim_in)
                )
            if len(shape) == 1:
                tensor = tensor.reshape(dim_out // num_heads, num_heads).permute(1, 0).reshape(dim_out)
            state_dict[key] = tensor

    def output_permute_weights(keys, num_heads=4):
        for key in keys:
            tensor = state_dict[key]
            dim_in = tensor.shape[1]
            dim_out = tensor.shape[0]
            tensor = tensor.reshape(dim_out, dim_in // num_heads, num_heads).permute(0, 2, 1).reshape(dim_out, dim_in)
            state_dict[key] = tensor

    conv_keys = []
    qkv_permute_keys = []
    output_permute_keys = []
    # Keypoint Encoder
    keypoint_encoder_key = "keypoint_encoder.encoder"
    for i in range(1, len(config.keypoint_encoder_sizes) + 2):
        old_conv_key = f"{keypoint_encoder_key}.{(i - 1) * 3}.old"
        new_index = i - 1
        new_conv_key = f"{keypoint_encoder_key}.{new_index}."
        if i < len(config.keypoint_encoder_sizes) + 1:
            new_conv_key = f"{new_conv_key}linear."
        converted_to_final_key_mapping[rf"{old_conv_key}\."] = new_conv_key
        if i < len(config.keypoint_encoder_sizes) + 1:
            old_batch_norm_key = f"{keypoint_encoder_key}.{(i - 1) * 3 + 1}.old"
            new_batch_norm_key = f"{keypoint_encoder_key}.{new_index}.batch_norm."
            converted_to_final_key_mapping[rf"{old_batch_norm_key}\."] = new_batch_norm_key

        conv_keys.append(f"{old_conv_key}.weight")

    # Attentional GNN
    for i in range(len(config.gnn_layers_types)):
        gnn_layer_key = f"gnn.layers.{i}"
        ## Attention
        attention_key = f"{gnn_layer_key}.attention"
        conv_keys.extend(
            [
                f"{attention_key}.self.query.weight",
                f"{attention_key}.self.key.weight",
                f"{attention_key}.self.value.weight",
                f"{attention_key}.output.dense.weight",
            ]
        )
        qkv_permute_keys.extend(
            [
                f"{attention_key}.self.query.weight",
                f"{attention_key}.self.key.weight",
                f"{attention_key}.self.value.weight",
                f"{attention_key}.self.query.bias",
                f"{attention_key}.self.key.bias",
                f"{attention_key}.self.value.bias",
            ]
        )
        output_permute_keys.append(f"{attention_key}.output.dense.weight")

        ## MLP
        conv_keys.extend([f"{gnn_layer_key}.mlp.0.linear.weight", f"{gnn_layer_key}.mlp.1.weight"])

    # Final Projection
    conv_keys.append("final_projection.final_proj.weight")

    convert_conv_to_linear(conv_keys)
    qkv_permute_weights_and_biases(qkv_permute_keys)
    output_permute_weights(output_permute_keys)
    all_keys = list(state_dict.keys())
    new_keys = convert_old_keys_to_new_keys(all_keys, converted_to_final_key_mapping)
    state_dict = replace_state_dict_keys(all_keys, new_keys, state_dict)
    return state_dict

def add_keypoint_detector_state_dict(superglue_state_dict):
    keypoint_detector = AutoModelForKeypointDetection.from_pretrained("magic-leap-community/superpoint")
    keypoint_detector_state_dict = keypoint_detector.state_dict()
    for k, v in keypoint_detector_state_dict.items():
        superglue_state_dict[f"keypoint_detector.{k}"] = v
    return superglue_state_dict

@torch.no_grad()
def write_model(
    model_path,
    checkpoint_url,
    safe_serialization=True,
    push_to_hub=False,
):
    os.makedirs(model_path, exist_ok=True)

    # ------------------------------------------------------------
    # SuperGlue config
    # ------------------------------------------------------------

    config = SuperGlueConfig(
        hidden_size=256,
        keypoint_encoder_sizes=[32, 64, 128, 256],
        gnn_layers_types=["self", "cross"] * 9,
        sinkhorn_iterations=100,
        matching_threshold=0.0,
    )
    config.architectures = ["SuperGlueForKeypointMatching"]
    config.save_pretrained(model_path, push_to_hub=push_to_hub)
    print("Model config saved successfully...")

    # ------------------------------------------------------------
    # Convert weights
    # ------------------------------------------------------------

    print(f"Fetching all parameters from the checkpoint at {checkpoint_url}...")
    original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)

    print("Converting model...")
    all_keys = list(original_state_dict.keys())
    new_keys = convert_old_keys_to_new_keys(all_keys)

    state_dict = replace_state_dict_keys(all_keys, new_keys, original_state_dict)
    state_dict = convert_state_dict(state_dict, config)

    del original_state_dict
    gc.collect()
    state_dict = add_keypoint_detector_state_dict(state_dict)

    print("Loading the checkpoint in a SuperGlue model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with torch.device(device):
        model = SuperGlueForKeypointMatching(config)
    model.load_state_dict(state_dict, strict=True)
    print("Checkpoint loaded successfully...")
    del model.config._name_or_path

    print("Saving the model...")
    model.save_pretrained(model_path, safe_serialization=safe_serialization)
    del state_dict, model

    # Safety check: reload the converted model
    gc.collect()
    print("Reloading the model to check if it's saved correctly.")
    model = SuperGlueForKeypointMatching.from_pretrained(model_path)
    print("Model reloaded successfully.")

    model_name = "superglue"
    if "superglue_outdoor.pth" in checkpoint_url:
        model_name += "_outdoor"
    elif "superglue_indoor.pth" in checkpoint_url:
        model_name += "_indoor"

    print("Checking the model outputs...")
    verify_model_outputs(model, model_name, device)
    print("Model outputs verified successfully.")

    organization = "magic-leap-community"
    if push_to_hub:
        print("Pushing model to the hub...")
        model.push_to_hub(
            repo_id=f"{organization}/{model_name}",
            commit_message="Add model",
        )

    write_image_processor(model_path, model_name, organization, push_to_hub=push_to_hub)

def write_image_processor(save_dir, model_name, organization, push_to_hub=False):
    image_processor = SuperGlueImageProcessor()
    image_processor.save_pretrained(save_dir)

    if push_to_hub:
        print("Pushing image processor to the hub...")
        image_processor.push_to_hub(
            repo_id=f"{organization}/{model_name}",
            commit_message="Add image processor",
        )

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--checkpoint_url",
        default="https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/models/weights/superglue_indoor.pth",
        type=str,
        help="URL of the original SuperGlue checkpoint you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        required=True,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument("--save_model", action="store_true", help="Save model to local")
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Push model and image preprocessor to the hub",
    )

    args = parser.parse_args()
    write_model(
        args.pytorch_dump_folder_path, args.checkpoint_url, safe_serialization=True, push_to_hub=args.push_to_hub
    )
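A hypothetical programmatic invocation of the script's entry point, equivalent to the CLI call above (a sketch only; the output directory name is made up, and the run downloads both the original checkpoint and the SuperPoint weights, so it assumes network access):

```python
# Hypothetical local output directory; the checkpoint URL is the script's own default.
write_model(
    model_path="./superglue_indoor_converted",
    checkpoint_url="https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/models/weights/superglue_indoor.pth",
    safe_serialization=True,
    push_to_hub=False,
)
```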
src/transformers/models/superglue/image_processing_superglue.py (new file, 407 lines)
@@ -0,0 +1,407 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for SuperGlue."""

from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import numpy as np

from ... import is_torch_available, is_vision_available
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import resize, to_channel_dimension_format
from ...image_utils import (
    ChannelDimension,
    ImageInput,
    ImageType,
    PILImageResampling,
    get_image_type,
    infer_channel_dimension_format,
    is_pil_image,
    is_scaled_image,
    is_valid_image,
    to_numpy_array,
    valid_images,
    validate_preprocess_arguments,
)
from ...utils import TensorType, logging, requires_backends


if is_torch_available():
    import torch

if TYPE_CHECKING:
    from .modeling_superglue import KeypointMatchingOutput

if is_vision_available():
    import PIL

logger = logging.get_logger(__name__)

# Copied from transformers.models.superpoint.image_processing_superpoint.is_grayscale
def is_grayscale(
    image: ImageInput,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
    if input_data_format == ChannelDimension.FIRST:
        if image.shape[0] == 1:
            return True
        return np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...])
    elif input_data_format == ChannelDimension.LAST:
        if image.shape[-1] == 1:
            return True
        return np.all(image[..., 0] == image[..., 1]) and np.all(image[..., 1] == image[..., 2])

# Copied from transformers.models.superpoint.image_processing_superpoint.convert_to_grayscale
def convert_to_grayscale(
    image: ImageInput,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> ImageInput:
    """
    Converts an image to grayscale format using the NTSC formula. Only support numpy and PIL Image. TODO support torch
    and tensorflow grayscale conversion

    This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each
    channel, because of an issue that is discussed in :
    https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446

    Args:
        image (Image):
            The image to convert.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format for the input image.
    """
    requires_backends(convert_to_grayscale, ["vision"])

    if isinstance(image, np.ndarray):
        if is_grayscale(image, input_data_format=input_data_format):
            return image
        if input_data_format == ChannelDimension.FIRST:
            gray_image = image[0, ...] * 0.2989 + image[1, ...] * 0.5870 + image[2, ...] * 0.1140
            gray_image = np.stack([gray_image] * 3, axis=0)
        elif input_data_format == ChannelDimension.LAST:
            gray_image = image[..., 0] * 0.2989 + image[..., 1] * 0.5870 + image[..., 2] * 0.1140
            gray_image = np.stack([gray_image] * 3, axis=-1)
        return gray_image

    if not isinstance(image, PIL.Image.Image):
        return image

    image = image.convert("L")
    return image

def validate_and_format_image_pairs(images: ImageInput):
    error_message = (
        "Input images must be one of the following:",
        " - A pair of PIL images.",
        " - A pair of 3D arrays.",
        " - A list of pairs of PIL images.",
        " - A list of pairs of 3D arrays.",
    )

    def _is_valid_image(image):
        """images is a PIL Image or a 3D array."""
        return is_pil_image(image) or (
            is_valid_image(image) and get_image_type(image) != ImageType.PIL and len(image.shape) == 3
        )

    if isinstance(images, list):
        if len(images) == 2 and all((_is_valid_image(image)) for image in images):
            return images
        if all(
            isinstance(image_pair, list)
            and len(image_pair) == 2
            and all(_is_valid_image(image) for image in image_pair)
            for image_pair in images
        ):
            return [image for image_pair in images for image in image_pair]
    raise ValueError(error_message)

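As an aside (not part of the file), a small sketch of the two accepted input layouts, assuming `image1` and `image2` are already-loaded PIL images:

```python
# A single pair is returned unchanged as a flat list of two images.
single_pair = validate_and_format_image_pairs([image1, image2])

# A list of pairs is flattened to [pair0_image0, pair0_image1, pair1_image0, pair1_image1, ...].
batch = validate_and_format_image_pairs([[image1, image2], [image2, image1]])
print(len(single_pair), len(batch))  # 2 4
```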
class SuperGlueImageProcessor(BaseImageProcessor):
    r"""
    Constructs a SuperGlue image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be
            overridden by `do_resize` in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 480, "width": 640}`):
            Resolution of the output image after `resize` is applied. Only has an effect if `do_resize` is set to
            `True`. Can be overridden by `size` in the `preprocess` method.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
            method.
        do_grayscale (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_grayscale: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 480, "width": 640}
        size = get_size_dict(size, default_to_square=False)

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_grayscale = do_grayscale

    # Copied from transformers.models.superpoint.image_processing_superpoint.SuperPointImageProcessor.resize
    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ):
        """
        Resize an image.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Dictionary of the form `{"height": int, "width": int}`, specifying the size of the output image.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the output image. If not provided, it will be inferred from the input
                image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        size = get_size_dict(size, default_to_square=False)

        return resize(
            image,
            size=(size["height"], size["width"]),
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

def preprocess(
|
||||
self,
|
||||
images,
|
||||
do_resize: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: bool = None,
|
||||
rescale_factor: float = None,
|
||||
do_grayscale: bool = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Preprocess an image or batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image pairs to preprocess. Expects either a list of 2 images or a list of list of 2 images list with
|
||||
pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set
|
||||
`do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image
|
||||
is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the
|
||||
image will be matched to `int(size["shortest_edge"]/ crop_pct)`, after which the image is cropped to
|
||||
`(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of `PILImageResampling`, filters. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image values between [0 - 1].
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_grayscale (`bool`, *optional*, defaults to `self.do_grayscale`):
|
||||
Whether to convert the image to grayscale.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_grayscale = do_grayscale if do_grayscale is not None else self.do_grayscale
|
||||
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
|
||||
# Validate and convert the input images into a flattened list of images for all subsequent processing steps.
|
||||
images = validate_and_format_image_pairs(images)
|
||||
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
validate_preprocess_arguments(
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
)
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if is_scaled_image(images[0]) and do_rescale:
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
all_images = []
|
||||
for image in images:
|
||||
if do_resize:
|
||||
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
|
||||
if do_rescale:
|
||||
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
|
||||
if do_grayscale:
|
||||
image = convert_to_grayscale(image, input_data_format=input_data_format)
|
||||
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
all_images.append(image)
|
||||
|
||||
# Convert back the flattened list of images into a list of pairs of images.
|
||||
image_pairs = [all_images[i : i + 2] for i in range(0, len(all_images), 2)]
|
||||
|
||||
data = {"pixel_values": image_pairs}
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
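A minimal usage sketch of the pairing behaviour of `preprocess`: a single pair is passed as a list of 2 images, a batch as a list of such 2-image lists, and the flattened images are regrouped into pairs before being returned. The image file names are placeholders.

```python
from PIL import Image

from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
image0 = Image.open("view_0.jpg")  # hypothetical local files forming one image pair
image1 = Image.open("view_1.jpg")

inputs = processor([image0, image1], return_tensors="pt")
print(inputs["pixel_values"].shape)  # (1, 2, 3, 480, 640) with the default size and grayscale settings
```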
|
||||
|
||||
def post_process_keypoint_matching(
|
||||
self,
|
||||
outputs: "KeypointMatchingOutput",
|
||||
target_sizes: Union[TensorType, List[Tuple]],
|
||||
threshold: float = 0.0,
|
||||
) -> List[Dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Converts the raw output of [`KeypointMatchingOutput`] into lists of keypoints, scores and descriptors
|
||||
with coordinates absolute to the original image sizes.
|
||||
Args:
|
||||
outputs ([`KeypointMatchingOutput`]):
|
||||
Raw outputs of the model.
|
||||
target_sizes (`torch.Tensor` or `List[Tuple[Tuple[int, int]]]`):
|
||||
Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`Tuple[int, int]`) containing the
|
||||
target size `(height, width)` of each image in the batch. This must be the original image size (before
|
||||
any processing).
|
||||
threshold (`float`, *optional*, defaults to 0.0):
|
||||
Threshold to filter out the matches with low scores.
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing the keypoints in the first and second image
|
||||
of the pair, the matching scores and the matching indices.
|
||||
"""
|
||||
if outputs.mask.shape[0] != len(target_sizes):
|
||||
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
|
||||
if not all(len(target_size) == 2 for target_size in target_sizes):
|
||||
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
|
||||
|
||||
if isinstance(target_sizes, List):
|
||||
image_pair_sizes = torch.tensor(target_sizes, device=outputs.mask.device)
|
||||
else:
|
||||
if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2:
|
||||
raise ValueError(
|
||||
"Each element of target_sizes must contain the size (h, w) of each image of the batch"
|
||||
)
|
||||
image_pair_sizes = target_sizes
|
||||
|
||||
keypoints = outputs.keypoints.clone()
|
||||
keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2)
|
||||
keypoints = keypoints.to(torch.int32)
|
||||
|
||||
results = []
|
||||
for mask_pair, keypoints_pair, matches, scores in zip(
|
||||
outputs.mask, keypoints, outputs.matches[:, 0], outputs.matching_scores[:, 0]
|
||||
):
|
||||
mask0 = mask_pair[0] > 0
|
||||
mask1 = mask_pair[1] > 0
|
||||
keypoints0 = keypoints_pair[0][mask0]
|
||||
keypoints1 = keypoints_pair[1][mask1]
|
||||
matches0 = matches[mask0]
|
||||
scores0 = scores[mask0]
|
||||
|
||||
# Filter out matches with low scores
|
||||
valid_matches = torch.logical_and(scores0 > threshold, matches0 > -1)
|
||||
|
||||
matched_keypoints0 = keypoints0[valid_matches]
|
||||
matched_keypoints1 = keypoints1[matches0[valid_matches]]
|
||||
matching_scores = scores0[valid_matches]
|
||||
|
||||
results.append(
|
||||
{
|
||||
"keypoints0": matched_keypoints0,
|
||||
"keypoints1": matched_keypoints1,
|
||||
"matching_scores": matching_scores,
|
||||
}
|
||||
)
|
||||
|
||||
return results
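A hedged sketch of converting the raw model output back to image coordinates with this method, reusing `processor` and the image pair from the sketch above and assuming `outputs` is the `KeypointMatchingOutput` of a `SuperGlueForKeypointMatching` forward pass on `inputs` (the threshold value is arbitrary):

```python
image_sizes = [[(image.height, image.width) for image in [image0, image1]]]
results = processor.post_process_keypoint_matching(outputs, target_sizes=image_sizes, threshold=0.2)
for keypoint0, keypoint1, score in zip(
    results[0]["keypoints0"], results[0]["keypoints1"], results[0]["matching_scores"]
):
    print(f"{keypoint0.tolist()} -> {keypoint1.tolist()} (score {score:.3f})")
```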
|
||||
|
||||
|
||||
__all__ = ["SuperGlueImageProcessor"]
|
src/transformers/models/superglue/modeling_superglue.py (new file, 866 lines)
@ -0,0 +1,866 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch SuperGlue model."""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import PreTrainedModel, add_start_docstrings
|
||||
from transformers.models.superglue.configuration_superglue import SuperGlueConfig
|
||||
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging
|
||||
from ..auto import AutoModelForKeypointDetection
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "SuperGlueConfig"
_CHECKPOINT_FOR_DOC = "magic-leap-community/superglue_indoor"
|
||||
|
||||
|
||||
def concat_pairs(tensor_tuple0: Tuple[torch.Tensor], tensor_tuple1: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
|
||||
"""
|
||||
Concatenate two tuples of tensors pairwise
|
||||
|
||||
Args:
|
||||
tensor_tuple0 (`Tuple[torch.Tensor]`):
|
||||
Tuple of tensors.
|
||||
tensor_tuple1 (`Tuple[torch.Tensor]`):
|
||||
Tuple of tensors.
|
||||
|
||||
Returns:
|
||||
(`Tuple[torch.Tensor]`): Tuple of concatenated tensors.
|
||||
"""
|
||||
return tuple([torch.cat([tensor0, tensor1]) for tensor0, tensor1 in zip(tensor_tuple0, tensor_tuple1)])
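A small illustration of `concat_pairs`: tensors are concatenated position-wise along dimension 0 (shapes below are arbitrary).

```python
import torch

pair0 = (torch.ones(2, 3), torch.zeros(1, 3))
pair1 = (torch.ones(4, 3), torch.zeros(2, 3))
merged = concat_pairs(pair0, pair1)
print([t.shape for t in merged])  # [torch.Size([6, 3]), torch.Size([3, 3])]
```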
|
||||
|
||||
|
||||
def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
Normalize keypoint locations based on the image shape.
|
||||
|
||||
Args:
|
||||
keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`):
|
||||
Keypoints locations in (x, y) format.
|
||||
height (`int`):
|
||||
Image height.
|
||||
width (`int`):
|
||||
Image width.
|
||||
|
||||
Returns:
|
||||
Normalized keypoint locations (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`).
|
||||
"""
|
||||
size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None]
|
||||
center = size / 2
|
||||
scaling = size.max(1, keepdim=True).values * 0.7
|
||||
return (keypoints - center[:, None, :]) / scaling[:, None, :]
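A quick numeric illustration of the normalization convention, for a hypothetical 640x480 image:

```python
import torch

keypoints = torch.tensor([[[0.0, 0.0], [320.0, 240.0], [640.0, 480.0]]])  # (1, 3, 2) in (x, y)
normalized = normalize_keypoints(keypoints, height=480, width=640)
print(normalized)
# The image centre maps to (0, 0); the corners map to roughly (+/-0.71, +/-0.54),
# since coordinates are divided by 0.7 * max(width, height) = 448.
```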
|
||||
|
||||
|
||||
def log_sinkhorn_iterations(
|
||||
log_cost_matrix: torch.Tensor,
|
||||
log_source_distribution: torch.Tensor,
|
||||
log_target_distribution: torch.Tensor,
|
||||
num_iterations: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform Sinkhorn Normalization in Log-space for stability
|
||||
|
||||
Args:
|
||||
log_cost_matrix (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`):
|
||||
Logarithm of the cost matrix.
|
||||
log_source_distribution (`torch.Tensor` of shape `(batch_size, num_rows)`):
|
||||
Logarithm of the source distribution.
|
||||
log_target_distribution (`torch.Tensor` of shape `(batch_size, num_columns)`):
Logarithm of the target distribution.
num_iterations (`int`):
Number of Sinkhorn iterations.
|
||||
|
||||
Returns:
|
||||
log_cost_matrix (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): Logarithm of the optimal
|
||||
transport matrix.
|
||||
"""
|
||||
log_u_scaling = torch.zeros_like(log_source_distribution)
|
||||
log_v_scaling = torch.zeros_like(log_target_distribution)
|
||||
for _ in range(num_iterations):
|
||||
log_u_scaling = log_source_distribution - torch.logsumexp(log_cost_matrix + log_v_scaling.unsqueeze(1), dim=2)
|
||||
log_v_scaling = log_target_distribution - torch.logsumexp(log_cost_matrix + log_u_scaling.unsqueeze(2), dim=1)
|
||||
return log_cost_matrix + log_u_scaling.unsqueeze(2) + log_v_scaling.unsqueeze(1)
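A minimal sanity-check sketch for the log-space Sinkhorn normalization above, using toy sizes and uniform marginals (not values taken from the model):

```python
import math

import torch

cost = torch.randn(1, 4, 5)                     # arbitrary log-cost matrix
log_mu = torch.full((1, 4), math.log(1.0 / 4))  # uniform source marginal
log_nu = torch.full((1, 5), math.log(1.0 / 5))  # uniform target marginal
log_plan = log_sinkhorn_iterations(cost, log_mu, log_nu, num_iterations=100)
plan = log_plan.exp()
print(plan.sum(dim=2))  # rows approach 0.25 each after enough iterations
print(plan.sum(dim=1))  # columns match 0.20 exactly (the last update normalizes them)
```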
|
||||
|
||||
|
||||
def log_optimal_transport(scores: torch.Tensor, reg_param: torch.Tensor, iterations: int) -> torch.Tensor:
|
||||
"""
|
||||
Perform Differentiable Optimal Transport in Log-space for stability
|
||||
|
||||
Args:
|
||||
scores: (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`):
|
||||
Cost matrix.
|
||||
reg_param: (`torch.Tensor` of shape `(batch_size, 1, 1)`):
|
||||
Regularization parameter.
|
||||
iterations: (`int`):
|
||||
Number of Sinkhorn iterations.
|
||||
|
||||
Returns:
|
||||
log_optimal_transport_matrix: (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): Logarithm of the
|
||||
optimal transport matrix.
|
||||
"""
|
||||
batch_size, num_rows, num_columns = scores.shape
|
||||
one_tensor = scores.new_tensor(1)
|
||||
num_rows_tensor, num_columns_tensor = (num_rows * one_tensor).to(scores), (num_columns * one_tensor).to(scores)
|
||||
|
||||
source_reg_param = reg_param.expand(batch_size, num_rows, 1)
|
||||
target_reg_param = reg_param.expand(batch_size, 1, num_columns)
|
||||
reg_param = reg_param.expand(batch_size, 1, 1)
|
||||
|
||||
couplings = torch.cat([torch.cat([scores, source_reg_param], -1), torch.cat([target_reg_param, reg_param], -1)], 1)
|
||||
|
||||
log_normalization = -(num_rows_tensor + num_columns_tensor).log()
|
||||
log_source_distribution = torch.cat(
|
||||
[log_normalization.expand(num_rows), num_columns_tensor.log()[None] + log_normalization]
|
||||
)
|
||||
log_target_distribution = torch.cat(
|
||||
[log_normalization.expand(num_columns), num_rows_tensor.log()[None] + log_normalization]
|
||||
)
|
||||
log_source_distribution, log_target_distribution = (
|
||||
log_source_distribution[None].expand(batch_size, -1),
|
||||
log_target_distribution[None].expand(batch_size, -1),
|
||||
)
|
||||
|
||||
log_optimal_transport_matrix = log_sinkhorn_iterations(
|
||||
couplings, log_source_distribution, log_target_distribution, num_iterations=iterations
|
||||
)
|
||||
log_optimal_transport_matrix = log_optimal_transport_matrix - log_normalization # multiply probabilities by M+N
|
||||
return log_optimal_transport_matrix
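An illustrative call, mirroring how the model uses this function: `reg_param` plays the role of the learned `bin_score`, and the returned plan gains one extra "dustbin" row and column for unmatched keypoints. Sizes are arbitrary.

```python
import torch

scores = torch.randn(2, 6, 7)  # (batch, keypoints_image0, keypoints_image1)
bin_score = torch.nn.Parameter(torch.tensor(1.0))
log_plan = log_optimal_transport(scores, bin_score, iterations=20)
print(log_plan.shape)          # (2, 7, 8): one extra dustbin row and column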
|
||||
|
||||
|
||||
def arange_like(x, dim: int) -> torch.Tensor:
|
||||
return x.new_ones(x.shape[dim]).cumsum(0) - 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeypointMatchingOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of keypoint matching models. Due to the nature of keypoint detection and matching, the number
|
||||
of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of
|
||||
images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask tensor is
|
||||
used to indicate which values in the keypoints, matches and matching_scores tensors are keypoint matching
|
||||
information.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
||||
Loss computed during training.
|
||||
mask (`torch.IntTensor` of shape `(batch_size, num_keypoints)`):
|
||||
Mask indicating which values in matches and matching_scores are keypoint matching information.
|
||||
matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
|
||||
Index of keypoint matched in the other image.
|
||||
matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
|
||||
Scores of predicted matches.
|
||||
keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
|
||||
Absolute (x, y) coordinates of predicted keypoints in a given image.
|
||||
hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*):
|
||||
Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
|
||||
num_keypoints)`, returned when `output_hidden_states=True` is passed or when
`config.output_hidden_states=True`.
|
||||
attentions (`Tuple[torch.FloatTensor, ...]`, *optional*):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
|
||||
num_keypoints)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
matches: Optional[torch.FloatTensor] = None
|
||||
matching_scores: Optional[torch.FloatTensor] = None
|
||||
keypoints: Optional[torch.FloatTensor] = None
|
||||
mask: Optional[torch.IntTensor] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
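A hedged sketch of reading one pair out of a `KeypointMatchingOutput` using the mask described above. In practice the model returns these tensors with an extra pair dimension (index 0 or 1 selects the image in the pair); the `outputs` variable and indices are illustrative.

```python
num_keypoints0 = int(outputs.mask[0, 0].sum())            # number of real keypoints in the first image
keypoints0 = outputs.keypoints[0, 0, :num_keypoints0]     # padded entries beyond this are zeros
matches0 = outputs.matches[0, 0, :num_keypoints0]          # -1 marks keypoints without a match
scores0 = outputs.matching_scores[0, 0, :num_keypoints0]
```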
|
||||
|
||||
|
||||
class SuperGlueMultiLayerPerceptron(nn.Module):
|
||||
def __init__(self, config: SuperGlueConfig, in_channels: int, out_channels: int) -> None:
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(in_channels, out_channels)
|
||||
self.batch_norm = nn.BatchNorm1d(out_channels)
|
||||
self.activation = nn.ReLU()
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
hidden_state = self.linear(hidden_state)
|
||||
hidden_state = hidden_state.transpose(-1, -2)
|
||||
hidden_state = self.batch_norm(hidden_state)
|
||||
hidden_state = hidden_state.transpose(-1, -2)
|
||||
hidden_state = self.activation(hidden_state)
|
||||
return hidden_state
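A shape sketch for the Linear -> BatchNorm1d -> ReLU block above; `config` is unused by this module, so `None` is passed purely for illustration.

```python
import torch

mlp = SuperGlueMultiLayerPerceptron(config=None, in_channels=32, out_channels=64)
hidden_state = torch.randn(8, 100, 32)  # (batch, num_keypoints, channels)
output = mlp(hidden_state)              # the transposes let BatchNorm1d run over the channel axis
print(output.shape)                     # torch.Size([8, 100, 64])
```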
|
||||
|
||||
|
||||
class SuperGlueKeypointEncoder(nn.Module):
|
||||
def __init__(self, config: SuperGlueConfig) -> None:
|
||||
super().__init__()
|
||||
layer_sizes = config.keypoint_encoder_sizes
|
||||
hidden_size = config.hidden_size
|
||||
# 3 here consists of 2 for the (x, y) coordinates and 1 for the score of the keypoint
|
||||
encoder_channels = [3] + layer_sizes + [hidden_size]
|
||||
|
||||
layers = [
|
||||
SuperGlueMultiLayerPerceptron(config, encoder_channels[i - 1], encoder_channels[i])
|
||||
for i in range(1, len(encoder_channels) - 1)
|
||||
]
|
||||
layers.append(nn.Linear(encoder_channels[-2], encoder_channels[-1]))
|
||||
self.encoder = nn.ModuleList(layers)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
keypoints: torch.Tensor,
|
||||
scores: torch.Tensor,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
|
||||
scores = scores.unsqueeze(2)
|
||||
hidden_state = torch.cat([keypoints, scores], dim=2)
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for layer in self.encoder:
|
||||
hidden_state = layer(hidden_state)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_state,)
|
||||
return hidden_state, all_hidden_states
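An illustrative shape walk-through for the keypoint encoder, assuming a default `SuperGlueConfig`; the actual layer sizes come from `config.keypoint_encoder_sizes` and `config.hidden_size`.

```python
import torch

from transformers import SuperGlueConfig

config = SuperGlueConfig()
encoder = SuperGlueKeypointEncoder(config)
keypoints = torch.rand(4, 128, 2)  # normalized (x, y) locations
scores = torch.rand(4, 128)        # detector confidence per keypoint
descriptors, hidden_states = encoder(keypoints, scores, output_hidden_states=True)
print(descriptors.shape)           # (4, 128, config.hidden_size)
```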
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->SuperGlue
|
||||
class SuperGlueSelfAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
||||
-1, 1
|
||||
)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in SuperGlueModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
return outputs
|
||||
|
||||
|
||||
class SuperGlueSelfOutput(nn.Module):
|
||||
def __init__(self, config: SuperGlueConfig):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, *args) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
SUPERGLUE_SELF_ATTENTION_CLASSES = {
|
||||
"eager": SuperGlueSelfAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->SuperGlue,BERT->SUPERGLUE
|
||||
class SuperGlueAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
self.self = SUPERGLUE_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config, position_embedding_type=position_embedding_type
|
||||
)
|
||||
self.output = SuperGlueSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
def prune_heads(self, heads):
|
||||
if len(heads) == 0:
|
||||
return
|
||||
heads, index = find_pruneable_heads_and_indices(
|
||||
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
||||
)
|
||||
|
||||
# Prune linear layers
|
||||
self.self.query = prune_linear_layer(self.self.query, index)
|
||||
self.self.key = prune_linear_layer(self.self.key, index)
|
||||
self.self.value = prune_linear_layer(self.self.value, index)
|
||||
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
||||
|
||||
# Update hyper params and store pruned heads
|
||||
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
return outputs
|
||||
|
||||
|
||||
class SuperGlueAttentionalPropagation(nn.Module):
|
||||
def __init__(self, config: SuperGlueConfig) -> None:
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.attention = SuperGlueAttention(config)
|
||||
mlp_channels = [hidden_size * 2, hidden_size * 2, hidden_size]
|
||||
layers = [
|
||||
SuperGlueMultiLayerPerceptron(config, mlp_channels[i - 1], mlp_channels[i])
|
||||
for i in range(1, len(mlp_channels) - 1)
|
||||
]
|
||||
layers.append(nn.Linear(mlp_channels[-2], mlp_channels[-1]))
|
||||
self.mlp = nn.ModuleList(layers)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
descriptors: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor]]]:
|
||||
attention_outputs = self.attention(
|
||||
descriptors,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
output = attention_outputs[0]
|
||||
attention = attention_outputs[1:]
|
||||
|
||||
hidden_state = torch.cat([descriptors, output], dim=2)
|
||||
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for layer in self.mlp:
|
||||
hidden_state = layer(hidden_state)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_state,)
|
||||
|
||||
return hidden_state, all_hidden_states, attention
|
||||
|
||||
|
||||
class SuperGlueAttentionalGNN(nn.Module):
|
||||
def __init__(self, config: SuperGlueConfig) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.layers_types = config.gnn_layers_types
|
||||
self.layers = nn.ModuleList([SuperGlueAttentionalPropagation(config) for _ in range(len(self.layers_types))])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
descriptors: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple], Optional[Tuple]]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
batch_size, num_keypoints, _ = descriptors.shape
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (descriptors,)
|
||||
|
||||
for gnn_layer, layer_type in zip(self.layers, self.layers_types):
|
||||
encoder_hidden_states = None
|
||||
encoder_attention_mask = None
|
||||
if layer_type == "cross":
|
||||
encoder_hidden_states = (
|
||||
descriptors.reshape(-1, 2, num_keypoints, self.hidden_size)
|
||||
.flip(1)
|
||||
.reshape(batch_size, num_keypoints, self.hidden_size)
|
||||
)
|
||||
encoder_attention_mask = (
|
||||
mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints)
|
||||
if mask is not None
|
||||
else None
|
||||
)
|
||||
|
||||
gnn_outputs = gnn_layer(
|
||||
descriptors,
|
||||
attention_mask=mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
delta = gnn_outputs[0]
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + gnn_outputs[1]
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + gnn_outputs[2]
|
||||
|
||||
descriptors = descriptors + delta
|
||||
return descriptors, all_hidden_states, all_attentions
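The "cross" layers attend to the descriptors of the other image in each pair; the flip on the pair dimension above is how that pairing is realised. A small standalone check of that indexing trick, with illustrative shapes:

```python
import torch

batch_size, num_keypoints, hidden_size = 4, 6, 8  # batch_size counts images, i.e. 2 per pair
descriptors = torch.randn(batch_size, num_keypoints, hidden_size)
other_image = (
    descriptors.reshape(-1, 2, num_keypoints, hidden_size)
    .flip(1)  # swap image0 <-> image1 inside each pair
    .reshape(batch_size, num_keypoints, hidden_size)
)
assert torch.equal(other_image[0], descriptors[1]) and torch.equal(other_image[1], descriptors[0])
```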
|
||||
|
||||
|
||||
class SuperGlueFinalProjection(nn.Module):
|
||||
def __init__(self, config: SuperGlueConfig) -> None:
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.final_proj = nn.Linear(hidden_size, hidden_size, bias=True)
|
||||
|
||||
def forward(self, descriptors: torch.Tensor) -> torch.Tensor:
|
||||
return self.final_proj(descriptors)
|
||||
|
||||
|
||||
class SuperGluePreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = SuperGlueConfig
|
||||
base_model_prefix = "superglue"
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def _init_weights(self, module: nn.Module) -> None:
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv1d)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, SuperGlueMultiLayerPerceptron):
|
||||
nn.init.constant_(module.linear.bias, 0.0)
|
||||
|
||||
|
||||
SUPERGLUE_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
||||
behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`SuperGlueConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
SUPERGLUE_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Pixel values can be obtained using [`SuperGlueImageProcessor`]. See
|
||||
[`SuperGlueImageProcessor.__call__`] for details.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors. See `attentions` under returned tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"SuperGlue model taking images as inputs and outputting the matching of them.",
|
||||
SUPERGLUE_START_DOCSTRING,
|
||||
)
|
||||
class SuperGlueForKeypointMatching(SuperGluePreTrainedModel):
|
||||
"""SuperGlue feature matching middle-end
|
||||
|
||||
Given two sets of keypoints and locations, we determine the
|
||||
correspondences by:
|
||||
1. Keypoint Encoding (normalization + visual feature and location fusion)
|
||||
2. Graph Neural Network with multiple self and cross-attention layers
|
||||
3. Final projection layer
|
||||
4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)
|
||||
5. Thresholding matrix based on mutual exclusivity and a match_threshold
|
||||
|
||||
The correspondence ids use -1 to indicate non-matching points.
|
||||
|
||||
Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
|
||||
Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
|
||||
Networks. In CVPR, 2020. https://arxiv.org/abs/1911.11763
|
||||
"""
|
||||
|
||||
def __init__(self, config: SuperGlueConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)
|
||||
|
||||
self.keypoint_encoder = SuperGlueKeypointEncoder(config)
|
||||
self.gnn = SuperGlueAttentionalGNN(config)
|
||||
self.final_projection = SuperGlueFinalProjection(config)
|
||||
|
||||
bin_score = torch.nn.Parameter(torch.tensor(1.0))
|
||||
self.register_parameter("bin_score", bin_score)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def _match_image_pair(
|
||||
self,
|
||||
keypoints: torch.Tensor,
|
||||
descriptors: torch.Tensor,
|
||||
scores: torch.Tensor,
|
||||
height: int,
|
||||
width: int,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Tuple, Tuple]:
|
||||
"""
|
||||
Perform keypoint matching between two images.
|
||||
|
||||
Args:
|
||||
keypoints (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, 2)`):
|
||||
Keypoints detected in the pair of images.
|
||||
descriptors (`torch.Tensor` of shape `(batch_size, 2, descriptor_dim, num_keypoints)`):
|
||||
Descriptors of the keypoints detected in the image pair.
|
||||
scores (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
|
||||
Confidence scores of the keypoints detected in the image pair.
|
||||
height (`int`): Image height.
|
||||
width (`int`): Image width.
|
||||
mask (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`, *optional*):
|
||||
Mask indicating which values in the keypoints, matches and matching_scores tensors are keypoint matching
|
||||
information.
|
||||
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors. Defaults to `config.output_attentions`.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. Defaults to `config.output_hidden_states`.
|
||||
|
||||
Returns:
|
||||
matches (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
|
||||
For each image pair, the index of the keypoint in the other image that each keypoint was matched with
(for keypoints in image0, the index of the match in image1, and vice versa).
|
||||
matching_scores (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
|
||||
Scores of predicted matches for each image pair
|
||||
all_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
|
||||
Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(1, 2, num_keypoints,
|
||||
num_channels)`.
|
||||
all_attentions (`tuple(torch.FloatTensor)`, *optional*):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(1, 2, num_heads, num_keypoints,
|
||||
num_keypoints)`.
|
||||
"""
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
if keypoints.shape[2] == 0: # no keypoints
|
||||
shape = keypoints.shape[:-1]
|
||||
return (
|
||||
keypoints.new_full(shape, -1, dtype=torch.int),
|
||||
keypoints.new_zeros(shape),
|
||||
all_hidden_states,
|
||||
all_attentions,
|
||||
)
|
||||
|
||||
batch_size, _, num_keypoints, _ = keypoints.shape
|
||||
# (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2)
|
||||
keypoints = keypoints.reshape(batch_size * 2, num_keypoints, 2)
|
||||
descriptors = descriptors.reshape(batch_size * 2, num_keypoints, self.config.hidden_size)
|
||||
scores = scores.reshape(batch_size * 2, num_keypoints)
|
||||
mask = mask.reshape(batch_size * 2, num_keypoints) if mask is not None else None
|
||||
|
||||
# Keypoint normalization
|
||||
keypoints = normalize_keypoints(keypoints, height, width)
|
||||
|
||||
encoded_keypoints = self.keypoint_encoder(keypoints, scores, output_hidden_states=output_hidden_states)
|
||||
|
||||
last_hidden_state = encoded_keypoints[0]
|
||||
|
||||
# Keypoint MLP encoder.
|
||||
descriptors = descriptors + last_hidden_state
|
||||
|
||||
if mask is not None:
|
||||
input_shape = descriptors.size()
|
||||
extended_attention_mask = self.get_extended_attention_mask(mask, input_shape)
|
||||
else:
|
||||
extended_attention_mask = torch.ones((batch_size, num_keypoints), device=keypoints.device)
|
||||
|
||||
# Multi-layer Transformer network.
|
||||
gnn_outputs = self.gnn(
|
||||
descriptors,
|
||||
mask=extended_attention_mask,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
descriptors = gnn_outputs[0]
|
||||
|
||||
# Final MLP projection.
|
||||
projected_descriptors = self.final_projection(descriptors)
|
||||
|
||||
# (batch_size * 2, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim)
|
||||
final_descriptors = projected_descriptors.reshape(batch_size, 2, num_keypoints, self.config.hidden_size)
|
||||
final_descriptors0 = final_descriptors[:, 0]
|
||||
final_descriptors1 = final_descriptors[:, 1]
|
||||
|
||||
# Compute matching descriptor distance.
|
||||
scores = final_descriptors0 @ final_descriptors1.transpose(1, 2)
|
||||
scores = scores / self.config.hidden_size**0.5
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.reshape(batch_size, 2, num_keypoints)
|
||||
mask0 = mask[:, 0].unsqueeze(-1).expand(-1, -1, num_keypoints)
|
||||
scores = scores.masked_fill(mask0 == 0, -1e9)
|
||||
|
||||
# Run the optimal transport.
|
||||
scores = log_optimal_transport(scores, self.bin_score, iterations=self.config.sinkhorn_iterations)
|
||||
|
||||
# Get the matches with score above "match_threshold".
|
||||
max0 = scores[:, :-1, :-1].max(2)
|
||||
max1 = scores[:, :-1, :-1].max(1)
|
||||
indices0 = max0.indices
|
||||
indices1 = max1.indices
|
||||
mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
|
||||
mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
|
||||
zero = scores.new_tensor(0)
|
||||
matching_scores0 = torch.where(mutual0, max0.values.exp(), zero)
|
||||
matching_scores0 = torch.where(matching_scores0 > self.config.matching_threshold, matching_scores0, zero)
|
||||
matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, indices1), zero)
|
||||
valid0 = mutual0 & (matching_scores0 > zero)
|
||||
valid1 = mutual1 & valid0.gather(1, indices1)
|
||||
matches0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
|
||||
matches1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
|
||||
|
||||
matches = torch.cat([matches0, matches1]).reshape(batch_size, 2, -1)
|
||||
matching_scores = torch.cat([matching_scores0, matching_scores1]).reshape(batch_size, 2, -1)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + encoded_keypoints[1]
|
||||
all_hidden_states = all_hidden_states + gnn_outputs[1]
|
||||
all_hidden_states = all_hidden_states + (projected_descriptors,)
|
||||
all_hidden_states = tuple(
|
||||
x.reshape(batch_size, 2, num_keypoints, -1).transpose(-1, -2) for x in all_hidden_states
|
||||
)
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + gnn_outputs[2]
|
||||
all_attentions = tuple(x.reshape(batch_size, 2, -1, num_keypoints, num_keypoints) for x in all_attentions)
|
||||
|
||||
return (
|
||||
matches,
|
||||
matching_scores,
|
||||
all_hidden_states,
|
||||
all_attentions,
|
||||
)
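A toy illustration of the mutual-nearest-neighbour check used above (steps 4-5 in the class docstring): a keypoint is kept only if it is the argmax of its partner as well. Values are made up.

```python
import torch

scores = torch.tensor([[[0.9, 0.1, 0.0],
                        [0.2, 0.1, 0.8],
                        [0.5, 0.2, 0.1]]])  # (1, 3, 3) match scores between two keypoint sets
indices0 = scores.max(2).indices            # best partner in image1 for each keypoint of image0
indices1 = scores.max(1).indices            # best partner in image0 for each keypoint of image1
mutual0 = torch.arange(3)[None] == indices1.gather(1, indices0)
print(indices0, mutual0)  # keypoints 0 and 1 are mutual matches; keypoint 2 is not,
                          # because its favourite (keypoint 0 of image1) prefers keypoint 0 of image0
```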
|
||||
|
||||
@add_start_docstrings_to_model_forward(SUPERGLUE_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, KeypointMatchingOutput]:
|
||||
"""
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoImageProcessor, AutoModel
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
|
||||
>>> image1 = Image.open(requests.get(url, stream=True).raw)
|
||||
>>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
|
||||
>>> image2 = Image.open(requests.get(url, stream=True).raw)
|
||||
>>> images = [image1, image2]
|
||||
|
||||
>>> processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
|
||||
>>> model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")
|
||||
|
||||
>>> with torch.no_grad():
...     inputs = processor(images, return_tensors="pt")
...     outputs = model(**inputs)
|
||||
```"""
|
||||
loss = None
|
||||
if labels is not None:
|
||||
raise ValueError("SuperGlue is not trainable, no labels should be provided.")
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if pixel_values.ndim != 5 or pixel_values.size(1) != 2:
|
||||
raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")
|
||||
|
||||
batch_size, _, channels, height, width = pixel_values.shape
|
||||
pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width)
|
||||
keypoint_detections = self.keypoint_detector(pixel_values)
|
||||
|
||||
keypoints, scores, descriptors, mask = keypoint_detections[:4]
|
||||
keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values)
|
||||
scores = scores.reshape(batch_size, 2, -1).to(pixel_values)
|
||||
descriptors = descriptors.reshape(batch_size, 2, -1, self.config.hidden_size).to(pixel_values)
|
||||
mask = mask.reshape(batch_size, 2, -1)
|
||||
|
||||
absolute_keypoints = keypoints.clone()
|
||||
absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width
|
||||
absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height
|
||||
|
||||
matches, matching_scores, hidden_states, attentions = self._match_image_pair(
|
||||
absolute_keypoints,
|
||||
descriptors,
|
||||
scores,
|
||||
height,
|
||||
width,
|
||||
mask=mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [loss, matches, matching_scores, keypoints, mask, hidden_states, attentions]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return KeypointMatchingOutput(
|
||||
loss=loss,
|
||||
matches=matches,
|
||||
matching_scores=matching_scores,
|
||||
keypoints=keypoints,
|
||||
mask=mask,
|
||||
hidden_states=hidden_states,
|
||||
attentions=attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["SuperGluePreTrainedModel", "SuperGlueForKeypointMatching"]
|
@ -144,7 +144,7 @@ def convert_superpoint_checkpoint(checkpoint_url, pytorch_dump_folder_path, save
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
preprocessor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
model_name = "superpoint"
|
||||
model_name = "magic-leap-community/superpoint"
|
||||
if push_to_hub:
|
||||
print(f"Pushing {model_name} to the hub...")
|
||||
model.push_to_hub(model_name)
|
||||
|
@ -49,8 +49,12 @@ def is_grayscale(
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
if input_data_format == ChannelDimension.FIRST:
|
||||
if image.shape[0] == 1:
|
||||
return True
|
||||
return np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...])
|
||||
elif input_data_format == ChannelDimension.LAST:
|
||||
if image.shape[-1] == 1:
|
||||
return True
|
||||
return np.all(image[..., 0] == image[..., 1]) and np.all(image[..., 1] == image[..., 2])
|
||||
|
||||
|
||||
@ -75,6 +79,8 @@ def convert_to_grayscale(
|
||||
requires_backends(convert_to_grayscale, ["vision"])
|
||||
|
||||
if isinstance(image, np.ndarray):
|
||||
if is_grayscale(image, input_data_format=input_data_format):
|
||||
return image
|
||||
if input_data_format == ChannelDimension.FIRST:
|
||||
gray_image = image[0, ...] * 0.2989 + image[1, ...] * 0.5870 + image[2, ...] * 0.1140
|
||||
gray_image = np.stack([gray_image] * 3, axis=0)
|
||||
@ -107,6 +113,8 @@ class SuperPointImageProcessor(BaseImageProcessor):
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
method.
do_grayscale (`bool`, *optional*, defaults to `False`):
Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
@ -117,6 +125,7 @@ class SuperPointImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int] = None,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: float = 1 / 255,
|
||||
do_grayscale: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
@ -127,6 +136,7 @@ class SuperPointImageProcessor(BaseImageProcessor):
|
||||
self.size = size
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_grayscale = do_grayscale
|
||||
|
||||
def resize(
|
||||
self,
|
||||
@ -174,6 +184,7 @@ class SuperPointImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int] = None,
|
||||
do_rescale: bool = None,
|
||||
rescale_factor: float = None,
|
||||
do_grayscale: bool = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
@ -197,6 +208,8 @@ class SuperPointImageProcessor(BaseImageProcessor):
|
||||
Whether to rescale the image values between [0 - 1].
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_grayscale (`bool`, *optional*, defaults to `self.do_grayscale`):
|
||||
Whether to convert the image to grayscale.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
@ -220,6 +233,7 @@ class SuperPointImageProcessor(BaseImageProcessor):
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_grayscale = do_grayscale if do_grayscale is not None else self.do_grayscale
|
||||
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
@ -264,10 +278,8 @@ class SuperPointImageProcessor(BaseImageProcessor):
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
# Checking if image is RGB or grayscale
|
||||
for i in range(len(images)):
|
||||
if not is_grayscale(images[i], input_data_format):
|
||||
images[i] = convert_to_grayscale(images[i], input_data_format=input_data_format)
|
||||
if do_grayscale:
|
||||
images = [convert_to_grayscale(image, input_data_format=input_data_format) for image in images]
|
||||
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
@ -299,7 +311,7 @@ class SuperPointImageProcessor(BaseImageProcessor):
|
||||
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
|
||||
|
||||
if isinstance(target_sizes, List):
|
||||
image_sizes = torch.tensor(target_sizes)
|
||||
image_sizes = torch.tensor(target_sizes, device=outputs.mask.device)
|
||||
else:
|
||||
if target_sizes.shape[1] != 2:
|
||||
raise ValueError(
|
||||
|
@ -8961,6 +8961,20 @@ class Starcoder2PreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class SuperGlueForKeypointMatching(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class SuperGluePreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class SuperPointForKeypointDetection(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -611,6 +611,13 @@ class SiglipImageProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class SuperGlueImageProcessor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class SuperPointImageProcessor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
|
tests/models/superglue/__init__.py (new file)
tests/models/superglue/test_image_processing_superglue.py (new file, 384 lines)
@ -0,0 +1,384 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import (
|
||||
ImageProcessingTestMixin,
|
||||
prepare_image_inputs,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from transformers.models.superglue.modeling_superglue import KeypointMatchingOutput
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import SuperGlueImageProcessor
|
||||
|
||||
|
||||
def random_array(size):
|
||||
return np.random.randint(255, size=size)
|
||||
|
||||
|
||||
def random_tensor(size):
|
||||
return torch.rand(size)
|
||||
|
||||
|
||||
class SuperGlueImageProcessingTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=6,
|
||||
num_channels=3,
|
||||
image_size=18,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_grayscale=True,
|
||||
):
|
||||
size = size if size is not None else {"height": 480, "width": 640}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_grayscale = do_grayscale
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_grayscale": self.do_grayscale,
|
||||
}
|
||||
|
||||
def expected_output_image_shape(self, images):
|
||||
return 2, self.num_channels, self.size["height"], self.size["width"]
|
||||
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False, pairs=True, batch_size=None):
|
||||
batch_size = batch_size if batch_size is not None else self.batch_size
|
||||
image_inputs = prepare_image_inputs(
|
||||
batch_size=batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
if pairs:
|
||||
image_inputs = [image_inputs[i : i + 2] for i in range(0, len(image_inputs), 2)]
|
||||
return image_inputs
    def prepare_keypoint_matching_output(self, pixel_values):
        max_number_keypoints = 50
        batch_size = len(pixel_values)
        mask = torch.zeros((batch_size, 2, max_number_keypoints), dtype=torch.int)
        keypoints = torch.zeros((batch_size, 2, max_number_keypoints, 2))
        matches = torch.full((batch_size, 2, max_number_keypoints), -1, dtype=torch.int)
        scores = torch.zeros((batch_size, 2, max_number_keypoints))
        for i in range(batch_size):
            random_number_keypoints0 = np.random.randint(10, max_number_keypoints)
            random_number_keypoints1 = np.random.randint(10, max_number_keypoints)
            random_number_matches = np.random.randint(5, min(random_number_keypoints0, random_number_keypoints1))
            mask[i, 0, :random_number_keypoints0] = 1
            mask[i, 1, :random_number_keypoints1] = 1
            keypoints[i, 0, :random_number_keypoints0] = torch.rand((random_number_keypoints0, 2))
            keypoints[i, 1, :random_number_keypoints1] = torch.rand((random_number_keypoints1, 2))
            random_matches_indices0 = torch.randperm(random_number_keypoints1, dtype=torch.int)[:random_number_matches]
            random_matches_indices1 = torch.randperm(random_number_keypoints0, dtype=torch.int)[:random_number_matches]
            matches[i, 0, random_matches_indices1] = random_matches_indices0
            matches[i, 1, random_matches_indices0] = random_matches_indices1
            scores[i, 0, random_matches_indices1] = torch.rand((random_number_matches,))
            scores[i, 1, random_matches_indices0] = torch.rand((random_number_matches,))
        return KeypointMatchingOutput(mask=mask, keypoints=keypoints, matches=matches, matching_scores=scores)

@require_torch
@require_vision
class SuperGlueImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    image_processing_class = SuperGlueImageProcessor if is_vision_available() else None

    def setUp(self) -> None:
        super().setUp()
        self.image_processor_tester = SuperGlueImageProcessingTester(self)

    @property
    def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processing(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)
        self.assertTrue(hasattr(image_processing, "do_resize"))
        self.assertTrue(hasattr(image_processing, "size"))
        self.assertTrue(hasattr(image_processing, "do_rescale"))
        self.assertTrue(hasattr(image_processing, "rescale_factor"))
        self.assertTrue(hasattr(image_processing, "do_grayscale"))

    def test_image_processor_from_dict_with_kwargs(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        self.assertEqual(image_processor.size, {"height": 480, "width": 640})

        image_processor = self.image_processing_class.from_dict(
            self.image_processor_dict, size={"height": 42, "width": 42}
        )
        self.assertEqual(image_processor.size, {"height": 42, "width": 42})

    @unittest.skip(reason="SuperGlueImageProcessor is always supposed to return a grayscaled image")
    def test_call_numpy_4_channels(self):
        pass

    def test_number_and_format_of_images_in_input(self):
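        # `preprocess` groups its inputs into image pairs: the returned `pixel_values` is expected to have
        # shape (number_of_pairs, 2, num_channels, height, width).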
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)

        # Cases where the number of images and the format of lists in the input is correct
        image_input = self.image_processor_tester.prepare_image_inputs(pairs=False, batch_size=2)
        image_processed = image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual((1, 2, 3, 480, 640), tuple(image_processed["pixel_values"].shape))

        image_input = self.image_processor_tester.prepare_image_inputs(pairs=True, batch_size=2)
        image_processed = image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual((1, 2, 3, 480, 640), tuple(image_processed["pixel_values"].shape))

        image_input = self.image_processor_tester.prepare_image_inputs(pairs=True, batch_size=4)
        image_processed = image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual((2, 2, 3, 480, 640), tuple(image_processed["pixel_values"].shape))

        image_input = self.image_processor_tester.prepare_image_inputs(pairs=True, batch_size=6)
        image_processed = image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual((3, 2, 3, 480, 640), tuple(image_processed["pixel_values"].shape))

        # Cases where the number of images or the format of lists in the input is incorrect
        ## List of 4 images
        image_input = self.image_processor_tester.prepare_image_inputs(pairs=False, batch_size=4)
        with self.assertRaises(ValueError) as cm:
            image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual(ValueError, cm.exception.__class__)

        ## List of 3 images
        image_input = self.image_processor_tester.prepare_image_inputs(pairs=False, batch_size=3)
        with self.assertRaises(ValueError) as cm:
            image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual(ValueError, cm.exception.__class__)

        ## List of 2 pairs and 1 image
        image_input = self.image_processor_tester.prepare_image_inputs(pairs=True, batch_size=3)
        with self.assertRaises(ValueError) as cm:
            image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual(ValueError, cm.exception.__class__)

    @parameterized.expand(
        [
            ([random_array((3, 100, 200)), random_array((3, 100, 200))], (1, 2, 3, 480, 640)),
            ([[random_array((3, 100, 200)), random_array((3, 100, 200))]], (1, 2, 3, 480, 640)),
            ([random_tensor((3, 100, 200)), random_tensor((3, 100, 200))], (1, 2, 3, 480, 640)),
            ([[random_tensor((3, 100, 200)), random_tensor((3, 100, 200))]], (1, 2, 3, 480, 640)),
        ],
    )
    def test_valid_image_shape_in_input(self, image_input, output):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        image_processed = image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual(output, tuple(image_processed["pixel_values"].shape))

    @parameterized.expand(
        [
            (random_array((3, 100, 200)),),
            ([random_array((3, 100, 200))],),
            (random_array((1, 3, 100, 200)),),
            ([[random_array((3, 100, 200))]],),
            ([[random_array((3, 100, 200))], [random_array((3, 100, 200))]],),
            ([random_array((1, 3, 100, 200)), random_array((1, 3, 100, 200))],),
            (random_array((1, 1, 3, 100, 200)),),
        ],
    )
    def test_invalid_image_shape_in_input(self, image_input):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        with self.assertRaises(ValueError) as cm:
            image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual(ValueError, cm.exception.__class__)

    def test_input_images_properly_paired(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        image_inputs = self.image_processor_tester.prepare_image_inputs()
        pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="np")
        self.assertEqual(len(pre_processed_images["pixel_values"].shape), 5)
        self.assertEqual(pre_processed_images["pixel_values"].shape[1], 2)

    def test_input_not_paired_images_raises_error(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        image_inputs = self.image_processor_tester.prepare_image_inputs(pairs=False)
        with self.assertRaises(ValueError):
            image_processor.preprocess(image_inputs[0])

    def test_input_image_properly_converted_to_grayscale(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        image_inputs = self.image_processor_tester.prepare_image_inputs()
        pre_processed_images = image_processor.preprocess(image_inputs)
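        # Grayscale conversion keeps three channels, so every channel of a converted image should be identical.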
        for image_pair in pre_processed_images["pixel_values"]:
            for image in image_pair:
                self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]))

    def test_call_numpy(self):
        # Test overwritten because SuperGlueImageProcessor combines images by pairs to feed them into SuperGlue

        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random numpy tensors
        image_pairs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
        for image_pair in image_pairs:
            self.assertEqual(len(image_pair), 2)

        expected_batch_size = int(self.image_processor_tester.batch_size / 2)

        # Test with 2 images
        encoded_images = image_processing(image_pairs[0], return_tensors="pt").pixel_values
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs[0])
        self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))

        # Test with list of pairs
        encoded_images = image_processing(image_pairs, return_tensors="pt").pixel_values
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs)
        self.assertEqual(tuple(encoded_images.shape), (expected_batch_size, *expected_output_image_shape))

        # Test without paired images
        image_pairs = self.image_processor_tester.prepare_image_inputs(
            equal_resolution=False, numpify=True, pairs=False
        )
        with self.assertRaises(ValueError):
            image_processing(image_pairs, return_tensors="pt").pixel_values

    def test_call_pil(self):
        # Test overwritten because SuperGlueImageProcessor combines images by pairs to feed them into SuperGlue

        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random PIL images
        image_pairs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
        for image_pair in image_pairs:
            self.assertEqual(len(image_pair), 2)

        expected_batch_size = int(self.image_processor_tester.batch_size / 2)

        # Test with 2 images
        encoded_images = image_processing(image_pairs[0], return_tensors="pt").pixel_values
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs[0])
        self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))

        # Test with list of pairs
        encoded_images = image_processing(image_pairs, return_tensors="pt").pixel_values
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs)
        self.assertEqual(tuple(encoded_images.shape), (expected_batch_size, *expected_output_image_shape))

        # Test without paired images
        image_pairs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, pairs=False)
        with self.assertRaises(ValueError):
            image_processing(image_pairs, return_tensors="pt").pixel_values

    def test_call_pytorch(self):
        # Test overwritten because SuperGlueImageProcessor combines images by pairs to feed them into SuperGlue

        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random PyTorch tensors
        image_pairs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
        for image_pair in image_pairs:
            self.assertEqual(len(image_pair), 2)

        expected_batch_size = int(self.image_processor_tester.batch_size / 2)

        # Test with 2 images
        encoded_images = image_processing(image_pairs[0], return_tensors="pt").pixel_values
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs[0])
        self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))

        # Test with list of pairs
        encoded_images = image_processing(image_pairs, return_tensors="pt").pixel_values
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs)
        self.assertEqual(tuple(encoded_images.shape), (expected_batch_size, *expected_output_image_shape))

        # Test without paired images
        image_pairs = self.image_processor_tester.prepare_image_inputs(
            equal_resolution=False, torchify=True, pairs=False
        )
        with self.assertRaises(ValueError):
            image_processing(image_pairs, return_tensors="pt").pixel_values

    def test_image_processor_with_list_of_two_images(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)

        image_pairs = self.image_processor_tester.prepare_image_inputs(
            equal_resolution=False, numpify=True, batch_size=2, pairs=False
        )
        self.assertEqual(len(image_pairs), 2)
        self.assertTrue(isinstance(image_pairs[0], np.ndarray))
        self.assertTrue(isinstance(image_pairs[1], np.ndarray))

        expected_batch_size = 1
        encoded_images = image_processing(image_pairs, return_tensors="pt").pixel_values
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs[0])
        self.assertEqual(tuple(encoded_images.shape), (expected_batch_size, *expected_output_image_shape))

    @require_torch
    def test_post_processing_keypoint_matching(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        image_inputs = self.image_processor_tester.prepare_image_inputs()
        pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="pt")
        outputs = self.image_processor_tester.prepare_keypoint_matching_output(**pre_processed_images)

        def check_post_processed_output(post_processed_output, image_pair_size):
            for post_processed_output, (image_size0, image_size1) in zip(post_processed_output, image_pair_size):
                self.assertTrue("keypoints0" in post_processed_output)
                self.assertTrue("keypoints1" in post_processed_output)
                self.assertTrue("matching_scores" in post_processed_output)
                keypoints0 = post_processed_output["keypoints0"]
                keypoints1 = post_processed_output["keypoints1"]
                all_below_image_size0 = torch.all(keypoints0[:, 0] <= image_size0[1]) and torch.all(
                    keypoints0[:, 1] <= image_size0[0]
                )
                all_below_image_size1 = torch.all(keypoints1[:, 0] <= image_size1[1]) and torch.all(
                    keypoints1[:, 1] <= image_size1[0]
                )
                all_above_zero0 = torch.all(keypoints0[:, 0] >= 0) and torch.all(keypoints0[:, 1] >= 0)
                all_above_zero1 = torch.all(keypoints1[:, 0] >= 0) and torch.all(keypoints1[:, 1] >= 0)
                self.assertTrue(all_below_image_size0)
                self.assertTrue(all_below_image_size1)
                self.assertTrue(all_above_zero0)
                self.assertTrue(all_above_zero1)
                all_scores_different_from_minus_one = torch.all(post_processed_output["matching_scores"] != -1)
                self.assertTrue(all_scores_different_from_minus_one)

        tuple_image_sizes = [
            ((image_pair[0].size[0], image_pair[0].size[1]), (image_pair[1].size[0], image_pair[1].size[1]))
            for image_pair in image_inputs
        ]
        tuple_post_processed_outputs = image_processor.post_process_keypoint_matching(outputs, tuple_image_sizes)

        check_post_processed_output(tuple_post_processed_outputs, tuple_image_sizes)
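
        # `PIL.Image.size` is (width, height); `.flip(2)` reverses each size pair so the tensor holds the
        # values in (height, width) order before being passed to post-processing.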
        tensor_image_sizes = torch.tensor(
            [(image_pair[0].size, image_pair[1].size) for image_pair in image_inputs]
        ).flip(2)
        tensor_post_processed_outputs = image_processor.post_process_keypoint_matching(outputs, tensor_image_sizes)

        check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes)
tests/models/superglue/test_modeling_superglue.py (new file, 427 lines)
@@ -0,0 +1,427 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
from typing import List

from datasets import load_dataset

from transformers.models.superglue.configuration_superglue import SuperGlueConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor


if is_torch_available():
    import torch

    from transformers import SuperGlueForKeypointMatching

if is_vision_available():
    from transformers import AutoImageProcessor


class SuperGlueModelTester:
    def __init__(
        self,
        parent,
        batch_size=2,
        image_width=80,
        image_height=60,
        keypoint_detector_config=None,
        hidden_size: int = 64,
        keypoint_encoder_sizes: List[int] = [32, 64],
        gnn_layers_types: List[str] = ["self", "cross"] * 2,
        num_attention_heads: int = 4,
        sinkhorn_iterations: int = 100,
        matching_threshold: float = 0.2,
    ):
        if keypoint_detector_config is None:
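            # Default to a small keypoint detector configuration (presumably a scaled-down SuperPoint config,
            # given the keys below) so the tested model stays tiny and the test fast.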
            keypoint_detector_config = {
                "encoder_hidden_sizes": [32, 64],
                "decoder_hidden_size": 64,
                "keypoint_decoder_dim": 65,
                "descriptor_decoder_dim": 64,
                "keypoint_threshold": 0.005,
                "max_keypoints": 256,
                "nms_radius": 4,
                "border_removal_distance": 4,
            }
        self.parent = parent
        self.batch_size = batch_size
        self.image_width = image_width
        self.image_height = image_height

        self.keypoint_detector_config = keypoint_detector_config
        self.hidden_size = hidden_size
        self.keypoint_encoder_sizes = keypoint_encoder_sizes
        self.gnn_layers_types = gnn_layers_types
        self.num_attention_heads = num_attention_heads
        self.sinkhorn_iterations = sinkhorn_iterations
        self.matching_threshold = matching_threshold

    def prepare_config_and_inputs(self):
        # SuperGlue operates on pairs of images: pixel_values has shape (batch_size, 2, num_channels, height, width)
        pixel_values = floats_tensor([self.batch_size, 2, 3, self.image_height, self.image_width])
        config = self.get_config()
        return config, pixel_values

    def get_config(self):
        return SuperGlueConfig(
            keypoint_detector_config=self.keypoint_detector_config,
            hidden_size=self.hidden_size,
            keypoint_encoder_sizes=self.keypoint_encoder_sizes,
            gnn_layers_types=self.gnn_layers_types,
            num_attention_heads=self.num_attention_heads,
            sinkhorn_iterations=self.sinkhorn_iterations,
            matching_threshold=self.matching_threshold,
        )

    def create_and_check_model(self, config, pixel_values):
        model = SuperGlueForKeypointMatching(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)
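        # Outputs are padded to a common keypoint count; `mask` flags the valid entries and its last
        # dimension gives the padded length used in the shape checks below.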
        maximum_num_matches = result.mask.shape[-1]
        self.parent.assertEqual(
            result.keypoints.shape,
            (self.batch_size, 2, maximum_num_matches, 2),
        )
        self.parent.assertEqual(
            result.matches.shape,
            (self.batch_size, 2, maximum_num_matches),
        )
        self.parent.assertEqual(
            result.matching_scores.shape,
            (self.batch_size, 2, maximum_num_matches),
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class SuperGlueModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (SuperGlueForKeypointMatching,) if is_torch_available() else ()
    all_generative_model_classes = () if is_torch_available() else ()

    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    has_attentions = True

    def setUp(self):
        self.model_tester = SuperGlueModelTester(self)
        self.config_tester = ConfigTester(self, config_class=SuperGlueConfig, has_text_modality=False, hidden_size=64)

    def test_config(self):
        self.config_tester.create_and_test_config_to_json_string()
        self.config_tester.create_and_test_config_to_json_file()
        self.config_tester.create_and_test_config_from_and_save_pretrained()
        self.config_tester.create_and_test_config_with_num_labels()
        self.config_tester.check_config_can_be_init_without_params()
        self.config_tester.check_config_arguments_init()

    @unittest.skip(reason="SuperGlueForKeypointMatching does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching does not support input and output embeddings")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching does not use feedforward chunking")
    def test_feed_forward_chunking(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable")
    def test_training(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="SuperGlue does not output any loss term in the forward pass")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.hidden_states
            maximum_num_matches = outputs.mask.shape[-1]
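
            # Expected per-layer channel sizes, presumably: the keypoint encoder MLP layers, the encoder
            # output, then for each GNN layer one state of `hidden_size` and one of `2 * hidden_size`
            # (descriptor concatenated with its message), and finally the two final projection layers.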
            hidden_states_sizes = (
                self.model_tester.keypoint_encoder_sizes
                + [self.model_tester.hidden_size]
                + [self.model_tester.hidden_size, self.model_tester.hidden_size * 2]
                * len(self.model_tester.gnn_layers_types)
                + [self.model_tester.hidden_size] * 2
            )

            for i, hidden_states_size in enumerate(hidden_states_sizes):
                self.assertListEqual(
                    list(hidden_states[i].shape[-2:]),
                    [hidden_states_size, maximum_num_matches],
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also works using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_attention_outputs(self):
        def check_attention_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            attentions = outputs.attentions
            maximum_num_matches = outputs.mask.shape[-1]

            expected_attention_shape = [
                self.model_tester.num_attention_heads,
                maximum_num_matches,
                maximum_num_matches,
            ]

            for i, attention in enumerate(attentions):
                self.assertListEqual(
                    list(attention.shape[-3:]),
                    expected_attention_shape,
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            check_attention_output(inputs_dict, config, model_class)

            # check that output_attentions also works using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True

            check_attention_output(inputs_dict, config, model_class)

    @slow
    def test_model_from_pretrained(self):
        from_pretrained_ids = ["magic-leap-community/superglue_indoor", "magic-leap-community/superglue_outdoor"]
        for model_name in from_pretrained_ids:
            model = SuperGlueForKeypointMatching.from_pretrained(model_name)
            self.assertIsNotNone(model)

    def test_forward_labels_should_be_none(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                model_inputs = self._prepare_for_class(inputs_dict, model_class)
                # Provide an arbitrarily sized Tensor as labels to model inputs
                model_inputs["labels"] = torch.rand((128, 128))

                with self.assertRaises(ValueError) as cm:
                    model(**model_inputs)
                self.assertEqual(ValueError, cm.exception.__class__)

    def test_batching_equivalence(self):
        """
        Overwriting ModelTesterMixin.test_batching_equivalence since SuperGlue returns `matching_scores` tensors
        full of zeros, which causes the test to fail, because the cosine similarity of two zero tensors is 0.
        Discussed here: https://github.com/huggingface/transformers/pull/29886#issuecomment-2481539481
        """

        def recursive_check(batched_object, single_row_object, model_name, key):
            if isinstance(batched_object, (list, tuple)):
                for batched_object_value, single_row_object_value in zip(batched_object, single_row_object):
                    recursive_check(batched_object_value, single_row_object_value, model_name, key)
            elif isinstance(batched_object, dict):
                for batched_object_value, single_row_object_value in zip(
                    batched_object.values(), single_row_object.values()
                ):
                    recursive_check(batched_object_value, single_row_object_value, model_name, key)
            # do not compare returned loss (0-dim tensor) / codebook ids (int) / caching objects
            elif batched_object is None or not isinstance(batched_object, torch.Tensor):
                return
            elif batched_object.dim() == 0:
                return
            else:
                # indexing the first element does not always work
                # e.g. models that output similarity scores of size (N, M) would need to index [0, 0]
                slice_ids = [slice(0, index) for index in single_row_object.shape]
                batched_row = batched_object[slice_ids]
                self.assertFalse(
                    torch.isnan(batched_row).any(), f"Batched output has `nan` in {model_name} for key={key}"
                )
                self.assertFalse(
                    torch.isinf(batched_row).any(), f"Batched output has `inf` in {model_name} for key={key}"
                )
                self.assertFalse(
                    torch.isnan(single_row_object).any(), f"Single row output has `nan` in {model_name} for key={key}"
                )
                self.assertFalse(
                    torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}"
                )
                self.assertTrue(
                    (equivalence(batched_row, single_row_object)) <= 1e-03,
                    msg=(
                        f"Batched and Single row outputs are not equal in {model_name} for key={key}. "
                        f"Difference={equivalence(batched_row, single_row_object)}."
                    ),
                )
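
        # Unlike the base test, which uses cosine similarity, compare tensors via the maximum absolute
        # difference so that all-zero `matching_scores` tensors do not make the check meaningless.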
        def equivalence(tensor1, tensor2):
            return torch.max(torch.abs(tensor1 - tensor2))

        config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            config.output_hidden_states = True

            model_name = model_class.__name__
            batched_input_prepared = self._prepare_for_class(batched_input, model_class)
            model = model_class(config).to(torch_device).eval()

            batch_size = self.model_tester.batch_size
            single_row_input = {}
            for key, value in batched_input_prepared.items():
                if isinstance(value, torch.Tensor) and value.shape[0] % batch_size == 0:
                    # e.g. musicgen has inputs of size (bs*codebooks). in most cases value.shape[0] == batch_size
                    single_batch_shape = value.shape[0] // batch_size
                    single_row_input[key] = value[:single_batch_shape]
                else:
                    single_row_input[key] = value

            with torch.no_grad():
                model_batched_output = model(**batched_input_prepared)
                model_row_output = model(**single_row_input)

            if isinstance(model_batched_output, torch.Tensor):
                model_batched_output = {"model_output": model_batched_output}
                model_row_output = {"model_output": model_row_output}

            for key in model_batched_output:
                recursive_check(model_batched_output[key], model_row_output[key], model_name, key)
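

# Builds two image pairs from the Hub test dataset for the integration test below; note that the second
# pair reuses `image2`.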
def prepare_imgs():
    dataset = load_dataset("hf-internal-testing/image-matching-test-dataset", split="train")
    image1 = dataset[0]["image"]
    image2 = dataset[1]["image"]
    image3 = dataset[2]["image"]
    return [[image1, image2], [image3, image2]]


@require_torch
@require_vision
class SuperGlueModelIntegrationTest(unittest.TestCase):
    @cached_property
    def default_image_processor(self):
        return (
            AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
            if is_vision_available()
            else None
        )

    @slow
    def test_inference(self):
        model = SuperGlueForKeypointMatching.from_pretrained("magic-leap-community/superglue_outdoor").to(torch_device)
        preprocessor = self.default_image_processor
        images = prepare_imgs()
        inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True, output_attentions=True)

        predicted_number_of_matches = torch.sum(outputs.matches[0][0] != -1).item()
        predicted_matches_values = outputs.matches[0, 0, :30]
        predicted_matching_scores_values = outputs.matching_scores[0, 0, :20]
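
        # Reference values, presumably captured from a run of the converted checkpoint on this image pair;
        # see the note below on CUDA-related variability.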
        expected_number_of_matches = 282
        expected_matches_values = torch.tensor([125,630,137,138,136,143,135,-1,-1,153,
                                                154,156,117,160,-1,149,147,152,168,-1,
                                                165,182,-1,190,187,188,189,112,-1,193],
                                               device=predicted_matches_values.device)  # fmt:skip
        expected_matching_scores_values = torch.tensor([0.9899,0.0033,0.9897,0.9889,0.9879,0.7464,0.7109,0.0,0.0,0.9841,
                                                        0.9889,0.9639,0.0114,0.9559,0.0,0.9735,0.8018,0.5190,0.9157,0.0],
                                                       device=predicted_matches_values.device)  # fmt:skip

        """
        Because of inconsistencies introduced between CUDA versions, the checks here are less strict. SuperGlue relies
        on SuperPoint, which may, depending on the CUDA version, return a different number of keypoints (866 or 867 in
        this specific test example). The consequence of having a different number of keypoints is that the number of
        matches will also be different. In the first 20 matches being checked, having one keypoint less will result in
        1 less match. The matching scores will also be different, as the keypoints are different.

        Therefore, the test checks that the predicted number of matches, matches and matching scores are close to the
        expected values, individually, with the tolerance on the number of differing values set to strictly fewer
        than 4.

        This was discussed [here](https://github.com/huggingface/transformers/pull/29886#issuecomment-2482752787).
        Such CUDA inconsistencies can be found
        [here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300).
        """

        self.assertTrue(abs(predicted_number_of_matches - expected_number_of_matches) < 4)
        self.assertTrue(
            torch.sum(~torch.isclose(predicted_matching_scores_values, expected_matching_scores_values, atol=1e-2)) < 4
        )
        self.assertTrue(torch.sum(predicted_matches_values != expected_matches_values) < 4)
@@ -44,6 +44,7 @@ class SuperPointImageProcessingTester:
        max_resolution=400,
        do_resize=True,
        size=None,
        do_grayscale=True,
    ):
        size = size if size is not None else {"height": 480, "width": 640}
        self.parent = parent
@@ -54,11 +55,13 @@ class SuperPointImageProcessingTester:
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size
        self.do_grayscale = do_grayscale

    def prepare_image_processor_dict(self):
        return {
            "do_resize": self.do_resize,
            "size": self.size,
            "do_grayscale": self.do_grayscale,
        }

    def expected_output_image_shape(self, images):
@@ -112,6 +115,7 @@ class SuperPointImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
        self.assertTrue(hasattr(image_processing, "size"))
        self.assertTrue(hasattr(image_processing, "do_rescale"))
        self.assertTrue(hasattr(image_processing, "rescale_factor"))
        self.assertTrue(hasattr(image_processing, "do_grayscale"))

    def test_image_processor_from_dict_with_kwargs(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)