mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00

* Initial commit with template code generated by transformers-cli
* Multiple additions to SuperGlue implementation :
- Added the SuperGlueConfig
- Added the SuperGlueModel and its implementation
- Added basic weight conversion script
- Added new ImageMatchingOutput dataclass
* Few changes for SuperGlue
* Multiple changes :
- Added keypoint detection config to SuperGlueConfig
- Completed convert_superglue_to_pytorch and successfully ran inference
* Reverted unintentional change
* Multiple changes :
- Added SuperGlue to a bunch of places
- Divided SuperGlue into SuperGlueForImageMatching and SuperGlueModel
- Added testing images
* Moved things in init files
* Added docs (to be finished depending on the final implementation)
* Added necessary imports and some doc
* Removed unnecessary import
* Fixed make fix-copies bug and ran it
* Deleted SuperGlueModel
Fixed convert script
* Added SuperGlueImageProcessor
* Changed SuperGlue to support batching pairs of images and modified ImageMatchingOutput accordingly
* Changed convert_superglue_to_hf.py script to experiment with different ways of reading an image and see their impact on performance
* Added initial tests for SuperGlueImageProcessor
* Added AutoModelForImageMatching in missing places and tests
* Fixed keypoint_detector_output instructions
* Fix style
* Adapted to latest main changes
* Added integration test
* Fixed bugs to pass tests
* Added keypoints returned by keypoint detector in the output of SuperGlue
* Added doc to SuperGlue
* SuperGlue returning all attention and hidden states for a fixed number of keypoints
* Make style
* Changed SuperGlueImageProcessor tests
* Revert "SuperGlue returning all attention and hidden states for a fixed number of keypoints"
Changed tests accordingly
This reverts commit 5b3b669c
* Added back hidden_states and attentions masked outputs with tests
* Renamed ImageMatching occurrences into KeypointMatching
* Changed SuperGlueImageProcessor to raise error when batch_size is not even
* Added docs and clarity to hidden state and attention grouping function
* Fixed some code and did some refactoring
* Fixed typo in SuperPoint output doc
* Fixed some of the formatting and variable naming problems
* Removed useless function call
* Removed AutoModelForKeypointMatching
* Fixed SuperGlueImageProcessor to only accept pairs of images
* Added more fixes to SuperGlueImageProcessor
* Simplified the batching of attention and hidden states
* Simplified stack functions
* Moved attention instructions into class
* Removed unused do_batch_norm argument
* Moved weight initialization to the proper place
* Replaced deepcopy for instantiation
* Fixed small bug
* Changed from stevenbucaille to magic-leap repo
* Renamed London Bridge images to Tower Bridge
* Fixed formatting
* Renamed remaining "london" to "tower"
* Apply suggestions from code review
Small changes in the docs
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Added AutoModelForKeypointMatching
* Changed images used in example
* Several changes to image_processing_superglue and style
* Fixed resample type hint
* Changed SuperGlueImageProcessor and added test case for list of 2 images
* Changed list_of_tuples implementation
* Fix in dummy objects
* Added normalize_keypoint, log_sinkhorn_iterations and log_optimal_transport docstring
* Added missing docstring
* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Moved forward block at bottom
* Added docstring to forward method
* Added docstring to match_image_pair method
* Changed test_model_common_attributes to test_model_get_set_embeddings test method signature
* Removed AutoModelForKeypointMatching
* Removed image fixtures and added load_dataset
* Added padding of images in SuperGlueImageProcessor
* Cleaned up convert_superglue_to_hf script
* Added missing docs and fixed unused argument
* Fixed SuperGlueImageProcessor tests
* Transposed all hidden states from SuperGlue to reflect the standard (..., seq_len, feature_dim) shape
* Added SuperGlueForKeypointMatching back to modeling_auto
* Fixed image processor padding test
* Changed SuperGlue docs
* changes:
- Abstraction to batch, concat and stack of inconsistent tensors
- Changed conv1d's to linears to match standard attention implementations
- Renamed all tensors to be tensor0 and not tensor_0 and be consistent
- Changed match image pair to run keypoint detection on all images first, create batching tensors and then fill these tensors match by match
- Various changes in docs, etc
* Changes to SuperGlueImageProcessor:
- Reworked the input image pairs checking function and added tests accordingly
- Added Copied from statements
- Added do_grayscale tag (also for SuperPointImageProcessor)
- Misc changes for better code
* Formatting changes
* Reverted conv1d to linear conversion because of numerical differences
* fix: changed some code to be more straightforward (e.g. filtering keypoints) and converted plot from opencv to matplotlib
* fix: removed unnecessary test
* chore: removed commented code and added back hidden states transpositions
* chore: changed from "inconsistent" to "ragged" function names as suggested
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* docs: applied suggestions
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* docs: updated to display matched output
* chore: applied suggestion for check_image_pairs_input function
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* chore: changed check_image_pairs_input function name to validate_and_format_image_pairs and used validate_preprocess_arguments function
* tests: simplified tests for image input format and shapes
* feat: converted SuperGlue's use of Conv1d with kernel_size of 1 with Linear layers. Changed tests and conversion script accordingly
* feat: several changes to address comments
Conversion script:
- Reverted fuse batchnorm to linear conversion
- Changed all 'nn.Module' to respective SuperGlue models
- Changed conversion script to use regex mapping and match other recent scripts
Modeling SuperGlue:
- Added batching with mask and padding to attention
- Removed unnecessary concat, stack and batch ragged pairs functions
- Reverted batchnorm layer
- Renamed query, key, value and merge layers into q, k, v, out proj
- Removed Union of different Module into nn.Module in _init_weights method typehint
- Changed several method's signature to combine image0 and image1 inputs with appropriate doc changes
- Updated SuperGlue's doc with torch.no_grad()
Updated test to reflect changes in SuperGlue model
* refactor: changed validate_and_format_image_pairs function with clarity
* refactor: changed from one SuperGlueMLP class to a list of SuperGlueMLP class
* fix: fixed forgotten init weight change from last commit
* fix: fixed rebase mistake
* fix: removed leftover commented code
* fix: added typehint and changed some of arguments default values
* fix: fixed attribute default values for SuperGlueConfig
* feat: added SuperGlueImageProcessor post process keypoint matching method with tests
* fix: fixed SuperGlue attention and hidden state tuples aggregation
* chore: fixed mask optionality and reordered tensor reshapes to be cleaner
* chore: fixed docs and error message returned in validate_and_format_image_pairs function
* fix: fixed returned keypoints to be the ones that SuperPoint returns
* fix: fixed check on number of image sizes for post process compared to the pairs in outputs of SuperGlue
* fix: fixed check on number of image sizes for post process compared to the pairs in outputs of SuperGlue (bis)
* fix: Changed SuperGlueMultiLayerPerceptron instantiation to avoid if statement
* fix: Changed convert_superglue_to_hf script to reflect latest SuperGlue changes and got rid of nn.Modules
* WIP: implement Attention from an existing class (like BERT)
* docs: Changed docs to include more appealing matching plot
* WIP: Implement Attention
* chore: minor typehint change
* chore: changed convert superglue script by removing all classes and apply conv to linear conversion in state dict + rearrange keys to comply with changes in model's layers organisation
* Revert "Fixed typo in SuperPoint output doc"
This reverts commit 2120390e82
.
* chore: added comments in SuperGlueImageProcessor
* chore: changed SuperGlue organization HF repo to magic-leap-community
* [run-slow] refactor: small change in layer instantiation
* [run-slow] chore: replaced remaining stevenbucaille org to magic-leap-community
* [run-slow] chore: make style
* chore: update image matching fixture dataset HF repository
* [run-slow] superglue
* tests: overwriting test_batching_equivalence
* [run-slow] superglue
* tests: changed test to cope with value changing depending on cuda version
* [run-slow] superglue
* tests: changed matching_threshold value
* [run-slow] superglue
* [run-slow] superglue
* tests: changed tests for integration
* [run-slow] superglue
* fix: Changed tensor view and permutations to match original implementation results
* fix: updated convert script and integration test to include last change in model
* fix: increase tolerance for CUDA variances
* Apply suggestions from code review
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
* [run-slow] superglue
* chore: removed blank whitespaces
* [run-slow] superglue
* Revert SuperPoint image processor accident changes
* [run-slow] superglue
* refactor: reverted copy from BERT class
* tests: lower the tolerance in integration tests for SuperGlue
* [run-slow] superglue
* chore: set do_grayscale to False in SuperPoint and SuperGlue image processors
* [run-slow] superglue
* fix: fixed imports in SuperGlue files
* chore: changed do_grayscale SuperGlueImageProcessing default value to True
* docs: added typehint to post_process_keypoint_matching method in SuperGlueImageProcessor
* fix: set matching_threshold default value to 0.0 instead of 0.2
* feat: added matching_threshold to post_process_keypoint_matching method
* docs: update superglue.md to include matching_threshold parameter
* docs: updated SuperGlueConfig docstring for matching_threshold default value
* refactor: removed unnecessary parameters in SuperGlueConfig
* fix: changed from matching_threshold to threshold
* fix: re-revert changes to make SuperGlue attention classes copies of BERT
* [run-slow] superglue
* fix: added missing device argument in post_processing method
* [run-slow] superglue
* fix: add matches different from -1 to compute valid matches in post_process_keypoint_matching (and docstring)
* fix: add device to image_sizes tensor instantiation
* tests: added checks on do_grayscale test
* chore: reordered and added Optional typehint to KeypointMatchingOutput
* LightGluePR suggestions:
- use `post_process_keypoint_matching` as default docs example
- add `post_process_keypoint_matching` in autodoc
- add `SuperPointConfig` import under TYPE_CHECKING condition
- format SuperGlueConfig docstring
- add device in convert_superglue_to_hf
- Fix typo
- Fix KeypointMatchingOutput docstring
- Removed unnecessary line
- Added missing SuperGlueConfig in __init__ methods
* LightGluePR suggestions:
- use batching to get keypoint detection
* refactor: processing images done in 1 for loop instead of 4
* fix: use @ instead of torch.einsum for scores computation
* style: added #fmt skip to long tensor values
* refactor: rolled back validate_and_format_image_pairs valid and invalid cases to simpler ones
* refactor: prepare_imgs
* refactor: simplified `validate_and_format_image_pairs`
* docs: fixed doc
---------
Co-authored-by: steven <steven.bucaillle@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Steven Bucaille <steven.bucaille@buawei.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
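Several entries above convert SuperGlue's `Conv1d` layers with `kernel_size=1` into `Linear` layers and apply that conversion to the state dict in the conversion script. This works because a kernel-size-1 convolution is the same linear map applied independently at every sequence position. The sketch below illustrates that equivalence; `conv1d_to_linear_state` is a hypothetical helper, not the conversion script's actual code, and it assumes inputs laid out as `(batch, channels, seq_len)`.

```python
import torch
from torch import nn


def conv1d_to_linear_state(conv: nn.Conv1d) -> dict:
    """Hypothetical helper: turn a kernel_size=1 Conv1d's parameters into nn.Linear parameters."""
    assert conv.kernel_size == (1,), "only kernel_size=1 convolutions are per-position linear maps"
    # Conv1d weight is (out, in, 1); Linear expects (out, in)
    state = {"weight": conv.weight.detach().squeeze(-1).clone()}
    if conv.bias is not None:
        state["bias"] = conv.bias.detach().clone()
    return state


torch.manual_seed(0)
conv = nn.Conv1d(in_channels=8, out_channels=4, kernel_size=1)
linear = nn.Linear(in_features=8, out_features=4)
linear.load_state_dict(conv1d_to_linear_state(conv))

x = torch.randn(2, 8, 10)  # (batch, channels, seq_len): the layout Conv1d expects
conv_out = conv(x)  # (batch, out_channels, seq_len)
# Linear acts on the last dimension, so move channels last and back again
linear_out = linear(x.transpose(1, 2)).transpose(1, 2)
print(torch.allclose(conv_out, linear_out, atol=1e-5))
```

The two layers agree only up to floating-point rounding, so the comparison uses a tolerance; differences of this kind may be what the "Reverted conv1d to linear conversion because of numerical differences" entry refers to.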
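The log also switches the scores computation from `torch.einsum` to the `@` operator, after hidden states were transposed to the standard `(..., seq_len, feature_dim)` shape. In that layout both forms compute the same batched dot products between every pair of descriptors. A minimal sketch; the variable names and the `1 / sqrt(dim)` scaling (taken from the original SuperGlue formulation) are assumptions here, not the model's code:

```python
import torch

torch.manual_seed(0)
batch, n0, n1, dim = 2, 5, 7, 16
descriptors0 = torch.randn(batch, n0, dim)  # (batch, seq_len0, feature_dim)
descriptors1 = torch.randn(batch, n1, dim)  # (batch, seq_len1, feature_dim)

# einsum form: contract the feature dimension for every (n, m) keypoint pair
scores_einsum = torch.einsum("bnd,bmd->bnm", descriptors0, descriptors1) / dim**0.5
# matmul form: the same contraction written as a batched matrix product
scores_matmul = descriptors0 @ descriptors1.transpose(-1, -2) / dim**0.5

print(scores_matmul.shape)  # torch.Size([2, 5, 7])
print(torch.allclose(scores_einsum, scores_matmul, atol=1e-5))
```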
169 lines
7.2 KiB
Python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available

from ...test_image_processing_common import (
    ImageProcessingTestMixin,
    prepare_image_inputs,
)


if is_torch_available():
    import torch

    from transformers.models.superpoint.modeling_superpoint import SuperPointKeypointDescriptionOutput

if is_vision_available():
    from transformers import SuperPointImageProcessor


class SuperPointImageProcessingTester:
    def __init__(
        self,
        parent,
        batch_size=7,
        num_channels=3,
        image_size=18,
        min_resolution=30,
        max_resolution=400,
        do_resize=True,
        size=None,
        do_grayscale=True,
    ):
        size = size if size is not None else {"height": 480, "width": 640}
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size
        self.do_grayscale = do_grayscale

    def prepare_image_processor_dict(self):
        return {
            "do_resize": self.do_resize,
            "size": self.size,
            "do_grayscale": self.do_grayscale,
        }

    def expected_output_image_shape(self, images):
        return self.num_channels, self.size["height"], self.size["width"]

    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
        return prepare_image_inputs(
            batch_size=self.batch_size,
            num_channels=self.num_channels,
            min_resolution=self.min_resolution,
            max_resolution=self.max_resolution,
            equal_resolution=equal_resolution,
            numpify=numpify,
            torchify=torchify,
        )

    def prepare_keypoint_detection_output(self, pixel_values):
        # Build a fake, padded SuperPoint output: each image gets a random number of keypoints
        # (fewer than max_number_keypoints) and `mask` flags which entries are real
        max_number_keypoints = 50
        batch_size = len(pixel_values)
        mask = torch.zeros((batch_size, max_number_keypoints))
        keypoints = torch.zeros((batch_size, max_number_keypoints, 2))
        scores = torch.zeros((batch_size, max_number_keypoints))
        descriptors = torch.zeros((batch_size, max_number_keypoints, 16))
        for i in range(batch_size):
            random_number_keypoints = np.random.randint(0, max_number_keypoints)
            mask[i, :random_number_keypoints] = 1
            # Keypoint coordinates are random values in [0, 1)
            keypoints[i, :random_number_keypoints] = torch.rand((random_number_keypoints, 2))
            scores[i, :random_number_keypoints] = torch.rand((random_number_keypoints,))
            descriptors[i, :random_number_keypoints] = torch.rand((random_number_keypoints, 16))
        return SuperPointKeypointDescriptionOutput(
            loss=None, keypoints=keypoints, scores=scores, descriptors=descriptors, mask=mask, hidden_states=None
        )


@require_torch
@require_vision
class SuperPointImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    image_processing_class = SuperPointImageProcessor if is_vision_available() else None

    def setUp(self) -> None:
        super().setUp()
        self.image_processor_tester = SuperPointImageProcessingTester(self)

    @property
    def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processing(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)
        self.assertTrue(hasattr(image_processing, "do_resize"))
        self.assertTrue(hasattr(image_processing, "size"))
        self.assertTrue(hasattr(image_processing, "do_rescale"))
        self.assertTrue(hasattr(image_processing, "rescale_factor"))
        self.assertTrue(hasattr(image_processing, "do_grayscale"))

    def test_image_processor_from_dict_with_kwargs(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        self.assertEqual(image_processor.size, {"height": 480, "width": 640})

        image_processor = self.image_processing_class.from_dict(
            self.image_processor_dict, size={"height": 42, "width": 42}
        )
        self.assertEqual(image_processor.size, {"height": 42, "width": 42})

    @unittest.skip(reason="SuperPointImageProcessor is always supposed to return a grayscaled image")
    def test_call_numpy_4_channels(self):
        pass

    def test_input_image_properly_converted_to_grayscale(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        image_inputs = self.image_processor_tester.prepare_image_inputs()
        pre_processed_images = image_processor.preprocess(image_inputs)
        # Grayscale output keeps 3 channels, and all of them must be identical
        for image in pre_processed_images["pixel_values"]:
            self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]))

    @require_torch
    def test_post_processing_keypoint_detection(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        image_inputs = self.image_processor_tester.prepare_image_inputs()
        pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="pt")
        outputs = self.image_processor_tester.prepare_keypoint_detection_output(**pre_processed_images)

        def check_post_processed_output(post_processed_outputs, image_sizes):
            # image_sizes holds (height, width) pairs; keypoints are (x, y) pixel coordinates
            for post_processed_output, image_size in zip(post_processed_outputs, image_sizes):
                self.assertTrue("keypoints" in post_processed_output)
                self.assertTrue("descriptors" in post_processed_output)
                self.assertTrue("scores" in post_processed_output)
                keypoints = post_processed_output["keypoints"]
                all_below_image_size = torch.all(keypoints[:, 0] <= image_size[1]) and torch.all(
                    keypoints[:, 1] <= image_size[0]
                )
                all_above_zero = torch.all(keypoints[:, 0] >= 0) and torch.all(keypoints[:, 1] >= 0)
                self.assertTrue(all_below_image_size)
                self.assertTrue(all_above_zero)

        # PIL's `image.size` is (width, height), while target sizes are given as (height, width)
        tuple_image_sizes = [(image.size[1], image.size[0]) for image in image_inputs]
        tuple_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tuple_image_sizes)

        check_post_processed_output(tuple_post_processed_outputs, tuple_image_sizes)

        tensor_image_sizes = torch.tensor([image.size for image in image_inputs]).flip(1)
        tensor_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tensor_image_sizes)

        check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes)
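`test_input_image_properly_converted_to_grayscale` only checks that the three output channels are identical. A minimal sketch of a conversion with that property follows, assuming channels-first float arrays and ITU-R BT.601 luma weights; `to_grayscale_3ch` is hypothetical and is not necessarily how `SuperPointImageProcessor` implements grayscaling.

```python
import numpy as np


def to_grayscale_3ch(image: np.ndarray) -> np.ndarray:
    """Hypothetical helper: (3, H, W) RGB -> (3, H, W) grayscale with identical channels."""
    if image.ndim != 3 or image.shape[0] != 3:
        raise ValueError(f"expected a channels-first RGB image, got shape {image.shape}")
    # ITU-R BT.601 luma weights (an assumption for this sketch)
    luma = 0.299 * image[0] + 0.587 * image[1] + 0.114 * image[2]
    # Repeat the single luminance channel so downstream code still sees 3 channels
    return np.stack([luma, luma, luma], axis=0)


rng = np.random.default_rng(0)
rgb = rng.random((3, 4, 5))
gray = to_grayscale_3ch(rgb)
# Same property the test asserts: every channel equals the others
print(np.all(gray[0] == gray[1]) and np.all(gray[1] == gray[2]))
```

Keeping three identical channels lets a grayscale image flow through code paths that expect RGB-shaped input, which is also why the 4-channel test is skipped above.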
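`test_post_processing_keypoint_detection` passes random keypoints in `[0, 1)` through `post_process_keypoint_detection` and checks that every result lies inside its image. The sketch below shows the kind of transformation that check implies: drop the padded entries using `mask`, then scale relative `(x, y)` coordinates by `(width, height)` obtained by flipping a `(height, width)` target size. `rescale_keypoints` is hypothetical; the library's actual post-processing may differ, for example in dtype or rounding.

```python
import torch


def rescale_keypoints(keypoints, mask, target_sizes):
    """Hypothetical helper.

    keypoints:    (batch, max_keypoints, 2) relative (x, y) coordinates in [0, 1)
    mask:         (batch, max_keypoints), 1 for real keypoints and 0 for padding
    target_sizes: (batch, 2) tensor of (height, width) per image
    Returns one (num_keypoints, 2) tensor of absolute (x, y) coordinates per image.
    """
    # (height, width) -> (width, height) so the scale lines up with (x, y)
    scale = target_sizes.flip(1).to(keypoints.dtype)
    results = []
    for image_keypoints, image_mask, image_scale in zip(keypoints, mask, scale):
        valid = image_keypoints[image_mask.bool()]  # keep only the unpadded entries
        results.append(valid * image_scale)
    return results


keypoints = torch.tensor([[[0.5, 0.25], [0.9, 0.9], [0.0, 0.0]]])
mask = torch.tensor([[1, 1, 0]])
target_sizes = torch.tensor([[480, 640]])  # (height, width)
rescaled = rescale_keypoints(keypoints, mask, target_sizes)
print(rescaled[0])
```

Passing the size as `(width, height)` instead would silently swap the scaling of x and y, and a bounds check that reads the same mis-ordered sizes would not notice, which is why the order of the tuple sizes matters in the test.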