All tests passing

This commit is contained in:
yonigozlan 2025-07-02 20:37:38 +00:00
parent aebcb34dad
commit 978b02edc2
6 changed files with 272 additions and 175 deletions

View File

@ -4572,7 +4572,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
local_files_only = True
# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
print("here", cls.config_class)
config_path = config if config is not None else pretrained_model_name_or_path
config, model_kwargs = cls.config_class.from_pretrained(
config_path,

View File

@ -222,6 +222,8 @@ class SamAttention(nn.Module):
self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
self.is_causal = False
def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
batch, point_batch_size, n_tokens, channel = hidden_states.shape
c_per_head = channel // num_attention_heads
@ -265,7 +267,7 @@ class SamAttention(nn.Module):
attention_mask=attention_similarity,
dropout=0.0 if not self.training else self.dropout_p,
scaling=scale,
is_causal=False,
is_causal=self.is_causal,
**kwargs,
)

View File

@ -89,6 +89,9 @@ class Sam2VisionConfig(PretrainedConfig):
"""
base_config_key = "vision_config"
model_type = "sam2_vision_model"
def __init__(
self,
hidden_size=96,
@ -188,6 +191,8 @@ class Sam2PromptEncoderConfig(PretrainedConfig):
The scale factor for the prompt encoder.
"""
base_config_key = "prompt_encoder_config"
def __init__(
self,
hidden_size=256,
@ -256,6 +261,8 @@ class Sam2MaskDecoderConfig(PretrainedConfig):
"""
base_config_key = "mask_decoder_config"
def __init__(
self,
hidden_size=256,
@ -267,6 +274,9 @@ class Sam2MaskDecoderConfig(PretrainedConfig):
num_multimask_outputs=3,
iou_head_depth=3,
iou_head_hidden_dim=256,
dynamic_multimask_via_stability=True,
dynamic_multimask_stability_delta=0.05,
dynamic_multimask_stability_thresh=0.98,
feed_forward_hidden_act="relu",
two_way_transformer_activation="relu",
**kwargs,
@ -279,6 +289,9 @@ class Sam2MaskDecoderConfig(PretrainedConfig):
self.iou_head_depth = iou_head_depth
self.iou_head_hidden_dim = iou_head_hidden_dim
self.feed_forward_hidden_act = feed_forward_hidden_act
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
# TwoWayTransformer configuration
self.num_hidden_layers = num_hidden_layers
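Illustrative sketch (not part of the diff): how the three new stability options would be set on the decoder config, assuming the constructor shown in this hunk and the usual transformers module path for the sam2 model folder; the values simply mirror the defaults introduced above.
# Hypothetical usage sketch; import path is an assumption based on this PR's file layout.
from transformers.models.sam2.configuration_sam2 import Sam2MaskDecoderConfig

decoder_config = Sam2MaskDecoderConfig(
    dynamic_multimask_via_stability=True,     # enable the stability-based fallback at inference
    dynamic_multimask_stability_delta=0.05,   # +/- logit offset used to compute the stability IoU
    dynamic_multimask_stability_thresh=0.98,  # minimum stability score to keep the single-mask output
)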
@ -329,6 +342,8 @@ class Sam2MemoryAttentionConfig(PretrainedConfig):
"""
base_config_key = "memory_attention_config"
def __init__(
self,
hidden_size=256,
@ -404,6 +419,8 @@ class Sam2MemoryEncoderConfig(PretrainedConfig):
"""
base_config_key = "memory_encoder_config"
def __init__(
self,
hidden_size=256,

View File

@ -38,7 +38,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, auto_docstring, logging
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from .configuration_sam2 import Sam2Config, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig, Sam2VisionConfig
@ -413,18 +413,17 @@ class Sam2VisionEncoder(nn.Module):
pos_embed = pos_embed.permute(0, 2, 3, 1)
return pos_embed
@can_return_tuple
def forward(
self,
pixel_values: torch.FloatTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, Sam2VisionEncoderOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
@ -460,14 +459,6 @@ class Sam2VisionEncoder(nn.Module):
fpn_position_encoding[-self.num_feature_levels :][::-1],
)
if not return_dict:
outputs = (hidden_states, fpn_hidden_states, fpn_position_encoding)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_self_attentions,)
return outputs
return Sam2VisionEncoderOutput(
last_hidden_state=hidden_states,
fpn_hidden_states=fpn_hidden_states,
@ -874,6 +865,9 @@ class Sam2MaskDecoder(nn.Module):
self.num_multimask_outputs = config.num_multimask_outputs
self.num_mask_tokens = config.num_multimask_outputs + 1
self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh
self.iou_token = nn.Embedding(1, self.hidden_size)
self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
@ -913,6 +907,53 @@ class Sam2MaskDecoder(nn.Module):
self.obj_score_token = nn.Embedding(1, self.hidden_size)
self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3, activation="relu")
def _get_stability_scores(self, mask_logits):
"""
Compute stability scores of the mask logits based on the IoU between upper and
lower thresholds.
"""
mask_logits = mask_logits.flatten(-2)
stability_delta = self.dynamic_multimask_stability_delta
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
return stability_scores
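Illustrative sketch (not part of the diff, toy values only): a mask whose pixel count barely changes between the +delta and -delta thresholds scores near 1, while a borderline mask scores lower and triggers the fallback defined below.
import torch

delta = 0.05  # dynamic_multimask_stability_delta
mask_logits = torch.tensor([[[[[4.0, 3.0], [-5.0, 0.02]]]]])  # (batch, point_batch, num_masks, H, W)
flat = mask_logits.flatten(-2)                     # merge H and W -> shape (1, 1, 1, 4)
area_i = torch.sum(flat > delta, dim=-1).float()   # pixels above +delta: 2
area_u = torch.sum(flat > -delta, dim=-1).float()  # pixels above -delta: 3
stability = torch.where(area_u > 0, area_i / area_u, torch.ones_like(area_i))
print(stability)  # tensor([[[0.6667]]]) -> below a 0.98 threshold, so considered unstable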
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
"""
When outputting a single mask, if the stability score from the current single-mask
output (based on output token 0) falls below a threshold, we instead select from
multi-mask outputs (based on output token 1~3) the mask with the highest predicted
IoU score. This is intended to ensure a valid mask for both clicking and tracking.
"""
# The best mask from multimask output tokens (1~3)
multimask_logits = all_mask_logits[:, :, 1:, :, :]
multimask_iou_scores = all_iou_scores[:, :, 1:]
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
point_batch_inds = torch.arange(multimask_iou_scores.size(1), device=all_iou_scores.device)
best_multimask_logits = multimask_logits[batch_inds, point_batch_inds, best_scores_inds]
best_multimask_iou_scores = multimask_iou_scores[batch_inds, point_batch_inds, best_scores_inds]
# The mask from singlemask output token 0 and its stability score
singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
singlemask_iou_scores = all_iou_scores[:, :, 0:1]
stability_scores = self._get_stability_scores(singlemask_logits)
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
# Dynamically fall back to best multimask output upon low stability scores.
mask_logits_out = torch.where(
is_stable[..., None, None].expand_as(singlemask_logits),
singlemask_logits,
best_multimask_logits,
)
iou_scores_out = torch.where(
is_stable.expand_as(singlemask_iou_scores),
singlemask_iou_scores,
best_multimask_iou_scores,
)
return mask_logits_out, iou_scores_out
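Illustrative sketch (not part of the diff, toy values only): token 0 is the single-mask output and tokens 1-3 are the multimask outputs; when token 0's stability score falls below the threshold, the mask with the highest predicted IoU among tokens 1-3 is selected instead.
import torch

thresh = 0.98  # dynamic_multimask_stability_thresh
iou_scores = torch.tensor([[[0.50, 0.70, 0.90, 0.80]]])  # (batch, point_batch, 4 mask tokens)
stability_token0 = torch.tensor([[0.67]])                # e.g. from _get_stability_scores

is_stable = stability_token0 >= thresh                          # tensor([[False]])
best_multimask = 1 + torch.argmax(iou_scores[..., 1:], dim=-1)  # token 2 has the highest IoU
chosen_token = torch.where(is_stable, torch.zeros_like(best_multimask), best_multimask)
print(chosen_token)  # tensor([[2]]) -> fall back to the best multimask output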
def forward(
self,
image_embeddings: torch.Tensor,
@ -1003,10 +1044,16 @@ class Sam2MaskDecoder(nn.Module):
# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]
elif self.dynamic_multimask_via_stability and not self.training:
mask_slice = slice(0, 1)
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
else:
mask_slice = slice(0, 1)
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]
sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape
outputs = (masks, iou_pred, sam_tokens_out, object_score_logits)
@ -1416,6 +1463,8 @@ class Sam2Attention(nn.Module):
self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
self.is_causal = False
def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
batch, point_batch_size, n_tokens, channel = hidden_states.shape
c_per_head = channel // num_attention_heads
@ -1459,7 +1508,7 @@ class Sam2Attention(nn.Module):
attention_mask=attention_similarity,
dropout=0.0 if not self.training else self.dropout_p,
scaling=scale,
is_causal=False,
is_causal=self.is_causal,
**kwargs,
)
@ -2242,13 +2291,11 @@ class Sam2Model(Sam2PreTrainedModel):
pixel_values: torch.FloatTensor,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
vision_outputs = self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
feature_maps = vision_outputs[1]
@ -2265,6 +2312,7 @@ class Sam2Model(Sam2PreTrainedModel):
return feature_maps, feature_maps_position_embeddings, vision_hidden_states, vision_attentions
@can_return_tuple
@auto_docstring
def forward(
self,
@ -2280,7 +2328,6 @@ class Sam2Model(Sam2PreTrainedModel):
target_embedding: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> list[dict[str, torch.Tensor]]:
r"""
@ -2365,7 +2412,6 @@ class Sam2Model(Sam2PreTrainedModel):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None and image_embeddings is None:
raise ValueError("Either pixel_values or image_embeddings must be provided.")
@ -2410,7 +2456,6 @@ class Sam2Model(Sam2PreTrainedModel):
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
)
# flatten NxCxHxW to HWxNxC
@ -2432,14 +2477,6 @@ class Sam2Model(Sam2PreTrainedModel):
if input_points is not None and input_labels is None:
input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
# if input_points is not None and image_embeddings[-1].shape[1] != input_points.shape[0]:
# raise ValueError(
# "The batch size of the image embeddings and the input points must be the same. ",
# "Got {} and {} respectively.".format(image_embeddings[-1].shape[1], input_points.shape[0]),
# " if you want to pass multiple points for the same image, make sure that you passed ",
# " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
# " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
# )
if input_points is None:
# If no points are provided, pad with an empty point (with label -1)
input_points = torch.zeros(batch_size, point_batch_size, 1, 2, device=image_embeddings[-1].device)
@ -2447,11 +2484,9 @@ class Sam2Model(Sam2PreTrainedModel):
batch_size, point_batch_size, 1, dtype=torch.int32, device=image_embeddings[-1].device
)
# b) Handle mask prompts
if input_masks is not None:
# If mask_inputs is provided, downsize it into low-res mask input if needed
# and feed it as a dense mask prompt into the SAM mask encoder
assert len(input_masks.shape) == 4 and input_masks.shape[:2] == (batch_size, 1)
if input_masks.shape[-2:] != self.prompt_encoder.image_embedding_size:
input_masks = F.interpolate(
input_masks.float(),
@ -2523,15 +2558,6 @@ class Sam2Model(Sam2PreTrainedModel):
high_res_masks = None
obj_ptr = None
if not return_dict:
output = (iou_scores, low_res_masks, high_res_masks, obj_ptr, object_score_logits, image_embeddings)
if output_hidden_states:
output = output + (vision_hidden_states,)
# if output_attentions:
# output = output + (vision_attentions, mask_decoder_attentions)
return output
return Sam2ImageSegmentationOutput(
iou_scores=iou_scores,
low_res_masks=low_res_masks,
@ -3039,9 +3065,9 @@ class Sam2Model(Sam2PreTrainedModel):
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
"""
Directly turn binary `mask_inputs` into output mask logits without using SAM.
(same input and output shapes as in _forward_sam_heads above).
(same input and output shapes as in forward above).
"""
# Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
# Use -10/+20 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
mask_inputs_float = mask_inputs.float()
high_res_masks = mask_inputs_float * out_scale + out_bias
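Illustrative sketch (not part of the diff): the scale/bias above maps binary pixels to logits of -10 and +10, which sigmoid squashes to roughly 0 and 1.
import torch

out_scale, out_bias = 20.0, -10.0
mask_inputs_float = torch.tensor([[0.0, 1.0]])  # one background and one foreground pixel
high_res_masks = mask_inputs_float * out_scale + out_bias
print(high_res_masks)                 # tensor([[-10.,  10.]])
print(torch.sigmoid(high_res_masks))  # tensor([[4.5398e-05, 9.9995e-01]]) -> ~0 and ~1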

View File

@ -46,7 +46,7 @@ from ...activations import ACT2FN
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, auto_docstring, logging
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from .configuration_sam2 import Sam2Config, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig, Sam2VisionConfig
@ -482,18 +482,17 @@ class Sam2VisionEncoder(nn.Module):
pos_embed = pos_embed.permute(0, 2, 3, 1)
return pos_embed
@can_return_tuple
def forward(
self,
pixel_values: torch.FloatTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, Sam2VisionEncoderOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
@ -529,14 +528,6 @@ class Sam2VisionEncoder(nn.Module):
fpn_position_encoding[-self.num_feature_levels :][::-1],
)
if not return_dict:
outputs = (hidden_states, fpn_hidden_states, fpn_position_encoding)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_self_attentions,)
return outputs
return Sam2VisionEncoderOutput(
last_hidden_state=hidden_states,
fpn_hidden_states=fpn_hidden_states,
@ -686,6 +677,9 @@ class Sam2MaskDecoder(nn.Module):
self.num_multimask_outputs = config.num_multimask_outputs
self.num_mask_tokens = config.num_multimask_outputs + 1
self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh
self.iou_token = nn.Embedding(1, self.hidden_size)
self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
@ -725,6 +719,53 @@ class Sam2MaskDecoder(nn.Module):
self.obj_score_token = nn.Embedding(1, self.hidden_size)
self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3, activation="relu")
def _get_stability_scores(self, mask_logits):
"""
Compute stability scores of the mask logits based on the IoU between upper and
lower thresholds.
"""
mask_logits = mask_logits.flatten(-2)
stability_delta = self.dynamic_multimask_stability_delta
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
return stability_scores
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
"""
When outputting a single mask, if the stability score from the current single-mask
output (based on output token 0) falls below a threshold, we instead select from
multi-mask outputs (based on output token 1~3) the mask with the highest predicted
IoU score. This is intended to ensure a valid mask for both clicking and tracking.
"""
# The best mask from multimask output tokens (1~3)
multimask_logits = all_mask_logits[:, :, 1:, :, :]
multimask_iou_scores = all_iou_scores[:, :, 1:]
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
point_batch_inds = torch.arange(multimask_iou_scores.size(1), device=all_iou_scores.device)
best_multimask_logits = multimask_logits[batch_inds, point_batch_inds, best_scores_inds]
best_multimask_iou_scores = multimask_iou_scores[batch_inds, point_batch_inds, best_scores_inds]
# The mask from singlemask output token 0 and its stability score
singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
singlemask_iou_scores = all_iou_scores[:, :, 0:1]
stability_scores = self._get_stability_scores(singlemask_logits)
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
# Dynamically fall back to best multimask output upon low stability scores.
mask_logits_out = torch.where(
is_stable[..., None, None].expand_as(singlemask_logits),
singlemask_logits,
best_multimask_logits,
)
iou_scores_out = torch.where(
is_stable.expand_as(singlemask_iou_scores),
singlemask_iou_scores,
best_multimask_iou_scores,
)
return mask_logits_out, iou_scores_out
def forward(
self,
image_embeddings: torch.Tensor,
@ -815,10 +856,16 @@ class Sam2MaskDecoder(nn.Module):
# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]
elif self.dynamic_multimask_via_stability and not self.training:
mask_slice = slice(0, 1)
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
else:
mask_slice = slice(0, 1)
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]
sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape
outputs = (masks, iou_pred, sam_tokens_out, object_score_logits)
@ -1249,6 +1296,8 @@ class Sam2Attention(SamAttention):
self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
self.is_causal = False
def init_2d_position_ids(end_x: int, end_y: int):
"""Generate 2D position indices for axial rotary embedding."""
@ -1956,13 +2005,11 @@ class Sam2Model(Sam2PreTrainedModel):
pixel_values: torch.FloatTensor,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
vision_outputs = self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
feature_maps = vision_outputs[1]
@ -1979,6 +2026,7 @@ class Sam2Model(Sam2PreTrainedModel):
return feature_maps, feature_maps_position_embeddings, vision_hidden_states, vision_attentions
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1994,7 +2042,6 @@ class Sam2Model(Sam2PreTrainedModel):
target_embedding: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> list[dict[str, torch.Tensor]]:
r"""
@ -2079,7 +2126,6 @@ class Sam2Model(Sam2PreTrainedModel):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None and image_embeddings is None:
raise ValueError("Either pixel_values or image_embeddings must be provided.")
@ -2124,7 +2170,6 @@ class Sam2Model(Sam2PreTrainedModel):
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
)
# flatten NxCxHxW to HWxNxC
@ -2146,14 +2191,6 @@ class Sam2Model(Sam2PreTrainedModel):
if input_points is not None and input_labels is None:
input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
# if input_points is not None and image_embeddings[-1].shape[1] != input_points.shape[0]:
# raise ValueError(
# "The batch size of the image embeddings and the input points must be the same. ",
# "Got {} and {} respectively.".format(image_embeddings[-1].shape[1], input_points.shape[0]),
# " if you want to pass multiple points for the same image, make sure that you passed ",
# " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
# " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
# )
if input_points is None:
# If no points are provided, pad with an empty point (with label -1)
input_points = torch.zeros(batch_size, point_batch_size, 1, 2, device=image_embeddings[-1].device)
@ -2161,11 +2198,9 @@ class Sam2Model(Sam2PreTrainedModel):
batch_size, point_batch_size, 1, dtype=torch.int32, device=image_embeddings[-1].device
)
# b) Handle mask prompts
if input_masks is not None:
# If mask_inputs is provided, downsize it into low-res mask input if needed
# and feed it as a dense mask prompt into the SAM mask encoder
assert len(input_masks.shape) == 4 and input_masks.shape[:2] == (batch_size, 1)
if input_masks.shape[-2:] != self.prompt_encoder.image_embedding_size:
input_masks = F.interpolate(
input_masks.float(),
@ -2237,15 +2272,6 @@ class Sam2Model(Sam2PreTrainedModel):
high_res_masks = None
obj_ptr = None
if not return_dict:
output = (iou_scores, low_res_masks, high_res_masks, obj_ptr, object_score_logits, image_embeddings)
if output_hidden_states:
output = output + (vision_hidden_states,)
# if output_attentions:
# output = output + (vision_attentions, mask_decoder_attentions)
return output
return Sam2ImageSegmentationOutput(
iou_scores=iou_scores,
low_res_masks=low_res_masks,
@ -2753,9 +2779,9 @@ class Sam2Model(Sam2PreTrainedModel):
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
"""
Directly turn binary `mask_inputs` into output mask logits without using SAM.
(same input and output shapes as in _forward_sam_heads above).
(same input and output shapes as in forward above).
"""
# Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
# Use -10/+20 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
mask_inputs_float = mask_inputs.float()
high_res_masks = mask_inputs_float * out_scale + out_bias

View File

@ -263,6 +263,10 @@ class Sam2VisionModelTest(ModelTesterMixin, unittest.TestCase):
check_hidden_states_output(inputs_dict, config, model_class, image_size)
# Override as the difference is slightly higher than the threshold
def test_batching_equivalence(self, atol=5e-4, rtol=5e-4):
super().test_batching_equivalence(atol=atol, rtol=rtol)
@require_torch_sdpa
def test_sdpa_can_compile_dynamic(self):
self.skipTest(reason="SAM model can't be compiled dynamic yet")
@ -358,6 +362,8 @@ class Sam2MemoryEncoderTester:
patch_kernel_size=2,
patch_stride=2,
patch_padding=1,
mask_downsampler_embed_dim=32,
memory_fuser_embed_dim=32,
):
self.hidden_size = hidden_size
self.num_heads = num_heads
@ -366,6 +372,8 @@ class Sam2MemoryEncoderTester:
self.patch_kernel_size = patch_kernel_size
self.patch_stride = patch_stride
self.patch_padding = patch_padding
self.mask_downsampler_embed_dim = mask_downsampler_embed_dim
self.memory_fuser_embed_dim = memory_fuser_embed_dim
def get_config(self):
return Sam2MemoryEncoderConfig(
@ -376,6 +384,8 @@ class Sam2MemoryEncoderTester:
patch_kernel_size=self.patch_kernel_size,
patch_stride=self.patch_stride,
patch_padding=self.patch_padding,
mask_downsampler_embed_dim=self.mask_downsampler_embed_dim,
memory_fuser_embed_dim=self.memory_fuser_embed_dim,
)
def prepare_config_and_inputs(self):
@ -445,6 +455,12 @@ class Sam2ModelTester:
fpn_hidden_size=self.fpn_hidden_size,
)
memory_attention_config = Sam2MemoryAttentionConfig(
hidden_size=self.hidden_size,
num_layers=1,
dim_feedforward=32,
)
prompt_encoder_config = self.prompt_encoder_tester.get_config()
mask_decoder_config = self.mask_decoder_tester.get_config()
@ -455,7 +471,7 @@ class Sam2ModelTester:
vision_config=vision_config,
prompt_encoder_config=prompt_encoder_config,
mask_decoder_config=mask_decoder_config,
memory_attention_config=Sam2MemoryAttentionConfig(),
memory_attention_config=memory_attention_config,
memory_encoder_config=memory_encoder_config,
image_size=self.image_size,
)
@ -467,43 +483,7 @@ class Sam2ModelTester:
with torch.no_grad():
result = model(pixel_values)
self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3))
self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3))
def create_and_check_get_image_features(self, config, pixel_values):
model = Sam2Model(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model.get_image_embeddings(pixel_values)
self.parent.assertEqual(result[0].shape, (self.output_channels, 12, 12))
def create_and_check_get_image_hidden_states(self, config, pixel_values):
model = Sam2Model(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model.vision_encoder(
pixel_values,
output_hidden_states=True,
return_dict=True,
)
# after computing the convolutional features
expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)
with torch.no_grad():
result = model.vision_encoder(
pixel_values,
output_hidden_states=True,
return_dict=False,
)
# after computing the convolutional features
expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)
self.parent.assertEqual(result.low_res_masks.shape[:3], (self.batch_size, 1, 3))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
@ -557,14 +537,6 @@ class Sam2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_get_image_features(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_get_image_features(*config_and_inputs)
def test_image_hidden_states(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_get_image_hidden_states(*config_and_inputs)
# Overriding as attention shape depends on window_size
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -620,24 +592,7 @@ class Sam2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
[num_windows, window_size, window_size, out_dim],
)
@unittest.skip(reason="Sam2Model does not support training")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="Hidden_states is tested in create_and_check_model tests")
def test_hidden_states_output(self):
pass
# @slow
# def test_model_from_pretrained(self):
# model_name = "facebook/sam-vit-huge"
# model = SamModel.from_pretrained(model_name)
# self.assertIsNotNone(model)
@require_torch_sdpa
def test_sdpa_can_compile_dynamic(self):
self.skipTest(reason="SAM2 model can't be compiled dynamic yet")
# Override as Sam2Model has different sub-modules
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
"""
@ -662,7 +617,7 @@ class Sam2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
model_sdpa = model_sdpa.eval().to(torch_device)
vision_encoder_sdpa = getattr(model_sdpa, "vision_encoder")
@ -687,44 +642,116 @@ class Sam2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
):
raise ValueError("The eager model should not have SDPA attention layers")
# # Overriding as attention shape depends on window_size
# def test_hidden_states_output(self):
# def check_hidden_states_output(inputs_dict, config, model_class, image_size):
# model = model_class(config)
# model.to(torch_device)
# model.eval()
# with torch.no_grad():
# outputs = model(**self._prepare_for_class(inputs_dict, model_class))
# hidden_states = outputs.hidden_states
# expected_num_layers = sum(self.model_tester.stages) + 1
# self.assertEqual(len(hidden_states), expected_num_layers)
# self.assertListEqual(
# list(hidden_states[0].shape[-4:]),
# [
# self.model_tester.batch_size,
# self.model_tester.image_size // self.model_tester.patch_stride,
# self.model_tester.image_size // self.model_tester.patch_stride,
# self.model_tester.hidden_size,
# ],
# )
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# image_size = self.model_tester.image_size
# for model_class in self.all_model_classes:
# inputs_dict["output_hidden_states"] = True
# check_hidden_states_output(inputs_dict, config, model_class, image_size)
# # check that output_hidden_states also work using config
# del inputs_dict["output_hidden_states"]
# config.output_hidden_states = True
# check_hidden_states_output(inputs_dict, config, model_class, image_size)
# Override as Sam2Model doesn't have hidden states
def flash_attn_inference_equivalence(self, attn_implementation: str, padding_side: str):
r"""
Tests the equivalence between the eager and flash attention implementations.
This test is only for inference and runs with `torch_dtype=torch.bfloat16`.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
for model_class in self.all_model_classes:
if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
):
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
dummy_input = inputs_dict[model.main_input_name][:1]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
dummy_attention_mask = dummy_attention_mask[:1]
if padding_side == "left":
dummy_attention_mask[:, 1:] = 1
dummy_attention_mask[:, :1] = 0
else:
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
if model.config.is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
else:
outputs = model(dummy_input, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
logits = outputs.vision_hidden_states[-1]
logits_fa = outputs_fa.vision_hidden_states[-1]
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
if model.config.is_encoder_decoder:
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
else:
other_inputs = {
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
logits = outputs.vision_hidden_states[-1]
logits_fa = outputs_fa.vision_hidden_states[-1]
if padding_side == "left":
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
# check with inference + dropout
model.train()
_ = model_fa(dummy_input, **other_inputs)
else:
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
# Override as the difference is slightly higher than the threshold
def test_batching_equivalence(self, atol=5e-4, rtol=5e-4):
super().test_batching_equivalence(atol=atol, rtol=rtol)
@unittest.skip(reason="Sam2Model does not support training")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="Hidden_states is tested in sub modules tests")
def test_hidden_states_output(self):
pass
# @slow
# def test_model_from_pretrained(self):
# model_name = "facebook/sam-vit-huge"
# model = SamModel.from_pretrained(model_name)
# self.assertIsNotNone(model)
@require_torch_sdpa
def test_sdpa_can_compile_dynamic(self):
self.skipTest(reason="SAM2 model can't be compiled dynamic yet")
def prepare_image():