add Sam2VideoSessionState + fast image processor + video processor

yonigozlan 2025-06-25 20:45:16 +00:00
parent 79055ad65f
commit c3330c677a
9 changed files with 547 additions and 290 deletions

View File

@ -142,7 +142,7 @@ else:
("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
("sam", ("SamImageProcessor",)),
("sam2", ("Sam2ImageProcessor",)),
("sam2", ("Sam2ImageProcessor", "Sam2ImageProcessorFast")),
("sam_hq", ("SamImageProcessor",)),
("segformer", ("SegformerImageProcessor",)),
("seggpt", ("SegGptImageProcessor",)),

View File

@ -54,6 +54,7 @@ else:
("qwen2_5_omni", "Qwen2VLVideoProcessor"),
("qwen2_5_vl", "Qwen2VLVideoProcessor"),
("qwen2_vl", "Qwen2VLVideoProcessor"),
("sam2", "Sam2VideoProcessor"),
("smolvlm", "SmolVLMVideoProcessor"),
("video_llava", "VideoLlavaVideoProcessor"),
("vjepa2", "VJEPA2VideoProcessor"),

View File

@ -20,8 +20,10 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_sam2 import *
from .image_processing_sam2 import *
from .image_processing_sam2_fast import *
from .modeling_sam2 import *
from .processing_sam2 import *
from .video_processing_sam2 import *
else:
import sys

View File

@ -30,13 +30,14 @@ from PIL import Image
from transformers import (
Sam2Config,
Sam2ImageEncoderConfig,
Sam2ImageProcessor,
Sam2ImageProcessorFast,
Sam2MaskDecoderConfig,
Sam2MemoryAttentionConfig,
Sam2MemoryEncoderConfig,
Sam2Model,
Sam2Processor,
Sam2PromptEncoderConfig,
Sam2VideoProcessor,
)
@ -54,13 +55,28 @@ def get_config(model_name):
memory_attention_config = Sam2MemoryAttentionConfig()
memory_encoder_config = Sam2MemoryEncoderConfig()
elif "sam2.1_hiera_base_plus" in model_name:
image_encoder_config = Sam2ImageEncoderConfig(hidden_size=112, num_heads=2, stages=(2, 3, 16, 3), global_attention_blocks=(12, 16, 20), window_positional_embedding_background_size=(14, 14), backbone_channel_list=[896, 448, 224, 112])
image_encoder_config = Sam2ImageEncoderConfig(
hidden_size=112,
num_heads=2,
stages=(2, 3, 16, 3),
global_attention_blocks=(12, 16, 20),
window_positional_embedding_background_size=(14, 14),
backbone_channel_list=[896, 448, 224, 112],
)
prompt_encoder_config = Sam2PromptEncoderConfig()
mask_decoder_config = Sam2MaskDecoderConfig()
memory_attention_config = Sam2MemoryAttentionConfig()
memory_encoder_config = Sam2MemoryEncoderConfig()
elif "sam2.1_hiera_large" in model_name:
image_encoder_config = Sam2ImageEncoderConfig(hidden_size=144, num_heads=2, stages=(2, 6, 36, 4), global_attention_blocks=(23, 33, 43), window_positional_embedding_background_size=(7, 7), window_spec=(8, 4, 16, 8), backbone_channel_list=[1152, 576, 288, 144])
image_encoder_config = Sam2ImageEncoderConfig(
hidden_size=144,
num_heads=2,
stages=(2, 6, 36, 4),
global_attention_blocks=(23, 33, 43),
window_positional_embedding_background_size=(7, 7),
window_spec=(8, 4, 16, 8),
backbone_channel_list=[1152, 576, 288, 144],
)
prompt_encoder_config = Sam2PromptEncoderConfig()
mask_decoder_config = Sam2MaskDecoderConfig()
memory_attention_config = Sam2MemoryAttentionConfig()
@ -197,8 +213,9 @@ def convert_sam2_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, pu
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
state_dict = replace_keys(state_dict)
image_processor = Sam2ImageProcessor()
processor = Sam2Processor(image_processor=image_processor)
image_processor = Sam2ImageProcessorFast()
video_processor = Sam2VideoProcessor()
processor = Sam2Processor(image_processor=image_processor, video_processor=video_processor)
hf_model = Sam2Model(config)
hf_model.eval()
@ -292,6 +309,10 @@ if __name__ == "__main__":
args = parser.parse_args()
hf_model_name = args.model_name.replace("_", "-")
checkpoint_path = hf_hub_download(f"facebook/{hf_model_name}", f"{args.model_name}.pt") if args.checkpoint_path is None else args.checkpoint_path
checkpoint_path = (
hf_hub_download(f"facebook/{hf_model_name}", f"{args.model_name}.pt")
if args.checkpoint_path is None
else args.checkpoint_path
)
convert_sam2_checkpoint(args.model_name, checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)
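For orientation, a small sketch of driving the conversion directly from Python, mirroring the `__main__` block above; the model name and dump folder are illustrative, and `convert_sam2_checkpoint` is the function defined earlier in this script:
from huggingface_hub import hf_hub_download
# Illustrative: convert the SAM 2.1 large checkpoint locally without pushing to the Hub.
model_name = "sam2.1_hiera_large"
hf_model_name = model_name.replace("_", "-")  # e.g. "sam2.1-hiera-large"
checkpoint_path = hf_hub_download(f"facebook/{hf_model_name}", f"{model_name}.pt")
convert_sam2_checkpoint(model_name, checkpoint_path, "sam2.1_hiera_large_hf", push_to_hub=False)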

View File

@ -0,0 +1,128 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for SAM2."""
from typing import Optional, Union
import numpy as np
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
)
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
PILImageResampling,
SizeDict,
)
from ...utils import (
TensorType,
auto_docstring,
is_torch_available,
)
if is_torch_available():
import torch
from torch.nn import functional as F_t
class Sam2ImageProcessorFastKwargs(DefaultFastImageProcessorKwargs):
do_pad: bool
mask_pad_size: SizeDict
@auto_docstring
class Sam2ImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
image_mean = IMAGENET_DEFAULT_MEAN
image_std = IMAGENET_DEFAULT_STD
size = {"height": 1024, "width": 1024}
do_resize = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_pad = False
def _preprocess(
self,
images: list["torch.Tensor"],
size: Optional[SizeDict],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
original_sizes = [image.shape[-2:] for image in images]
reshaped_input_sizes = [(size.height, size.width) for _ in range(len(images))]
batch_feature = super()._preprocess(images, size=size, return_tensors=return_tensors, **kwargs)
batch_feature = BatchFeature(
data={
"original_sizes": original_sizes,
"reshaped_input_sizes": reshaped_input_sizes,
**batch_feature.data,
},
tensor_type=return_tensors,
)
return batch_feature
def post_process_masks(
self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
):
"""
Remove padding and upscale masks to the original image size.
Args:
masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
The original sizes of each image before it was resized to the model's expected input shape, in (height,
width) format.
reshaped_input_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
mask_threshold (`float`, *optional*, defaults to 0.0):
The threshold to use for binarizing the masks.
binarize (`bool`, *optional*, defaults to `True`):
Whether to binarize the masks.
pad_size (`Dict[str, int]`, *optional*, defaults to `self.size`):
The target size the images were padded to before being passed to the model. If None, the target size is
assumed to be the processor's `size`.
Returns:
(`torch.Tensor`): Batched masks in (batch_size, num_channels, height, width) format, where (height, width)
is given by original_size.
"""
pad_size = self.size if pad_size is None else pad_size
target_image_size = (pad_size["height"], pad_size["width"])
if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
original_sizes = original_sizes.tolist()
if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
reshaped_input_sizes = reshaped_input_sizes.tolist()
output_masks = []
for i, original_size in enumerate(original_sizes):
if isinstance(masks[i], np.ndarray):
masks[i] = torch.from_numpy(masks[i])
elif not isinstance(masks[i], torch.Tensor):
raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
interpolated_mask = F_t.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
interpolated_mask = F_t.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
if binarize:
interpolated_mask = interpolated_mask > mask_threshold
output_masks.append(interpolated_mask)
return output_masks
__all__ = ["Sam2ImageProcessorFast"]
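A minimal usage sketch of the new fast image processor, assuming the standard `pixel_values` output of `BaseImageProcessorFast`; the random tensor stands in for the 256x256 low-resolution mask logits a `Sam2Model` forward pass would produce:
import requests
import torch
from PIL import Image
from transformers import Sam2ImageProcessorFast
image_processor = Sam2ImageProcessorFast()
img_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg"
image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
inputs = image_processor(images=image, return_tensors="pt")
# pixel_values is resized to the default 1024x1024; original_sizes / reshaped_input_sizes
# are kept so that masks can be mapped back to the input resolution.
print(inputs.pixel_values.shape)
# Stand-in for model mask logits (the mask decoder works at 256x256 elsewhere in this commit).
dummy_logits = torch.randn(1, 3, 256, 256)
masks = image_processor.post_process_masks(
    [dummy_logits], inputs["original_sizes"], inputs["reshaped_input_sizes"]
)
print(masks[0].shape)  # (1, 3, original_height, original_width), binarized by default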

View File

@ -19,6 +19,7 @@ import collections.abc
import copy
import math
import warnings
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Iterator, Optional, Union
@ -125,6 +126,95 @@ def fill_holes_in_mask_scores(mask, max_area):
return mask
class Sam2VideoSessionState:
images: torch.FloatTensor = None
num_frames: int = None
offload_video_to_cpu: bool = None
offload_state_to_cpu: bool = None
video_height: int = None
video_width: int = None
device: torch.device = None
storage_device: torch.device = None
point_inputs_per_obj: dict = None
mask_inputs_per_obj: dict = None
cached_features: dict = None
constants: dict = None
obj_id_to_idx: dict = None
obj_idx_to_id: dict = None
obj_ids: list = None
output_dict_per_obj: dict = None
temp_output_dict_per_obj: dict = None
frames_tracked_per_obj: dict = None
# TODO add async video loading?
def __init__(
self,
video: torch.FloatTensor,
video_height: int,
video_width: int,
offload_video_to_cpu: bool = False,
offload_state_to_cpu: bool = False,
async_loading_frames: bool = False,
):
self.images = list(video)
self.num_frames = len(video)
self.offload_video_to_cpu = offload_video_to_cpu
self.offload_state_to_cpu = offload_state_to_cpu
self.async_loading_frames = async_loading_frames
self.video_height = video_height
self.video_width = video_width
self.device = video.device
self.storage_device = torch.device("cpu") if offload_state_to_cpu else video.device
self.cached_features = {}
self.point_inputs_per_obj = {}
self.mask_inputs_per_obj = {}
self.constants = {}
self.obj_id_to_idx = OrderedDict()
self.obj_idx_to_id = OrderedDict()
self.obj_ids = []
self.output_dict_per_obj = {}
self.temp_output_dict_per_obj = {}
self.frames_tracked_per_obj = {}
def reset_inference_session(self):
self.point_inputs_per_obj.clear()
self.mask_inputs_per_obj.clear()
self.constants.clear()
self.obj_id_to_idx.clear()
self.obj_idx_to_id.clear()
self.obj_ids.clear()
self.output_dict_per_obj.clear()
self.temp_output_dict_per_obj.clear()
self.frames_tracked_per_obj.clear()
def _obj_id_to_idx(self, obj_id: int) -> int:
"""Map client-side object id to model-side object index."""
obj_idx = self.obj_id_to_idx.get(obj_id, None)
if obj_idx is not None:
return obj_idx
# Add new object
obj_idx = len(self.obj_id_to_idx)
self.obj_id_to_idx[obj_id] = obj_idx
self.obj_idx_to_id[obj_idx] = obj_id
self.obj_ids = list(self.obj_id_to_idx)
# Set up input and output structures for this object
self.point_inputs_per_obj[obj_idx] = {}
self.mask_inputs_per_obj[obj_idx] = {}
self.output_dict_per_obj[obj_idx] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
self.temp_output_dict_per_obj[obj_idx] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
self.frames_tracked_per_obj[obj_idx] = {}
return obj_idx
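A minimal sketch of populating this session container by hand, mirroring how `Sam2Processor.init_video_session` builds it later in this commit; the frame tensor and original resolution are illustrative:
import torch
# 8 preprocessed frames at the model resolution (illustrative shapes).
video = torch.randn(8, 3, 1024, 1024)
session = Sam2VideoSessionState(
    video,
    video_height=480,  # original video resolution, used to resize masks back for output
    video_width=854,
    offload_state_to_cpu=True,  # keep per-frame outputs on the CPU storage device
)
obj_idx = session._obj_id_to_idx(obj_id=1)  # registers object id 1 and allocates its output dicts
session.reset_inference_session()  # drops all prompts/outputs but keeps the cached frames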
@dataclass
class Sam2ImageEncoderOutput(ModelOutput):
"""
@ -2366,20 +2456,20 @@ class Sam2Model(Sam2PreTrainedModel):
# Video Inference specific functions
def _obj_idx_to_id(self, inference_state, obj_idx):
"""Map model-side object index to client-side object id."""
return inference_state["obj_idx_to_id"][obj_idx]
return inference_state.obj_idx_to_id[obj_idx]
def _get_obj_num(self, inference_state):
"""Get the total number of unique object ids received so far in this session."""
return len(inference_state["obj_idx_to_id"])
return len(inference_state.obj_idx_to_id)
def _get_orig_video_res_output(self, inference_state, any_res_masks):
"""
Resize the object scores to the original video resolution (video_res_masks)
and apply non-overlapping constraints for final output.
"""
device = inference_state["device"]
video_H = inference_state["video_height"]
video_W = inference_state["video_width"]
device = inference_state.device
video_H = inference_state.video_height
video_W = inference_state.video_width
any_res_masks = any_res_masks.to(device, non_blocking=True)
if any_res_masks.shape[-2:] == (video_H, video_W):
video_res_masks = any_res_masks
@ -2415,8 +2505,8 @@ class Sam2Model(Sam2PreTrainedModel):
# Optionally, we allow consolidating the temporary outputs at the original
# video resolution (to provide a better editing experience for mask prompts).
if consolidate_at_video_res:
consolidated_H = inference_state["video_height"]
consolidated_W = inference_state["video_width"]
consolidated_H = inference_state.video_height
consolidated_W = inference_state.video_width
consolidated_mask_key = "pred_masks_video_res"
else:
consolidated_H = consolidated_W = self.image_size // 4
@ -2431,12 +2521,12 @@ class Sam2Model(Sam2PreTrainedModel):
size=(batch_size, 1, consolidated_H, consolidated_W),
fill_value=NO_OBJ_SCORE,
dtype=torch.float32,
device=inference_state["storage_device"],
device=inference_state.storage_device,
),
}
for obj_idx in range(batch_size):
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_temp_output_dict = inference_state.temp_output_dict_per_obj[obj_idx]
obj_output_dict = inference_state.output_dict_per_obj[obj_idx]
out = obj_temp_output_dict[storage_key].get(frame_idx, None)
# If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
# we fall back and look up its previous output in "output_dict_per_obj".
@ -2481,8 +2571,8 @@ class Sam2Model(Sam2PreTrainedModel):
"""
Add new conditioning inputs to a frame and run inference.
"""
device = inference_state["device"]
storage_device = inference_state["storage_device"]
device = inference_state.device
storage_device = inference_state.storage_device
# Prepare batch inputs
batch_size = 1
@ -2495,21 +2585,21 @@ class Sam2Model(Sam2PreTrainedModel):
is_init_cond_frame=is_init_cond_frame,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
output_dict=inference_state["output_dict_per_obj"][obj_idx],
output_dict=inference_state.output_dict_per_obj[obj_idx],
run_mem_encoder=False,
reverse=False,
)
# Update the output dictionary
output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
# output_dict = inference_state.temp_output_dict_per_obj[obj_idx]
if is_init_cond_frame:
output_dict["cond_frame_outputs"][frame_idx] = current_out
inference_state.temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"][frame_idx] = current_out
else:
output_dict["non_cond_frame_outputs"][frame_idx] = current_out
inference_state.temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"][frame_idx] = current_out
# Resize the output mask to the original video resolution
obj_ids = inference_state["obj_ids"]
obj_ids = inference_state.obj_ids
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state,
frame_idx,
@ -2531,8 +2621,8 @@ class Sam2Model(Sam2PreTrainedModel):
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
# add them into "output_dict".
for obj_idx in range(batch_size):
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
obj_output_dict = inference_state.output_dict_per_obj[obj_idx]
obj_temp_output_dict = inference_state.temp_output_dict_per_obj[obj_idx]
for is_cond in [False, True]:
# Separately consolidate conditioning and non-conditioning temp outputs
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
@ -2543,7 +2633,7 @@ class Sam2Model(Sam2PreTrainedModel):
# Run memory encoder on the temporary outputs (if the memory feature is missing)
if out["maskmem_features"] is None:
high_res_masks = torch.nn.functional.interpolate(
out["pred_masks"].to(inference_state["device"]),
out["pred_masks"].to(inference_state.device),
size=(self.image_size, self.image_size),
mode="bilinear",
align_corners=False,
@ -2566,7 +2656,7 @@ class Sam2Model(Sam2PreTrainedModel):
obj_temp_output_dict[storage_key].clear()
# check and make sure that every object has received input points or masks
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_output_dict = inference_state.output_dict_per_obj[obj_idx]
if len(obj_output_dict["cond_frame_outputs"]) == 0:
obj_id = self._obj_idx_to_id(inference_state, obj_idx)
raise RuntimeError(
@ -2591,8 +2681,8 @@ class Sam2Model(Sam2PreTrainedModel):
"""
self.propagate_in_video_preflight(inference_state)
obj_ids = inference_state["obj_ids"]
num_frames = inference_state["num_frames"]
obj_ids = inference_state.obj_ids
num_frames = inference_state.num_frames
batch_size = self._get_obj_num(inference_state)
# set start index, end index, and processing order
@ -2600,7 +2690,7 @@ class Sam2Model(Sam2PreTrainedModel):
# default: start from the earliest frame with input points
start_frame_idx = min(
t
for obj_output_dict in inference_state["output_dict_per_obj"].values()
for obj_output_dict in inference_state.output_dict_per_obj.values()
for t in obj_output_dict["cond_frame_outputs"]
)
if max_frame_num_to_track is None:
@ -2619,7 +2709,7 @@ class Sam2Model(Sam2PreTrainedModel):
for frame_idx in tqdm(processing_order, desc="propagate in video"):
pred_masks_per_obj = [None] * batch_size
for obj_idx in range(batch_size):
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_output_dict = inference_state.output_dict_per_obj[obj_idx]
# We skip those frames already in consolidated outputs (these are frames
# that received input clicks or mask). Note that we cannot directly run
# batched forward on them via `_run_single_frame_inference` because the
@ -2627,7 +2717,7 @@ class Sam2Model(Sam2PreTrainedModel):
if frame_idx in obj_output_dict["cond_frame_outputs"]:
storage_key = "cond_frame_outputs"
current_out = obj_output_dict[storage_key][frame_idx]
device = inference_state["device"]
device = inference_state.device
pred_masks = current_out["pred_masks"].to(device, non_blocking=True)
else:
storage_key = "non_cond_frame_outputs"
@ -2644,7 +2734,7 @@ class Sam2Model(Sam2PreTrainedModel):
)
obj_output_dict[storage_key][frame_idx] = current_out
inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {"reverse": reverse}
inference_state.frames_tracked_per_obj[obj_idx][frame_idx] = {"reverse": reverse}
pred_masks_per_obj[obj_idx] = pred_masks
# Resize the output mask to the original video resolution (we directly use
@ -2665,18 +2755,18 @@ class Sam2Model(Sam2PreTrainedModel):
"""Prepare vision features for a frame."""
# Check if features are cached
if frame_idx in inference_state["cached_features"]:
cached = inference_state["cached_features"][frame_idx]
if frame_idx in inference_state.cached_features:
cached = inference_state.cached_features[frame_idx]
vision_feats = cached["vision_feats"]
vision_pos_embeds = cached["vision_pos_embeds"]
else:
# Compute features using image encoder
image_batch = inference_state["images"][frame_idx].unsqueeze(0) # Add batch dimension
image_batch = inference_state.images[frame_idx].unsqueeze(0) # Add batch dimension
feature_maps, feature_maps_position_embeddings, _, _ = self.get_image_features(image_batch)
vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in feature_maps_position_embeddings]
# Cache features
inference_state["cached_features"][frame_idx] = {
inference_state.cached_features[frame_idx] = {
"vision_feats": vision_feats,
"vision_pos_embeds": vision_pos_embeds,
}
@ -2712,7 +2802,7 @@ class Sam2Model(Sam2PreTrainedModel):
)
# optionally offload the output to CPU memory to save GPU space
storage_device = inference_state["storage_device"]
storage_device = inference_state.storage_device
maskmem_features = maskmem_features.to(torch.bfloat16)
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
@ -2724,7 +2814,7 @@ class Sam2Model(Sam2PreTrainedModel):
`maskmem_pos_enc` is the same across frames and objects, so we cache it as
a constant in the inference session to reduce session storage size.
"""
model_constants = inference_state["constants"]
model_constants = inference_state.constants
# "out_maskmem_pos_enc" should be either a list of tensors or None
out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
if out_maskmem_pos_enc is not None:
@ -2771,14 +2861,14 @@ class Sam2Model(Sam2PreTrainedModel):
point_inputs=point_inputs,
mask_inputs=mask_inputs,
output_dict=output_dict,
num_frames=inference_state["num_frames"],
num_frames=inference_state.num_frames,
track_in_reverse=reverse,
run_mem_encoder=run_mem_encoder,
prev_sam_mask_logits=prev_sam_mask_logits,
)
# optionally offload the output to CPU memory to save GPU space
storage_device = inference_state["storage_device"]
storage_device = inference_state.storage_device
maskmem_features = current_out["maskmem_features"]
if maskmem_features is not None:
maskmem_features = maskmem_features.to(torch.bfloat16)
@ -3321,4 +3411,4 @@ class Sam2Model(Sam2PreTrainedModel):
return pred_masks
__all__ = ["Sam2Model", "Sam2PreTrainedModel"]
__all__ = ["Sam2Model", "Sam2VideoSessionState", "Sam2PreTrainedModel"]

View File

@ -16,18 +16,16 @@
Processor class for SAM2.
"""
from collections import OrderedDict
from copy import deepcopy
from pathlib import Path
from typing import Any, Optional, Union
import numpy as np
import torch.nn as nn
from torchvision.transforms import Normalize, ToTensor
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...utils import TensorType, is_tf_available, is_torch_available, logging
from ...video_utils import VideoInput
from .modeling_sam2 import Sam2VideoSessionState
logger = logging.get_logger(__name__)
@ -36,7 +34,7 @@ if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
pass
class Sam2Processor(ProcessorMixin):
@ -52,17 +50,16 @@ class Sam2Processor(ProcessorMixin):
An instance of [`Sam2ImageProcessor`]. The image processor is a required input.
"""
attributes = ["image_processor"]
image_processor_class = "Sam2ImageProcessor"
attributes = ["image_processor", "video_processor"]
image_processor_class = "Sam2ImageProcessorFast"
video_processor_class = "Sam2VideoProcessor"
def __init__(self, image_processor):
super().__init__(image_processor)
self.current_processor = self.image_processor
self.point_pad_value = -10
self.target_size = self.image_processor.size["longest_edge"]
# Video inference state
self.inference_state = None
def __init__(
self, image_processor, video_processor, target_size: Optional[int] = None, point_pad_value: int = -10, **kwargs
):
super().__init__(image_processor, video_processor, **kwargs)
self.point_pad_value = point_pad_value
self.target_size = target_size if target_size is not None else self.image_processor.size["height"]
def __call__(
self,
@ -108,6 +105,15 @@ class Sam2Processor(ProcessorMixin):
return encoding_image_processor
def init_video_session(self, video: VideoInput, device: Union[str, "torch.device"] = "cuda"):
# Preprocess the whole video once and move it to the requested device instead of hard-coding "cuda".
processed_video = self.video_processor(videos=video, return_tensors="pt").to(device)
inference_state = Sam2VideoSessionState(
processed_video.pixel_values_videos[0],
video_height=processed_video.original_sizes[0][0],
video_width=processed_video.original_sizes[0][1],
)
return inference_state
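Taken together with `process_new_points_or_box` below, the intended video flow looks roughly like this; the checkpoint path is the local WIP one used in the integration tests, and the point/label nesting follows the type hints rather than a documented API:
import numpy as np
from transformers import Sam2Processor
# Local WIP checkpoint path borrowed from the integration tests; a Hub id would go here eventually.
processor = Sam2Processor.from_pretrained("../sam2_hf_implem/sam2_tiny_hf")
# Dummy 8-frame RGB video; any VideoInput accepted by Sam2VideoProcessor works.
raw_video = np.random.randint(0, 255, (8, 480, 854, 3), dtype=np.uint8)
inference_state = processor.init_video_session(raw_video, device="cuda")  # use "cpu" when no GPU is available
# Register a positive click on object 1 in frame 0; the returned dict holds the point tensors
# that the model-side video methods consume.
frame_inputs = processor.process_new_points_or_box(
    inference_state,
    frame_idx=0,
    obj_id=1,
    points=[[210, 350]],
    labels=[1],
)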
def _normalize_and_convert(
self,
encoding_image_processor,
@ -155,30 +161,19 @@ class Sam2Processor(ProcessorMixin):
input_boxes = torch.from_numpy(input_boxes)
# boxes batch size of 1 by default
input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes
elif return_tensors == "tf":
input_boxes = tf.convert_to_tensor(input_boxes)
# boxes batch size of 1 by default
input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes
encoding_image_processor.update({"input_boxes": input_boxes})
if input_points is not None:
if return_tensors == "pt":
input_points = torch.from_numpy(input_points)
# point batch size of 1 by default
input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points
elif return_tensors == "tf":
input_points = tf.convert_to_tensor(input_points)
# point batch size of 1 by default
input_points = tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points
encoding_image_processor.update({"input_points": input_points})
if input_labels is not None:
if return_tensors == "pt":
input_labels = torch.from_numpy(input_labels)
# point batch size of 1 by default
input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels
elif return_tensors == "tf":
input_labels = tf.convert_to_tensor(input_labels)
# point batch size of 1 by default
input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels
encoding_image_processor.update({"input_labels": input_labels})
return encoding_image_processor
@ -267,172 +262,12 @@ class Sam2Processor(ProcessorMixin):
return input_points, input_labels, input_boxes
@property
def model_input_names(self):
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(image_processor_input_names))
def post_process_masks(self, *args, **kwargs):
return self.image_processor.post_process_masks(*args, **kwargs)
def init_state(
self,
video_path: Union[str, Path],
offload_video_to_cpu: bool = False,
offload_state_to_cpu: bool = False,
async_loading_frames: bool = False,
device: Optional[torch.device] = None,
) -> None:
"""Initialize video inference state."""
if not is_torch_available():
raise ImportError("Video inference requires PyTorch to be installed")
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load video frames
images, video_height, video_width = self._load_video_frames(
video_path=video_path,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
device=device,
)
# Initialize inference state
self.inference_state = {
"images": images,
"num_frames": len(images),
"offload_video_to_cpu": offload_video_to_cpu,
"offload_state_to_cpu": offload_state_to_cpu,
"video_height": video_height,
"video_width": video_width,
"device": device,
"storage_device": torch.device("cpu") if offload_state_to_cpu else device,
# Input tracking
"point_inputs_per_obj": {},
"mask_inputs_per_obj": {},
# Visual features cache
"cached_features": {},
"constants": {},
# Object management
"obj_id_to_idx": OrderedDict(),
"obj_idx_to_id": OrderedDict(),
"obj_ids": [],
# Output tracking
"output_dict_per_obj": {},
"temp_output_dict_per_obj": {},
"frames_tracked_per_obj": {},
}
logger.info(f"Initialized video state with {len(images)} frames at resolution {video_height}x{video_width}")
def reset_state(self) -> None:
"""Reset the video inference state."""
if self.inference_state is not None:
# Clear all state
self.inference_state["point_inputs_per_obj"].clear()
self.inference_state["mask_inputs_per_obj"].clear()
self.inference_state["cached_features"].clear()
self.inference_state["constants"].clear()
self.inference_state["obj_id_to_idx"].clear()
self.inference_state["obj_idx_to_id"].clear()
self.inference_state["obj_ids"].clear()
self.inference_state["output_dict_per_obj"].clear()
self.inference_state["temp_output_dict_per_obj"].clear()
self.inference_state["frames_tracked_per_obj"].clear()
self.inference_state = None
logger.info("Reset video inference state")
def _load_video_frames(
self,
video_path: Union[str, Path],
offload_video_to_cpu: bool = False,
async_loading_frames: bool = False,
device: torch.device = None,
) -> tuple[list[torch.Tensor], int, int]:
"""Load video frames from a directory of images."""
video_path = Path(video_path)
if not video_path.exists():
raise ValueError(f"Video path {video_path} does not exist")
# Get image files
image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}
image_files = [f for f in video_path.iterdir() if f.suffix.lower() in image_extensions]
if not image_files:
raise ValueError(f"No image files found in {video_path}")
# Sort files by name (assuming frame order)
image_files.sort(key=lambda x: x.name)
# Load first image to get dimensions
from PIL import Image
first_image = Image.open(image_files[0])
video_width, video_height = first_image.size
# Process images using image processor
images = []
for img_path in image_files:
image = Image.open(img_path)
# Convert to RGB if needed
if image.mode != "RGB":
image = image.convert("RGB")
# Process image
image = image.resize((1024, 1024))
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
to_tensor = ToTensor()
transforms = torch.jit.script(
nn.Sequential(
Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
)
)
# processed = self.image_processor(image, return_tensors="pt")
# image_tensor = processed["pixel_values"].squeeze(0) # Remove batch dim
image_tensor = transforms(to_tensor(image))
if not offload_video_to_cpu and device is not None:
image_tensor = image_tensor.to(device)
images.append(image_tensor)
return images, video_height, video_width
def _obj_id_to_idx(self, obj_id: int) -> int:
"""Map client-side object id to model-side object index."""
if self.inference_state is None:
raise ValueError("Video state not initialized. Call init_state() first.")
obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None)
if obj_idx is not None:
return obj_idx
# Add new object
obj_idx = len(self.inference_state["obj_id_to_idx"])
self.inference_state["obj_id_to_idx"][obj_id] = obj_idx
self.inference_state["obj_idx_to_id"][obj_idx] = obj_id
self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"])
# Set up input and output structures for this object
self.inference_state["point_inputs_per_obj"][obj_idx] = {}
self.inference_state["mask_inputs_per_obj"][obj_idx] = {}
self.inference_state["output_dict_per_obj"][obj_idx] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
self.inference_state["temp_output_dict_per_obj"][obj_idx] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
self.inference_state["frames_tracked_per_obj"][obj_idx] = {}
return obj_idx
def add_new_points_or_box(
def process_new_points_or_box(
self,
inference_state: Sam2VideoSessionState,
frame_idx: int,
obj_id: int,
points: Optional[list[list[float]]] = None,
@ -442,15 +277,9 @@ class Sam2Processor(ProcessorMixin):
box: Optional[list[float]] = None,
) -> dict[str, Any]:
"""Add new points or box to a frame and return preprocessed inputs for model."""
if self.inference_state is None:
raise ValueError("Video state not initialized. Call init_state() first.")
if not is_torch_available():
raise ImportError("Video inference requires PyTorch to be installed")
obj_idx = self._obj_id_to_idx(obj_id)
point_inputs_per_frame = self.inference_state["point_inputs_per_obj"][obj_idx]
mask_inputs_per_frame = self.inference_state["mask_inputs_per_obj"][obj_idx]
obj_idx = inference_state._obj_id_to_idx(obj_id)
point_inputs_per_frame = inference_state.point_inputs_per_obj[obj_idx]
mask_inputs_per_frame = inference_state.mask_inputs_per_obj[obj_idx]
# Validate inputs
if (points is not None) != (labels is not None):
@ -458,7 +287,7 @@ class Sam2Processor(ProcessorMixin):
if points is None and box is None:
raise ValueError("at least one of points or box must be provided as input")
device = self.inference_state["device"]
device = inference_state.device
# Process points
if points is None:
@ -496,8 +325,8 @@ class Sam2Processor(ProcessorMixin):
# Normalize coordinates
if normalize_coords:
video_H = self.inference_state["video_height"]
video_W = self.inference_state["video_width"]
video_H = inference_state.video_height
video_W = inference_state.video_width
points = points / torch.tensor([video_W, video_H]).to(points.device)
# Scale by model's internal image size
@ -523,7 +352,7 @@ class Sam2Processor(ProcessorMixin):
mask_inputs_per_frame.pop(frame_idx, None) # Clear any mask inputs
# Determine frame type and tracking direction
obj_frames_tracked = self.inference_state["frames_tracked_per_obj"][obj_idx]
obj_frames_tracked = inference_state.frames_tracked_per_obj[obj_idx]
is_init_cond_frame = frame_idx not in obj_frames_tracked
if is_init_cond_frame:
@ -544,22 +373,17 @@ class Sam2Processor(ProcessorMixin):
def add_new_mask(
self,
inference_state: Sam2VideoSessionState,
frame_idx: int,
obj_id: int,
mask: Union[np.ndarray, torch.Tensor],
) -> dict[str, Any]:
"""Add new mask to a frame and return preprocessed inputs for model."""
if self.inference_state is None:
raise ValueError("Video state not initialized. Call init_state() first.")
obj_idx = inference_state._obj_id_to_idx(obj_id)
point_inputs_per_frame = inference_state.point_inputs_per_obj[obj_idx]
mask_inputs_per_frame = inference_state.mask_inputs_per_obj[obj_idx]
if not is_torch_available():
raise ImportError("Video inference requires PyTorch to be installed")
obj_idx = self._obj_id_to_idx(obj_id)
point_inputs_per_frame = self.inference_state["point_inputs_per_obj"][obj_idx]
mask_inputs_per_frame = self.inference_state["mask_inputs_per_obj"][obj_idx]
device = self.inference_state["device"]
device = inference_state.device
# Process mask
if not isinstance(mask, torch.Tensor):
@ -586,7 +410,7 @@ class Sam2Processor(ProcessorMixin):
point_inputs_per_frame.pop(frame_idx, None) # Clear any point inputs
# Determine frame type and tracking direction
obj_frames_tracked = self.inference_state["frames_tracked_per_obj"][obj_idx]
obj_frames_tracked = inference_state.frames_tracked_per_obj[obj_idx]
is_init_cond_frame = frame_idx not in obj_frames_tracked
if is_init_cond_frame:

View File

@ -0,0 +1,117 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for SAM2."""
from typing import Optional, Union
import numpy as np
from ...image_processing_utils import BatchFeature
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
PILImageResampling,
SizeDict,
)
from ...utils import (
TensorType,
is_torch_available,
)
from ...video_processing_utils import BaseVideoProcessor
if is_torch_available():
import torch
from torch.nn import functional as F_t
class Sam2VideoProcessor(BaseVideoProcessor):
resample = PILImageResampling.BILINEAR
image_mean = IMAGENET_DEFAULT_MEAN
image_std = IMAGENET_DEFAULT_STD
size = {"height": 1024, "width": 1024}
do_resize = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
def _preprocess(
self,
videos: list["torch.Tensor"],
size: Optional[SizeDict],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
original_sizes = [video.shape[-2:] for video in videos]
reshaped_input_sizes = [(size.height, size.width) for _ in range(len(videos))]
batch_feature = super()._preprocess(videos, size=size, return_tensors=return_tensors, **kwargs)
batch_feature = BatchFeature(
data={
"original_sizes": original_sizes,
"reshaped_input_sizes": reshaped_input_sizes,
**batch_feature.data,
},
tensor_type=return_tensors,
)
return batch_feature
def post_process_masks(
self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
):
"""
Remove padding and upscale masks to the original image size.
Args:
masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
The original sizes of each image before it was resized to the model's expected input shape, in (height,
width) format.
reshaped_input_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
mask_threshold (`float`, *optional*, defaults to 0.0):
The threshold to use for binarizing the masks.
binarize (`bool`, *optional*, defaults to `True`):
Whether to binarize the masks.
pad_size (`Dict[str, int]`, *optional*, defaults to `self.size`):
The target size the images were padded to before being passed to the model. If None, the target size is
assumed to be the processor's `size`.
Returns:
(`torch.Tensor`): Batched masks in (batch_size, num_channels, height, width) format, where (height, width)
is given by original_size.
"""
pad_size = self.size if pad_size is None else pad_size
target_image_size = (pad_size["height"], pad_size["width"])
if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
original_sizes = original_sizes.tolist()
if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
reshaped_input_sizes = reshaped_input_sizes.tolist()
output_masks = []
for i, original_size in enumerate(original_sizes):
if isinstance(masks[i], np.ndarray):
masks[i] = torch.from_numpy(masks[i])
elif not isinstance(masks[i], torch.Tensor):
raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
interpolated_mask = F_t.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
interpolated_mask = F_t.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
if binarize:
interpolated_mask = interpolated_mask > mask_threshold
output_masks.append(interpolated_mask)
return output_masks
__all__ = ["Sam2VideoProcessor"]
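A small standalone sketch of the new video processor, matching how `init_video_session` consumes its output (`pixel_values_videos`, `original_sizes`) earlier in this commit; the dummy array is illustrative:
import numpy as np
from transformers import Sam2VideoProcessor
video_processor = Sam2VideoProcessor()
video = np.random.randint(0, 255, (8, 480, 854, 3), dtype=np.uint8)  # 8 RGB frames, illustrative
out = video_processor(videos=video, return_tensors="pt")
print(out.pixel_values_videos.shape)  # expected (1, 8, 3, 1024, 1024) with the default 1024x1024 resize
print(out.original_sizes)             # [[480, 854]]: the pre-resize resolution, used to map masks back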

View File

@ -19,9 +19,17 @@ import unittest
import requests
from transformers import Sam2Config, Sam2ImageEncoderConfig, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig, pipeline
from transformers import (
Sam2Config,
Sam2ImageEncoderConfig,
Sam2MaskDecoderConfig,
Sam2Processor,
Sam2PromptEncoderConfig,
pipeline,
)
from transformers.testing_utils import backend_empty_cache, require_torch, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available
from transformers.video_utils import load_video
from ...test_modeling_common import ModelTesterMixin, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@ -417,19 +425,32 @@ class Sam2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def prepare_image():
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
img_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
return raw_image
def prepare_dog_img():
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam2.png"
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
return raw_image
def prepare_video():
video_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/bedroom.mp4"
raw_video, _ = load_video(video_url)
return raw_video
@slow
class Sam2ModelIntegrationTest(unittest.TestCase):
def setUp(self):
super().setUp()
self.model = Sam2Model.from_pretrained("../sam2_hf_implem/sam2_tiny_hf")
self.processor = Sam2Processor.from_pretrained("../sam2_hf_implem/sam2_tiny_hf")
self.model.to(torch_device)
self.model.eval()
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
@ -437,45 +458,98 @@ class Sam2ModelIntegrationTest(unittest.TestCase):
backend_empty_cache(torch_device)
def test_inference_mask_generation_no_point(self):
model = Sam2Model.from_pretrained("facebook/sam2-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam2-vit-base")
pass
model.to(torch_device)
model.eval()
# model = Sam2Model.from_pretrained("facebook/sam2-vit-base")
# processor = SamProcessor.from_pretrained("facebook/sam2-vit-base")
# model.to(torch_device)
# model.eval()
# raw_image = prepare_image()
# inputs = processor(images=raw_image, return_tensors="pt").to(torch_device)
# with torch.no_grad():
# outputs = model(**inputs)
# scores = outputs.iou_scores.squeeze()
# masks = outputs.pred_masks[0, 0, 0, 0, :3]
# self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.4515), atol=2e-4))
# self.assertTrue(torch.allclose(masks, torch.tensor([-4.1800, -3.4948, -3.4481]).to(torch_device), atol=2e-4))
def test_inference_mask_generation_one_point_multimask(self):
raw_image = prepare_image()
inputs = processor(images=raw_image, return_tensors="pt").to(torch_device)
input_points = [[[[500, 375]]]]
input_labels = [[[1]]]
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.4515), atol=2e-4))
self.assertTrue(torch.allclose(masks, torch.tensor([-4.1800, -3.4948, -3.4481]).to(torch_device), atol=2e-4))
def test_inference_mask_generation_one_point_one_bb(self):
model = Sam2Model.from_pretrained("facebook/sam2-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam2-vit-base")
model.to(torch_device)
model.eval()
raw_image = prepare_image()
input_boxes = [[[650, 900, 1000, 1250]]]
input_points = [[[820, 1080]]]
inputs = processor(
images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt"
inputs = self.processor(
images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt"
).to(torch_device)
# to_tensor = ToTensor()
# transforms = torch.jit.script(
# nn.Sequential(
# Resize((1024, 1024)),
# Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
# )
# )
# inputs["pixel_values"] = transforms(to_tensor(raw_image)).unsqueeze(0).to("cuda")
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9566), atol=2e-4))
self.assertTrue(
torch.allclose(masks, torch.tensor([-12.7729, -12.3665, -12.6061]).to(torch_device), atol=2e-4)
outputs = self.model(**inputs)
self.assertEqual(outputs.ious.shape, (1, 1, 3))
self.assertEqual(outputs.low_res_masks.shape, (1, 1, 3, 256, 256))
sorted_indices = torch.argsort(outputs.ious.squeeze(), descending=True)
scores = outputs.ious.squeeze()[sorted_indices]
masks_logits = outputs.low_res_masks.squeeze()[sorted_indices][0, :3, :3]
print("scores", scores)
print("masks_logits", masks_logits)
torch.testing.assert_close(
scores, torch.tensor([0.9546, 0.4937, 0.0428]).to(torch_device), atol=1e-4, rtol=1e-4
)
torch.testing.assert_close(
masks_logits,
torch.tensor(
[[-25.0963, -41.5728, -30.8723], [-34.7112, -30.7988, -36.4013], [-25.3061, -37.4575, -33.1899]]
).to(torch_device),
atol=1e-4,
rtol=1e-4,
)
def test_inference_mask_generation_video_one_point(self):
pass
# raw_video = prepare_video()
# self.processor.init_state(video_path="./videos/bedroom_light")
# inputs = processor.add_new_points_or_box(
# frame_idx=0,
# obj_id=1,
# points=[[[[210, 350]]]],
# labels=[[[1]]],
# )
# def test_inference_mask_generation_one_point_one_bb(self):
# model = Sam2Model.from_pretrained("../sam2_hf_implem/sam2_tiny_hf")
# processor = SamProcessor.from_pretrained("../sam2_hf_implem/sam2_tiny_hf")
# model.to(torch_device)
# model.eval()
# raw_image = prepare_image()
# input_boxes = [[[[650, 900, 1000, 1250]]]]
# input_points = [[[[820, 1080]]]]
# inputs = processor(
# images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt"
# ).to(torch_device)
# with torch.no_grad():
# outputs = model(**inputs)
# scores = outputs.iou_scores.squeeze()
# masks = outputs.pred_masks[0, 0, 0, 0, :3]
# self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9566), atol=2e-4))
# self.assertTrue(
# torch.allclose(masks, torch.tensor([-12.7729, -12.3665, -12.6061]).to(torch_device), atol=2e-4)
# )
def test_inference_mask_generation_batched_points_batched_images(self):
model = Sam2Model.from_pretrained("facebook/sam2-vit-base")