Fix max_size parameter handling in all DETR image processors

- Fix ConditionalDetrImageProcessor.from_dict to handle max_size correctly
- Fix DetrImageProcessor.from_dict to handle max_size correctly
- Fix DeformableDetrImageProcessor.from_dict to handle max_size correctly
- Fix preprocess methods in all DETR variants to handle max_size properly
- Add comprehensive test verification script
- Ensure max_size is properly converted to longest_edge in size dict
- Handle both integer size and dict size with max_size parameter
- All fixes maintain backward compatibility with deprecation warnings

Fixes issue where max_size parameter would incorrectly overwrite
existing size settings instead of being properly integrated.
This commit is contained in:
nck90 2025-06-26 10:26:07 +09:00
parent 5e6ee58996
commit db9e5b98fe
10 changed files with 1920 additions and 35 deletions

View File

@ -969,7 +969,30 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
max_size = kwargs.pop("max_size")
# Check for size in both image_processor_dict and kwargs
size = kwargs.get("size", image_processor_dict.get("size"))
if size is not None:
# If size is an integer, convert to shortest_edge dict
if isinstance(size, int):
size = {"shortest_edge": size}
# If size is a dict but missing longest_edge, add it
elif isinstance(size, dict) and "longest_edge" not in size:
size = dict(size) # Make a copy
if isinstance(size, dict) and "longest_edge" not in size:
size["longest_edge"] = max_size
# Update both locations if size was in kwargs
if "size" in kwargs:
kwargs["size"] = size
else:
image_processor_dict["size"] = size
else:
# If no size provided, create default size with max_size
image_processor_dict["size"] = {"shortest_edge": 800, "longest_edge": max_size}
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)
@ -1433,8 +1456,11 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
else:
# If size is already provided, we need to handle max_size appropriately
if isinstance(size, dict) and "longest_edge" not in size:
size = get_size_dict(size, max_size=max_size, default_to_square=False)
size = dict(size) # Make a copy to avoid modifying the original
size["longest_edge"] = max_size
# If size already has longest_edge, the max_size is ignored (deprecated behavior)
else:
max_size = None if size is None else 1333
do_resize = self.do_resize if do_resize is None else do_resize
size = self.size if size is None else size
@ -1650,11 +1676,7 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
# Copied from transformers.models.deformable_detr.image_processing_deformable_detr.DeformableDetrImageProcessor.post_process_object_detection with DeformableDetr->ConditionalDetr
def post_process_object_detection(
self,
outputs,
threshold: float = 0.5,
target_sizes: Union[TensorType, list[tuple]] = None,
top_k: int = 100,
self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, list[tuple]] = None, top_k: int = 100
):
"""
Converts the raw output of [`ConditionalDetrForObjectDetection`] into final bounding boxes in (top_left_x,

View File

@ -946,7 +946,30 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
max_size = kwargs.pop("max_size")
# Check for size in both image_processor_dict and kwargs
size = kwargs.get("size", image_processor_dict.get("size"))
if size is not None:
# If size is an integer, convert to shortest_edge dict
if isinstance(size, int):
size = {"shortest_edge": size}
# If size is a dict but missing longest_edge, add it
elif isinstance(size, dict) and "longest_edge" not in size:
size = dict(size) # Make a copy
if isinstance(size, dict) and "longest_edge" not in size:
size["longest_edge"] = max_size
# Update both locations if size was in kwargs
if "size" in kwargs:
kwargs["size"] = size
else:
image_processor_dict["size"] = size
else:
# If no size provided, create default size with max_size
image_processor_dict["size"] = {"shortest_edge": 800, "longest_edge": max_size}
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)
@ -1034,11 +1057,17 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
size = get_size_dict(size, max_size=max_size, default_to_square=False)
if "shortest_edge" in size and "longest_edge" in size:
new_size = get_resize_output_image_size(
image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
image,
size["shortest_edge"],
size["longest_edge"],
input_data_format=input_data_format,
)
elif "max_height" in size and "max_width" in size:
new_size = get_image_size_for_max_height_width(
image, size["max_height"], size["max_width"], input_data_format=input_data_format
image,
size["max_height"],
size["max_width"],
input_data_format=input_data_format,
)
elif "height" in size and "width" in size:
new_size = (size["height"], size["width"])
@ -1098,7 +1127,12 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
return rescale(
image,
rescale_factor,
data_format=data_format,
input_data_format=input_data_format,
)
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation
def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict:
@ -1182,7 +1216,11 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
)
if annotation is not None:
annotation = self._update_annotation_for_padded_image(
annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes
annotation,
(input_height, input_width),
(output_height, output_width),
padding,
update_bboxes,
)
return padded_image, annotation
@ -1258,7 +1296,11 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
if return_pixel_mask:
masks = [
make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format)
make_pixel_mask(
image=image,
output_size=padded_size,
input_data_format=input_data_format,
)
for image in images
]
data["pixel_mask"] = masks
@ -1391,6 +1433,8 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
size = dict(size) # Make a copy to avoid modifying the original
size["longest_edge"] = max_size
# If size already has longest_edge, the max_size is ignored (deprecated behavior)
else:
max_size = None if size is None else 1333
do_resize = self.do_resize if do_resize is None else do_resize
size = self.size if size is None else size
@ -1415,7 +1459,10 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
validate_kwargs(
captured_kwargs=kwargs.keys(),
valid_processor_keys=self._valid_processor_keys,
)
# Here, the pad() method pads to the maximum of (width, height). It does not need to be validated.
validate_preprocess_arguments(

View File

@ -950,7 +950,30 @@ class DetrImageProcessor(BaseImageProcessor):
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
max_size = kwargs.pop("max_size")
# Check for size in both image_processor_dict and kwargs
size = kwargs.get("size", image_processor_dict.get("size"))
if size is not None:
# If size is an integer, convert to shortest_edge dict
if isinstance(size, int):
size = {"shortest_edge": size}
# If size is a dict but missing longest_edge, add it
elif isinstance(size, dict) and "longest_edge" not in size:
size = dict(size) # Make a copy
if isinstance(size, dict) and "longest_edge" not in size:
size["longest_edge"] = max_size
# Update both locations if size was in kwargs
if "size" in kwargs:
kwargs["size"] = size
else:
image_processor_dict["size"] = size
else:
# If no size provided, create default size with max_size
image_processor_dict["size"] = {"shortest_edge": 800, "longest_edge": max_size}
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)

File diff suppressed because it is too large Load Diff

View File

@ -71,7 +71,8 @@ def max_across_indices(values: Iterable[Any]) -> list[Any]:
# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
def get_max_height_width(
images: list[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
images: list[np.ndarray],
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> list[int]:
"""
Get the maximum height and width across all images in a batch.
@ -90,7 +91,9 @@ def get_max_height_width(
# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
def make_pixel_mask(
image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
image: np.ndarray,
output_size: tuple[int, int],
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
@ -218,7 +221,10 @@ def compute_segments(
if target_size is not None:
mask_probs = nn.functional.interpolate(
mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
mask_probs.unsqueeze(0),
size=target_size,
mode="bilinear",
align_corners=False,
)[0]
current_segment_id = 0
@ -551,7 +557,12 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
return rescale(
image,
rescale_factor,
data_format=data_format,
input_data_format=input_data_format,
)
# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.convert_segmentation_map_to_binary_masks
def convert_segmentation_map_to_binary_masks(

View File

@ -77,7 +77,8 @@ def max_across_indices(values: Iterable[Any]) -> list[Any]:
# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
def get_max_height_width(
images: list[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
images: list[np.ndarray],
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> list[int]:
"""
Get the maximum height and width across all images in a batch.
@ -96,7 +97,9 @@ def get_max_height_width(
# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
def make_pixel_mask(
image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
image: np.ndarray,
output_size: tuple[int, int],
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
@ -224,7 +227,10 @@ def compute_segments(
if target_size is not None:
mask_probs = nn.functional.interpolate(
mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
mask_probs.unsqueeze(0),
size=target_size,
mode="bilinear",
align_corners=False,
)[0]
current_segment_id = 0
@ -555,7 +561,12 @@ class MaskFormerImageProcessor(BaseImageProcessor):
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
return rescale(
image,
rescale_factor,
data_format=data_format,
input_data_format=input_data_format,
)
def convert_segmentation_map_to_binary_masks(
self,

View File

@ -74,7 +74,8 @@ def max_across_indices(values: Iterable[Any]) -> list[Any]:
# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
def get_max_height_width(
images: list[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
images: list[np.ndarray],
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> list[int]:
"""
Get the maximum height and width across all images in a batch.
@ -93,7 +94,9 @@ def get_max_height_width(
# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
def make_pixel_mask(
image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
image: np.ndarray,
output_size: tuple[int, int],
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
@ -221,7 +224,10 @@ def compute_segments(
if target_size is not None:
mask_probs = nn.functional.interpolate(
mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
mask_probs.unsqueeze(0),
size=target_size,
mode="bilinear",
align_corners=False,
)[0]
current_segment_id = 0
@ -558,7 +564,12 @@ class OneFormerImageProcessor(BaseImageProcessor):
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
return rescale(
image,
rescale_factor,
data_format=data_format,
input_data_format=input_data_format,
)
# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.convert_segmentation_map_to_binary_masks
def convert_segmentation_map_to_binary_masks(

View File

@ -238,7 +238,8 @@ def max_across_indices(values: Iterable[Any]) -> list[Any]:
# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
def get_max_height_width(
images: list[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
images: list[np.ndarray],
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> list[int]:
"""
Get the maximum height and width across all images in a batch.
@ -257,7 +258,9 @@ def get_max_height_width(
# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
def make_pixel_mask(
image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
image: np.ndarray,
output_size: tuple[int, int],
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
@ -537,16 +540,26 @@ class RTDetrImageProcessor(BaseImageProcessor):
"Please specify in `size['longest_edge'] instead`.",
)
max_size = kwargs.pop("max_size")
# If size is already a dict but missing longest_edge, add it from max_size
if isinstance(size, dict) and "longest_edge" not in size:
size = dict(size) # Make a copy
size["longest_edge"] = max_size
else:
max_size = None
size = get_size_dict(size, max_size=max_size, default_to_square=False)
if "shortest_edge" in size and "longest_edge" in size:
new_size = get_resize_output_image_size(
image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
image,
size["shortest_edge"],
size["longest_edge"],
input_data_format=input_data_format,
)
elif "max_height" in size and "max_width" in size:
new_size = get_image_size_for_max_height_width(
image, size["max_height"], size["max_width"], input_data_format=input_data_format
image,
size["max_height"],
size["max_width"],
input_data_format=input_data_format,
)
elif "height" in size and "width" in size:
new_size = (size["height"], size["width"])
@ -606,7 +619,12 @@ class RTDetrImageProcessor(BaseImageProcessor):
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
return rescale(
image,
rescale_factor,
data_format=data_format,
input_data_format=input_data_format,
)
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation
def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict:
@ -690,7 +708,11 @@ class RTDetrImageProcessor(BaseImageProcessor):
)
if annotation is not None:
annotation = self._update_annotation_for_padded_image(
annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes
annotation,
(input_height, input_width),
(output_height, output_width),
padding,
update_bboxes,
)
return padded_image, annotation
@ -766,7 +788,11 @@ class RTDetrImageProcessor(BaseImageProcessor):
if return_pixel_mask:
masks = [
make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format)
make_pixel_mask(
image=image,
output_size=padded_size,
input_data_format=input_data_format,
)
for image in images
]
data["pixel_mask"] = masks

67
test_fix_verification.py Normal file
View File

@ -0,0 +1,67 @@
#!/usr/bin/env python3
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
def test_all_fixes():
print("🧪 Testing all DETR max_size parameter fixes...")
try:
from transformers.models.conditional_detr.image_processing_conditional_detr import ConditionalDetrImageProcessor
from transformers.models.detr.image_processing_detr import DetrImageProcessor
from transformers.models.deformable_detr.image_processing_deformable_detr import DeformableDetrImageProcessor
processors = [
("ConditionalDetr", ConditionalDetrImageProcessor),
("Detr", DetrImageProcessor),
("DeformableDetr", DeformableDetrImageProcessor)
]
for name, ProcessorClass in processors:
print(f"\n🔧 Testing {name}ImageProcessor...")
# Test 1: from_dict with size=42, max_size=84
processor = ProcessorClass.from_dict({
"do_resize": True,
"do_normalize": True,
"do_pad": True,
}, size=42, max_size=84)
expected = {"shortest_edge": 42, "longest_edge": 84}
actual = processor.size
assert actual == expected, f"❌ Test 1 failed: expected {expected}, got {actual}"
print(f"✅ Test 1 passed: from_dict(size=42, max_size=84) = {actual}")
# Test 2: from_dict with size dict without longest_edge + max_size
processor = ProcessorClass.from_dict({
"do_resize": True,
"do_normalize": True,
"do_pad": True,
"size": {"shortest_edge": 100}
}, max_size=200)
expected = {"shortest_edge": 100, "longest_edge": 200}
actual = processor.size
assert actual == expected, f"❌ Test 2 failed: expected {expected}, got {actual}"
print(f"✅ Test 2 passed: size without longest_edge + max_size = {actual}")
# Test 3: init with max_size only
processor = ProcessorClass(max_size=500)
expected = {"shortest_edge": 800, "longest_edge": 500}
actual = processor.size
assert actual == expected, f"❌ Test 3 failed: expected {expected}, got {actual}"
print(f"✅ Test 3 passed: init(max_size=500) = {actual}")
print(f"🎉 All tests passed for {name}ImageProcessor!")
print("\n🌟 All DETR image processors work correctly!")
return True
except Exception as e:
print(f"❌ Test failed with error: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
success = test_all_fixes()
sys.exit(0 if success else 1)

View File

@ -640,7 +640,7 @@ class ConditionalDetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcess
expected_size = {"shortest_edge": 500, "longest_edge": 800}
self.assertEqual(image_processor.size, expected_size)
# Test 4: from_dict with max_size (using a dict without longest_edge)
# Test 4: from_dict with max_size (using a dict without longest_edge)
test_dict = {k: v for k, v in self.image_processor_dict.items() if k != "size"}
test_dict["size"] = {"shortest_edge": 18} # Only shortest_edge, no longest_edge
image_processor = image_processing_class.from_dict(test_dict, max_size=1100)