mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-06 14:20:04 +06:00

* Initial commit with template code generated by transformers-cli
* Multiple additions to SuperGlue implementation :
- Added the SuperGlueConfig
- Added the SuperGlueModel and its implementation
- Added basic weight conversion script
- Added new ImageMatchingOutput dataclass
* Few changes for SuperGlue
* Multiple changes :
- Added keypoint detection config to SuperGlueConfig
- Completed convert_superglue_to_pytorch and succesfully run inference
* Reverted unintentional change
* Multiple changes :
- Added SuperGlue to a bunch of places
- Divided SuperGlue into SuperGlueForImageMatching and SuperGlueModel
- Added testing images
* Moved things in init files
* Added docs (to be finished depending on the final implementation)
* Added necessary imports and some doc
* Removed unnecessary import
* Fixed make fix-copies bug and ran it
* Deleted SuperGlueModel
Fixed convert script
* Added SuperGlueImageProcessor
* Changed SuperGlue to support batching pairs of images and modified ImageMatchingOutput in consequences
* Changed convert_superglue_to_hf.py script to experiment different ways of reading an image and seeing its impact on performances
* Added initial tests for SuperGlueImageProcessor
* Added AutoModelForImageMatching in missing places and tests
* Fixed keypoint_detector_output instructions
* Fix style
* Adapted to latest main changes
* Added integration test
* Fixed bugs to pass tests
* Added keypoints returned by keypoint detector in the output of SuperGlue
* Added doc to SuperGlue
* SuperGlue returning all attention and hidden states for a fixed number of keypoints
* Make style
* Changed SuperGlueImageProcessor tests
* Revert "SuperGlue returning all attention and hidden states for a fixed number of keypoints"
Changed tests accordingly
This reverts commit 5b3b669c
* Added back hidden_states and attentions masked outputs with tests
* Renamed ImageMatching occurences into KeypointMatching
* Changed SuperGlueImageProcessor to raise error when batch_size is not even
* Added docs and clarity to hidden state and attention grouping function
* Fixed some code and done refactoring
* Fixed typo in SuperPoint output doc
* Fixed some of the formatting and variable naming problems
* Removed useless function call
* Removed AutoModelForKeypointMatching
* Fixed SuperGlueImageProcessor to only accept paris of images
* Added more fixes to SuperGlueImageProcessor
* Simplified the batching of attention and hidden states
* Simplified stack functions
* Moved attention instructions into class
* Removed unused do_batch_norm argument
* Moved weight initialization to the proper place
* Replaced deepcopy for instantiation
* Fixed small bug
* Changed from stevenbucaille to magic-leap repo
* Renamed London Bridge images to Tower Bridge
* Fixed formatting
* Renamed remaining "london" to "tower"
* Apply suggestions from code review
Small changes in the docs
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Added AutoModelForKeypointMatching
* Changed images used in example
* Several changes to image_processing_superglue and style
* Fixed resample type hint
* Changed SuperGlueImageProcessor and added test case for list of 2 images
* Changed list_of_tuples implementation
* Fix in dummy objects
* Added normalize_keypoint, log_sinkhorn_iterations and log_optimal_transport docstring
* Added missing docstring
* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Moved forward block at bottom
* Added docstring to forward method
* Added docstring to match_image_pair method
* Changed test_model_common_attributes to test_model_get_set_embeddings test method signature
* Removed AutoModelForKeypointMatching
* Removed image fixtures and added load_dataset
* Added padding of images in SuperGlueImageProcessor
* Cleaned up convert_superglue_to_hf script
* Added missing docs and fixed unused argument
* Fixed SuperGlueImageProcessor tests
* Transposed all hidden states from SuperGlue to reflect the standard (..., seq_len, feature_dim) shape
* Added SuperGlueForKeypointMatching back to modeling_auto
* Fixed image processor padding test
* Changed SuperGlue docs
* changes:
- Abstraction to batch, concat and stack of inconsistent tensors
- Changed conv1d's to linears to match standard attention implementations
- Renamed all tensors to be tensor0 and not tensor_0 and be consistent
- Changed match image pair to run keypoint detection on all image first, create batching tensors and then filling these tensors matches after matches
- Various changes in docs, etc
* Changes to SuperGlueImageProcessor:
- Reworked the input image pairs checking function and added tests accordingly
- Added Copied from statements
- Added do_grayscale tag (also for SuperPointImageProcessor)
- Misc changes for better code
* Formatting changes
* Reverted conv1d to linear conversion because of numerical differences
* fix: changed some code to be more straightforward (e.g. filtering keypoints) and converted plot from opencv to matplotlib
* fix: removed unnecessary test
* chore: removed commented code and added back hidden states transpositions
* chore: changed from "inconsistent" to "ragged" function names as suggested
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* docs: applied suggestions
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* docs: updated to display matched output
* chore: applied suggestion for check_image_pairs_input function
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* chore: changed check_image_pairs_input function name to validate_and_format_image_pairs and used validate_preprocess_arguments function
* tests: simplified tests for image input format and shapes
* feat: converted SuperGlue's use of Conv1d with kernel_size of 1 with Linear layers. Changed tests and conversion script accordingly
* feat: several changes to address comments
Conversion script:
- Reverted fuse batchnorm to linear conversion
- Changed all 'nn.Module' to respective SuperGlue models
- Changed conversion script to use regex mapping and match other recent scripts
Modeling SuperGlue:
- Added batching with mask and padding to attention
- Removed unnecessary concat, stack and batch ragged pairs functions
- Reverted batchnorm layer
- Renamed query, key, value and merge layers into q, k, v, out proj
- Removed Union of different Module into nn.Module in _init_weights method typehint
- Changed several method's signature to combine image0 and image1 inputs with appropriate doc changes
- Updated SuperGlue's doc with torch.no_grad()
Updated test to reflect changes in SuperGlue model
* refactor: changed validate_and_format_image_pairs function with clarity
* refactor: changed from one SuperGlueMLP class to a list of SuperGlueMLP class
* fix: fixed forgotten init weight change from last commit
* fix: fixed rebase mistake
* fix: removed leftover commented code
* fix: added typehint and changed some of arguments default values
* fix: fixed attribute default values for SuperGlueConfig
* feat: added SuperGlueImageProcessor post process keypoint matching method with tests
* fix: fixed SuperGlue attention and hidden state tuples aggregation
* chore: fixed mask optionality and reordered tensor reshapes to be cleaner
* chore: fixed docs and error message returned in validate_and_format_image_pairs function
* fix: fixed returned keypoints to be the ones that SuperPoint returns
* fix: fixed check on number of image sizes for post process compared to the pairs in outputs of SuperGlue
* fix: fixed check on number of image sizes for post process compared to the pairs in outputs of SuperGlue (bis)
* fix: Changed SuperGlueMultiLayerPerceptron instantiation to avoid if statement
* fix: Changed convert_superglue_to_hf script to reflect latest SuperGlue changes and got rid of nn.Modules
* WIP: implement Attention from an existing class (like BERT)
* docs: Changed docs to include more appealing matching plot
* WIP: Implement Attention
* chore: minor typehint change
* chore: changed convert superglue script by removing all classes and apply conv to linear conversion in state dict + rearrange keys to comply with changes in model's layers organisation
* Revert "Fixed typo in SuperPoint output doc"
This reverts commit 2120390e82
.
* chore: added comments in SuperGlueImageProcessor
* chore: changed SuperGlue organization HF repo to magic-leap-community
* [run-slow] refactor: small change in layer instantiation
* [run-slow] chore: replaced remaining stevenbucaille org to magic-leap-community
* [run-slow] chore: make style
* chore: update image matching fixture dataset HF repository
* [run-slow] superglue
* tests: overwriting test_batching_equivalence
* [run-slow] superglue
* tests: changed test to cope with value changing depending on cuda version
* [run-slow] superglue
* tests: changed matching_threshold value
* [run-slow] superglue
* [run-slow] superglue
* tests: changed tests for integration
* [run-slow] superglue
* fix: Changed tensor view and permutations to match original implementation results
* fix: updated convert script and integration test to include last change in model
* fix: increase tolerance for CUDA variances
* Apply suggestions from code review
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
* [run-slow] superglue
* chore: removed blank whitespaces
* [run-slow] superglue
* Revert SuperPoint image processor accident changes
* [run-slow] superglue
* refactor: reverted copy from BERT class
* tests: lower the tolerance in integration tests for SuperGlue
* [run-slow] superglue
* chore: set do_grayscale to False in SuperPoint and SuperGlue image processors
* [run-slow] superglue
* fix: fixed imports in SuperGlue files
* chore: changed do_grayscale SuperGlueImageProcessing default value to True
* docs: added typehint to post_process_keypoint_matching method in SuperGlueImageProcessor
* fix: set matching_threshold default value to 0.0 instead of 0.2
* feat: added matching_threshold to post_process_keypoint_matching method
* docs: update superglue.md to include matching_threshold parameter
* docs: updated SuperGlueConfig docstring for matching_threshold default value
* refactor: removed unnecessary parameters in SuperGlueConfig
* fix: changed from matching_threshold to threshold
* fix: re-revert changes to make SuperGlue attention classes copies of BERT
* [run-slow] superglue
* fix: added missing device argument in post_processing method
* [run-slow] superglue
* fix: add matches different from -1 to compute valid matches in post_process_keypoint_matching (and docstring)
* fix: add device to image_sizes tensor instantiation
* tests: added checks on do_grayscale test
* chore: reordered and added Optional typehint to KeypointMatchingOutput
* LightGluePR suggestions:
- use `post_process_keypoint_matching` as default docs example
- add `post_process_keypoint_matching` in autodoc
- add `SuperPointConfig` import under TYPE_CHECKING condition
- format SuperGlueConfig docstring
- add device in convert_superglue_to_hf
- Fix typo
- Fix KeypointMatchingOutput docstring
- Removed unnecessary line
- Added missing SuperGlueConfig in __init__ methods
* LightGluePR suggestions:
- use batching to get keypoint detection
* refactor: processing images done in 1 for loop instead of 4
* fix: use @ instead of torch.einsum for scores computation
* style: added #fmt skip to long tensor values
* refactor: rollbacked validate_and_format_image_pairs valid and invalid case to more simple ones
* refactor: prepare_imgs
* refactor: simplified `validate_and_format_image_pairs`
* docs: fixed doc
---------
Co-authored-by: steven <steven.bucaillle@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Steven Bucaille <steven.bucaille@buawei.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
138 lines
6.1 KiB
Markdown
138 lines
6.1 KiB
Markdown
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
|
|
|
Licensed under the MIT License; you may not use this file except in compliance with
|
|
the License.
|
|
|
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
|
specific language governing permissions and limitations under the License.
|
|
|
|
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
|
rendered properly in your Markdown viewer.
|
|
|
|
|
|
-->
|
|
|
|
# SuperGlue
|
|
|
|
## Overview
|
|
|
|
The SuperGlue model was proposed in [SuperGlue: Learning Feature Matching with Graph Neural Networks](https://arxiv.org/abs/1911.11763) by Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz and Andrew Rabinovich.
|
|
|
|
This model consists of matching two sets of interest points detected in an image. Paired with the
|
|
[SuperPoint model](https://huggingface.co/magic-leap-community/superpoint), it can be used to match two images and
|
|
estimate the pose between them. This model is useful for tasks such as image matching, homography estimation, etc.
|
|
|
|
The abstract from the paper is the following:
|
|
|
|
*This paper introduces SuperGlue, a neural network that matches two sets of local features by jointly finding correspondences
|
|
and rejecting non-matchable points. Assignments are estimated by solving a differentiable optimal transport problem, whose costs
|
|
are predicted by a graph neural network. We introduce a flexible context aggregation mechanism based on attention, enabling
|
|
SuperGlue to reason about the underlying 3D scene and feature assignments jointly. Compared to traditional, hand-designed heuristics,
|
|
our technique learns priors over geometric transformations and regularities of the 3D world through end-to-end training from image
|
|
pairs. SuperGlue outperforms other learned approaches and achieves state-of-the-art results on the task of pose estimation in
|
|
challenging real-world indoor and outdoor environments. The proposed method performs matching in real-time on a modern GPU and
|
|
can be readily integrated into modern SfM or SLAM systems. The code and trained weights are publicly available at this [URL](https://github.com/magicleap/SuperGluePretrainedNetwork).*
|
|
|
|
## How to use
|
|
|
|
Here is a quick example of using the model. Since this model is an image matching model, it requires pairs of images to be matched.
|
|
The raw outputs contain the list of keypoints detected by the keypoint detector as well as the list of matches with their corresponding
|
|
matching scores.
|
|
```python
|
|
from transformers import AutoImageProcessor, AutoModel
|
|
import torch
|
|
from PIL import Image
|
|
import requests
|
|
|
|
url_image1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
|
|
image1 = Image.open(requests.get(url_image1, stream=True).raw)
|
|
url_image2 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"
|
|
image_2 = Image.open(requests.get(url_image2, stream=True).raw)
|
|
|
|
images = [image1, image2]
|
|
|
|
processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
|
|
model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")
|
|
|
|
inputs = processor(images, return_tensors="pt")
|
|
with torch.no_grad():
|
|
outputs = model(**inputs)
|
|
```
|
|
|
|
You can use the `post_process_keypoint_matching` method from the `SuperGlueImageProcessor` to get the keypoints and matches in a more readable format:
|
|
|
|
```python
|
|
image_sizes = [[(image.height, image.width) for image in images]]
|
|
outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
|
|
for i, output in enumerate(outputs):
|
|
print("For the image pair", i)
|
|
for keypoint0, keypoint1, matching_score in zip(
|
|
output["keypoints0"], output["keypoints1"], output["matching_scores"]
|
|
):
|
|
print(
|
|
f"Keypoint at coordinate {keypoint0.numpy()} in the first image matches with keypoint at coordinate {keypoint1.numpy()} in the second image with a score of {matching_score}."
|
|
)
|
|
|
|
```
|
|
|
|
From the outputs, you can visualize the matches between the two images using the following code:
|
|
```python
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
# Create side by side image
|
|
merged_image = np.zeros((max(image1.height, image2.height), image1.width + image2.width, 3))
|
|
merged_image[: image1.height, : image1.width] = np.array(image1) / 255.0
|
|
merged_image[: image2.height, image1.width :] = np.array(image2) / 255.0
|
|
plt.imshow(merged_image)
|
|
plt.axis("off")
|
|
|
|
# Retrieve the keypoints and matches
|
|
output = outputs[0]
|
|
keypoints0 = output["keypoints0"]
|
|
keypoints1 = output["keypoints1"]
|
|
matching_scores = output["matching_scores"]
|
|
keypoints0_x, keypoints0_y = keypoints0[:, 0].numpy(), keypoints0[:, 1].numpy()
|
|
keypoints1_x, keypoints1_y = keypoints1[:, 0].numpy(), keypoints1[:, 1].numpy()
|
|
|
|
# Plot the matches
|
|
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
|
|
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, matching_scores
|
|
):
|
|
plt.plot(
|
|
[keypoint0_x, keypoint1_x + image1.width],
|
|
[keypoint0_y, keypoint1_y],
|
|
color=plt.get_cmap("RdYlGn")(matching_score.item()),
|
|
alpha=0.9,
|
|
linewidth=0.5,
|
|
)
|
|
plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2)
|
|
plt.scatter(keypoint1_x + image1.width, keypoint1_y, c="black", s=2)
|
|
|
|
# Save the plot
|
|
plt.savefig("matched_image.png", dpi=300, bbox_inches='tight')
|
|
plt.close()
|
|
```
|
|
|
|

|
|
|
|
This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
|
|
The original code can be found [here](https://github.com/magicleap/SuperGluePretrainedNetwork).
|
|
|
|
## SuperGlueConfig
|
|
|
|
[[autodoc]] SuperGlueConfig
|
|
|
|
## SuperGlueImageProcessor
|
|
|
|
[[autodoc]] SuperGlueImageProcessor
|
|
|
|
- preprocess
|
|
|
|
## SuperGlueForKeypointMatching
|
|
|
|
[[autodoc]] SuperGlueForKeypointMatching
|
|
|
|
- forward
|
|
- post_process_keypoint_matching |