transformers/tests/models/superglue/test_image_processing_superglue.py
StevenBucaille abe57b6f17
Add SuperGlue model (#29886)
* Initial commit with template code generated by transformers-cli

* Multiple additions to SuperGlue implementation :

- Added the SuperGlueConfig
- Added the SuperGlueModel and its implementation
- Added basic weight conversion script
- Added new ImageMatchingOutput dataclass

* Few changes for SuperGlue

* Multiple changes :
- Added keypoint detection config to SuperGlueConfig
- Completed convert_superglue_to_pytorch and successfully ran inference

* Reverted unintentional change

* Multiple changes :
 - Added SuperGlue to a bunch of places
 - Divided SuperGlue into SuperGlueForImageMatching and SuperGlueModel
 - Added testing images

* Moved things in init files

* Added docs (to be finished depending on the final implementation)

* Added necessary imports and some doc

* Removed unnecessary import

* Fixed make fix-copies bug and ran it

* Deleted SuperGlueModel
Fixed convert script

* Added SuperGlueImageProcessor

* Changed SuperGlue to support batching pairs of images and modified ImageMatchingOutput accordingly

* Changed convert_superglue_to_hf.py script to experiment with different ways of reading an image and see their impact on performance

* Added initial tests for SuperGlueImageProcessor

* Added AutoModelForImageMatching in missing places and tests

* Fixed keypoint_detector_output instructions

* Fix style

* Adapted to latest main changes

* Added integration test

* Fixed bugs to pass tests

* Added keypoints returned by keypoint detector in the output of SuperGlue

* Added doc to SuperGlue

* SuperGlue returning all attention and hidden states for a fixed number of keypoints

* Make style

* Changed SuperGlueImageProcessor tests

* Revert "SuperGlue returning all attention and hidden states for a fixed number of keypoints"
Changed tests accordingly

This reverts commit 5b3b669c

* Added back hidden_states and attentions masked outputs with tests

* Renamed ImageMatching occurrences to KeypointMatching

* Changed SuperGlueImageProcessor to raise error when batch_size is not even

* Added docs and clarity to hidden state and attention grouping function

* Fixed some code and done refactoring

* Fixed typo in SuperPoint output doc

* Fixed some of the formatting and variable naming problems

* Removed useless function call

* Removed AutoModelForKeypointMatching

* Fixed SuperGlueImageProcessor to only accept pairs of images

* Added more fixes to SuperGlueImageProcessor

* Simplified the batching of attention and hidden states

* Simplified stack functions

* Moved attention instructions into class

* Removed unused do_batch_norm argument

* Moved weight initialization to the proper place

* Replaced deepcopy for instantiation

* Fixed small bug

* Changed from stevenbucaille to magic-leap repo

* Renamed London Bridge images to Tower Bridge

* Fixed formatting

* Renamed remaining "london" to "tower"

* Apply suggestions from code review

Small changes in the docs

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Added AutoModelForKeypointMatching

* Changed images used in example

* Several changes to image_processing_superglue and style

* Fixed resample type hint

* Changed SuperGlueImageProcessor and added test case for list of 2 images

* Changed list_of_tuples implementation

* Fix in dummy objects

* Added normalize_keypoint, log_sinkhorn_iterations and log_optimal_transport docstring

* Added missing docstring

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Moved forward block at bottom

* Added docstring to forward method

* Added docstring to match_image_pair method

* Changed test_model_common_attributes to test_model_get_set_embeddings test method signature

* Removed AutoModelForKeypointMatching

* Removed image fixtures and added load_dataset

* Added padding of images in SuperGlueImageProcessor

* Cleaned up convert_superglue_to_hf script

* Added missing docs and fixed unused argument

* Fixed SuperGlueImageProcessor tests

* Transposed all hidden states from SuperGlue to reflect the standard (..., seq_len, feature_dim) shape

* Added SuperGlueForKeypointMatching back to modeling_auto

* Fixed image processor padding test

* Changed SuperGlue docs

* changes:
 - Abstraction to batch, concat and stack of inconsistent tensors
 - Changed conv1d's to linears to match standard attention implementations
 - Renamed all tensors to be tensor0 and not tensor_0, for consistency
 - Changed match image pair to run keypoint detection on all images first, create batching tensors and then fill these tensors match after match
 - Various changes in docs, etc

* Changes to SuperGlueImageProcessor:
- Reworked the input image pairs checking function and added tests accordingly
- Added Copied from statements
- Added do_grayscale tag (also for SuperPointImageProcessor)
- Misc changes for better code

* Formatting changes

* Reverted conv1d to linear conversion because of numerical differences

* fix: changed some code to be more straightforward (e.g. filtering keypoints) and converted plot from opencv to matplotlib

* fix: removed unnecessary test

* chore: removed commented code and added back hidden states transpositions

* chore: changed from "inconsistent" to "ragged" function names as suggested

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* docs: applied suggestions

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* docs: updated to display matched output

* chore: applied suggestion for check_image_pairs_input function

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* chore: changed check_image_pairs_input function name to validate_and_format_image_pairs and used validate_preprocess_arguments function

* tests: simplified tests for image input format and shapes

* feat: converted SuperGlue's use of Conv1d with kernel_size of 1 to Linear layers. Changed tests and conversion script accordingly
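
For context, a Conv1d with kernel_size=1 is a pointwise linear map over the channel
dimension, which is why this conversion is lossless up to floating-point error. A
minimal sketch of the equivalence (layer sizes here are illustrative, not SuperGlue's):

    import torch
    import torch.nn as nn

    conv = nn.Conv1d(64, 128, kernel_size=1)
    linear = nn.Linear(64, 128)
    linear.weight.data = conv.weight.data.squeeze(-1)  # (128, 64, 1) -> (128, 64)
    linear.bias.data = conv.bias.data.clone()

    x = torch.randn(2, 64, 10)  # (batch, channels, sequence)
    out_conv = conv(x)
    out_linear = linear(x.transpose(1, 2)).transpose(1, 2)
    assert torch.allclose(out_conv, out_linear, atol=1e-6)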

* feat: several changes to address comments

Conversion script:
- Reverted fuse batchnorm to linear conversion
- Changed all 'nn.Module' to respective SuperGlue models
- Changed conversion script to use regex mapping and match other recent scripts

Modeling SuperGlue:
- Added batching with mask and padding to attention
- Removed unnecessary concat, stack and batch ragged pairs functions
- Reverted batchnorm layer
- Renamed query, key, value and merge layers into q, k, v, out proj
- Removed Union of different Module into nn.Module in _init_weights method typehint
- Changed several method's signature to combine image0 and image1 inputs with appropriate doc changes
- Updated SuperGlue's doc with torch.no_grad()

Updated test to reflect changes in SuperGlue model

* refactor: changed validate_and_format_image_pairs function for clarity

* refactor: changed from one SuperGlueMLP class to a list of SuperGlueMLP modules

* fix: fixed forgotten init weight change from last commit

* fix: fixed rebase mistake

* fix: removed leftover commented code

* fix: added typehint and changed some of arguments default values

* fix: fixed attribute default values for SuperGlueConfig

* feat: added SuperGlueImageProcessor post process keypoint matching method with tests

* fix: fixed SuperGlue attention and hidden state tuples aggregation

* chore: fixed mask optionality and reordered tensor reshapes to be cleaner

* chore: fixed docs and error message returned in validate_and_format_image_pairs function

* fix: fixed returned keypoints to be the ones that SuperPoint returns

* fix: fixed check on number of image sizes for post process compared to the pairs in outputs of SuperGlue

* fix: fixed check on number of image sizes for post process compared to the pairs in outputs of SuperGlue (bis)

* fix: Changed SuperGlueMultiLayerPerceptron instantiation to avoid if statement

* fix: Changed convert_superglue_to_hf script to reflect latest SuperGlue changes and got rid of nn.Modules

* WIP: implement Attention from an existing class (like BERT)

* docs: Changed docs to include more appealing matching plot

* WIP: Implement Attention

* chore: minor typehint change

* chore: changed convert superglue script by removing all classes and applying the conv-to-linear conversion in the state dict + rearranging keys to comply with changes in the model's layer organization

* Revert "Fixed typo in SuperPoint output doc"

This reverts commit 2120390e82.

* chore: added comments in SuperGlueImageProcessor

* chore: changed SuperGlue organization HF repo to magic-leap-community

* [run-slow] refactor: small change in layer instantiation

* [run-slow] chore: replaced remaining stevenbucaille org to magic-leap-community

* [run-slow] chore: make style

* chore: update image matching fixture dataset HF repository

* [run-slow] superglue

* tests: overwriting test_batching_equivalence

* [run-slow] superglue

* tests: changed test to cope with value changing depending on cuda version

* [run-slow] superglue

* tests: changed matching_threshold value

* [run-slow] superglue

* [run-slow] superglue

* tests: changed tests for integration

* [run-slow] superglue

* fix: Changed tensor view and permutations to match original implementation results

* fix: updated convert script and integration test to include last change in model

* fix: increase tolerance for CUDA variances

* Apply suggestions from code review

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* [run-slow] superglue

* chore: removed blank whitespaces

* [run-slow] superglue

* Revert SuperPoint image processor accidental changes

* [run-slow] superglue

* refactor: reverted copy from BERT class

* tests: lower the tolerance in integration tests for SuperGlue

* [run-slow] superglue

* chore: set do_grayscale to False in SuperPoint and SuperGlue image processors

* [run-slow] superglue

* fix: fixed imports in SuperGlue files

* chore: changed SuperGlueImageProcessor do_grayscale default value to True

* docs: added typehint to post_process_keypoint_matching method in SuperGlueImageProcessor

* fix: set matching_threshold default value to 0.0 instead of 0.2

* feat: added matching_threshold to post_process_keypoint_matching method

* docs: update superglue.md to include matching_threshold parameter

* docs: updated SuperGlueConfig docstring for matching_threshold default value

* refactor: removed unnecessary parameters in SuperGlueConfig

* fix: changed from matching_threshold to threshold

* fix: re-revert changes to make SuperGlue attention classes copies of BERT

* [run-slow] superglue

* fix: added missing device argument in post_processing method

* [run-slow] superglue

* fix: add matches different from -1 to compute valid matches in post_process_keypoint_matching (and docstring)

* fix: add device to image_sizes tensor instantiation

* tests: added checks on do_grayscale test

* chore: reordered and added Optional typehint to KeypointMatchingOutput

* LightGlue PR suggestions:
- use `post_process_keypoint_matching` as default docs example
- add `post_process_keypoint_matching` in autodoc
- add `SuperPointConfig` import under TYPE_CHECKING condition
- format SuperGlueConfig docstring
- add device in convert_superglue_to_hf
- Fix typo
- Fix KeypointMatchingOutput docstring
- Removed unnecessary line
- Added missing SuperGlueConfig in __init__ methods

* LightGlue PR suggestions:
- use batching to get keypoint detection

* refactor: processing images done in 1 for loop instead of 4

* fix: use @ instead of torch.einsum for scores computation

* style: added #fmt skip to long tensor values

* refactor: rolled back validate_and_format_image_pairs valid and invalid cases to simpler ones

* refactor: prepare_imgs

* refactor: simplified `validate_and_format_image_pairs`

* docs: fixed doc

---------

Co-authored-by: steven <steven.bucaillle@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Steven Bucaille <steven.bucaille@buawei.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
2025-01-20 10:32:39 +00:00


# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from parameterized import parameterized

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available

from ...test_image_processing_common import (
    ImageProcessingTestMixin,
    prepare_image_inputs,
)


if is_torch_available():
    import numpy as np
    import torch

    from transformers.models.superglue.modeling_superglue import KeypointMatchingOutput

if is_vision_available():
    from transformers import SuperGlueImageProcessor


def random_array(size):
    return np.random.randint(255, size=size)


def random_tensor(size):
    return torch.rand(size)


class SuperGlueImageProcessingTester(unittest.TestCase):
    def __init__(
        self,
        parent,
        batch_size=6,
        num_channels=3,
        image_size=18,
        min_resolution=30,
        max_resolution=400,
        do_resize=True,
        size=None,
        do_grayscale=True,
    ):
        size = size if size is not None else {"height": 480, "width": 640}
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size
        self.do_grayscale = do_grayscale

    def prepare_image_processor_dict(self):
        return {
            "do_resize": self.do_resize,
            "size": self.size,
            "do_grayscale": self.do_grayscale,
        }

    def expected_output_image_shape(self, images):
        return 2, self.num_channels, self.size["height"], self.size["width"]

    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False, pairs=True, batch_size=None):
        batch_size = batch_size if batch_size is not None else self.batch_size
        image_inputs = prepare_image_inputs(
            batch_size=batch_size,
            num_channels=self.num_channels,
            min_resolution=self.min_resolution,
            max_resolution=self.max_resolution,
            equal_resolution=equal_resolution,
            numpify=numpify,
            torchify=torchify,
        )
        if pairs:
            image_inputs = [image_inputs[i : i + 2] for i in range(0, len(image_inputs), 2)]
        return image_inputs

    def prepare_keypoint_matching_output(self, pixel_values):
        max_number_keypoints = 50
        batch_size = len(pixel_values)
        mask = torch.zeros((batch_size, 2, max_number_keypoints), dtype=torch.int)
        keypoints = torch.zeros((batch_size, 2, max_number_keypoints, 2))
        matches = torch.full((batch_size, 2, max_number_keypoints), -1, dtype=torch.int)
        scores = torch.zeros((batch_size, 2, max_number_keypoints))
        for i in range(batch_size):
            random_number_keypoints0 = np.random.randint(10, max_number_keypoints)
            random_number_keypoints1 = np.random.randint(10, max_number_keypoints)
            random_number_matches = np.random.randint(5, min(random_number_keypoints0, random_number_keypoints1))
            mask[i, 0, :random_number_keypoints0] = 1
            mask[i, 1, :random_number_keypoints1] = 1
            keypoints[i, 0, :random_number_keypoints0] = torch.rand((random_number_keypoints0, 2))
            keypoints[i, 1, :random_number_keypoints1] = torch.rand((random_number_keypoints1, 2))
            random_matches_indices0 = torch.randperm(random_number_keypoints1, dtype=torch.int)[:random_number_matches]
            random_matches_indices1 = torch.randperm(random_number_keypoints0, dtype=torch.int)[:random_number_matches]
            matches[i, 0, random_matches_indices1] = random_matches_indices0
            matches[i, 1, random_matches_indices0] = random_matches_indices1
            scores[i, 0, random_matches_indices1] = torch.rand((random_number_matches,))
            scores[i, 1, random_matches_indices0] = torch.rand((random_number_matches,))
        return KeypointMatchingOutput(mask=mask, keypoints=keypoints, matches=matches, matching_scores=scores)
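

# Note on the synthetic output built above: `mask` flags the valid keypoint slots of
# each image, `matches` stores, for every keypoint, the index of its match in the
# other image (-1 means unmatched), and the pairing is symmetric by construction,
# i.e. for a match (a, b) in pair i:
#     matches[i, 0, a] == b and matches[i, 1, b] == a
# The two match directions get independent random `matching_scores` here, which is
# sufficient for the post-processing test below, as it only checks value ranges.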


@require_torch
@require_vision
class SuperGlueImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    image_processing_class = SuperGlueImageProcessor if is_vision_available() else None

    def setUp(self) -> None:
        super().setUp()
        self.image_processor_tester = SuperGlueImageProcessingTester(self)

    @property
    def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processing(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)
        self.assertTrue(hasattr(image_processing, "do_resize"))
        self.assertTrue(hasattr(image_processing, "size"))
        self.assertTrue(hasattr(image_processing, "do_rescale"))
        self.assertTrue(hasattr(image_processing, "rescale_factor"))
        self.assertTrue(hasattr(image_processing, "do_grayscale"))

    def test_image_processor_from_dict_with_kwargs(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        self.assertEqual(image_processor.size, {"height": 480, "width": 640})

        image_processor = self.image_processing_class.from_dict(
            self.image_processor_dict, size={"height": 42, "width": 42}
        )
        self.assertEqual(image_processor.size, {"height": 42, "width": 42})

    @unittest.skip(reason="SuperGlueImageProcessor is always supposed to return a grayscaled image")
    def test_call_numpy_4_channels(self):
        pass

    def test_number_and_format_of_images_in_input(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)

        # Cases where the number of images and the format of lists in the input are correct
        image_input = self.image_processor_tester.prepare_image_inputs(pairs=False, batch_size=2)
        image_processed = image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual((1, 2, 3, 480, 640), tuple(image_processed["pixel_values"].shape))

        image_input = self.image_processor_tester.prepare_image_inputs(pairs=True, batch_size=2)
        image_processed = image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual((1, 2, 3, 480, 640), tuple(image_processed["pixel_values"].shape))

        image_input = self.image_processor_tester.prepare_image_inputs(pairs=True, batch_size=4)
        image_processed = image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual((2, 2, 3, 480, 640), tuple(image_processed["pixel_values"].shape))

        image_input = self.image_processor_tester.prepare_image_inputs(pairs=True, batch_size=6)
        image_processed = image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual((3, 2, 3, 480, 640), tuple(image_processed["pixel_values"].shape))

        # Cases where the number of images or the format of lists in the input is incorrect
        ## Flat list of 4 images
        image_input = self.image_processor_tester.prepare_image_inputs(pairs=False, batch_size=4)
        with self.assertRaises(ValueError) as cm:
            image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual(ValueError, cm.exception.__class__)

        ## Flat list of 3 images
        image_input = self.image_processor_tester.prepare_image_inputs(pairs=False, batch_size=3)
        with self.assertRaises(ValueError) as cm:
            image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual(ValueError, cm.exception.__class__)

        ## 1 pair and 1 lone image
        image_input = self.image_processor_tester.prepare_image_inputs(pairs=True, batch_size=3)
        with self.assertRaises(ValueError) as cm:
            image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual(ValueError, cm.exception.__class__)

    @parameterized.expand(
        [
            ([random_array((3, 100, 200)), random_array((3, 100, 200))], (1, 2, 3, 480, 640)),
            ([[random_array((3, 100, 200)), random_array((3, 100, 200))]], (1, 2, 3, 480, 640)),
            ([random_tensor((3, 100, 200)), random_tensor((3, 100, 200))], (1, 2, 3, 480, 640)),
            ([[random_tensor((3, 100, 200)), random_tensor((3, 100, 200))]], (1, 2, 3, 480, 640)),
        ],
    )
    def test_valid_image_shape_in_input(self, image_input, output):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        image_processed = image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual(output, tuple(image_processed["pixel_values"].shape))

    @parameterized.expand(
        [
            (random_array((3, 100, 200)),),
            ([random_array((3, 100, 200))],),
            (random_array((1, 3, 100, 200)),),
            ([[random_array((3, 100, 200))]],),
            ([[random_array((3, 100, 200))], [random_array((3, 100, 200))]],),
            ([random_array((1, 3, 100, 200)), random_array((1, 3, 100, 200))],),
            (random_array((1, 1, 3, 100, 200)),),
        ],
    )
    def test_invalid_image_shape_in_input(self, image_input):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        with self.assertRaises(ValueError) as cm:
            image_processor.preprocess(image_input, return_tensors="pt")
        self.assertEqual(ValueError, cm.exception.__class__)
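
    # Taken together, the two parameterized tests above pin down the accepted input
    # layouts: either a flat list of exactly two images or a list of two-image pairs,
    # each image being a 3D (channels, height, width)-style array, normalized to a
    # (num_pairs, 2, 3, height, width) batch. Lone images, flat lists of any other
    # length, pre-batched 4D/5D arrays and deeper nesting raise a ValueError.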

    def test_input_images_properly_paired(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        image_inputs = self.image_processor_tester.prepare_image_inputs()
        pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="np")
        self.assertEqual(len(pre_processed_images["pixel_values"].shape), 5)
        self.assertEqual(pre_processed_images["pixel_values"].shape[1], 2)

    def test_input_not_paired_images_raises_error(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        image_inputs = self.image_processor_tester.prepare_image_inputs(pairs=False)
        with self.assertRaises(ValueError):
            image_processor.preprocess(image_inputs[0])

    def test_input_image_properly_converted_to_grayscale(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        image_inputs = self.image_processor_tester.prepare_image_inputs()
        pre_processed_images = image_processor.preprocess(image_inputs)
        for image_pair in pre_processed_images["pixel_values"]:
            for image in image_pair:
                self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]))
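
    # Note on the check above: with do_grayscale=True the processor keeps the
    # 3-channel layout but replicates the single gray channel, so pairwise equality
    # of the three channel planes is sufficient evidence that the conversion ran.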

    def test_call_numpy(self):
        # Test overwritten because SuperGlueImageProcessor combines images by pairs to feed them into SuperGlue
        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random numpy tensors
        image_pairs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
        for image_pair in image_pairs:
            self.assertEqual(len(image_pair), 2)

        expected_batch_size = int(self.image_processor_tester.batch_size / 2)

        # Test with 2 images
        encoded_images = image_processing(image_pairs[0], return_tensors="pt").pixel_values
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs[0])
        self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))

        # Test with list of pairs
        encoded_images = image_processing(image_pairs, return_tensors="pt").pixel_values
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs)
        self.assertEqual(tuple(encoded_images.shape), (expected_batch_size, *expected_output_image_shape))

        # Test without paired images
        image_pairs = self.image_processor_tester.prepare_image_inputs(
            equal_resolution=False, numpify=True, pairs=False
        )
        with self.assertRaises(ValueError):
            image_processing(image_pairs, return_tensors="pt").pixel_values

    def test_call_pil(self):
        # Test overwritten because SuperGlueImageProcessor combines images by pairs to feed them into SuperGlue
        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random PIL images
        image_pairs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
        for image_pair in image_pairs:
            self.assertEqual(len(image_pair), 2)

        expected_batch_size = int(self.image_processor_tester.batch_size / 2)

        # Test with 2 images
        encoded_images = image_processing(image_pairs[0], return_tensors="pt").pixel_values
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs[0])
        self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))

        # Test with list of pairs
        encoded_images = image_processing(image_pairs, return_tensors="pt").pixel_values
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs)
        self.assertEqual(tuple(encoded_images.shape), (expected_batch_size, *expected_output_image_shape))

        # Test without paired images
        image_pairs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, pairs=False)
        with self.assertRaises(ValueError):
            image_processing(image_pairs, return_tensors="pt").pixel_values

    def test_call_pytorch(self):
        # Test overwritten because SuperGlueImageProcessor combines images by pairs to feed them into SuperGlue
        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random PyTorch tensors
        image_pairs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
        for image_pair in image_pairs:
            self.assertEqual(len(image_pair), 2)

        expected_batch_size = int(self.image_processor_tester.batch_size / 2)

        # Test with 2 images
        encoded_images = image_processing(image_pairs[0], return_tensors="pt").pixel_values
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs[0])
        self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))

        # Test with list of pairs
        encoded_images = image_processing(image_pairs, return_tensors="pt").pixel_values
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs)
        self.assertEqual(tuple(encoded_images.shape), (expected_batch_size, *expected_output_image_shape))

        # Test without paired images
        image_pairs = self.image_processor_tester.prepare_image_inputs(
            equal_resolution=False, torchify=True, pairs=False
        )
        with self.assertRaises(ValueError):
            image_processing(image_pairs, return_tensors="pt").pixel_values

    def test_image_processor_with_list_of_two_images(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)

        image_pairs = self.image_processor_tester.prepare_image_inputs(
            equal_resolution=False, numpify=True, batch_size=2, pairs=False
        )
        self.assertEqual(len(image_pairs), 2)
        self.assertTrue(isinstance(image_pairs[0], np.ndarray))
        self.assertTrue(isinstance(image_pairs[1], np.ndarray))

        expected_batch_size = 1
        encoded_images = image_processing(image_pairs, return_tensors="pt").pixel_values
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs[0])
        self.assertEqual(tuple(encoded_images.shape), (expected_batch_size, *expected_output_image_shape))

    @require_torch
    def test_post_processing_keypoint_matching(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        image_inputs = self.image_processor_tester.prepare_image_inputs()
        pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="pt")
        outputs = self.image_processor_tester.prepare_keypoint_matching_output(**pre_processed_images)

        def check_post_processed_output(post_processed_outputs, image_pair_sizes):
            for post_processed_output, (image_size0, image_size1) in zip(post_processed_outputs, image_pair_sizes):
                self.assertTrue("keypoints0" in post_processed_output)
                self.assertTrue("keypoints1" in post_processed_output)
                self.assertTrue("matching_scores" in post_processed_output)
                keypoints0 = post_processed_output["keypoints0"]
                keypoints1 = post_processed_output["keypoints1"]
                all_below_image_size0 = torch.all(keypoints0[:, 0] <= image_size0[1]) and torch.all(
                    keypoints0[:, 1] <= image_size0[0]
                )
                all_below_image_size1 = torch.all(keypoints1[:, 0] <= image_size1[1]) and torch.all(
                    keypoints1[:, 1] <= image_size1[0]
                )
                all_above_zero0 = torch.all(keypoints0[:, 0] >= 0) and torch.all(keypoints0[:, 1] >= 0)
                all_above_zero1 = torch.all(keypoints1[:, 0] >= 0) and torch.all(keypoints1[:, 1] >= 0)
                self.assertTrue(all_below_image_size0)
                self.assertTrue(all_below_image_size1)
                self.assertTrue(all_above_zero0)
                self.assertTrue(all_above_zero1)
                all_scores_different_from_minus_one = torch.all(post_processed_output["matching_scores"] != -1)
                self.assertTrue(all_scores_different_from_minus_one)

        tuple_image_sizes = [
            ((image_pair[0].size[0], image_pair[0].size[1]), (image_pair[1].size[0], image_pair[1].size[1]))
            for image_pair in image_inputs
        ]
        tuple_post_processed_outputs = image_processor.post_process_keypoint_matching(outputs, tuple_image_sizes)
        check_post_processed_output(tuple_post_processed_outputs, tuple_image_sizes)

        tensor_image_sizes = torch.tensor(
            [(image_pair[0].size, image_pair[1].size) for image_pair in image_inputs]
        ).flip(2)
        tensor_post_processed_outputs = image_processor.post_process_keypoint_matching(outputs, tensor_image_sizes)
        check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes)
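
# To run only this module locally (a standard pytest invocation, assuming a
# transformers development checkout with the torch and vision extras installed):
#   python -m pytest tests/models/superglue/test_image_processing_superglue.py -v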