mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 12:50:06 +06:00
All tests passing
This commit is contained in:
parent aebcb34dad
commit 978b02edc2
@@ -4572,7 +4572,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
local_files_only = True
# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
- print("here", cls.config_class)
config_path = config if config is not None else pretrained_model_name_or_path
config, model_kwargs = cls.config_class.from_pretrained(
config_path,
@@ -222,6 +222,8 @@ class SamAttention(nn.Module):
self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
+
+ self.is_causal = False

def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
batch, point_batch_size, n_tokens, channel = hidden_states.shape
c_per_head = channel // num_attention_heads

@@ -265,7 +267,7 @@ class SamAttention(nn.Module):
attention_mask=attention_similarity,
dropout=0.0 if not self.training else self.dropout_p,
scaling=scale,
- is_causal=False,
+ is_causal=self.is_causal,
**kwargs,
)
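Note: the new self.is_causal attribute is what the attention call above forwards to the attention kernel. As a rough, illustrative sketch (not part of this commit), PyTorch's scaled_dot_product_attention exposes the same knobs the call passes through (mask, dropout, scaling, is_causal):

# Illustrative only: a minimal non-causal SDPA call with made-up shapes,
# mirroring the arguments SamAttention forwards (dropout, scaling, is_causal).
import torch
import torch.nn.functional as F

query = torch.randn(2, 8, 16, 32)  # (batch, num_heads, num_tokens, head_dim)
key = torch.randn(2, 8, 16, 32)
value = torch.randn(2, 8, 16, 32)

attn_output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,   # attention_similarity would go here when provided
    dropout_p=0.0,    # 0.0 at inference, self.dropout_p while training
    is_causal=False,  # SAM attention is bidirectional, hence is_causal=False
    scale=None,       # None defaults to 1/sqrt(head_dim), like `scaling=scale`
)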
@@ -89,6 +89,9 @@ class Sam2VisionConfig(PretrainedConfig):

"""

base_config_key = "vision_config"
model_type = "sam2_vision_model"

def __init__(
self,
hidden_size=96,

@@ -188,6 +191,8 @@ class Sam2PromptEncoderConfig(PretrainedConfig):
The scale factor for the prompt encoder.
"""

base_config_key = "prompt_encoder_config"

def __init__(
self,
hidden_size=256,

@@ -256,6 +261,8 @@ class Sam2MaskDecoderConfig(PretrainedConfig):

"""

base_config_key = "mask_decoder_config"

def __init__(
self,
hidden_size=256,

@@ -267,6 +274,9 @@ class Sam2MaskDecoderConfig(PretrainedConfig):
num_multimask_outputs=3,
iou_head_depth=3,
iou_head_hidden_dim=256,
dynamic_multimask_via_stability=True,
dynamic_multimask_stability_delta=0.05,
dynamic_multimask_stability_thresh=0.98,
feed_forward_hidden_act="relu",
two_way_transformer_activation="relu",
**kwargs,

@@ -279,6 +289,9 @@ class Sam2MaskDecoderConfig(PretrainedConfig):
self.iou_head_depth = iou_head_depth
self.iou_head_hidden_dim = iou_head_hidden_dim
self.feed_forward_hidden_act = feed_forward_hidden_act
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh

# TwoWayTransformer configuration
self.num_hidden_layers = num_hidden_layers

@@ -329,6 +342,8 @@ class Sam2MemoryAttentionConfig(PretrainedConfig):

"""

base_config_key = "memory_attention_config"

def __init__(
self,
hidden_size=256,

@@ -404,6 +419,8 @@ class Sam2MemoryEncoderConfig(PretrainedConfig):

"""

base_config_key = "memory_encoder_config"

def __init__(
self,
hidden_size=256,
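For reference, a minimal sketch of how the new mask-decoder stability options could be set, assuming Sam2MaskDecoderConfig is importable from the sam2 package once this change lands (keyword names are taken from the diff above):

# Sketch only: the import path assumes the module layout shown in this commit.
from transformers.models.sam2.configuration_sam2 import Sam2MaskDecoderConfig

mask_decoder_config = Sam2MaskDecoderConfig(
    num_multimask_outputs=3,
    dynamic_multimask_via_stability=True,     # enable the stability-based fallback
    dynamic_multimask_stability_delta=0.05,   # +/- logit offset used for the IoU proxy
    dynamic_multimask_stability_thresh=0.98,  # minimum stability score to keep mask token 0
)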
@@ -38,7 +38,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
- from ...utils import ModelOutput, auto_docstring, logging
+ from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from .configuration_sam2 import Sam2Config, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig, Sam2VisionConfig
@@ -413,18 +413,17 @@ class Sam2VisionEncoder(nn.Module):
pos_embed = pos_embed.permute(0, 2, 3, 1)
return pos_embed

+ @can_return_tuple
def forward(
self,
pixel_values: torch.FloatTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, Sam2VisionEncoderOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if pixel_values is None:
raise ValueError("You have to specify pixel_values")
@@ -460,14 +459,6 @@ class Sam2VisionEncoder(nn.Module):
fpn_position_encoding[-self.num_feature_levels :][::-1],
)

- if not return_dict:
- outputs = (hidden_states, fpn_hidden_states, fpn_position_encoding)
- if output_hidden_states:
- outputs = outputs + (all_hidden_states,)
- if output_attentions:
- outputs = outputs + (all_self_attentions,)
- return outputs
-
return Sam2VisionEncoderOutput(
last_hidden_state=hidden_states,
fpn_hidden_states=fpn_hidden_states,
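The tuple-building block deleted above is what the newly imported can_return_tuple decorator takes over. A simplified sketch of the idea (not the actual transformers implementation): the decorated forward always builds a ModelOutput, and the wrapper converts it to a plain tuple when return_dict resolves to False:

# Hedged sketch of the pattern behind @can_return_tuple; names are illustrative.
import functools

def can_return_tuple_sketch(forward):
    @functools.wraps(forward)
    def wrapper(self, *args, return_dict=None, **kwargs):
        use_return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output = forward(self, *args, **kwargs)  # always a ModelOutput here
        return output if use_return_dict else output.to_tuple()
    return wrapper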
@@ -874,6 +865,9 @@ class Sam2MaskDecoder(nn.Module):

self.num_multimask_outputs = config.num_multimask_outputs
self.num_mask_tokens = config.num_multimask_outputs + 1
+ self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
+ self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
+ self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh

self.iou_token = nn.Embedding(1, self.hidden_size)
self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
@@ -913,6 +907,53 @@ class Sam2MaskDecoder(nn.Module):
self.obj_score_token = nn.Embedding(1, self.hidden_size)
self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3, activation="relu")

def _get_stability_scores(self, mask_logits):
"""
Compute stability scores of the mask logits based on the IoU between upper and
lower thresholds.
"""
mask_logits = mask_logits.flatten(-2)
stability_delta = self.dynamic_multimask_stability_delta
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
return stability_scores

def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
"""
When outputting a single mask, if the stability score from the current single-mask
output (based on output token 0) falls below a threshold, we instead select from
multi-mask outputs (based on output token 1~3) the mask with the highest predicted
IoU score. This is intended to ensure a valid mask for both clicking and tracking.
"""
# The best mask from multimask output tokens (1~3)
multimask_logits = all_mask_logits[:, :, 1:, :, :]
multimask_iou_scores = all_iou_scores[:, :, 1:]
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
point_batch_inds = torch.arange(multimask_iou_scores.size(1), device=all_iou_scores.device)
best_multimask_logits = multimask_logits[batch_inds, point_batch_inds, best_scores_inds]
best_multimask_iou_scores = multimask_iou_scores[batch_inds, point_batch_inds, best_scores_inds]

# The mask from singlemask output token 0 and its stability score
singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
singlemask_iou_scores = all_iou_scores[:, :, 0:1]
stability_scores = self._get_stability_scores(singlemask_logits)
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh

# Dynamically fall back to best multimask output upon low stability scores.
mask_logits_out = torch.where(
is_stable[..., None, None].expand_as(singlemask_logits),
singlemask_logits,
best_multimask_logits,
)
iou_scores_out = torch.where(
is_stable.expand_as(singlemask_iou_scores),
singlemask_iou_scores,
best_multimask_iou_scores,
)
return mask_logits_out, iou_scores_out

def forward(
self,
image_embeddings: torch.Tensor,
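To make the stability heuristic concrete, here is a standalone illustration on dummy tensors (same arithmetic as _get_stability_scores above): the score is the ratio of the mask area at a raised threshold (+delta) to the area at a lowered one (-delta), i.e. an IoU between the two thresholded masks.

# Standalone illustration with made-up shapes; mirrors the method above.
import torch

mask_logits = torch.randn(1, 1, 1, 64, 64)  # (batch, point_batch, num_masks, H, W)
stability_delta = 0.05

flat_logits = mask_logits.flatten(-2)  # merge the two spatial dims
area_i = torch.sum(flat_logits > stability_delta, dim=-1).float()   # area at the strict threshold
area_u = torch.sum(flat_logits > -stability_delta, dim=-1).float()  # area at the relaxed threshold
stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)    # close to 1.0 = stable mask
print(stability_scores.shape)  # torch.Size([1, 1, 1])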
@@ -1003,10 +1044,16 @@ class Sam2MaskDecoder(nn.Module):
# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]
elif self.dynamic_multimask_via_stability and not self.training:
mask_slice = slice(0, 1)
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
else:
mask_slice = slice(0, 1)
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]

sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape
outputs = (masks, iou_pred, sam_tokens_out, object_score_logits)
@@ -1416,6 +1463,8 @@ class Sam2Attention(nn.Module):
self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
+
+ self.is_causal = False

def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
batch, point_batch_size, n_tokens, channel = hidden_states.shape
c_per_head = channel // num_attention_heads

@@ -1459,7 +1508,7 @@ class Sam2Attention(nn.Module):
attention_mask=attention_similarity,
dropout=0.0 if not self.training else self.dropout_p,
scaling=scale,
- is_causal=False,
+ is_causal=self.is_causal,
**kwargs,
)
@@ -2242,13 +2291,11 @@ class Sam2Model(Sam2PreTrainedModel):
pixel_values: torch.FloatTensor,
output_attentions: bool = False,
output_hidden_states: bool = False,
- return_dict: bool = True,
):
vision_outputs = self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
)

feature_maps = vision_outputs[1]

@@ -2265,6 +2312,7 @@ class Sam2Model(Sam2PreTrainedModel):

return feature_maps, feature_maps_position_embeddings, vision_hidden_states, vision_attentions

+ @can_return_tuple
@auto_docstring
def forward(
self,

@@ -2280,7 +2328,6 @@ class Sam2Model(Sam2PreTrainedModel):
target_embedding: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
**kwargs,
) -> list[dict[str, torch.Tensor]]:
r"""

@@ -2365,7 +2412,6 @@ class Sam2Model(Sam2PreTrainedModel):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if pixel_values is None and image_embeddings is None:
raise ValueError("Either pixel_values or image_embeddings must be provided.")

@@ -2410,7 +2456,6 @@ class Sam2Model(Sam2PreTrainedModel):
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
)
)
# flatten NxCxHxW to HWxNxC

@@ -2432,14 +2477,6 @@ class Sam2Model(Sam2PreTrainedModel):
if input_points is not None and input_labels is None:
input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)

- # if input_points is not None and image_embeddings[-1].shape[1] != input_points.shape[0]:
- # raise ValueError(
- # "The batch size of the image embeddings and the input points must be the same. ",
- # "Got {} and {} respectively.".format(image_embeddings[-1].shape[1], input_points.shape[0]),
- # " if you want to pass multiple points for the same image, make sure that you passed ",
- # " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
- # " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
- # )
if input_points is None:
# If no points are provide, pad with an empty point (with label -1)
input_points = torch.zeros(batch_size, point_batch_size, 1, 2, device=image_embeddings[-1].device)

@@ -2447,11 +2484,9 @@ class Sam2Model(Sam2PreTrainedModel):
batch_size, point_batch_size, 1, dtype=torch.int32, device=image_embeddings[-1].device
)

# b) Handle mask prompts
if input_masks is not None:
# If mask_inputs is provided, downsize it into low-res mask input if needed
# and feed it as a dense mask prompt into the SAM mask encoder
assert len(input_masks.shape) == 4 and input_masks.shape[:2] == (batch_size, 1)
if input_masks.shape[-2:] != self.prompt_encoder.image_embedding_size:
input_masks = F.interpolate(
input_masks.float(),

@@ -2523,15 +2558,6 @@ class Sam2Model(Sam2PreTrainedModel):
high_res_masks = None
obj_ptr = None

- if not return_dict:
- output = (iou_scores, low_res_masks, high_res_masks, obj_ptr, object_score_logits, image_embeddings)
- if output_hidden_states:
- output = output + (vision_hidden_states,)
-
- # if output_attentions:
- # output = output + (vision_attentions, mask_decoder_attentions)
- return output
-
return Sam2ImageSegmentationOutput(
iou_scores=iou_scores,
low_res_masks=low_res_masks,
@@ -3039,9 +3065,9 @@ class Sam2Model(Sam2PreTrainedModel):
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
"""
Directly turn binary `mask_inputs` into a output mask logits without using SAM.
(same input and output shapes as in _forward_sam_heads above).
(same input and output shapes as in forward above).
"""
# Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
# Use -10/+20 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
mask_inputs_float = mask_inputs.float()
high_res_masks = mask_inputs_float * out_scale + out_bias
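As a quick sanity check on the scale and bias above (not part of the diff): with out_scale=20 and out_bias=-10, a binary mask value of 0 or 1 maps to a logit of -10 or +10, which sigmoid squashes to roughly 4.5e-05 and 0.99995, i.e. effectively hard 0/1 probabilities.

# Worked example for the logit mapping used by _use_mask_as_output.
import torch

mask_inputs = torch.tensor([0.0, 1.0])
out_scale, out_bias = 20.0, -10.0
logits = mask_inputs * out_scale + out_bias  # tensor([-10., 10.])
probs = torch.sigmoid(logits)                # ~tensor([4.5398e-05, 9.9995e-01])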
@@ -46,7 +46,7 @@ from ...activations import ACT2FN
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
- from ...utils import ModelOutput, auto_docstring, logging
+ from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from .configuration_sam2 import Sam2Config, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig, Sam2VisionConfig
@@ -482,18 +482,17 @@ class Sam2VisionEncoder(nn.Module):
pos_embed = pos_embed.permute(0, 2, 3, 1)
return pos_embed

+ @can_return_tuple
def forward(
self,
pixel_values: torch.FloatTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, Sam2VisionEncoderOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if pixel_values is None:
raise ValueError("You have to specify pixel_values")
@@ -529,14 +528,6 @@ class Sam2VisionEncoder(nn.Module):
fpn_position_encoding[-self.num_feature_levels :][::-1],
)

- if not return_dict:
- outputs = (hidden_states, fpn_hidden_states, fpn_position_encoding)
- if output_hidden_states:
- outputs = outputs + (all_hidden_states,)
- if output_attentions:
- outputs = outputs + (all_self_attentions,)
- return outputs
-
return Sam2VisionEncoderOutput(
last_hidden_state=hidden_states,
fpn_hidden_states=fpn_hidden_states,
@@ -686,6 +677,9 @@ class Sam2MaskDecoder(nn.Module):

self.num_multimask_outputs = config.num_multimask_outputs
self.num_mask_tokens = config.num_multimask_outputs + 1
+ self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
+ self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
+ self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh

self.iou_token = nn.Embedding(1, self.hidden_size)
self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
@@ -725,6 +719,53 @@ class Sam2MaskDecoder(nn.Module):
self.obj_score_token = nn.Embedding(1, self.hidden_size)
self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3, activation="relu")

def _get_stability_scores(self, mask_logits):
"""
Compute stability scores of the mask logits based on the IoU between upper and
lower thresholds.
"""
mask_logits = mask_logits.flatten(-2)
stability_delta = self.dynamic_multimask_stability_delta
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
return stability_scores

def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
"""
When outputting a single mask, if the stability score from the current single-mask
output (based on output token 0) falls below a threshold, we instead select from
multi-mask outputs (based on output token 1~3) the mask with the highest predicted
IoU score. This is intended to ensure a valid mask for both clicking and tracking.
"""
# The best mask from multimask output tokens (1~3)
multimask_logits = all_mask_logits[:, :, 1:, :, :]
multimask_iou_scores = all_iou_scores[:, :, 1:]
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
point_batch_inds = torch.arange(multimask_iou_scores.size(1), device=all_iou_scores.device)
best_multimask_logits = multimask_logits[batch_inds, point_batch_inds, best_scores_inds]
best_multimask_iou_scores = multimask_iou_scores[batch_inds, point_batch_inds, best_scores_inds]

# The mask from singlemask output token 0 and its stability score
singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
singlemask_iou_scores = all_iou_scores[:, :, 0:1]
stability_scores = self._get_stability_scores(singlemask_logits)
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh

# Dynamically fall back to best multimask output upon low stability scores.
mask_logits_out = torch.where(
is_stable[..., None, None].expand_as(singlemask_logits),
singlemask_logits,
best_multimask_logits,
)
iou_scores_out = torch.where(
is_stable.expand_as(singlemask_iou_scores),
singlemask_iou_scores,
best_multimask_iou_scores,
)
return mask_logits_out, iou_scores_out

def forward(
self,
image_embeddings: torch.Tensor,
@@ -815,10 +856,16 @@ class Sam2MaskDecoder(nn.Module):
# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]
elif self.dynamic_multimask_via_stability and not self.training:
mask_slice = slice(0, 1)
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
else:
mask_slice = slice(0, 1)
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]

sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape
outputs = (masks, iou_pred, sam_tokens_out, object_score_logits)
@@ -1249,6 +1296,8 @@ class Sam2Attention(SamAttention):
self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
+
+ self.is_causal = False


def init_2d_position_ids(end_x: int, end_y: int):
"""Generate 2D position indices for axial rotary embedding."""
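Only the signature and docstring of init_2d_position_ids appear in this hunk. As an assumption (mirroring how axial 2D position ids are usually built, e.g. in the reference SAM 2 code), the body would enumerate a flattened end_x-by-end_y grid and return the per-token x and y indices:

# Assumed sketch only: the actual body is not shown in the diff.
import torch

def init_2d_position_ids_sketch(end_x: int, end_y: int):
    t = torch.arange(end_x * end_y)
    t_x = t % end_x   # column index of each flattened grid position
    t_y = t // end_x  # row index of each flattened grid position
    return t_x, t_y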
@@ -1956,13 +2005,11 @@ class Sam2Model(Sam2PreTrainedModel):
pixel_values: torch.FloatTensor,
output_attentions: bool = False,
output_hidden_states: bool = False,
- return_dict: bool = True,
):
vision_outputs = self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
)

feature_maps = vision_outputs[1]

@@ -1979,6 +2026,7 @@ class Sam2Model(Sam2PreTrainedModel):

return feature_maps, feature_maps_position_embeddings, vision_hidden_states, vision_attentions

+ @can_return_tuple
@auto_docstring
def forward(
self,

@@ -1994,7 +2042,6 @@ class Sam2Model(Sam2PreTrainedModel):
target_embedding: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
**kwargs,
) -> list[dict[str, torch.Tensor]]:
r"""

@@ -2079,7 +2126,6 @@ class Sam2Model(Sam2PreTrainedModel):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if pixel_values is None and image_embeddings is None:
raise ValueError("Either pixel_values or image_embeddings must be provided.")

@@ -2124,7 +2170,6 @@ class Sam2Model(Sam2PreTrainedModel):
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
)
)
# flatten NxCxHxW to HWxNxC

@@ -2146,14 +2191,6 @@ class Sam2Model(Sam2PreTrainedModel):
if input_points is not None and input_labels is None:
input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)

- # if input_points is not None and image_embeddings[-1].shape[1] != input_points.shape[0]:
- # raise ValueError(
- # "The batch size of the image embeddings and the input points must be the same. ",
- # "Got {} and {} respectively.".format(image_embeddings[-1].shape[1], input_points.shape[0]),
- # " if you want to pass multiple points for the same image, make sure that you passed ",
- # " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
- # " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
- # )
if input_points is None:
# If no points are provide, pad with an empty point (with label -1)
input_points = torch.zeros(batch_size, point_batch_size, 1, 2, device=image_embeddings[-1].device)

@@ -2161,11 +2198,9 @@ class Sam2Model(Sam2PreTrainedModel):
batch_size, point_batch_size, 1, dtype=torch.int32, device=image_embeddings[-1].device
)

# b) Handle mask prompts
if input_masks is not None:
# If mask_inputs is provided, downsize it into low-res mask input if needed
# and feed it as a dense mask prompt into the SAM mask encoder
assert len(input_masks.shape) == 4 and input_masks.shape[:2] == (batch_size, 1)
if input_masks.shape[-2:] != self.prompt_encoder.image_embedding_size:
input_masks = F.interpolate(
input_masks.float(),

@@ -2237,15 +2272,6 @@ class Sam2Model(Sam2PreTrainedModel):
high_res_masks = None
obj_ptr = None

- if not return_dict:
- output = (iou_scores, low_res_masks, high_res_masks, obj_ptr, object_score_logits, image_embeddings)
- if output_hidden_states:
- output = output + (vision_hidden_states,)
-
- # if output_attentions:
- # output = output + (vision_attentions, mask_decoder_attentions)
- return output
-
return Sam2ImageSegmentationOutput(
iou_scores=iou_scores,
low_res_masks=low_res_masks,
@@ -2753,9 +2779,9 @@ class Sam2Model(Sam2PreTrainedModel):
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
"""
Directly turn binary `mask_inputs` into a output mask logits without using SAM.
(same input and output shapes as in _forward_sam_heads above).
(same input and output shapes as in forward above).
"""
# Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
# Use -10/+20 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
mask_inputs_float = mask_inputs.float()
high_res_masks = mask_inputs_float * out_scale + out_bias
@@ -263,6 +263,10 @@ class Sam2VisionModelTest(ModelTesterMixin, unittest.TestCase):

check_hidden_states_output(inputs_dict, config, model_class, image_size)

+ # Override as diffence slightly higher than the threshold
+ def test_batching_equivalence(self, atol=5e-4, rtol=5e-4):
+ super().test_batching_equivalence(atol=atol, rtol=rtol)

@require_torch_sdpa
def test_sdpa_can_compile_dynamic(self):
self.skipTest(reason="SAM model can't be compiled dynamic yet")
@@ -358,6 +362,8 @@ class Sam2MemoryEncoderTester:
patch_kernel_size=2,
patch_stride=2,
patch_padding=1,
+ mask_downsampler_embed_dim=32,
+ memory_fuser_embed_dim=32,
):
self.hidden_size = hidden_size
self.num_heads = num_heads

@@ -366,6 +372,8 @@ class Sam2MemoryEncoderTester:
self.patch_kernel_size = patch_kernel_size
self.patch_stride = patch_stride
self.patch_padding = patch_padding
+ self.mask_downsampler_embed_dim = mask_downsampler_embed_dim
+ self.memory_fuser_embed_dim = memory_fuser_embed_dim

def get_config(self):
return Sam2MemoryEncoderConfig(

@@ -376,6 +384,8 @@ class Sam2MemoryEncoderTester:
patch_kernel_size=self.patch_kernel_size,
patch_stride=self.patch_stride,
patch_padding=self.patch_padding,
+ mask_downsampler_embed_dim=self.mask_downsampler_embed_dim,
+ memory_fuser_embed_dim=self.memory_fuser_embed_dim,
)

def prepare_config_and_inputs(self):
@@ -445,6 +455,12 @@ class Sam2ModelTester:
fpn_hidden_size=self.fpn_hidden_size,
)

+ memory_attention_config = Sam2MemoryAttentionConfig(
+ hidden_size=self.hidden_size,
+ num_layers=1,
+ dim_feedforward=32,
+ )
+
prompt_encoder_config = self.prompt_encoder_tester.get_config()

mask_decoder_config = self.mask_decoder_tester.get_config()

@@ -455,7 +471,7 @@ class Sam2ModelTester:
vision_config=vision_config,
prompt_encoder_config=prompt_encoder_config,
mask_decoder_config=mask_decoder_config,
- memory_attention_config=Sam2MemoryAttentionConfig(),
+ memory_attention_config=memory_attention_config,
memory_encoder_config=memory_encoder_config,
image_size=self.image_size,
)
@@ -467,43 +483,7 @@ class Sam2ModelTester:
with torch.no_grad():
result = model(pixel_values)
self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3))
self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3))

def create_and_check_get_image_features(self, config, pixel_values):
model = Sam2Model(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model.get_image_embeddings(pixel_values)
self.parent.assertEqual(result[0].shape, (self.output_channels, 12, 12))

def create_and_check_get_image_hidden_states(self, config, pixel_values):
model = Sam2Model(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model.vision_encoder(
pixel_values,
output_hidden_states=True,
return_dict=True,
)

# after computing the convolutional features
expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)

with torch.no_grad():
result = model.vision_encoder(
pixel_values,
output_hidden_states=True,
return_dict=False,
)

# after computing the convolutional features
expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)
self.parent.assertEqual(result.low_res_masks.shape[:3], (self.batch_size, 1, 3))

def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_get_image_features(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_get_image_features(*config_and_inputs)
|
||||
|
||||
def test_image_hidden_states(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_get_image_hidden_states(*config_and_inputs)
|
||||
|
||||
# Overriding as attention shape depends on window_size
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -620,24 +592,7 @@ class Sam2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
[num_windows, window_size, window_size, out_dim],
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Sam2Model does not support training")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Hidden_states is tested in create_and_check_model tests")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
# @slow
|
||||
# def test_model_from_pretrained(self):
|
||||
# model_name = "facebook/sam-vit-huge"
|
||||
# model = SamModel.from_pretrained(model_name)
|
||||
# self.assertIsNotNone(model)
|
||||
|
||||
@require_torch_sdpa
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
self.skipTest(reason="SAM2 model can't be compiled dynamic yet")
|
||||
|
||||
# Override as Sam2Model has different sub-modules
|
||||
@require_torch_sdpa
|
||||
def test_sdpa_can_dispatch_composite_models(self):
|
||||
"""
|
||||
@ -662,7 +617,7 @@ class Sam2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname)
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
|
||||
model_sdpa = model_sdpa.eval().to(torch_device)
|
||||
|
||||
vision_encoder_sdpa = getattr(model_sdpa, "vision_encoder")
|
||||
@ -687,44 +642,116 @@ class Sam2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
):
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
# # Overriding as attention shape depends on window_size
|
||||
# def test_hidden_states_output(self):
|
||||
# def check_hidden_states_output(inputs_dict, config, model_class, image_size):
|
||||
# model = model_class(config)
|
||||
# model.to(torch_device)
|
||||
# model.eval()
|
||||
# Override as Sam2Model doesn't have hidden states
|
||||
def flash_attn_inference_equivalence(self, attn_implementation: str, padding_side: str):
|
||||
r"""
|
||||
Tests the equivalence between the eager and flash attention implementations.
|
||||
This test is only for inference and runs with `torch_dtype=torch.bfloat16`.
|
||||
"""
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
# with torch.no_grad():
|
||||
# outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
for model_class in self.all_model_classes:
|
||||
if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
|
||||
attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
|
||||
):
|
||||
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
||||
|
||||
# hidden_states = outputs.hidden_states
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
# expected_num_layers = sum(self.model_tester.stages) + 1
|
||||
# self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
# self.assertListEqual(
|
||||
# list(hidden_states[0].shape[-4:]),
|
||||
# [
|
||||
# self.model_tester.batch_size,
|
||||
# self.model_tester.image_size // self.model_tester.patch_stride,
|
||||
# self.model_tester.image_size // self.model_tester.patch_stride,
|
||||
# self.model_tester.hidden_size,
|
||||
# ],
|
||||
# )
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
|
||||
model.to(torch_device)
|
||||
|
||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
dummy_input = inputs_dict[model.main_input_name][:1]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
|
||||
# image_size = self.model_tester.image_size
|
||||
dummy_attention_mask = inputs_dict.get("attention_mask", None)
|
||||
|
||||
# for model_class in self.all_model_classes:
|
||||
# inputs_dict["output_hidden_states"] = True
|
||||
# check_hidden_states_output(inputs_dict, config, model_class, image_size)
|
||||
if dummy_attention_mask is not None:
|
||||
dummy_attention_mask = dummy_attention_mask[:1]
|
||||
if padding_side == "left":
|
||||
dummy_attention_mask[:, 1:] = 1
|
||||
dummy_attention_mask[:, :1] = 0
|
||||
else:
|
||||
dummy_attention_mask[:, :-1] = 1
|
||||
dummy_attention_mask[:, -1:] = 0
|
||||
if model.config.is_encoder_decoder:
|
||||
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
|
||||
|
||||
# # check that output_hidden_states also work using config
|
||||
# del inputs_dict["output_hidden_states"]
|
||||
# config.output_hidden_states = True
|
||||
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
|
||||
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
|
||||
else:
|
||||
outputs = model(dummy_input, output_hidden_states=True)
|
||||
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
|
||||
|
||||
# check_hidden_states_output(inputs_dict, config, model_class, image_size)
|
||||
logits = outputs.vision_hidden_states[-1]
|
||||
logits_fa = outputs_fa.vision_hidden_states[-1]
|
||||
|
||||
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
other_inputs = {
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": dummy_attention_mask,
|
||||
"output_hidden_states": True,
|
||||
}
|
||||
if dummy_attention_mask is not None:
|
||||
other_inputs["attention_mask"] = dummy_attention_mask
|
||||
|
||||
outputs = model(dummy_input, **other_inputs)
|
||||
outputs_fa = model_fa(dummy_input, **other_inputs)
|
||||
else:
|
||||
other_inputs = {
|
||||
"output_hidden_states": True,
|
||||
}
|
||||
if dummy_attention_mask is not None:
|
||||
other_inputs["attention_mask"] = dummy_attention_mask
|
||||
|
||||
outputs = model(dummy_input, **other_inputs)
|
||||
outputs_fa = model_fa(dummy_input, **other_inputs)
|
||||
|
||||
logits = outputs.vision_hidden_states[-1]
|
||||
logits_fa = outputs_fa.vision_hidden_states[-1]
|
||||
|
||||
if padding_side == "left":
|
||||
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
|
||||
|
||||
# check with inference + dropout
|
||||
model.train()
|
||||
_ = model_fa(dummy_input, **other_inputs)
|
||||
else:
|
||||
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
|
||||
|
||||
# Override as diffence slightly higher than the threshold
|
||||
def test_batching_equivalence(self, atol=5e-4, rtol=5e-4):
|
||||
super().test_batching_equivalence(atol=atol, rtol=rtol)
|
||||
|
||||
@unittest.skip(reason="Sam2Model does not support training")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Hidden_states is tested in sub modules tests")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
# @slow
|
||||
# def test_model_from_pretrained(self):
|
||||
# model_name = "facebook/sam-vit-huge"
|
||||
# model = SamModel.from_pretrained(model_name)
|
||||
# self.assertIsNotNone(model)
|
||||
|
||||
@require_torch_sdpa
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
self.skipTest(reason="SAM2 model can't be compiled dynamic yet")
|
||||
|
||||
|
||||
def prepare_image():
|
||||
|