mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
add Sam2VideoSessionState + fast image proc + video proc
This commit is contained in:
parent
79055ad65f
commit
c3330c677a
@ -142,7 +142,7 @@ else:
|
||||
("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
|
||||
("sam", ("SamImageProcessor",)),
|
||||
("sam2", ("Sam2ImageProcessor",)),
|
||||
("sam2", ("Sam2ImageProcessor", "Sam2ImageProcessorFast")),
|
||||
("sam_hq", ("SamImageProcessor",)),
|
||||
("segformer", ("SegformerImageProcessor",)),
|
||||
("seggpt", ("SegGptImageProcessor",)),
|
||||
|
@ -54,6 +54,7 @@ else:
|
||||
("qwen2_5_omni", "Qwen2VLVideoProcessor"),
|
||||
("qwen2_5_vl", "Qwen2VLVideoProcessor"),
|
||||
("qwen2_vl", "Qwen2VLVideoProcessor"),
|
||||
("sam2", "Sam2VideoProcessor"),
|
||||
("smolvlm", "SmolVLMVideoProcessor"),
|
||||
("video_llava", "VideoLlavaVideoProcessor"),
|
||||
("vjepa2", "VJEPA2VideoProcessor"),
|
||||
|
@ -20,8 +20,10 @@ from ...utils.import_utils import define_import_structure
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_sam2 import *
|
||||
from .image_processing_sam2 import *
|
||||
from .image_processing_sam2_fast import *
|
||||
from .modeling_sam2 import *
|
||||
from .processing_sam2 import *
|
||||
from .video_processing_sam2 import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
@ -30,13 +30,14 @@ from PIL import Image
|
||||
from transformers import (
|
||||
Sam2Config,
|
||||
Sam2ImageEncoderConfig,
|
||||
Sam2ImageProcessor,
|
||||
Sam2ImageProcessorFast,
|
||||
Sam2MaskDecoderConfig,
|
||||
Sam2MemoryAttentionConfig,
|
||||
Sam2MemoryEncoderConfig,
|
||||
Sam2Model,
|
||||
Sam2Processor,
|
||||
Sam2PromptEncoderConfig,
|
||||
Sam2VideoProcessor,
|
||||
)
|
||||
|
||||
|
||||
@ -54,13 +55,28 @@ def get_config(model_name):
|
||||
memory_attention_config = Sam2MemoryAttentionConfig()
|
||||
memory_encoder_config = Sam2MemoryEncoderConfig()
|
||||
elif "sam2.1_hiera_base_plus" in model_name:
|
||||
image_encoder_config = Sam2ImageEncoderConfig(hidden_size=112, num_heads=2, stages=(2, 3, 16, 3), global_attention_blocks=(12, 16, 20), window_positional_embedding_background_size=(14, 14), backbone_channel_list=[896, 448, 224, 112])
|
||||
image_encoder_config = Sam2ImageEncoderConfig(
|
||||
hidden_size=112,
|
||||
num_heads=2,
|
||||
stages=(2, 3, 16, 3),
|
||||
global_attention_blocks=(12, 16, 20),
|
||||
window_positional_embedding_background_size=(14, 14),
|
||||
backbone_channel_list=[896, 448, 224, 112],
|
||||
)
|
||||
prompt_encoder_config = Sam2PromptEncoderConfig()
|
||||
mask_decoder_config = Sam2MaskDecoderConfig()
|
||||
memory_attention_config = Sam2MemoryAttentionConfig()
|
||||
memory_encoder_config = Sam2MemoryEncoderConfig()
|
||||
elif "sam2.1_hiera_large" in model_name:
|
||||
image_encoder_config = Sam2ImageEncoderConfig(hidden_size=144, num_heads=2, stages=(2, 6, 36, 4), global_attention_blocks=(23, 33, 43), window_positional_embedding_background_size=(7, 7), window_spec=(8, 4, 16, 8), backbone_channel_list=[1152, 576, 288, 144])
|
||||
image_encoder_config = Sam2ImageEncoderConfig(
|
||||
hidden_size=144,
|
||||
num_heads=2,
|
||||
stages=(2, 6, 36, 4),
|
||||
global_attention_blocks=(23, 33, 43),
|
||||
window_positional_embedding_background_size=(7, 7),
|
||||
window_spec=(8, 4, 16, 8),
|
||||
backbone_channel_list=[1152, 576, 288, 144],
|
||||
)
|
||||
prompt_encoder_config = Sam2PromptEncoderConfig()
|
||||
mask_decoder_config = Sam2MaskDecoderConfig()
|
||||
memory_attention_config = Sam2MemoryAttentionConfig()
|
||||
@ -197,8 +213,9 @@ def convert_sam2_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, pu
|
||||
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
|
||||
state_dict = replace_keys(state_dict)
|
||||
|
||||
image_processor = Sam2ImageProcessor()
|
||||
processor = Sam2Processor(image_processor=image_processor)
|
||||
image_processor = Sam2ImageProcessorFast()
|
||||
video_processor = Sam2VideoProcessor()
|
||||
processor = Sam2Processor(image_processor=image_processor, video_processor=video_processor)
|
||||
hf_model = Sam2Model(config)
|
||||
hf_model.eval()
|
||||
|
||||
@ -292,6 +309,10 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
hf_model_name = args.model_name.replace("_", "-")
|
||||
checkpoint_path = hf_hub_download(f"facebook/{hf_model_name}", f"{args.model_name}.pt") if args.checkpoint_path is None else args.checkpoint_path
|
||||
checkpoint_path = (
|
||||
hf_hub_download(f"facebook/{hf_model_name}", f"{args.model_name}.pt")
|
||||
if args.checkpoint_path is None
|
||||
else args.checkpoint_path
|
||||
)
|
||||
|
||||
convert_sam2_checkpoint(args.model_name, checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)
|
||||
|
128
src/transformers/models/sam2/image_processing_sam2_fast.py
Normal file
128
src/transformers/models/sam2/image_processing_sam2_fast.py
Normal file
@ -0,0 +1,128 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fast Image processor class for SAM2."""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_processing_utils_fast import (
|
||||
BaseImageProcessorFast,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
)
|
||||
from ...image_utils import (
|
||||
IMAGENET_DEFAULT_MEAN,
|
||||
IMAGENET_DEFAULT_STD,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
)
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
auto_docstring,
|
||||
is_torch_available,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.nn import functional as F_t
|
||||
|
||||
|
||||
class Sam2ImageProcessorFastKwargs(DefaultFastImageProcessorKwargs):
    # Typed keyword arguments declared on top of `DefaultFastImageProcessorKwargs`.
    # NOTE(review): nothing in this file references this class (for example there is
    # no `valid_kwargs` assignment on `Sam2ImageProcessorFast`) -- confirm whether it
    # is meant to be wired into the processor.

    # presumably toggles padding of the inputs -- TODO confirm against the consumer
    do_pad: bool
    # presumably the size masks are padded to -- TODO confirm against the consumer
    mask_pad_size: SizeDict
|
||||
|
||||
|
||||
@auto_docstring
class Sam2ImageProcessorFast(BaseImageProcessorFast):
    # Class-level preprocessing defaults for SAM2: inputs are resized to 1024x1024 with
    # bilinear resampling, rescaled, normalized with the ImageNet mean/std and converted
    # to RGB; padding is disabled by default.
    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_DEFAULT_MEAN
    image_std = IMAGENET_DEFAULT_STD
    size = {"height": 1024, "width": 1024}
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    do_pad = False

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        size: Optional[SizeDict],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Run the base preprocessing and attach size bookkeeping to the result.

        Args:
            images (`list[torch.Tensor]`):
                Images to preprocess; the last two dimensions of each are recorded as its original size.
            size (`SizeDict`):
                Target size; its `height` and `width` are recorded as the reshaped size of every image.
            return_tensors (`str` or `TensorType`, *optional*):
                Tensor type forwarded to the base implementation and to the returned `BatchFeature`.
            **kwargs:
                Forwarded unchanged to `BaseImageProcessorFast._preprocess`.

        Returns:
            `BatchFeature`: the base processor output extended with `original_sizes` and
            `reshaped_input_sizes`.
        """
        # Sizes are captured before delegating so they describe the incoming images.
        original_sizes = [image.shape[-2:] for image in images]
        reshaped_input_sizes = [(size.height, size.width) for _ in images]
        batch_feature = super()._preprocess(images, size=size, return_tensors=return_tensors, **kwargs)
        # Keys coming from the base output take precedence because they are unpacked last.
        return BatchFeature(
            data={
                "original_sizes": original_sizes,
                "reshaped_input_sizes": reshaped_input_sizes,
                **batch_feature.data,
            },
            tensor_type=return_tensors,
        )

    def post_process_masks(
        self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
    ):
        """
        Remove padding and upscale masks to the original image size.

        Args:
            masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
                Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
                The input container is not modified.
            original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
                The original sizes of each image before it was resized to the model's expected input shape, in (height,
                width) format.
            reshaped_input_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
                The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
            mask_threshold (`float`, *optional*, defaults to 0.0):
                The threshold to use for binarizing the masks.
            binarize (`bool`, *optional*, defaults to `True`):
                Whether to binarize the masks.
            pad_size (`dict[str, int]`, *optional*, defaults to `self.size`):
                Mapping with `"height"` and `"width"` keys giving the size the masks are first interpolated to.
                If None, the processor's `size` is used.

        Returns:
            `list[torch.Tensor]`: One tensor per entry of `original_sizes`, in (batch_size, num_channels, height,
            width) format, where (height, width) is given by the matching original size. The tensors are boolean
            when `binarize=True`.

        Raises:
            ValueError: If an entry of `masks` is neither a `torch.Tensor` nor a `np.ndarray`.
        """
        pad_size = self.size if pad_size is None else pad_size
        target_image_size = (pad_size["height"], pad_size["width"])
        if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
            original_sizes = original_sizes.tolist()
        if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
            reshaped_input_sizes = reshaped_input_sizes.tolist()

        output_masks = []
        for i, original_size in enumerate(original_sizes):
            # Work on a local reference so the caller's `masks` container is left untouched.
            mask = masks[i]
            if isinstance(mask, np.ndarray):
                mask = torch.from_numpy(mask)
            elif not isinstance(mask, torch.Tensor):
                raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
            # 1) bring the mask to the padded model-input resolution,
            # 2) crop away the padded region,
            # 3) resize the remaining area to the original image size.
            interpolated_mask = F_t.interpolate(mask, target_image_size, mode="bilinear", align_corners=False)
            interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
            interpolated_mask = F_t.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
            if binarize:
                interpolated_mask = interpolated_mask > mask_threshold
            output_masks.append(interpolated_mask)

        return output_masks
|
||||
|
||||
|
||||
__all__ = ["Sam2ImageProcessorFast"]
|
@ -19,6 +19,7 @@ import collections.abc
|
||||
import copy
|
||||
import math
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Iterator, Optional, Union
|
||||
@ -125,6 +126,95 @@ def fill_holes_in_mask_scores(mask, max_area):
|
||||
return mask
|
||||
|
||||
|
||||
class Sam2VideoSessionState:
    """
    Mutable container holding the state of one SAM2 video inference session.

    It stores the video frames, the devices in use, and the per-object dictionaries that
    are filled in while prompts are added and masks are propagated.

    Args:
        video (`torch.FloatTensor`):
            Video tensor; iterating over it yields the individual frames and its `.device`
            becomes the session device.
        video_height (`int`):
            Stored as `video_height`.
        video_width (`int`):
            Stored as `video_width`.
        offload_video_to_cpu (`bool`, *optional*, defaults to `False`):
            Stored on the session; this class does not act on it itself.
        offload_state_to_cpu (`bool`, *optional*, defaults to `False`):
            When `True`, `storage_device` is the CPU instead of the video's device.
        async_loading_frames (`bool`, *optional*, defaults to `False`):
            Stored on the session; asynchronous loading is not implemented here.
    """

    # Class-level declarations of every attribute assigned in `__init__`.
    images: Optional[list[torch.Tensor]] = None  # per-frame tensors (`list(video)`)
    num_frames: Optional[int] = None
    offload_video_to_cpu: Optional[bool] = None
    offload_state_to_cpu: Optional[bool] = None
    async_loading_frames: Optional[bool] = None
    video_height: Optional[int] = None
    video_width: Optional[int] = None
    device: Optional[torch.device] = None
    storage_device: Optional[torch.device] = None
    point_inputs_per_obj: Optional[dict] = None  # keyed by object index
    mask_inputs_per_obj: Optional[dict] = None  # keyed by object index
    cached_features: Optional[dict] = None
    constants: Optional[dict] = None
    obj_id_to_idx: Optional[dict] = None  # client-side object id -> model-side index
    obj_idx_to_id: Optional[dict] = None  # model-side index -> client-side object id
    obj_ids: Optional[list] = None
    output_dict_per_obj: Optional[dict] = None  # keyed by object index
    temp_output_dict_per_obj: Optional[dict] = None  # keyed by object index
    frames_tracked_per_obj: Optional[dict] = None  # keyed by object index

    # TODO add async video loading?
    def __init__(
        self,
        video: torch.FloatTensor,
        video_height: int,
        video_width: int,
        offload_video_to_cpu: bool = False,
        offload_state_to_cpu: bool = False,
        async_loading_frames: bool = False,
    ):
        self.images = list(video)
        self.num_frames = len(video)
        self.offload_video_to_cpu = offload_video_to_cpu
        self.offload_state_to_cpu = offload_state_to_cpu
        self.async_loading_frames = async_loading_frames
        self.video_height = video_height
        self.video_width = video_width
        self.device = video.device
        # State is kept on the CPU only when explicitly requested.
        self.storage_device = torch.device("cpu") if offload_state_to_cpu else video.device
        self.cached_features = {}
        self.point_inputs_per_obj = {}
        self.mask_inputs_per_obj = {}
        self.constants = {}
        self.obj_id_to_idx = OrderedDict()
        self.obj_idx_to_id = OrderedDict()
        self.obj_ids = []
        self.output_dict_per_obj = {}
        self.temp_output_dict_per_obj = {}
        self.frames_tracked_per_obj = {}

    def reset_inference_session(self):
        """
        Clear all prompt, object and output bookkeeping in place.

        The frames (`images`) and `cached_features` are left untouched.
        """
        self.point_inputs_per_obj.clear()
        self.mask_inputs_per_obj.clear()
        self.constants.clear()
        self.obj_id_to_idx.clear()
        self.obj_idx_to_id.clear()
        self.obj_ids.clear()
        self.output_dict_per_obj.clear()
        self.temp_output_dict_per_obj.clear()
        self.frames_tracked_per_obj.clear()

    def _obj_id_to_idx(self, obj_id: int) -> int:
        """
        Map client-side object id to model-side object index.

        An unknown `obj_id` is registered on the fly: it receives the next free index and
        empty input/output structures are created for it.
        """
        obj_idx = self.obj_id_to_idx.get(obj_id)
        if obj_idx is not None:
            return obj_idx

        # Add new object
        obj_idx = len(self.obj_id_to_idx)
        self.obj_id_to_idx[obj_id] = obj_idx
        self.obj_idx_to_id[obj_idx] = obj_id
        # Rebinds `obj_ids` to a fresh list of known ids in registration order.
        self.obj_ids = list(self.obj_id_to_idx)

        # Set up input and output structures for this object
        self.point_inputs_per_obj[obj_idx] = {}
        self.mask_inputs_per_obj[obj_idx] = {}
        self.output_dict_per_obj[obj_idx] = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }
        self.temp_output_dict_per_obj[obj_idx] = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }
        self.frames_tracked_per_obj[obj_idx] = {}

        return obj_idx
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sam2ImageEncoderOutput(ModelOutput):
|
||||
"""
|
||||
@ -2366,20 +2456,20 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
# Video Inference specific functions
|
||||
def _obj_idx_to_id(self, inference_state, obj_idx):
|
||||
"""Map model-side object index to client-side object id."""
|
||||
return inference_state["obj_idx_to_id"][obj_idx]
|
||||
return inference_state.obj_idx_to_id[obj_idx]
|
||||
|
||||
def _get_obj_num(self, inference_state):
|
||||
"""Get the total number of unique object ids received so far in this session."""
|
||||
return len(inference_state["obj_idx_to_id"])
|
||||
return len(inference_state.obj_idx_to_id)
|
||||
|
||||
def _get_orig_video_res_output(self, inference_state, any_res_masks):
|
||||
"""
|
||||
Resize the object scores to the original video resolution (video_res_masks)
|
||||
and apply non-overlapping constraints for final output.
|
||||
"""
|
||||
device = inference_state["device"]
|
||||
video_H = inference_state["video_height"]
|
||||
video_W = inference_state["video_width"]
|
||||
device = inference_state.device
|
||||
video_H = inference_state.video_height
|
||||
video_W = inference_state.video_width
|
||||
any_res_masks = any_res_masks.to(device, non_blocking=True)
|
||||
if any_res_masks.shape[-2:] == (video_H, video_W):
|
||||
video_res_masks = any_res_masks
|
||||
@ -2415,8 +2505,8 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
# Optionally, we allow consolidating the temporary outputs at the original
|
||||
# video resolution (to provide a better editing experience for mask prompts).
|
||||
if consolidate_at_video_res:
|
||||
consolidated_H = inference_state["video_height"]
|
||||
consolidated_W = inference_state["video_width"]
|
||||
consolidated_H = inference_state.video_height
|
||||
consolidated_W = inference_state.video_width
|
||||
consolidated_mask_key = "pred_masks_video_res"
|
||||
else:
|
||||
consolidated_H = consolidated_W = self.image_size // 4
|
||||
@ -2431,12 +2521,12 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
size=(batch_size, 1, consolidated_H, consolidated_W),
|
||||
fill_value=NO_OBJ_SCORE,
|
||||
dtype=torch.float32,
|
||||
device=inference_state["storage_device"],
|
||||
device=inference_state.storage_device,
|
||||
),
|
||||
}
|
||||
for obj_idx in range(batch_size):
|
||||
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
obj_temp_output_dict = inference_state.temp_output_dict_per_obj[obj_idx]
|
||||
obj_output_dict = inference_state.output_dict_per_obj[obj_idx]
|
||||
out = obj_temp_output_dict[storage_key].get(frame_idx, None)
|
||||
# If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
|
||||
# we fall back and look up its previous output in "output_dict_per_obj".
|
||||
@ -2481,8 +2571,8 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
"""
|
||||
Add new conditioning inputs to a frame and run inference.
|
||||
"""
|
||||
device = inference_state["device"]
|
||||
storage_device = inference_state["storage_device"]
|
||||
device = inference_state.device
|
||||
storage_device = inference_state.storage_device
|
||||
|
||||
# Prepare batch inputs
|
||||
batch_size = 1
|
||||
@ -2495,21 +2585,21 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
is_init_cond_frame=is_init_cond_frame,
|
||||
point_inputs=point_inputs,
|
||||
mask_inputs=mask_inputs,
|
||||
output_dict=inference_state["output_dict_per_obj"][obj_idx],
|
||||
output_dict=inference_state.output_dict_per_obj[obj_idx],
|
||||
run_mem_encoder=False,
|
||||
reverse=False,
|
||||
)
|
||||
|
||||
# Update the output dictionary
|
||||
output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||
# output_dict = inference_state.temp_output_dict_per_obj[obj_idx]
|
||||
|
||||
if is_init_cond_frame:
|
||||
output_dict["cond_frame_outputs"][frame_idx] = current_out
|
||||
inference_state.temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"][frame_idx] = current_out
|
||||
else:
|
||||
output_dict["non_cond_frame_outputs"][frame_idx] = current_out
|
||||
inference_state.temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"][frame_idx] = current_out
|
||||
|
||||
# Resize the output mask to the original video resolution
|
||||
obj_ids = inference_state["obj_ids"]
|
||||
obj_ids = inference_state.obj_ids
|
||||
consolidated_out = self._consolidate_temp_output_across_obj(
|
||||
inference_state,
|
||||
frame_idx,
|
||||
@ -2531,8 +2621,8 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
|
||||
# add them into "output_dict".
|
||||
for obj_idx in range(batch_size):
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||
obj_output_dict = inference_state.output_dict_per_obj[obj_idx]
|
||||
obj_temp_output_dict = inference_state.temp_output_dict_per_obj[obj_idx]
|
||||
for is_cond in [False, True]:
|
||||
# Separately consolidate conditioning and non-conditioning temp outputs
|
||||
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||
@ -2543,7 +2633,7 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
# Run memory encoder on the temporary outputs (if the memory feature is missing)
|
||||
if out["maskmem_features"] is None:
|
||||
high_res_masks = torch.nn.functional.interpolate(
|
||||
out["pred_masks"].to(inference_state["device"]),
|
||||
out["pred_masks"].to(inference_state.device),
|
||||
size=(self.image_size, self.image_size),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
@ -2566,7 +2656,7 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
obj_temp_output_dict[storage_key].clear()
|
||||
|
||||
# check and make sure that every object has received input points or masks
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
obj_output_dict = inference_state.output_dict_per_obj[obj_idx]
|
||||
if len(obj_output_dict["cond_frame_outputs"]) == 0:
|
||||
obj_id = self._obj_idx_to_id(inference_state, obj_idx)
|
||||
raise RuntimeError(
|
||||
@ -2591,8 +2681,8 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
"""
|
||||
self.propagate_in_video_preflight(inference_state)
|
||||
|
||||
obj_ids = inference_state["obj_ids"]
|
||||
num_frames = inference_state["num_frames"]
|
||||
obj_ids = inference_state.obj_ids
|
||||
num_frames = inference_state.num_frames
|
||||
batch_size = self._get_obj_num(inference_state)
|
||||
|
||||
# set start index, end index, and processing order
|
||||
@ -2600,7 +2690,7 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
# default: start from the earliest frame with input points
|
||||
start_frame_idx = min(
|
||||
t
|
||||
for obj_output_dict in inference_state["output_dict_per_obj"].values()
|
||||
for obj_output_dict in inference_state.output_dict_per_obj.values()
|
||||
for t in obj_output_dict["cond_frame_outputs"]
|
||||
)
|
||||
if max_frame_num_to_track is None:
|
||||
@ -2619,7 +2709,7 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
for frame_idx in tqdm(processing_order, desc="propagate in video"):
|
||||
pred_masks_per_obj = [None] * batch_size
|
||||
for obj_idx in range(batch_size):
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
obj_output_dict = inference_state.output_dict_per_obj[obj_idx]
|
||||
# We skip those frames already in consolidated outputs (these are frames
|
||||
# that received input clicks or mask). Note that we cannot directly run
|
||||
# batched forward on them via `_run_single_frame_inference` because the
|
||||
@ -2627,7 +2717,7 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
if frame_idx in obj_output_dict["cond_frame_outputs"]:
|
||||
storage_key = "cond_frame_outputs"
|
||||
current_out = obj_output_dict[storage_key][frame_idx]
|
||||
device = inference_state["device"]
|
||||
device = inference_state.device
|
||||
pred_masks = current_out["pred_masks"].to(device, non_blocking=True)
|
||||
else:
|
||||
storage_key = "non_cond_frame_outputs"
|
||||
@ -2644,7 +2734,7 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
)
|
||||
obj_output_dict[storage_key][frame_idx] = current_out
|
||||
|
||||
inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {"reverse": reverse}
|
||||
inference_state.frames_tracked_per_obj[obj_idx][frame_idx] = {"reverse": reverse}
|
||||
pred_masks_per_obj[obj_idx] = pred_masks
|
||||
|
||||
# Resize the output mask to the original video resolution (we directly use
|
||||
@ -2665,18 +2755,18 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
"""Prepare vision features for a frame."""
|
||||
|
||||
# Check if features are cached
|
||||
if frame_idx in inference_state["cached_features"]:
|
||||
cached = inference_state["cached_features"][frame_idx]
|
||||
if frame_idx in inference_state.cached_features:
|
||||
cached = inference_state.cached_features[frame_idx]
|
||||
vision_feats = cached["vision_feats"]
|
||||
vision_pos_embeds = cached["vision_pos_embeds"]
|
||||
else:
|
||||
# Compute features using image encoder
|
||||
image_batch = inference_state["images"][frame_idx].unsqueeze(0) # Add batch dimension
|
||||
image_batch = inference_state.images[frame_idx].unsqueeze(0) # Add batch dimension
|
||||
feature_maps, feature_maps_position_embeddings, _, _ = self.get_image_features(image_batch)
|
||||
vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
|
||||
vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in feature_maps_position_embeddings]
|
||||
# Cache features
|
||||
inference_state["cached_features"][frame_idx] = {
|
||||
inference_state.cached_features[frame_idx] = {
|
||||
"vision_feats": vision_feats,
|
||||
"vision_pos_embeds": vision_pos_embeds,
|
||||
}
|
||||
@ -2712,7 +2802,7 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
)
|
||||
|
||||
# optionally offload the output to CPU memory to save GPU space
|
||||
storage_device = inference_state["storage_device"]
|
||||
storage_device = inference_state.storage_device
|
||||
maskmem_features = maskmem_features.to(torch.bfloat16)
|
||||
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
|
||||
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
||||
@ -2724,7 +2814,7 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
`maskmem_pos_enc` is the same across frames and objects, so we cache it as
|
||||
a constant in the inference session to reduce session storage size.
|
||||
"""
|
||||
model_constants = inference_state["constants"]
|
||||
model_constants = inference_state.constants
|
||||
# "out_maskmem_pos_enc" should be either a list of tensors or None
|
||||
out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
||||
if out_maskmem_pos_enc is not None:
|
||||
@ -2771,14 +2861,14 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
point_inputs=point_inputs,
|
||||
mask_inputs=mask_inputs,
|
||||
output_dict=output_dict,
|
||||
num_frames=inference_state["num_frames"],
|
||||
num_frames=inference_state.num_frames,
|
||||
track_in_reverse=reverse,
|
||||
run_mem_encoder=run_mem_encoder,
|
||||
prev_sam_mask_logits=prev_sam_mask_logits,
|
||||
)
|
||||
|
||||
# optionally offload the output to CPU memory to save GPU space
|
||||
storage_device = inference_state["storage_device"]
|
||||
storage_device = inference_state.storage_device
|
||||
maskmem_features = current_out["maskmem_features"]
|
||||
if maskmem_features is not None:
|
||||
maskmem_features = maskmem_features.to(torch.bfloat16)
|
||||
@ -3321,4 +3411,4 @@ class Sam2Model(Sam2PreTrainedModel):
|
||||
return pred_masks
|
||||
|
||||
|
||||
__all__ = ["Sam2Model", "Sam2PreTrainedModel"]
|
||||
__all__ = ["Sam2Model", "Sam2VideoSessionState", "Sam2PreTrainedModel"]
|
||||
|
@ -16,18 +16,16 @@
|
||||
Processor class for SAM2.
|
||||
"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from torchvision.transforms import Normalize, ToTensor
|
||||
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from ...utils import TensorType, is_tf_available, is_torch_available, logging
|
||||
from ...video_utils import VideoInput
|
||||
from .modeling_sam2 import Sam2VideoSessionState
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -36,7 +34,7 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
pass
|
||||
|
||||
|
||||
class Sam2Processor(ProcessorMixin):
|
||||
@ -52,17 +50,16 @@ class Sam2Processor(ProcessorMixin):
|
||||
An instance of [`Sam2ImageProcessor`]. The image processor is a required input.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor"]
|
||||
image_processor_class = "Sam2ImageProcessor"
|
||||
attributes = ["image_processor", "video_processor"]
|
||||
image_processor_class = "Sam2ImageProcessorFast"
|
||||
video_processor_class = "Sam2VideoProcessor"
|
||||
|
||||
def __init__(self, image_processor):
|
||||
super().__init__(image_processor)
|
||||
self.current_processor = self.image_processor
|
||||
self.point_pad_value = -10
|
||||
self.target_size = self.image_processor.size["longest_edge"]
|
||||
|
||||
# Video inference state
|
||||
self.inference_state = None
|
||||
def __init__(
|
||||
self, image_processor, video_processor, target_size: Optional[int] = None, point_pad_value: int = -10, **kwargs
|
||||
):
|
||||
super().__init__(image_processor, video_processor, **kwargs)
|
||||
self.point_pad_value = point_pad_value
|
||||
self.target_size = target_size if target_size is not None else self.image_processor.size["height"]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -108,6 +105,15 @@ class Sam2Processor(ProcessorMixin):
|
||||
|
||||
return encoding_image_processor
|
||||
|
||||
def init_video_session(self, video: VideoInput, device: Optional[Union[str, "torch.device"]] = None):
    """
    Preprocess a video and wrap it in a new `Sam2VideoSessionState`.

    Args:
        video (`VideoInput`):
            The video passed to `self.video_processor`.
        device (`str` or `torch.device`, *optional*):
            Device the processed video is moved to. Defaults to `"cuda"` when CUDA is available
            and to `"cpu"` otherwise.

    Returns:
        `Sam2VideoSessionState`: session state built from the first processed video and its
        first entry in `original_sizes`.
    """
    if device is None:
        # Avoid failing on CPU-only hosts instead of unconditionally requesting CUDA.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    processed_video = self.video_processor(videos=video, return_tensors="pt").to(device)
    inference_state = Sam2VideoSessionState(
        processed_video.pixel_values_videos[0],
        video_height=processed_video.original_sizes[0][0],
        video_width=processed_video.original_sizes[0][1],
    )
    return inference_state
|
||||
|
||||
def _normalize_and_convert(
|
||||
self,
|
||||
encoding_image_processor,
|
||||
@ -155,30 +161,19 @@ class Sam2Processor(ProcessorMixin):
|
||||
input_boxes = torch.from_numpy(input_boxes)
|
||||
# boxes batch size of 1 by default
|
||||
input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes
|
||||
elif return_tensors == "tf":
|
||||
input_boxes = tf.convert_to_tensor(input_boxes)
|
||||
# boxes batch size of 1 by default
|
||||
input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes
|
||||
|
||||
encoding_image_processor.update({"input_boxes": input_boxes})
|
||||
if input_points is not None:
|
||||
if return_tensors == "pt":
|
||||
input_points = torch.from_numpy(input_points)
|
||||
# point batch size of 1 by default
|
||||
input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points
|
||||
elif return_tensors == "tf":
|
||||
input_points = tf.convert_to_tensor(input_points)
|
||||
# point batch size of 1 by default
|
||||
input_points = tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points
|
||||
encoding_image_processor.update({"input_points": input_points})
|
||||
if input_labels is not None:
|
||||
if return_tensors == "pt":
|
||||
input_labels = torch.from_numpy(input_labels)
|
||||
# point batch size of 1 by default
|
||||
input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels
|
||||
elif return_tensors == "tf":
|
||||
input_labels = tf.convert_to_tensor(input_labels)
|
||||
# point batch size of 1 by default
|
||||
input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels
|
||||
encoding_image_processor.update({"input_labels": input_labels})
|
||||
|
||||
return encoding_image_processor
|
||||
@ -267,172 +262,12 @@ class Sam2Processor(ProcessorMixin):
|
||||
|
||||
return input_points, input_labels, input_boxes
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(image_processor_input_names))
|
||||
|
||||
def post_process_masks(self, *args, **kwargs):
|
||||
return self.image_processor.post_process_masks(*args, **kwargs)
|
||||
|
||||
def init_state(
|
||||
self,
|
||||
video_path: Union[str, Path],
|
||||
offload_video_to_cpu: bool = False,
|
||||
offload_state_to_cpu: bool = False,
|
||||
async_loading_frames: bool = False,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> None:
|
||||
"""Initialize video inference state."""
|
||||
if not is_torch_available():
|
||||
raise ImportError("Video inference requires PyTorch to be installed")
|
||||
|
||||
if device is None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Load video frames
|
||||
images, video_height, video_width = self._load_video_frames(
|
||||
video_path=video_path,
|
||||
offload_video_to_cpu=offload_video_to_cpu,
|
||||
async_loading_frames=async_loading_frames,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Initialize inference state
|
||||
self.inference_state = {
|
||||
"images": images,
|
||||
"num_frames": len(images),
|
||||
"offload_video_to_cpu": offload_video_to_cpu,
|
||||
"offload_state_to_cpu": offload_state_to_cpu,
|
||||
"video_height": video_height,
|
||||
"video_width": video_width,
|
||||
"device": device,
|
||||
"storage_device": torch.device("cpu") if offload_state_to_cpu else device,
|
||||
# Input tracking
|
||||
"point_inputs_per_obj": {},
|
||||
"mask_inputs_per_obj": {},
|
||||
# Visual features cache
|
||||
"cached_features": {},
|
||||
"constants": {},
|
||||
# Object management
|
||||
"obj_id_to_idx": OrderedDict(),
|
||||
"obj_idx_to_id": OrderedDict(),
|
||||
"obj_ids": [],
|
||||
# Output tracking
|
||||
"output_dict_per_obj": {},
|
||||
"temp_output_dict_per_obj": {},
|
||||
"frames_tracked_per_obj": {},
|
||||
}
|
||||
|
||||
logger.info(f"Initialized video state with {len(images)} frames at resolution {video_height}x{video_width}")
|
||||
|
||||
def reset_state(self) -> None:
    """Empty every container of the current inference state, then drop the state.

    After this call `self.inference_state` is `None`.
    """
    state = self.inference_state
    if state is not None:
        # Same containers, in the same order, as the explicit per-key clears.
        container_keys = (
            "point_inputs_per_obj",
            "mask_inputs_per_obj",
            "cached_features",
            "constants",
            "obj_id_to_idx",
            "obj_idx_to_id",
            "obj_ids",
            "output_dict_per_obj",
            "temp_output_dict_per_obj",
            "frames_tracked_per_obj",
        )
        for key in container_keys:
            state[key].clear()

    self.inference_state = None
    logger.info("Reset video inference state")
|
||||
|
||||
def _load_video_frames(
    self,
    video_path: Union[str, Path],
    offload_video_to_cpu: bool = False,
    async_loading_frames: bool = False,
    device: torch.device = None,
) -> tuple[list[torch.Tensor], int, int]:
    """Load, resize and normalize every frame image found in a directory.

    Args:
        video_path: Directory containing one image file per frame. Files are
            ordered by file name.
        offload_video_to_cpu: When `True`, frame tensors are not moved to `device`.
        async_loading_frames: Accepted for interface compatibility; this
            implementation does not use it.
        device: Device the frame tensors are moved to (skipped when `None` or
            when `offload_video_to_cpu` is set).

    Returns:
        A tuple `(images, video_height, video_width)`: the list of per-frame
        tensors, followed by the height and width of the first image file
        before resizing.

    Raises:
        ValueError: If `video_path` does not exist or contains no file with a
            supported image extension.
    """
    video_path = Path(video_path)

    if not video_path.exists():
        raise ValueError(f"Video path {video_path} does not exist")

    # Get image files
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}
    image_files = [f for f in video_path.iterdir() if f.suffix.lower() in image_extensions]

    if not image_files:
        raise ValueError(f"No image files found in {video_path}")

    # Sort files by name (assuming frame order)
    image_files.sort(key=lambda x: x.name)

    from PIL import Image

    # Only the dimensions of the first frame are needed; close the file right away.
    with Image.open(image_files[0]) as first_image:
        video_width, video_height = first_image.size

    # Build the preprocessing pipeline once; it is identical for every frame
    # and scripting it per frame is needlessly expensive.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    to_tensor = ToTensor()
    normalize = torch.jit.script(nn.Sequential(Normalize(imagenet_mean, imagenet_std)))
    move_to_device = not offload_video_to_cpu and device is not None

    images = []
    for img_path in image_files:
        # The context manager releases the file handle; `resize` returns a
        # fully loaded copy, so the frame stays usable after the file closes.
        with Image.open(img_path) as image:
            frame = image if image.mode == "RGB" else image.convert("RGB")
            frame = frame.resize((1024, 1024))

        image_tensor = normalize(to_tensor(frame))
        if move_to_device:
            image_tensor = image_tensor.to(device)

        images.append(image_tensor)

    return images, video_height, video_width
|
||||
|
||||
def _obj_id_to_idx(self, obj_id: int) -> int:
|
||||
"""Map client-side object id to model-side object index."""
|
||||
if self.inference_state is None:
|
||||
raise ValueError("Video state not initialized. Call init_state() first.")
|
||||
|
||||
obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None)
|
||||
if obj_idx is not None:
|
||||
return obj_idx
|
||||
|
||||
# Add new object
|
||||
obj_idx = len(self.inference_state["obj_id_to_idx"])
|
||||
self.inference_state["obj_id_to_idx"][obj_id] = obj_idx
|
||||
self.inference_state["obj_idx_to_id"][obj_idx] = obj_id
|
||||
self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"])
|
||||
|
||||
# Set up input and output structures for this object
|
||||
self.inference_state["point_inputs_per_obj"][obj_idx] = {}
|
||||
self.inference_state["mask_inputs_per_obj"][obj_idx] = {}
|
||||
self.inference_state["output_dict_per_obj"][obj_idx] = {
|
||||
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
}
|
||||
self.inference_state["temp_output_dict_per_obj"][obj_idx] = {
|
||||
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
}
|
||||
self.inference_state["frames_tracked_per_obj"][obj_idx] = {}
|
||||
|
||||
return obj_idx
|
||||
|
||||
def add_new_points_or_box(
|
||||
def process_new_points_or_box(
|
||||
self,
|
||||
inference_state: Sam2VideoSessionState,
|
||||
frame_idx: int,
|
||||
obj_id: int,
|
||||
points: Optional[list[list[float]]] = None,
|
||||
@ -442,15 +277,9 @@ class Sam2Processor(ProcessorMixin):
|
||||
box: Optional[list[float]] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Add new points or box to a frame and return preprocessed inputs for model."""
|
||||
if self.inference_state is None:
|
||||
raise ValueError("Video state not initialized. Call init_state() first.")
|
||||
|
||||
if not is_torch_available():
|
||||
raise ImportError("Video inference requires PyTorch to be installed")
|
||||
|
||||
obj_idx = self._obj_id_to_idx(obj_id)
|
||||
point_inputs_per_frame = self.inference_state["point_inputs_per_obj"][obj_idx]
|
||||
mask_inputs_per_frame = self.inference_state["mask_inputs_per_obj"][obj_idx]
|
||||
obj_idx = inference_state._obj_id_to_idx(obj_id)
|
||||
point_inputs_per_frame = inference_state.point_inputs_per_obj[obj_idx]
|
||||
mask_inputs_per_frame = inference_state.mask_inputs_per_obj[obj_idx]
|
||||
|
||||
# Validate inputs
|
||||
if (points is not None) != (labels is not None):
|
||||
@ -458,7 +287,7 @@ class Sam2Processor(ProcessorMixin):
|
||||
if points is None and box is None:
|
||||
raise ValueError("at least one of points or box must be provided as input")
|
||||
|
||||
device = self.inference_state["device"]
|
||||
device = inference_state.device
|
||||
|
||||
# Process points
|
||||
if points is None:
|
||||
@ -496,8 +325,8 @@ class Sam2Processor(ProcessorMixin):
|
||||
|
||||
# Normalize coordinates
|
||||
if normalize_coords:
|
||||
video_H = self.inference_state["video_height"]
|
||||
video_W = self.inference_state["video_width"]
|
||||
video_H = inference_state.video_height
|
||||
video_W = inference_state.video_width
|
||||
points = points / torch.tensor([video_W, video_H]).to(points.device)
|
||||
|
||||
# Scale by model's internal image size
|
||||
@ -523,7 +352,7 @@ class Sam2Processor(ProcessorMixin):
|
||||
mask_inputs_per_frame.pop(frame_idx, None) # Clear any mask inputs
|
||||
|
||||
# Determine frame type and tracking direction
|
||||
obj_frames_tracked = self.inference_state["frames_tracked_per_obj"][obj_idx]
|
||||
obj_frames_tracked = inference_state.frames_tracked_per_obj[obj_idx]
|
||||
is_init_cond_frame = frame_idx not in obj_frames_tracked
|
||||
|
||||
if is_init_cond_frame:
|
||||
@ -544,22 +373,17 @@ class Sam2Processor(ProcessorMixin):
|
||||
|
||||
def add_new_mask(
|
||||
self,
|
||||
inference_state: Sam2VideoSessionState,
|
||||
frame_idx: int,
|
||||
obj_id: int,
|
||||
mask: Union[np.ndarray, torch.Tensor],
|
||||
) -> dict[str, Any]:
|
||||
"""Add new mask to a frame and return preprocessed inputs for model."""
|
||||
if self.inference_state is None:
|
||||
raise ValueError("Video state not initialized. Call init_state() first.")
|
||||
obj_idx = inference_state._obj_id_to_idx(obj_id)
|
||||
point_inputs_per_frame = inference_state.point_inputs_per_obj[obj_idx]
|
||||
mask_inputs_per_frame = inference_state.mask_inputs_per_obj[obj_idx]
|
||||
|
||||
if not is_torch_available():
|
||||
raise ImportError("Video inference requires PyTorch to be installed")
|
||||
|
||||
obj_idx = self._obj_id_to_idx(obj_id)
|
||||
point_inputs_per_frame = self.inference_state["point_inputs_per_obj"][obj_idx]
|
||||
mask_inputs_per_frame = self.inference_state["mask_inputs_per_obj"][obj_idx]
|
||||
|
||||
device = self.inference_state["device"]
|
||||
device = inference_state.device
|
||||
|
||||
# Process mask
|
||||
if not isinstance(mask, torch.Tensor):
|
||||
@ -586,7 +410,7 @@ class Sam2Processor(ProcessorMixin):
|
||||
point_inputs_per_frame.pop(frame_idx, None) # Clear any point inputs
|
||||
|
||||
# Determine frame type and tracking direction
|
||||
obj_frames_tracked = self.inference_state["frames_tracked_per_obj"][obj_idx]
|
||||
obj_frames_tracked = inference_state.frames_tracked_per_obj[obj_idx]
|
||||
is_init_cond_frame = frame_idx not in obj_frames_tracked
|
||||
|
||||
if is_init_cond_frame:
|
||||
|
117
src/transformers/models/sam2/video_processing_sam2.py
Normal file
117
src/transformers/models/sam2/video_processing_sam2.py
Normal file
@ -0,0 +1,117 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Video processor class for SAM2."""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_utils import (
|
||||
IMAGENET_DEFAULT_MEAN,
|
||||
IMAGENET_DEFAULT_STD,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
)
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
is_torch_available,
|
||||
)
|
||||
from ...video_processing_utils import BaseVideoProcessor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.nn import functional as F_t
|
||||
|
||||
|
||||
class Sam2VideoProcessor(BaseVideoProcessor):
    """Video processor for SAM2.

    Class-level defaults: bilinear resampling, resize to 1024x1024, rescale,
    normalize with the ImageNet default mean/std, and convert to RGB. The
    actual preprocessing is delegated to `BaseVideoProcessor._preprocess`; this
    class adds size bookkeeping to the output and provides mask post-processing.
    """

    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_DEFAULT_MEAN
    image_std = IMAGENET_DEFAULT_STD
    size = {"height": 1024, "width": 1024}
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True

    def _preprocess(
        self,
        videos: list["torch.Tensor"],
        size: Optional[SizeDict],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """Run the base preprocessing and attach per-video size information.

        Args:
            videos: Input videos; the last two dimensions of each are recorded
                as its original size.
            size: Target size; its `height` and `width` are recorded as the
                reshaped input size of every video.
            return_tensors: Tensor type forwarded to the base class and used
                for the returned `BatchFeature`.
            **kwargs: Forwarded unchanged to the base `_preprocess`.

        Returns:
            A `BatchFeature` containing `original_sizes`, `reshaped_input_sizes`
            and every entry produced by the base class.
        """
        # Capture the sizes before the base class resizes the videos.
        original_sizes = [video.shape[-2:] for video in videos]
        reshaped_input_sizes = [(size.height, size.width) for _ in videos]
        batch_feature = super()._preprocess(videos, size=size, return_tensors=return_tensors, **kwargs)
        batch_feature = BatchFeature(
            data={
                "original_sizes": original_sizes,
                "reshaped_input_sizes": reshaped_input_sizes,
                **batch_feature.data,
            },
            tensor_type=return_tensors,
        )
        return batch_feature

    def post_process_masks(
        self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
    ):
        """
        Remove padding and upscale masks to the original image size.

        Args:
            masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
                Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
                The list passed in is not modified.
            original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
                The original sizes of each image before it was resized to the model's expected input shape, in (height,
                width) format.
            reshaped_input_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
                The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
            mask_threshold (`float`, *optional*, defaults to 0.0):
                The threshold to use for binarizing the masks.
            binarize (`bool`, *optional*, defaults to `True`):
                Whether to binarize the masks.
            pad_size (`dict`, *optional*, defaults to `self.size`):
                Mapping with `"height"` and `"width"` keys giving the size the images were padded to before being
                passed to the model. If None, the processor's `size` is used.

        Returns:
            (`List[torch.Tensor]`): One entry per item of `original_sizes`, each in (batch_size, num_channels,
            height, width) format, where (height, width) is given by the corresponding original size.

        Raises:
            ValueError: If an entry of `masks` is neither a `torch.Tensor` nor a `np.ndarray`.
        """
        pad_size = self.size if pad_size is None else pad_size
        target_image_size = (pad_size["height"], pad_size["width"])
        if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
            original_sizes = original_sizes.tolist()
        if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
            reshaped_input_sizes = reshaped_input_sizes.tolist()
        output_masks = []
        for i, original_size in enumerate(original_sizes):
            # Work on a local reference so the caller's list is left untouched.
            mask = masks[i]
            if isinstance(mask, np.ndarray):
                mask = torch.from_numpy(mask)
            elif not isinstance(mask, torch.Tensor):
                raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
            # Upscale to the padded size, crop away the padding, then resize to the original size.
            interpolated_mask = F_t.interpolate(mask, target_image_size, mode="bilinear", align_corners=False)
            interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
            interpolated_mask = F_t.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
            if binarize:
                interpolated_mask = interpolated_mask > mask_threshold
            output_masks.append(interpolated_mask)

        return output_masks
|
||||
|
||||
|
||||
__all__ = ["Sam2VideoProcessor"]
|
@ -19,9 +19,17 @@ import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from transformers import Sam2Config, Sam2ImageEncoderConfig, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig, pipeline
|
||||
from transformers import (
|
||||
Sam2Config,
|
||||
Sam2ImageEncoderConfig,
|
||||
Sam2MaskDecoderConfig,
|
||||
Sam2Processor,
|
||||
Sam2PromptEncoderConfig,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import backend_empty_cache, require_torch, slow, torch_device
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
from transformers.video_utils import load_video
|
||||
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
@ -417,19 +425,32 @@ class Sam2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
|
||||
def prepare_image():
    """Download the truck test fixture and return it as an RGB PIL image."""
    # A single URL assignment: the earlier car.png value was overwritten
    # before use and has been dropped.
    img_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
    return raw_image
|
||||
|
||||
|
||||
def prepare_dog_img():
    """Download the dog documentation image and return it as an RGB PIL image."""
    # A single URL assignment: the earlier dog-sam2.png value was overwritten
    # before use and has been dropped.
    img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
    return raw_image
|
||||
|
||||
|
||||
def prepare_video():
    """Load the bedroom test clip and return the first element of `load_video`'s result.

    `load_video` returns a pair; the second element is discarded here.
    """
    bedroom_clip_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/bedroom.mp4"
    video_data, _unused = load_video(bedroom_clip_url)
    return video_data
|
||||
|
||||
|
||||
@slow
|
||||
class Sam2ModelIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.model = Sam2Model.from_pretrained("../sam2_hf_implem/sam2_tiny_hf")
|
||||
self.processor = Sam2Processor.from_pretrained("../sam2_hf_implem/sam2_tiny_hf")
|
||||
self.model.to(torch_device)
|
||||
self.model.eval()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
@ -437,45 +458,98 @@ class Sam2ModelIntegrationTest(unittest.TestCase):
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_inference_mask_generation_no_point(self):
|
||||
model = Sam2Model.from_pretrained("facebook/sam2-vit-base")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam2-vit-base")
|
||||
pass
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
# model = Sam2Model.from_pretrained("facebook/sam2-vit-base")
|
||||
|
||||
# processor = SamProcessor.from_pretrained("facebook/sam2-vit-base")
|
||||
|
||||
# model.to(torch_device)
|
||||
# model.eval()
|
||||
|
||||
# raw_image = prepare_image()
|
||||
# inputs = processor(images=raw_image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# with torch.no_grad():
|
||||
# outputs = model(**inputs)
|
||||
# scores = outputs.iou_scores.squeeze()
|
||||
# masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||
# self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.4515), atol=2e-4))
|
||||
# self.assertTrue(torch.allclose(masks, torch.tensor([-4.1800, -3.4948, -3.4481]).to(torch_device), atol=2e-4))
|
||||
|
||||
def test_inference_mask_generation_one_point_multimask(self):
|
||||
raw_image = prepare_image()
|
||||
inputs = processor(images=raw_image, return_tensors="pt").to(torch_device)
|
||||
input_points = [[[[500, 375]]]]
|
||||
input_labels = [[[1]]]
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
scores = outputs.iou_scores.squeeze()
|
||||
masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.4515), atol=2e-4))
|
||||
self.assertTrue(torch.allclose(masks, torch.tensor([-4.1800, -3.4948, -3.4481]).to(torch_device), atol=2e-4))
|
||||
|
||||
def test_inference_mask_generation_one_point_one_bb(self):
|
||||
model = Sam2Model.from_pretrained("facebook/sam2-vit-base")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam2-vit-base")
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
raw_image = prepare_image()
|
||||
input_boxes = [[[650, 900, 1000, 1250]]]
|
||||
input_points = [[[820, 1080]]]
|
||||
|
||||
inputs = processor(
|
||||
images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt"
|
||||
inputs = self.processor(
|
||||
images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt"
|
||||
).to(torch_device)
|
||||
# to_tensor = ToTensor()
|
||||
# transforms = torch.jit.script(
|
||||
# nn.Sequential(
|
||||
# Resize((1024, 1024)),
|
||||
# Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
# )
|
||||
# )
|
||||
# inputs["pixel_values"] = transforms(to_tensor(raw_image)).unsqueeze(0).to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
scores = outputs.iou_scores.squeeze()
|
||||
masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9566), atol=2e-4))
|
||||
self.assertTrue(
|
||||
torch.allclose(masks, torch.tensor([-12.7729, -12.3665, -12.6061]).to(torch_device), atol=2e-4)
|
||||
outputs = self.model(**inputs)
|
||||
self.assertEqual(outputs.ious.shape, (1, 1, 3))
|
||||
self.assertEqual(outputs.low_res_masks.shape, (1, 1, 3, 256, 256))
|
||||
sorted_indices = torch.argsort(outputs.ious.squeeze(), descending=True)
|
||||
scores = outputs.ious.squeeze()[sorted_indices]
|
||||
masks_logits = outputs.low_res_masks.squeeze()[sorted_indices][0, :3, :3]
|
||||
print("scores", scores)
|
||||
print("masks_logits", masks_logits)
|
||||
torch.testing.assert_close(
|
||||
scores, torch.tensor([0.9546, 0.4937, 0.0428]).to(torch_device), atol=1e-4, rtol=1e-4
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
masks_logits,
|
||||
torch.tensor(
|
||||
[[-25.0963, -41.5728, -30.8723], [-34.7112, -30.7988, -36.4013], [-25.3061, -37.4575, -33.1899]]
|
||||
).to(torch_device),
|
||||
atol=1e-4,
|
||||
rtol=1e-4,
|
||||
)
|
||||
|
||||
def test_inference_mask_generation_video_one_point(self):
|
||||
pass
|
||||
# raw_video = prepare_video()
|
||||
# self.processor.init_state(video_path="./videos/bedroom_light")
|
||||
|
||||
# inputs = processor.add_new_points_or_box(
|
||||
# frame_idx=0,
|
||||
# obj_id=1,
|
||||
# points=[[[[210, 350]]]],
|
||||
# labels=[[[1]]],
|
||||
# )
|
||||
|
||||
# def test_inference_mask_generation_one_point_one_bb(self):
|
||||
# model = Sam2Model.from_pretrained("../sam2_hf_implem/sam2_tiny_hf")
|
||||
# processor = SamProcessor.from_pretrained("../sam2_hf_implem/sam2_tiny_hf")
|
||||
|
||||
# model.to(torch_device)
|
||||
# model.eval()
|
||||
|
||||
# raw_image = prepare_image()
|
||||
# input_boxes = [[[[650, 900, 1000, 1250]]]]
|
||||
# input_points = [[[[820, 1080]]]]
|
||||
|
||||
# inputs = processor(
|
||||
# images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt"
|
||||
# ).to(torch_device)
|
||||
|
||||
# with torch.no_grad():
|
||||
# outputs = model(**inputs)
|
||||
# scores = outputs.iou_scores.squeeze()
|
||||
# masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||
# self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9566), atol=2e-4))
|
||||
# self.assertTrue(
|
||||
# torch.allclose(masks, torch.tensor([-12.7729, -12.3665, -12.6061]).to(torch_device), atol=2e-4)
|
||||
# )
|
||||
|
||||
def test_inference_mask_generation_batched_points_batched_images(self):
|
||||
model = Sam2Model.from_pretrained("facebook/sam2-vit-base")
|
||||
|
Loading…
Reference in New Issue
Block a user