mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add Fast Conditional-DETR Processor (#37071)
* Add Fast Conditional-DETR Processor * Update image_processing_conditional_detr_fast.py * Add modular_conditional_detr.py * Update image_processing_conditional_detr_fast.py * Update tests * make fix --------- Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
parent
4f1dbe8152
commit
51f544a4d4
@ -48,6 +48,11 @@ This model was contributed by [DepuMeng](https://huggingface.co/DepuMeng). The o
|
||||
|
||||
[[autodoc]] ConditionalDetrImageProcessor
|
||||
- preprocess
|
||||
|
||||
## ConditionalDetrImageProcessorFast
|
||||
|
||||
[[autodoc]] ConditionalDetrImageProcessorFast
|
||||
- preprocess
|
||||
- post_process_object_detection
|
||||
- post_process_instance_segmentation
|
||||
- post_process_semantic_segmentation
|
||||
|
@ -43,6 +43,11 @@ alt="描画" width="600"/>
|
||||
|
||||
[[autodoc]] ConditionalDetrImageProcessor
|
||||
- preprocess
|
||||
|
||||
## ConditionalDetrImageProcessorFast
|
||||
|
||||
[[autodoc]] ConditionalDetrImageProcessorFast
|
||||
- preprocess
|
||||
- post_process_object_detection
|
||||
- post_process_instance_segmentation
|
||||
- post_process_semantic_segmentation
|
||||
|
@ -67,7 +67,7 @@ else:
|
||||
("chinese_clip", ("ChineseCLIPImageProcessor", "ChineseCLIPImageProcessorFast")),
|
||||
("clip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||
("conditional_detr", ("ConditionalDetrImageProcessor",)),
|
||||
("conditional_detr", ("ConditionalDetrImageProcessor", "ConditionalDetrImageProcessorFast")),
|
||||
("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
|
@ -21,6 +21,7 @@ if TYPE_CHECKING:
|
||||
from .configuration_conditional_detr import *
|
||||
from .feature_extraction_conditional_detr import *
|
||||
from .image_processing_conditional_detr import *
|
||||
from .image_processing_conditional_detr_fast import *
|
||||
from .modeling_conditional_detr import *
|
||||
else:
|
||||
import sys
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,137 @@
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from transformers.models.detr.image_processing_detr_fast import DetrImageProcessorFast
|
||||
|
||||
from ...image_transforms import (
|
||||
center_to_corners_format,
|
||||
)
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class ConditionalDetrImageProcessorFast(DetrImageProcessorFast):
|
||||
def post_process(self, outputs, target_sizes):
|
||||
"""
|
||||
Converts the output of [`ConditionalDetrForObjectDetection`] into the format expected by the Pascal VOC format (xmin, ymin, xmax, ymax).
|
||||
Only supports PyTorch.
|
||||
|
||||
Args:
|
||||
outputs ([`ConditionalDetrObjectDetectionOutput`]):
|
||||
Raw outputs of the model.
|
||||
target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
|
||||
Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
|
||||
image size (before any data augmentation). For visualization, this should be the image size after data
|
||||
augment, but before padding.
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
|
||||
in the batch as predicted by the model.
|
||||
"""
|
||||
logging.warning_once(
|
||||
"`post_process` is deprecated and will be removed in v5 of Transformers, please use"
|
||||
" `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
|
||||
)
|
||||
|
||||
out_logits, out_bbox = outputs.logits, outputs.pred_boxes
|
||||
|
||||
if len(out_logits) != len(target_sizes):
|
||||
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
|
||||
if target_sizes.shape[1] != 2:
|
||||
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
|
||||
|
||||
prob = out_logits.sigmoid()
|
||||
topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 300, dim=1)
|
||||
scores = topk_values
|
||||
topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
|
||||
labels = topk_indexes % out_logits.shape[2]
|
||||
boxes = center_to_corners_format(out_bbox)
|
||||
boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
|
||||
|
||||
# and from relative [0, 1] to absolute [0, height] coordinates
|
||||
img_h, img_w = target_sizes.unbind(1)
|
||||
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
|
||||
boxes = boxes * scale_fct[:, None, :]
|
||||
|
||||
results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
|
||||
|
||||
return results
|
||||
|
||||
def post_process_object_detection(
|
||||
self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None, top_k: int = 100
|
||||
):
|
||||
"""
|
||||
Converts the raw output of [`ConditionalDetrForObjectDetection`] into final bounding boxes in (top_left_x,
|
||||
top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch.
|
||||
|
||||
Args:
|
||||
outputs ([`ConditionalDetrObjectDetectionOutput`]):
|
||||
Raw outputs of the model.
|
||||
threshold (`float`, *optional*):
|
||||
Score threshold to keep object detection predictions.
|
||||
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
|
||||
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
|
||||
(height, width) of each image in the batch. If left to None, predictions will not be resized.
|
||||
top_k (`int`, *optional*, defaults to 100):
|
||||
Keep only top k bounding boxes before filtering by thresholding.
|
||||
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
|
||||
in the batch as predicted by the model.
|
||||
"""
|
||||
out_logits, out_bbox = outputs.logits, outputs.pred_boxes
|
||||
|
||||
if target_sizes is not None:
|
||||
if len(out_logits) != len(target_sizes):
|
||||
raise ValueError(
|
||||
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
|
||||
)
|
||||
|
||||
prob = out_logits.sigmoid()
|
||||
prob = prob.view(out_logits.shape[0], -1)
|
||||
k_value = min(top_k, prob.size(1))
|
||||
topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)
|
||||
scores = topk_values
|
||||
topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
|
||||
labels = topk_indexes % out_logits.shape[2]
|
||||
boxes = center_to_corners_format(out_bbox)
|
||||
boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
|
||||
|
||||
# and from relative [0, 1] to absolute [0, height] coordinates
|
||||
if target_sizes is not None:
|
||||
if isinstance(target_sizes, List):
|
||||
img_h = torch.Tensor([i[0] for i in target_sizes])
|
||||
img_w = torch.Tensor([i[1] for i in target_sizes])
|
||||
else:
|
||||
img_h, img_w = target_sizes.unbind(1)
|
||||
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
|
||||
boxes = boxes * scale_fct[:, None, :]
|
||||
|
||||
results = []
|
||||
for s, l, b in zip(scores, labels, boxes):
|
||||
score = s[s > threshold]
|
||||
label = l[s > threshold]
|
||||
box = b[s > threshold]
|
||||
results.append({"scores": score, "labels": label, "boxes": box})
|
||||
|
||||
return results
|
||||
|
||||
def post_process_segmentation():
|
||||
raise NotImplementedError("Segmentation post-processing is not implemented for Conditional DETR yet.")
|
||||
|
||||
def post_process_instance():
|
||||
raise NotImplementedError("Instance post-processing is not implemented for Conditional DETR yet.")
|
||||
|
||||
def post_process_panoptic():
|
||||
raise NotImplementedError("Panoptic post-processing is not implemented for Conditional DETR yet.")
|
||||
|
||||
|
||||
__all__ = ["ConditionalDetrImageProcessorFast"]
|
@ -20,7 +20,7 @@ import unittest
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision, slow
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import AnnotationFormatTestMixin, ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
@ -33,6 +33,9 @@ if is_vision_available():
|
||||
|
||||
from transformers import ConditionalDetrImageProcessor
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import ConditionalDetrImageProcessorFast
|
||||
|
||||
|
||||
class ConditionalDetrImageProcessingTester:
|
||||
def __init__(
|
||||
@ -132,6 +135,7 @@ class ConditionalDetrImageProcessingTester:
|
||||
@require_vision
|
||||
class ConditionalDetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = ConditionalDetrImageProcessor if is_vision_available() else None
|
||||
fast_image_processing_class = ConditionalDetrImageProcessorFast if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@ -142,23 +146,25 @@ class ConditionalDetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcess
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 18, "longest_edge": 1333})
|
||||
self.assertEqual(image_processor.do_pad, True)
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 18, "longest_edge": 1333})
|
||||
self.assertEqual(image_processor.do_pad, True)
|
||||
|
||||
image_processor = self.image_processing_class.from_dict(
|
||||
self.image_processor_dict, size=42, max_size=84, pad_and_return_pixel_mask=False
|
||||
)
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 42, "longest_edge": 84})
|
||||
self.assertEqual(image_processor.do_pad, False)
|
||||
image_processor = image_processing_class.from_dict(
|
||||
self.image_processor_dict, size=42, max_size=84, pad_and_return_pixel_mask=False
|
||||
)
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 42, "longest_edge": 84})
|
||||
self.assertEqual(image_processor.do_pad, False)
|
||||
|
||||
@slow
|
||||
def test_call_pytorch_with_coco_detection_annotations(self):
|
||||
@ -169,40 +175,41 @@ class ConditionalDetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcess
|
||||
|
||||
target = {"image_id": 39769, "annotations": target}
|
||||
|
||||
# encode them
|
||||
image_processing = ConditionalDetrImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50")
|
||||
encoding = image_processing(images=image, annotations=target, return_tensors="pt")
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# encode them
|
||||
image_processing = image_processing_class.from_pretrained("microsoft/conditional-detr-resnet-50")
|
||||
encoding = image_processing(images=image, annotations=target, return_tensors="pt")
|
||||
|
||||
# verify pixel values
|
||||
expected_shape = torch.Size([1, 3, 800, 1066])
|
||||
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
|
||||
# verify pixel values
|
||||
expected_shape = torch.Size([1, 3, 800, 1066])
|
||||
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([0.2796, 0.3138, 0.3481])
|
||||
torch.testing.assert_close(encoding["pixel_values"][0, 0, 0, :3], expected_slice, rtol=1e-4, atol=1e-4)
|
||||
expected_slice = torch.tensor([0.2796, 0.3138, 0.3481])
|
||||
torch.testing.assert_close(encoding["pixel_values"][0, 0, 0, :3], expected_slice, rtol=1e-4, atol=1e-4)
|
||||
|
||||
# verify area
|
||||
expected_area = torch.tensor([5887.9600, 11250.2061, 489353.8438, 837122.7500, 147967.5156, 165732.3438])
|
||||
torch.testing.assert_close(encoding["labels"][0]["area"], expected_area)
|
||||
# verify boxes
|
||||
expected_boxes_shape = torch.Size([6, 4])
|
||||
self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape)
|
||||
expected_boxes_slice = torch.tensor([0.5503, 0.2765, 0.0604, 0.2215])
|
||||
torch.testing.assert_close(encoding["labels"][0]["boxes"][0], expected_boxes_slice, rtol=1e-3, atol=1e-3)
|
||||
# verify image_id
|
||||
expected_image_id = torch.tensor([39769])
|
||||
torch.testing.assert_close(encoding["labels"][0]["image_id"], expected_image_id)
|
||||
# verify is_crowd
|
||||
expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0])
|
||||
torch.testing.assert_close(encoding["labels"][0]["iscrowd"], expected_is_crowd)
|
||||
# verify class_labels
|
||||
expected_class_labels = torch.tensor([75, 75, 63, 65, 17, 17])
|
||||
torch.testing.assert_close(encoding["labels"][0]["class_labels"], expected_class_labels)
|
||||
# verify orig_size
|
||||
expected_orig_size = torch.tensor([480, 640])
|
||||
torch.testing.assert_close(encoding["labels"][0]["orig_size"], expected_orig_size)
|
||||
# verify size
|
||||
expected_size = torch.tensor([800, 1066])
|
||||
torch.testing.assert_close(encoding["labels"][0]["size"], expected_size)
|
||||
# verify area
|
||||
expected_area = torch.tensor([5887.9600, 11250.2061, 489353.8438, 837122.7500, 147967.5156, 165732.3438])
|
||||
torch.testing.assert_close(encoding["labels"][0]["area"], expected_area)
|
||||
# verify boxes
|
||||
expected_boxes_shape = torch.Size([6, 4])
|
||||
self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape)
|
||||
expected_boxes_slice = torch.tensor([0.5503, 0.2765, 0.0604, 0.2215])
|
||||
torch.testing.assert_close(encoding["labels"][0]["boxes"][0], expected_boxes_slice, rtol=1e-3, atol=1e-3)
|
||||
# verify image_id
|
||||
expected_image_id = torch.tensor([39769])
|
||||
torch.testing.assert_close(encoding["labels"][0]["image_id"], expected_image_id)
|
||||
# verify is_crowd
|
||||
expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0])
|
||||
torch.testing.assert_close(encoding["labels"][0]["iscrowd"], expected_is_crowd)
|
||||
# verify class_labels
|
||||
expected_class_labels = torch.tensor([75, 75, 63, 65, 17, 17])
|
||||
torch.testing.assert_close(encoding["labels"][0]["class_labels"], expected_class_labels)
|
||||
# verify orig_size
|
||||
expected_orig_size = torch.tensor([480, 640])
|
||||
torch.testing.assert_close(encoding["labels"][0]["orig_size"], expected_orig_size)
|
||||
# verify size
|
||||
expected_size = torch.tensor([800, 1066])
|
||||
torch.testing.assert_close(encoding["labels"][0]["size"], expected_size)
|
||||
|
||||
@slow
|
||||
def test_call_pytorch_with_coco_panoptic_annotations(self):
|
||||
@ -215,43 +222,45 @@ class ConditionalDetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcess
|
||||
|
||||
masks_path = pathlib.Path("./tests/fixtures/tests_samples/COCO/coco_panoptic")
|
||||
|
||||
# encode them
|
||||
image_processing = ConditionalDetrImageProcessor(format="coco_panoptic")
|
||||
encoding = image_processing(images=image, annotations=target, masks_path=masks_path, return_tensors="pt")
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# encode them
|
||||
image_processing = image_processing_class(format="coco_panoptic")
|
||||
encoding = image_processing(images=image, annotations=target, masks_path=masks_path, return_tensors="pt")
|
||||
|
||||
# verify pixel values
|
||||
expected_shape = torch.Size([1, 3, 800, 1066])
|
||||
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
|
||||
# verify pixel values
|
||||
expected_shape = torch.Size([1, 3, 800, 1066])
|
||||
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([0.2796, 0.3138, 0.3481])
|
||||
torch.testing.assert_close(encoding["pixel_values"][0, 0, 0, :3], expected_slice, rtol=1e-4, atol=1e-4)
|
||||
expected_slice = torch.tensor([0.2796, 0.3138, 0.3481])
|
||||
torch.testing.assert_close(encoding["pixel_values"][0, 0, 0, :3], expected_slice, rtol=1e-4, atol=1e-4)
|
||||
|
||||
# verify area
|
||||
expected_area = torch.tensor([147979.6875, 165527.0469, 484638.5938, 11292.9375, 5879.6562, 7634.1147])
|
||||
torch.testing.assert_close(encoding["labels"][0]["area"], expected_area)
|
||||
# verify boxes
|
||||
expected_boxes_shape = torch.Size([6, 4])
|
||||
self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape)
|
||||
expected_boxes_slice = torch.tensor([0.2625, 0.5437, 0.4688, 0.8625])
|
||||
torch.testing.assert_close(encoding["labels"][0]["boxes"][0], expected_boxes_slice, rtol=1e-3, atol=1e-3)
|
||||
# verify image_id
|
||||
expected_image_id = torch.tensor([39769])
|
||||
torch.testing.assert_close(encoding["labels"][0]["image_id"], expected_image_id)
|
||||
# verify is_crowd
|
||||
expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0])
|
||||
torch.testing.assert_close(encoding["labels"][0]["iscrowd"], expected_is_crowd)
|
||||
# verify class_labels
|
||||
expected_class_labels = torch.tensor([17, 17, 63, 75, 75, 93])
|
||||
torch.testing.assert_close(encoding["labels"][0]["class_labels"], expected_class_labels)
|
||||
# verify masks
|
||||
expected_masks_sum = 822873
|
||||
self.assertEqual(encoding["labels"][0]["masks"].sum().item(), expected_masks_sum)
|
||||
# verify orig_size
|
||||
expected_orig_size = torch.tensor([480, 640])
|
||||
torch.testing.assert_close(encoding["labels"][0]["orig_size"], expected_orig_size)
|
||||
# verify size
|
||||
expected_size = torch.tensor([800, 1066])
|
||||
torch.testing.assert_close(encoding["labels"][0]["size"], expected_size)
|
||||
# verify area
|
||||
expected_area = torch.tensor([147979.6875, 165527.0469, 484638.5938, 11292.9375, 5879.6562, 7634.1147])
|
||||
torch.testing.assert_close(encoding["labels"][0]["area"], expected_area)
|
||||
# verify boxes
|
||||
expected_boxes_shape = torch.Size([6, 4])
|
||||
self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape)
|
||||
expected_boxes_slice = torch.tensor([0.2625, 0.5437, 0.4688, 0.8625])
|
||||
torch.testing.assert_close(encoding["labels"][0]["boxes"][0], expected_boxes_slice, rtol=1e-3, atol=1e-3)
|
||||
# verify image_id
|
||||
expected_image_id = torch.tensor([39769])
|
||||
torch.testing.assert_close(encoding["labels"][0]["image_id"], expected_image_id)
|
||||
# verify is_crowd
|
||||
expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0])
|
||||
torch.testing.assert_close(encoding["labels"][0]["iscrowd"], expected_is_crowd)
|
||||
# verify class_labels
|
||||
expected_class_labels = torch.tensor([17, 17, 63, 75, 75, 93])
|
||||
torch.testing.assert_close(encoding["labels"][0]["class_labels"], expected_class_labels)
|
||||
# verify masks
|
||||
expected_masks_sum = 822873
|
||||
relative_error = torch.abs(encoding["labels"][0]["masks"].sum() - expected_masks_sum) / expected_masks_sum
|
||||
self.assertTrue(relative_error < 1e-3)
|
||||
# verify orig_size
|
||||
expected_orig_size = torch.tensor([480, 640])
|
||||
torch.testing.assert_close(encoding["labels"][0]["orig_size"], expected_orig_size)
|
||||
# verify size
|
||||
expected_size = torch.tensor([800, 1066])
|
||||
torch.testing.assert_close(encoding["labels"][0]["size"], expected_size)
|
||||
|
||||
@slow
|
||||
# Copied from tests.models.detr.test_image_processing_detr.DetrImageProcessingTest.test_batched_coco_detection_annotations with Detr->ConditionalDetr, facebook/detr-resnet-50 ->microsoft/conditional-detr-resnet-50
|
||||
|
Loading…
Reference in New Issue
Block a user