All tests passing

yonigozlan 2025-07-02 20:37:38 +00:00
parent aebcb34dad
commit 978b02edc2
6 changed files with 272 additions and 175 deletions

View File

@@ -4572,7 +4572,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
local_files_only = True
# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
print("here", cls.config_class)
config_path = config if config is not None else pretrained_model_name_or_path
config, model_kwargs = cls.config_class.from_pretrained(
config_path,
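Note: the only change in this hunk is dropping a stray debug print; the surrounding config-resolution logic is untouched. For readers unfamiliar with it, a compressed sketch of what that branch does (simplified, with most keyword arguments omitted):

```python
from transformers import PretrainedConfig


def resolve_config(config, pretrained_model_name_or_path, config_class):
    # If the caller did not pass a ready PretrainedConfig, load one from the
    # explicit config path or, failing that, from the model path itself.
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, _unused = config_class.from_pretrained(config_path, return_unused_kwargs=True)
    return config
```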

View File

@@ -222,6 +222,8 @@ class SamAttention(nn.Module):
self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
self.is_causal = False
def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
batch, point_batch_size, n_tokens, channel = hidden_states.shape
c_per_head = channel // num_attention_heads
@@ -265,7 +267,7 @@ class SamAttention(nn.Module):
attention_mask=attention_similarity,
dropout=0.0 if not self.training else self.dropout_p,
scaling=scale,
- is_causal=False,
+ is_causal=self.is_causal,
**kwargs,
)
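The functional change here is that `is_causal` is now read from the module instead of being hard-coded at the call site, so the shared attention dispatch can stay generic. A minimal, self-contained sketch of that pattern (a toy module, not the SAM attention class itself):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.out = nn.Linear(dim, dim)
        # Attribute read by the attention call below, mirroring the change above.
        self.is_causal = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch, seq, dim = hidden_states.shape
        q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
        shape = (batch, seq, self.num_heads, self.head_dim)
        q, k, v = (t.view(shape).transpose(1, 2) for t in (q, k, v))
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)
        attn = attn.transpose(1, 2).reshape(batch, seq, dim)
        return self.out(attn)


x = torch.randn(2, 16, 64)
module = TinyAttention(dim=64, num_heads=8)
print(module(x).shape)  # torch.Size([2, 16, 64])
```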

View File

@@ -89,6 +89,9 @@ class Sam2VisionConfig(PretrainedConfig):
"""
base_config_key = "vision_config"
model_type = "sam2_vision_model"
def __init__(
self,
hidden_size=96,
@@ -188,6 +191,8 @@ class Sam2PromptEncoderConfig(PretrainedConfig):
The scale factor for the prompt encoder.
"""
base_config_key = "prompt_encoder_config"
def __init__(
self,
hidden_size=256,
@@ -256,6 +261,8 @@ class Sam2MaskDecoderConfig(PretrainedConfig):
"""
base_config_key = "mask_decoder_config"
def __init__(
self,
hidden_size=256,
@@ -267,6 +274,9 @@ class Sam2MaskDecoderConfig(PretrainedConfig):
num_multimask_outputs=3,
iou_head_depth=3,
iou_head_hidden_dim=256,
dynamic_multimask_via_stability=True,
dynamic_multimask_stability_delta=0.05,
dynamic_multimask_stability_thresh=0.98,
feed_forward_hidden_act="relu",
two_way_transformer_activation="relu",
**kwargs,
@@ -279,6 +289,9 @@ class Sam2MaskDecoderConfig(PretrainedConfig):
self.iou_head_depth = iou_head_depth
self.iou_head_hidden_dim = iou_head_hidden_dim
self.feed_forward_hidden_act = feed_forward_hidden_act
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
# TwoWayTransformer configuration
self.num_hidden_layers = num_hidden_layers
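The three new arguments are stored on the config unchanged and later read by the mask decoder (see the modeling changes below). A small usage sketch, assuming the config class is importable from transformers once the SAM2 port lands:

```python
# Hypothetical import path; assumes the SAM2 classes are exported by transformers.
from transformers import Sam2MaskDecoderConfig

config = Sam2MaskDecoderConfig(
    dynamic_multimask_via_stability=True,    # enable the eval-time fallback
    dynamic_multimask_stability_delta=0.05,  # half-width of the +/- logit band used for the two thresholds
    dynamic_multimask_stability_thresh=0.98, # minimum stability score needed to keep the single-mask output
)
print(config.dynamic_multimask_stability_thresh)
```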
@@ -329,6 +342,8 @@ class Sam2MemoryAttentionConfig(PretrainedConfig):
"""
base_config_key = "memory_attention_config"
def __init__(
self,
hidden_size=256,
@@ -404,6 +419,8 @@ class Sam2MemoryEncoderConfig(PretrainedConfig):
"""
base_config_key = "memory_encoder_config"
def __init__(
self,
hidden_size=256,
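Each new `base_config_key` names the nested section of the composite SAM2 configuration that the corresponding sub-config class owns, which is what lets a sub-config be recovered from a full model's config.json. A rough illustration with a plain dict standing in for that file (all values hypothetical):

```python
# Stand-in for a saved SAM2 config.json (hypothetical values).
full_config = {
    "model_type": "sam2",
    "vision_config": {"model_type": "sam2_vision_model", "hidden_size": 96},
    "prompt_encoder_config": {"hidden_size": 256},
    "mask_decoder_config": {"hidden_size": 256},
    "memory_attention_config": {"hidden_size": 256},
    "memory_encoder_config": {"hidden_size": 256},
}

# Each sub-config class advertises which nested section belongs to it,
# e.g. Sam2VisionConfig.base_config_key == "vision_config".
base_config_key = "vision_config"
vision_section = full_config.get(base_config_key, full_config)
print(vision_section["hidden_size"])  # 96
```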

View File

@@ -38,7 +38,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
- from ...utils import ModelOutput, auto_docstring, logging
+ from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from .configuration_sam2 import Sam2Config, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig, Sam2VisionConfig
@@ -413,18 +413,17 @@ class Sam2VisionEncoder(nn.Module):
pos_embed = pos_embed.permute(0, 2, 3, 1)
return pos_embed
@can_return_tuple
def forward(
self,
pixel_values: torch.FloatTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, Sam2VisionEncoderOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
@@ -460,14 +459,6 @@ class Sam2VisionEncoder(nn.Module):
fpn_position_encoding[-self.num_feature_levels :][::-1],
)
if not return_dict:
outputs = (hidden_states, fpn_hidden_states, fpn_position_encoding)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_self_attentions,)
return outputs
return Sam2VisionEncoderOutput(
last_hidden_state=hidden_states,
fpn_hidden_states=fpn_hidden_states,
@@ -874,6 +865,9 @@ class Sam2MaskDecoder(nn.Module):
self.num_multimask_outputs = config.num_multimask_outputs
self.num_mask_tokens = config.num_multimask_outputs + 1
self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh
self.iou_token = nn.Embedding(1, self.hidden_size)
self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
@@ -913,6 +907,53 @@ class Sam2MaskDecoder(nn.Module):
self.obj_score_token = nn.Embedding(1, self.hidden_size)
self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3, activation="relu")
def _get_stability_scores(self, mask_logits):
"""
Compute stability scores of the mask logits based on the IoU between upper and
lower thresholds.
"""
mask_logits = mask_logits.flatten(-2)
stability_delta = self.dynamic_multimask_stability_delta
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
return stability_scores
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
"""
When outputting a single mask, if the stability score from the current single-mask
output (based on output token 0) falls below a threshold, we instead select from
multi-mask outputs (based on output token 1~3) the mask with the highest predicted
IoU score. This is intended to ensure a valid mask for both clicking and tracking.
"""
# The best mask from multimask output tokens (1~3)
multimask_logits = all_mask_logits[:, :, 1:, :, :]
multimask_iou_scores = all_iou_scores[:, :, 1:]
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
point_batch_inds = torch.arange(multimask_iou_scores.size(1), device=all_iou_scores.device)
best_multimask_logits = multimask_logits[batch_inds, point_batch_inds, best_scores_inds]
best_multimask_iou_scores = multimask_iou_scores[batch_inds, point_batch_inds, best_scores_inds]
# The mask from singlemask output token 0 and its stability score
singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
singlemask_iou_scores = all_iou_scores[:, :, 0:1]
stability_scores = self._get_stability_scores(singlemask_logits)
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
# Dynamically fall back to best multimask output upon low stability scores.
mask_logits_out = torch.where(
is_stable[..., None, None].expand_as(singlemask_logits),
singlemask_logits,
best_multimask_logits,
)
iou_scores_out = torch.where(
is_stable.expand_as(singlemask_iou_scores),
singlemask_iou_scores,
best_multimask_iou_scores,
)
return mask_logits_out, iou_scores_out
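The stability score defined above is the ratio between the mask area thresholded at `+delta` and at `-delta`: a mask whose footprint barely changes between the two cutoffs scores close to 1, and only then is the single-mask output kept at inference time. A self-contained sketch of that computation on synthetic logits, using the new config defaults:

```python
import torch

stability_delta = 0.05
stability_thresh = 0.98

# Synthetic single-mask logits: (batch, point_batch, 1, H, W)
mask_logits = torch.randn(2, 1, 1, 64, 64)

flat = mask_logits.flatten(-2)  # merge the spatial dimensions
area_i = torch.sum(flat > stability_delta, dim=-1).float()   # "intersection": confident positives
area_u = torch.sum(flat > -stability_delta, dim=-1).float()  # "union": anything above the lower threshold
stability_scores = torch.where(area_u > 0, area_i / area_u, torch.ones_like(area_u))

# The decoder keeps the single-mask output only where this score clears the threshold;
# otherwise it falls back to the best-IoU mask among the multimask outputs.
is_stable = stability_scores >= stability_thresh
print(stability_scores.shape, is_stable.shape)  # torch.Size([2, 1, 1]) twice
```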
def forward(
self,
image_embeddings: torch.Tensor,
@@ -1003,10 +1044,16 @@ class Sam2MaskDecoder(nn.Module):
# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]
elif self.dynamic_multimask_via_stability and not self.training:
mask_slice = slice(0, 1)
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
else:
mask_slice = slice(0, 1)
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]
sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape
outputs = (masks, iou_pred, sam_tokens_out, object_score_logits)
@@ -1416,6 +1463,8 @@ class Sam2Attention(nn.Module):
self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
self.is_causal = False
def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
batch, point_batch_size, n_tokens, channel = hidden_states.shape
c_per_head = channel // num_attention_heads
@@ -1459,7 +1508,7 @@ class Sam2Attention(nn.Module):
attention_mask=attention_similarity,
dropout=0.0 if not self.training else self.dropout_p,
scaling=scale,
- is_causal=False,
+ is_causal=self.is_causal,
**kwargs,
)
@@ -2242,13 +2291,11 @@ class Sam2Model(Sam2PreTrainedModel):
pixel_values: torch.FloatTensor,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
vision_outputs = self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
feature_maps = vision_outputs[1]
@@ -2265,6 +2312,7 @@ class Sam2Model(Sam2PreTrainedModel):
return feature_maps, feature_maps_position_embeddings, vision_hidden_states, vision_attentions
@can_return_tuple
@auto_docstring
def forward(
self,
@@ -2280,7 +2328,6 @@ class Sam2Model(Sam2PreTrainedModel):
target_embedding: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> list[dict[str, torch.Tensor]]:
r"""
@@ -2365,7 +2412,6 @@ class Sam2Model(Sam2PreTrainedModel):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None and image_embeddings is None:
raise ValueError("Either pixel_values or image_embeddings must be provided.")
@@ -2410,7 +2456,6 @@ class Sam2Model(Sam2PreTrainedModel):
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
)
# flatten NxCxHxW to HWxNxC
@@ -2432,14 +2477,6 @@ class Sam2Model(Sam2PreTrainedModel):
if input_points is not None and input_labels is None:
input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
# if input_points is not None and image_embeddings[-1].shape[1] != input_points.shape[0]:
# raise ValueError(
# "The batch size of the image embeddings and the input points must be the same. ",
# "Got {} and {} respectively.".format(image_embeddings[-1].shape[1], input_points.shape[0]),
# " if you want to pass multiple points for the same image, make sure that you passed ",
# " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
# " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
# )
if input_points is None:
# If no points are provide, pad with an empty point (with label -1)
input_points = torch.zeros(batch_size, point_batch_size, 1, 2, device=image_embeddings[-1].device)
@@ -2447,11 +2484,9 @@ class Sam2Model(Sam2PreTrainedModel):
batch_size, point_batch_size, 1, dtype=torch.int32, device=image_embeddings[-1].device
)
# b) Handle mask prompts
if input_masks is not None:
# If mask_inputs is provided, downsize it into low-res mask input if needed
# and feed it as a dense mask prompt into the SAM mask encoder
assert len(input_masks.shape) == 4 and input_masks.shape[:2] == (batch_size, 1)
if input_masks.shape[-2:] != self.prompt_encoder.image_embedding_size:
input_masks = F.interpolate(
input_masks.float(),
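Dense mask prompts are resized to the prompt encoder's low-resolution grid before being embedded. A standalone sketch of that resize step; the 256x256 target below is only a placeholder for whatever `prompt_encoder.image_embedding_size` resolves to:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes: a binary mask prompt at full image resolution,
# resized to the prompt encoder's expected low-resolution grid.
input_masks = torch.rand(1, 1, 1024, 1024) > 0.5
target_size = (256, 256)  # placeholder for prompt_encoder.image_embedding_size

low_res_masks = F.interpolate(
    input_masks.float(),   # interpolate needs a float tensor
    size=target_size,
    mode="bilinear",
    align_corners=False,
    antialias=True,        # optional: reduces aliasing when downsampling
)
print(low_res_masks.shape)  # torch.Size([1, 1, 256, 256])
```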
@@ -2523,15 +2558,6 @@ class Sam2Model(Sam2PreTrainedModel):
high_res_masks = None
obj_ptr = None
if not return_dict:
output = (iou_scores, low_res_masks, high_res_masks, obj_ptr, object_score_logits, image_embeddings)
if output_hidden_states:
output = output + (vision_hidden_states,)
# if output_attentions:
# output = output + (vision_attentions, mask_decoder_attentions)
return output
return Sam2ImageSegmentationOutput(
iou_scores=iou_scores,
low_res_masks=low_res_masks,
@@ -3039,9 +3065,9 @@ class Sam2Model(Sam2PreTrainedModel):
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
"""
Directly turn binary `mask_inputs` into a output mask logits without using SAM.
- (same input and output shapes as in _forward_sam_heads above).
+ (same input and output shapes as in forward above).
"""
- # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
+ # Use -10/+20 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
mask_inputs_float = mask_inputs.float()
high_res_masks = mask_inputs_float * out_scale + out_bias
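A quick check of the scale/bias arithmetic used here: a binary mask mapped through `out_scale, out_bias = 20.0, -10.0` lands on logits of -10 for background and +10 for foreground, which a sigmoid squashes to roughly 4.5e-05 and 0.99995.

```python
import torch

out_scale, out_bias = 20.0, -10.0
binary_mask = torch.tensor([0.0, 1.0])       # background pixel, foreground pixel
logits = binary_mask * out_scale + out_bias  # tensor([-10., 10.])
print(logits)
print(torch.sigmoid(logits))                 # ~[4.54e-05, 9.9995e-01]
```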

View File

@@ -46,7 +46,7 @@ from ...activations import ACT2FN
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
- from ...utils import ModelOutput, auto_docstring, logging
+ from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from .configuration_sam2 import Sam2Config, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig, Sam2VisionConfig
@@ -482,18 +482,17 @@ class Sam2VisionEncoder(nn.Module):
pos_embed = pos_embed.permute(0, 2, 3, 1)
return pos_embed
@can_return_tuple
def forward(
self,
pixel_values: torch.FloatTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, Sam2VisionEncoderOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
@@ -529,14 +528,6 @@ class Sam2VisionEncoder(nn.Module):
fpn_position_encoding[-self.num_feature_levels :][::-1],
)
if not return_dict:
outputs = (hidden_states, fpn_hidden_states, fpn_position_encoding)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_self_attentions,)
return outputs
return Sam2VisionEncoderOutput(
last_hidden_state=hidden_states,
fpn_hidden_states=fpn_hidden_states,
@@ -686,6 +677,9 @@ class Sam2MaskDecoder(nn.Module):
self.num_multimask_outputs = config.num_multimask_outputs
self.num_mask_tokens = config.num_multimask_outputs + 1
self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh
self.iou_token = nn.Embedding(1, self.hidden_size)
self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
@@ -725,6 +719,53 @@ class Sam2MaskDecoder(nn.Module):
self.obj_score_token = nn.Embedding(1, self.hidden_size)
self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3, activation="relu")
def _get_stability_scores(self, mask_logits):
"""
Compute stability scores of the mask logits based on the IoU between upper and
lower thresholds.
"""
mask_logits = mask_logits.flatten(-2)
stability_delta = self.dynamic_multimask_stability_delta
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
return stability_scores
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
"""
When outputting a single mask, if the stability score from the current single-mask
output (based on output token 0) falls below a threshold, we instead select from
multi-mask outputs (based on output token 1~3) the mask with the highest predicted
IoU score. This is intended to ensure a valid mask for both clicking and tracking.
"""
# The best mask from multimask output tokens (1~3)
multimask_logits = all_mask_logits[:, :, 1:, :, :]
multimask_iou_scores = all_iou_scores[:, :, 1:]
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
point_batch_inds = torch.arange(multimask_iou_scores.size(1), device=all_iou_scores.device)
best_multimask_logits = multimask_logits[batch_inds, point_batch_inds, best_scores_inds]
best_multimask_iou_scores = multimask_iou_scores[batch_inds, point_batch_inds, best_scores_inds]
# The mask from singlemask output token 0 and its stability score
singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
singlemask_iou_scores = all_iou_scores[:, :, 0:1]
stability_scores = self._get_stability_scores(singlemask_logits)
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
# Dynamically fall back to best multimask output upon low stability scores.
mask_logits_out = torch.where(
is_stable[..., None, None].expand_as(singlemask_logits),
singlemask_logits,
best_multimask_logits,
)
iou_scores_out = torch.where(
is_stable.expand_as(singlemask_iou_scores),
singlemask_iou_scores,
best_multimask_iou_scores,
)
return mask_logits_out, iou_scores_out
def forward(
self,
image_embeddings: torch.Tensor,
@@ -815,10 +856,16 @@ class Sam2MaskDecoder(nn.Module):
# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]
elif self.dynamic_multimask_via_stability and not self.training:
mask_slice = slice(0, 1)
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
else:
mask_slice = slice(0, 1)
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]
sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape
outputs = (masks, iou_pred, sam_tokens_out, object_score_logits)
@@ -1249,6 +1296,8 @@ class Sam2Attention(SamAttention):
self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
self.is_causal = False
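In this modular file, `Sam2Attention` only re-declares what differs from `SamAttention` (projections driven by `kv_in_dim`, plus the `is_causal` flag); the modular converter then expands it into the standalone modeling file. A toy sketch of the same override-only pattern, with generic stand-in classes:

```python
import torch.nn as nn


class BaseAttention(nn.Module):
    """Stands in for SamAttention: the full implementation lives in the parent."""

    def __init__(self, hidden_size: int, internal_dim: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.internal_dim = internal_dim
        self.k_proj = nn.Linear(hidden_size, internal_dim)
        self.v_proj = nn.Linear(hidden_size, internal_dim)
        self.out_proj = nn.Linear(internal_dim, hidden_size)
        self.is_causal = False


class DerivedAttention(BaseAttention):
    """Stands in for Sam2Attention: re-declares only the layers whose shapes change."""

    def __init__(self, hidden_size: int, internal_dim: int, kv_in_dim: int):
        super().__init__(hidden_size, internal_dim)
        self.kv_in_dim = kv_in_dim
        # Keys and values may come from a differently sized stream, so only these change.
        self.k_proj = nn.Linear(kv_in_dim, internal_dim)
        self.v_proj = nn.Linear(kv_in_dim, internal_dim)
```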
def init_2d_position_ids(end_x: int, end_y: int):
"""Generate 2D position indices for axial rotary embedding."""
@@ -1956,13 +2005,11 @@ class Sam2Model(Sam2PreTrainedModel):
pixel_values: torch.FloatTensor,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
vision_outputs = self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
feature_maps = vision_outputs[1]
@@ -1979,6 +2026,7 @@ class Sam2Model(Sam2PreTrainedModel):
return feature_maps, feature_maps_position_embeddings, vision_hidden_states, vision_attentions
@can_return_tuple
@auto_docstring
def forward(
self,
@@ -1994,7 +2042,6 @@ class Sam2Model(Sam2PreTrainedModel):
target_embedding: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> list[dict[str, torch.Tensor]]:
r"""
@@ -2079,7 +2126,6 @@ class Sam2Model(Sam2PreTrainedModel):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None and image_embeddings is None:
raise ValueError("Either pixel_values or image_embeddings must be provided.")
@@ -2124,7 +2170,6 @@ class Sam2Model(Sam2PreTrainedModel):
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
)
# flatten NxCxHxW to HWxNxC
@@ -2146,14 +2191,6 @@ class Sam2Model(Sam2PreTrainedModel):
if input_points is not None and input_labels is None:
input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
# if input_points is not None and image_embeddings[-1].shape[1] != input_points.shape[0]:
# raise ValueError(
# "The batch size of the image embeddings and the input points must be the same. ",
# "Got {} and {} respectively.".format(image_embeddings[-1].shape[1], input_points.shape[0]),
# " if you want to pass multiple points for the same image, make sure that you passed ",
# " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
# " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
# )
if input_points is None:
# If no points are provide, pad with an empty point (with label -1)
input_points = torch.zeros(batch_size, point_batch_size, 1, 2, device=image_embeddings[-1].device)
@@ -2161,11 +2198,9 @@ class Sam2Model(Sam2PreTrainedModel):
batch_size, point_batch_size, 1, dtype=torch.int32, device=image_embeddings[-1].device
)
# b) Handle mask prompts
if input_masks is not None:
# If mask_inputs is provided, downsize it into low-res mask input if needed
# and feed it as a dense mask prompt into the SAM mask encoder
assert len(input_masks.shape) == 4 and input_masks.shape[:2] == (batch_size, 1)
if input_masks.shape[-2:] != self.prompt_encoder.image_embedding_size:
input_masks = F.interpolate(
input_masks.float(),
@@ -2237,15 +2272,6 @@ class Sam2Model(Sam2PreTrainedModel):
high_res_masks = None
obj_ptr = None
if not return_dict:
output = (iou_scores, low_res_masks, high_res_masks, obj_ptr, object_score_logits, image_embeddings)
if output_hidden_states:
output = output + (vision_hidden_states,)
# if output_attentions:
# output = output + (vision_attentions, mask_decoder_attentions)
return output
return Sam2ImageSegmentationOutput(
iou_scores=iou_scores,
low_res_masks=low_res_masks,
@@ -2753,9 +2779,9 @@ class Sam2Model(Sam2PreTrainedModel):
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
"""
Directly turn binary `mask_inputs` into a output mask logits without using SAM.
- (same input and output shapes as in _forward_sam_heads above).
+ (same input and output shapes as in forward above).
"""
- # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
+ # Use -10/+20 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
mask_inputs_float = mask_inputs.float()
high_res_masks = mask_inputs_float * out_scale + out_bias
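Taken together, the image path exercised by these changes can be sketched as follows. This is a rough usage example built only from the signatures visible in this diff (pixel_values, input_points/input_labels, outputs carrying iou_scores and low_res_masks); the checkpoint name is a placeholder and the exact preprocessing/processor API is not shown here, so treat it as illustrative.

```python
import torch
from transformers import Sam2Model  # assumes the model is exported once the port lands

model = Sam2Model.from_pretrained("facebook/sam2-placeholder")  # hypothetical checkpoint name
model.eval()

pixel_values = torch.randn(1, 3, 1024, 1024)            # already-preprocessed image
input_points = torch.tensor([[[[512.0, 512.0]]]])       # (batch, point_batch, num_points, 2)
input_labels = torch.ones(1, 1, 1, dtype=torch.int32)   # 1 = foreground click

with torch.no_grad():
    outputs = model(
        pixel_values=pixel_values,
        input_points=input_points,
        input_labels=input_labels,
        multimask_output=True,
    )

print(outputs.iou_scores.shape)     # expected (1, 1, 3): one score per candidate mask
print(outputs.low_res_masks.shape)  # (1, 1, 3, H, W) low-resolution mask logits
```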

View File

@@ -263,6 +263,10 @@ class Sam2VisionModelTest(ModelTesterMixin, unittest.TestCase):
check_hidden_states_output(inputs_dict, config, model_class, image_size)
# Override as diffence slightly higher than the threshold
def test_batching_equivalence(self, atol=5e-4, rtol=5e-4):
super().test_batching_equivalence(atol=atol, rtol=rtol)
@require_torch_sdpa
def test_sdpa_can_compile_dynamic(self):
self.skipTest(reason="SAM model can't be compiled dynamic yet")
@@ -358,6 +362,8 @@ class Sam2MemoryEncoderTester:
patch_kernel_size=2,
patch_stride=2,
patch_padding=1,
mask_downsampler_embed_dim=32,
memory_fuser_embed_dim=32,
):
self.hidden_size = hidden_size
self.num_heads = num_heads
@@ -366,6 +372,8 @@ class Sam2MemoryEncoderTester:
self.patch_kernel_size = patch_kernel_size
self.patch_stride = patch_stride
self.patch_padding = patch_padding
self.mask_downsampler_embed_dim = mask_downsampler_embed_dim
self.memory_fuser_embed_dim = memory_fuser_embed_dim
def get_config(self):
return Sam2MemoryEncoderConfig(
@@ -376,6 +384,8 @@ class Sam2MemoryEncoderTester:
patch_kernel_size=self.patch_kernel_size,
patch_stride=self.patch_stride,
patch_padding=self.patch_padding,
mask_downsampler_embed_dim=self.mask_downsampler_embed_dim,
memory_fuser_embed_dim=self.memory_fuser_embed_dim,
)
def prepare_config_and_inputs(self):
@@ -445,6 +455,12 @@ class Sam2ModelTester:
fpn_hidden_size=self.fpn_hidden_size,
)
memory_attention_config = Sam2MemoryAttentionConfig(
hidden_size=self.hidden_size,
num_layers=1,
dim_feedforward=32,
)
prompt_encoder_config = self.prompt_encoder_tester.get_config()
mask_decoder_config = self.mask_decoder_tester.get_config()
@@ -455,7 +471,7 @@ class Sam2ModelTester:
vision_config=vision_config,
prompt_encoder_config=prompt_encoder_config,
mask_decoder_config=mask_decoder_config,
- memory_attention_config=Sam2MemoryAttentionConfig(),
+ memory_attention_config=memory_attention_config,
memory_encoder_config=memory_encoder_config,
image_size=self.image_size,
)
@@ -467,43 +483,7 @@ class Sam2ModelTester:
with torch.no_grad():
result = model(pixel_values)
self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3))
- self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3))
+ self.parent.assertEqual(result.low_res_masks.shape[:3], (self.batch_size, 1, 3))
def create_and_check_get_image_features(self, config, pixel_values):
model = Sam2Model(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model.get_image_embeddings(pixel_values)
self.parent.assertEqual(result[0].shape, (self.output_channels, 12, 12))
def create_and_check_get_image_hidden_states(self, config, pixel_values):
model = Sam2Model(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model.vision_encoder(
pixel_values,
output_hidden_states=True,
return_dict=True,
)
# after computing the convolutional features
expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)
with torch.no_grad():
result = model.vision_encoder(
pixel_values,
output_hidden_states=True,
return_dict=False,
)
# after computing the convolutional features
expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
@@ -557,14 +537,6 @@ class Sam2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_get_image_features(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_get_image_features(*config_and_inputs)
def test_image_hidden_states(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_get_image_hidden_states(*config_and_inputs)
# Overriding as attention shape depends on window_size
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -620,24 +592,7 @@ class Sam2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
[num_windows, window_size, window_size, out_dim],
)
- @unittest.skip(reason="Sam2Model does not support training")
+ # Override as Sam2Model has different sub-modules
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="Hidden_states is tested in create_and_check_model tests")
def test_hidden_states_output(self):
pass
# @slow
# def test_model_from_pretrained(self):
# model_name = "facebook/sam-vit-huge"
# model = SamModel.from_pretrained(model_name)
# self.assertIsNotNone(model)
@require_torch_sdpa
def test_sdpa_can_compile_dynamic(self):
self.skipTest(reason="SAM2 model can't be compiled dynamic yet")
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
"""
@@ -662,7 +617,7 @@ class Sam2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
- model_sdpa = model_class.from_pretrained(tmpdirname)
+ model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
model_sdpa = model_sdpa.eval().to(torch_device)
vision_encoder_sdpa = getattr(model_sdpa, "vision_encoder")
@@ -687,44 +642,116 @@ class Sam2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
):
raise ValueError("The eager model should not have SDPA attention layers")
# # Overriding as attention shape depends on window_size
# def test_hidden_states_output(self):
# def check_hidden_states_output(inputs_dict, config, model_class, image_size):
# model = model_class(config)
# model.to(torch_device)
# model.eval()
# with torch.no_grad():
# outputs = model(**self._prepare_for_class(inputs_dict, model_class))
# hidden_states = outputs.hidden_states
# expected_num_layers = sum(self.model_tester.stages) + 1
# self.assertEqual(len(hidden_states), expected_num_layers)
# self.assertListEqual(
# list(hidden_states[0].shape[-4:]),
# [
# self.model_tester.batch_size,
# self.model_tester.image_size // self.model_tester.patch_stride,
# self.model_tester.image_size // self.model_tester.patch_stride,
# self.model_tester.hidden_size,
# ],
# )
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# image_size = self.model_tester.image_size
# for model_class in self.all_model_classes:
# inputs_dict["output_hidden_states"] = True
# check_hidden_states_output(inputs_dict, config, model_class, image_size)
# # check that output_hidden_states also work using config
# del inputs_dict["output_hidden_states"]
# config.output_hidden_states = True
# check_hidden_states_output(inputs_dict, config, model_class, image_size)
# Override as Sam2Model doesn't have hidden states
def flash_attn_inference_equivalence(self, attn_implementation: str, padding_side: str):
r"""
Tests the equivalence between the eager and flash attention implementations.
This test is only for inference and runs with `torch_dtype=torch.bfloat16`.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
for model_class in self.all_model_classes:
if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
):
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
dummy_input = inputs_dict[model.main_input_name][:1]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
dummy_attention_mask = dummy_attention_mask[:1]
if padding_side == "left":
dummy_attention_mask[:, 1:] = 1
dummy_attention_mask[:, :1] = 0
else:
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
if model.config.is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
else:
outputs = model(dummy_input, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
logits = outputs.vision_hidden_states[-1]
logits_fa = outputs_fa.vision_hidden_states[-1]
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
if model.config.is_encoder_decoder:
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
else:
other_inputs = {
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
logits = outputs.vision_hidden_states[-1]
logits_fa = outputs_fa.vision_hidden_states[-1]
if padding_side == "left":
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
# check with inference + dropout
model.train()
_ = model_fa(dummy_input, **other_inputs)
else:
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
# Override as diffence slightly higher than the threshold
def test_batching_equivalence(self, atol=5e-4, rtol=5e-4):
super().test_batching_equivalence(atol=atol, rtol=rtol)
@unittest.skip(reason="Sam2Model does not support training")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="Hidden_states is tested in sub modules tests")
def test_hidden_states_output(self):
pass
# @slow
# def test_model_from_pretrained(self):
# model_name = "facebook/sam-vit-huge"
# model = SamModel.from_pretrained(model_name)
# self.assertIsNotNone(model)
@require_torch_sdpa
def test_sdpa_can_compile_dynamic(self):
self.skipTest(reason="SAM2 model can't be compiled dynamic yet")
def prepare_image():