Arthur 2025-07-03 10:14:28 +02:00
parent 17cf5424b0
commit c4d43c5324
4 changed files with 144 additions and 450 deletions

View File

@ -21,13 +21,15 @@ from typing import Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import Tensor, nn
from transformers.utils.generic import TransformersKwargs, check_model_inputs
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
ModelOutput,
auto_docstring,
@ -329,7 +331,6 @@ class SamTwoWayAttentionBlock(nn.Module):
query_point_embedding: Tensor,
key_point_embedding: Tensor,
attention_similarity: Tensor,
output_attentions: bool = False,
):
# Self attention block
if self.skip_first_layer_pe:
@ -364,15 +365,7 @@ class SamTwoWayAttentionBlock(nn.Module):
keys = keys + attn_out
keys = self.layer_norm4(keys)
outputs = (queries, keys)
if output_attentions:
outputs = outputs + (attn_out,)
else:
outputs = outputs + (None,)
return outputs
return queries, keys, attn_out
class SamTwoWayTransformer(nn.Module):
@ -396,16 +389,7 @@ class SamTwoWayTransformer(nn.Module):
image_positional_embeddings: Tensor,
attention_similarity: Tensor,
target_embedding=None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> Union[tuple, BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
all_attentions = ()
if image_embeddings is None:
raise ValueError("You have to specify an image_embedding")
@ -421,18 +405,13 @@ class SamTwoWayTransformer(nn.Module):
if target_embedding is not None:
queries += target_embedding
queries, keys, attention_outputs = layer(
queries, keys, _ = layer(
queries=queries,
keys=keys,
query_point_embedding=point_embeddings,
key_point_embedding=image_positional_embeddings,
attention_similarity=attention_similarity,
output_attentions=output_attentions,
)
if output_attentions:
all_attentions = all_attentions + (attention_outputs,)
# Apply the final attention layer from the points to the image
query = queries + point_embeddings
key = keys + image_positional_embeddings
@ -441,7 +420,7 @@ class SamTwoWayTransformer(nn.Module):
queries = queries + attn_out
queries = self.layer_norm_final_attn(queries)
return queries, keys, all_attentions
return queries, keys
class SamFeedForward(nn.Module):
@ -468,9 +447,11 @@ class SamFeedForward(nn.Module):
return hidden_states
class SamMaskDecoder(nn.Module):
class SamMaskDecoder(PreTrainedModel):
_can_record_outputs = {"attentions": (SamTwoWayAttentionBlock, 2)}
def __init__(self, config: SamMaskDecoderConfig):
super().__init__()
super().__init__(config)
self.config = config
self.hidden_size = config.hidden_size
@ -504,9 +485,9 @@ class SamMaskDecoder(nn.Module):
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
output_attentions: Optional[bool] = None,
attention_similarity: Optional[torch.Tensor] = None,
target_embedding: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
@ -522,8 +503,6 @@ class SamMaskDecoder(nn.Module):
the embeddings of the mask inputs
multimask_output (bool):
Whether to return multiple masks or a single mask.
output_attentions (bool, *optional*):
Whether or not to return the attentions tensors of all attention layers.
"""
batch_size, num_channels, height, width = image_embeddings.shape
point_batch_size = sparse_prompt_embeddings.shape[1]
@ -543,13 +522,12 @@ class SamMaskDecoder(nn.Module):
image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
# Run the transformer, image_positional_embeddings are consumed
point_embedding, image_embeddings, attentions = self.transformer(
point_embedding, image_embeddings = self.transformer(
point_embeddings=point_embeddings,
image_embeddings=image_embeddings,
image_positional_embeddings=image_positional_embeddings,
attention_similarity=attention_similarity,
target_embedding=target_embedding,
output_attentions=output_attentions,
)
iou_token_out = point_embedding[:, :, 0, :]
mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
@ -583,15 +561,7 @@ class SamMaskDecoder(nn.Module):
mask_slice = slice(0, 1)
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]
outputs = (masks, iou_pred)
if output_attentions:
outputs = outputs + (attentions,)
else:
outputs = outputs + (None,)
return outputs
return masks, iou_pred
class SamPositionalEmbedding(nn.Module):
@ -887,13 +857,7 @@ class SamVisionAttention(nn.Module):
attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
attn_output = self.proj(attn_output)
if output_attentions:
outputs = (attn_output, attn_weights)
else:
outputs = (attn_output, None)
return outputs
return attn_output, attn_weights
class SamVisionSdpaAttention(SamVisionAttention):
@ -951,7 +915,6 @@ class SamVisionSdpaAttention(SamVisionAttention):
)
attn_output = self.proj(attn_output)
return attn_output, None
@ -1024,13 +987,8 @@ class SamVisionLayer(GradientCheckpointingLayer):
hidden_states = hidden_states[:, :height, :width, :].contiguous()
return hidden_states
def forward(
self,
hidden_states: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> tuple[torch.FloatTensor]:
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
# Window partition
if self.window_size > 0:
@ -1039,7 +997,6 @@ class SamVisionLayer(GradientCheckpointingLayer):
hidden_states, attn_weights = self.attn(
hidden_states=hidden_states,
output_attentions=output_attentions,
)
# Reverse window partition
if self.window_size > 0:
@ -1048,12 +1005,7 @@ class SamVisionLayer(GradientCheckpointingLayer):
hidden_states = residual + hidden_states
layernorm_output = self.layer_norm2(hidden_states)
hidden_states = hidden_states + self.mlp(layernorm_output)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
return hidden_states
class SamVisionNeck(nn.Module):
@ -1076,9 +1028,11 @@ class SamVisionNeck(nn.Module):
return hidden_states
class SamVisionEncoder(nn.Module):
class SamVisionEncoder(PreTrainedModel):
_can_record_outputs = {"hidden_states": (SamVisionLayer, 0), "attentions": (SamVisionAttention, 1)}
def __init__(self, config: SamVisionConfig):
super().__init__()
super().__init__(config)
self.config = config
self.image_size = config.image_size
@ -1111,49 +1065,21 @@ class SamVisionEncoder(nn.Module):
def get_input_embeddings(self):
return self.patch_embed
@can_return_tuple
@check_model_inputs
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
self, pixel_values: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs]
) -> SamVisionEncoderOutput:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.patch_embed(pixel_values)
if self.pos_embed is not None:
hidden_states = hidden_states + self.pos_embed
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
for layer_module in self.layers:
hidden_states = layer_module(hidden_states)
hidden_states = self.neck(hidden_states)
return SamVisionEncoderOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
@ -1197,8 +1123,6 @@ class SamVisionModel(SamPreTrainedModel):
def __init__(self, config: SamVisionConfig):
super().__init__(config)
self.vision_encoder = SamVisionEncoder(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
@ -1208,16 +1132,9 @@ class SamVisionModel(SamPreTrainedModel):
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, SamVisionEncoderOutput]:
return self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
return self.vision_encoder(pixel_values, **kwargs)
@auto_docstring(
@ -1262,12 +1179,7 @@ class SamModel(SamPreTrainedModel):
return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width
@torch.no_grad()
def get_image_embeddings(
self,
pixel_values,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
):
def get_image_embeddings(self, pixel_values, **kwargs: Unpack[TransformersKwargs]):
r"""
Returns the image embeddings by passing the pixel values through the vision encoder.
@ -1281,8 +1193,7 @@ class SamModel(SamPreTrainedModel):
"""
vision_output = self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
image_embeddings = vision_output[0]
return image_embeddings
@ -1333,9 +1244,7 @@ class SamModel(SamPreTrainedModel):
multimask_output: bool = True,
attention_similarity: Optional[torch.FloatTensor] = None,
target_embedding: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
**kwargs: Unpack[TransformersKwargs],
) -> SamImageSegmentationOutput:
r"""
input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
@ -1415,11 +1324,6 @@ class SamModel(SamPreTrainedModel):
... )
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if pixel_values is None and image_embeddings is None:
raise ValueError("Either pixel_values or image_embeddings must be provided.")
@ -1453,17 +1357,10 @@ class SamModel(SamPreTrainedModel):
vision_hidden_states = None
if pixel_values is not None:
vision_outputs: SamVisionEncoderOutput = self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
vision_outputs: SamVisionEncoderOutput = self.vision_encoder(pixel_values, **kwargs)
image_embeddings = vision_outputs.last_hidden_state
if output_hidden_states:
vision_hidden_states = vision_outputs.hidden_states
if output_attentions:
vision_attentions = vision_outputs.attentions
vision_hidden_states = vision_outputs.hidden_states
vision_attentions = vision_outputs.attentions
if input_points is not None and input_labels is None:
input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
@ -1484,7 +1381,7 @@ class SamModel(SamPreTrainedModel):
input_masks=input_masks,
)
low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
low_res_masks, iou_predictions = self.mask_decoder(
image_embeddings=image_embeddings,
image_positional_embeddings=image_positional_embeddings,
sparse_prompt_embeddings=sparse_embeddings,
@ -1492,7 +1389,7 @@ class SamModel(SamPreTrainedModel):
multimask_output=multimask_output,
attention_similarity=attention_similarity,
target_embedding=target_embedding,
output_attentions=output_attentions,
**kwargs,
)
return SamImageSegmentationOutput(
@ -1500,7 +1397,6 @@ class SamModel(SamPreTrainedModel):
pred_masks=low_res_masks,
vision_hidden_states=vision_hidden_states,
vision_attentions=vision_attentions,
mask_decoder_attentions=mask_decoder_attentions,
)
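With this file's changes, `output_attentions` and `output_hidden_states` are no longer threaded explicitly through every forward; they travel in `**kwargs`, and `@check_model_inputs` together with `_can_record_outputs` captures the tensors via hooks and fills the output dataclass. A minimal usage sketch under that assumption follows; the checkpoint id and the dummy image are illustrative, not part of this commit.

# Hedged sketch of the intended call pattern after this refactor; the checkpoint
# id and dummy inputs below are assumptions for illustration only.
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

image = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
inputs = processor(image, input_points=[[[128, 128]]], return_tensors="pt")

# output_attentions / output_hidden_states are plain kwargs now; the decorator
# consumes them and records the tensors declared in `_can_record_outputs`.
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True, output_hidden_states=True)

print(outputs.pred_masks.shape)
if outputs.vision_attentions is not None:
    print(len(outputs.vision_attentions))     # one entry per recorded attention call
if outputs.vision_hidden_states is not None:
    print(len(outputs.vision_hidden_states))  # per-layer vision hidden states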

View File

@ -28,11 +28,14 @@ import torch
import torch.nn.functional as F
from torch import Tensor, nn
from transformers.utils.generic import TransformersKwargs, can_return_tuple, check_model_inputs
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from ...processing_utils import Unpack
from ...utils import ModelOutput, auto_docstring, logging
from .configuration_sam_hq import SamHQConfig, SamHQMaskDecoderConfig, SamHQPromptEncoderConfig, SamHQVisionConfig
@ -102,55 +105,6 @@ class SamHQImageSegmentationOutput(ModelOutput):
mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
class SamHQPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values):
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values matches the one set in the configuration."
)
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
return embeddings
class SamHQMLPBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
self.act = ACT2FN[config.hidden_act]
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.lin1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.lin2(hidden_states)
return hidden_states
class SamHQVisionAttention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
@ -281,13 +235,56 @@ class SamHQVisionAttention(nn.Module):
attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
attn_output = self.proj(attn_output)
return attn_output, attn_weights
if output_attentions:
outputs = (attn_output, attn_weights)
else:
outputs = (attn_output, None)
return outputs
class SamHQPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values):
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values matches the one set in the configuration."
)
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
return embeddings
class SamHQMLPBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
self.act = ACT2FN[config.hidden_act]
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.lin1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.lin2(hidden_states)
return hidden_states
class SamHQVisionSdpaAttention(SamHQVisionAttention):
@ -345,7 +342,6 @@ class SamHQVisionSdpaAttention(SamHQVisionAttention):
)
attn_output = self.proj(attn_output)
return attn_output, None
@ -418,13 +414,8 @@ class SamHQVisionLayer(GradientCheckpointingLayer):
hidden_states = hidden_states[:, :height, :width, :].contiguous()
return hidden_states
def forward(
self,
hidden_states: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> tuple[torch.FloatTensor]:
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
# Window partition
if self.window_size > 0:
@ -433,7 +424,6 @@ class SamHQVisionLayer(GradientCheckpointingLayer):
hidden_states, attn_weights = self.attn(
hidden_states=hidden_states,
output_attentions=output_attentions,
)
# Reverse window partition
if self.window_size > 0:
@ -442,12 +432,7 @@ class SamHQVisionLayer(GradientCheckpointingLayer):
hidden_states = residual + hidden_states
layernorm_output = self.layer_norm2(hidden_states)
hidden_states = hidden_states + self.mlp(layernorm_output)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
return hidden_states
class SamHQVisionNeck(nn.Module):
@ -470,9 +455,11 @@ class SamHQVisionNeck(nn.Module):
return hidden_states
class SamHQVisionEncoder(nn.Module):
class SamHQVisionEncoder(PreTrainedModel):
_can_record_outputs = {"attentions": (SamHQVisionAttention, 1)}
def __init__(self, config: SamHQVisionConfig):
super().__init__()
super().__init__(config)
self.config = config
self.image_size = config.image_size
@ -505,20 +492,10 @@ class SamHQVisionEncoder(nn.Module):
def get_input_embeddings(self):
return self.patch_embed
@can_return_tuple
@check_model_inputs
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
self, pixel_values: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs]
) -> Union[tuple, SamHQVisionEncoderOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
@ -526,41 +503,20 @@ class SamHQVisionEncoder(nn.Module):
if self.pos_embed is not None:
hidden_states = hidden_states + self.pos_embed
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
intermediate_embeddings = []
for layer_module in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
hidden_states = layer_outputs[0]
hidden_states = layer_module(hidden_states)
# Collect embeddings from non-windowed blocks
if hasattr(layer_module, "window_size") and layer_module.window_size == 0:
intermediate_embeddings.append(hidden_states)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states = self.neck(hidden_states)
if not return_dict:
outputs = (hidden_states, intermediate_embeddings)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_self_attentions,)
return outputs
return SamHQVisionEncoderOutput(
last_hidden_state=hidden_states,
intermediate_embeddings=intermediate_embeddings,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
@ -746,7 +702,6 @@ class SamHQTwoWayAttentionBlock(nn.Module):
query_point_embedding: Tensor,
key_point_embedding: Tensor,
attention_similarity: Tensor,
output_attentions: bool = False,
):
# Self attention block
if self.skip_first_layer_pe:
@ -781,15 +736,7 @@ class SamHQTwoWayAttentionBlock(nn.Module):
keys = keys + attn_out
keys = self.layer_norm4(keys)
outputs = (queries, keys)
if output_attentions:
outputs = outputs + (attn_out,)
else:
outputs = outputs + (None,)
return outputs
return queries, keys, attn_out
class SamHQTwoWayTransformer(nn.Module):
@ -813,16 +760,7 @@ class SamHQTwoWayTransformer(nn.Module):
image_positional_embeddings: Tensor,
attention_similarity: Tensor,
target_embedding=None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> Union[tuple, BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
all_attentions = ()
if image_embeddings is None:
raise ValueError("You have to specify an image_embedding")
@ -838,18 +776,13 @@ class SamHQTwoWayTransformer(nn.Module):
if target_embedding is not None:
queries += target_embedding
queries, keys, attention_outputs = layer(
queries, keys, _ = layer(
queries=queries,
keys=keys,
query_point_embedding=point_embeddings,
key_point_embedding=image_positional_embeddings,
attention_similarity=attention_similarity,
output_attentions=output_attentions,
)
if output_attentions:
all_attentions = all_attentions + (attention_outputs,)
# Apply the final attention layer from the points to the image
query = queries + point_embeddings
key = keys + image_positional_embeddings
@ -858,7 +791,7 @@ class SamHQTwoWayTransformer(nn.Module):
queries = queries + attn_out
queries = self.layer_norm_final_attn(queries)
return queries, keys, all_attentions
return queries, keys
class SamHQFeedForward(nn.Module):
@ -940,9 +873,9 @@ class SamHQMaskDecoder(nn.Module):
multimask_output: bool,
hq_token_only: bool,
intermediate_embeddings: Optional[list[torch.Tensor]] = None,
output_attentions: Optional[bool] = None,
attention_similarity: Optional[torch.Tensor] = None,
target_embedding: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Predict high-quality masks given image and prompt embeddings.
@ -1004,18 +937,17 @@ class SamHQMaskDecoder(nn.Module):
else:
tokens = output_tokens
point_embeddings = tokens.to(self.iou_token.weight.dtype)
image_embeddings = image_embeddings + dense_prompt_embeddings
image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
point_embedding, image_embeddings, attentions = self.transformer(
point_embedding, image_embeddings = self.transformer(
point_embeddings=point_embeddings,
image_embeddings=image_embeddings,
image_positional_embeddings=image_positional_embeddings,
attention_similarity=attention_similarity,
target_embedding=target_embedding,
output_attentions=output_attentions,
**kwargs,
)
iou_token_out = point_embedding[:, :, 0, :]
mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
@ -1089,11 +1021,6 @@ class SamHQMaskDecoder(nn.Module):
masks = masks_sam + masks_hq
outputs = (masks, iou_pred)
if output_attentions:
outputs = outputs + (attentions,)
else:
outputs = outputs + (None,)
return outputs
@ -1140,8 +1067,6 @@ class SamHQVisionModel(SamHQPreTrainedModel):
def __init__(self, config: SamHQVisionConfig):
super().__init__(config)
self.vision_encoder = SamHQVisionEncoder(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
@ -1151,16 +1076,9 @@ class SamHQVisionModel(SamHQPreTrainedModel):
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, SamHQVisionEncoderOutput]:
return self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
return self.vision_encoder(pixel_values, **kwargs)
class SamHQPositionalEmbedding(nn.Module):
@ -1371,9 +1289,7 @@ class SamHQModel(SamHQPreTrainedModel):
def get_image_embeddings(
self,
pixel_values,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
):
r"""
Returns the image embeddings by passing the pixel values through the vision encoder.
@ -1391,9 +1307,6 @@ class SamHQModel(SamHQPreTrainedModel):
"""
vision_output = self.vision_encoder(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
image_embeddings = vision_output[0]
intermediate_embeddings = vision_output[1]
@ -1434,7 +1347,6 @@ class SamHQModel(SamHQPreTrainedModel):
return prompt_output
@can_return_tuple
@auto_docstring
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
@ -1447,11 +1359,8 @@ class SamHQModel(SamHQPreTrainedModel):
hq_token_only: bool = False,
attention_similarity: Optional[torch.FloatTensor] = None,
target_embedding: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
intermediate_embeddings: Optional[list[torch.FloatTensor]] = None,
**kwargs,
**kwargs: Unpack[TransformersKwargs],
) -> list[dict[str, torch.Tensor]]:
r"""
input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
@ -1540,12 +1449,6 @@ class SamHQModel(SamHQPreTrainedModel):
... )
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None and image_embeddings is None:
raise ValueError("Either pixel_values or image_embeddings must be provided.")
@ -1578,32 +1481,10 @@ class SamHQModel(SamHQPreTrainedModel):
batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
vision_attentions = None
vision_hidden_states = None
if pixel_values is not None:
vision_outputs = self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if return_dict:
image_embeddings = vision_outputs.last_hidden_state
intermediate_embeddings = vision_outputs.intermediate_embeddings
if output_hidden_states:
vision_hidden_states = vision_outputs.hidden_states
if output_attentions:
vision_attentions = vision_outputs.attentions
else:
image_embeddings = vision_outputs[0]
intermediate_embeddings = vision_outputs[1]
if output_hidden_states:
vision_hidden_states = vision_outputs[2]
if output_attentions:
vision_attentions = vision_outputs[-1]
vision_outputs = self.vision_encoder(pixel_values, **kwargs)
image_embeddings = vision_outputs.last_hidden_state
intermediate_embeddings = vision_outputs.intermediate_embeddings
if input_points is not None and input_labels is None:
input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
@ -1615,7 +1496,7 @@ class SamHQModel(SamHQPreTrainedModel):
)
# Predict masks
low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
low_res_masks, iou_predictions = self.mask_decoder(
image_embeddings=image_embeddings,
image_positional_embeddings=image_positional_embeddings,
sparse_prompt_embeddings=sparse_embeddings,
@ -1625,24 +1506,12 @@ class SamHQModel(SamHQPreTrainedModel):
intermediate_embeddings=intermediate_embeddings,
attention_similarity=attention_similarity,
target_embedding=target_embedding,
output_attentions=output_attentions,
)
if not return_dict:
output = (iou_predictions, low_res_masks)
if output_hidden_states:
output = output + (vision_hidden_states,)
if output_attentions:
output = output + (vision_attentions, mask_decoder_attentions)
return output
return SamHQImageSegmentationOutput(
iou_scores=iou_predictions,
pred_masks=low_res_masks,
vision_hidden_states=vision_hidden_states,
vision_attentions=vision_attentions,
mask_decoder_attentions=mask_decoder_attentions,
vision_hidden_states=vision_outputs.hidden_states,
vision_attentions=vision_outputs.attentions,
)
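The `(module_class, index)` pairs in `_can_record_outputs` (for example `{"attentions": (SamHQVisionAttention, 1)}`) tell the recording machinery which element of a module's return tuple to keep while the forward pass runs. A standalone toy sketch of that idea, not the actual Transformers implementation:

# Toy, self-contained illustration of index-based output recording via forward
# hooks; names and structure here are assumptions, not the library internals.
import torch
from torch import nn

class ToyAttention(nn.Module):
    def forward(self, x):
        weights = torch.softmax(x @ x.transpose(-1, -2), dim=-1)
        return weights @ x, weights  # (attn_output, attn_weights)

class ToyEncoder(nn.Module):
    # "attentions" -> capture element 1 of ToyAttention's return tuple
    _can_record_outputs = {"attentions": (ToyAttention, 1)}

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([ToyAttention() for _ in range(2)])

    def forward(self, x):
        for layer in self.layers:
            x, _ = layer(x)
        return x

def run_and_record(model, x, key="attentions"):
    module_cls, index = model._can_record_outputs[key]
    captured, handles = [], []
    for module in model.modules():
        if isinstance(module, module_cls):
            handles.append(module.register_forward_hook(
                lambda mod, args, out: captured.append(out[index])
            ))
    try:
        last_hidden = model(x)
    finally:
        for handle in handles:
            handle.remove()
    return last_hidden, tuple(captured)

hidden, attentions = run_and_record(ToyEncoder(), torch.randn(1, 4, 8))
print(hidden.shape, len(attentions))  # torch.Size([1, 4, 8]) 2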

View File

@ -16,9 +16,11 @@
from typing import Optional, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.utils.generic import TransformersKwargs, can_return_tuple, check_model_inputs
from ...processing_utils import Unpack
from ...utils import auto_docstring, logging
from ..sam.configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
from ..sam.modeling_sam import (
@ -28,6 +30,7 @@ from ..sam.modeling_sam import (
SamModel,
SamPreTrainedModel,
SamTwoWayTransformer,
SamVisionAttention,
SamVisionEncoder,
SamVisionEncoderOutput,
SamVisionModel,
@ -125,20 +128,17 @@ class SamHQImageSegmentationOutput(SamImageSegmentationOutput):
pass
class SamHQVisionEncoder(SamVisionEncoder):
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, SamHQVisionEncoderOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
class SamHQVisionAttention(SamVisionAttention):
pass
class SamHQVisionEncoder(SamVisionEncoder):
_can_record_outputs = {"attentions": (SamHQVisionAttention, 1)}
@check_model_inputs
def forward(
self, pixel_values: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs]
) -> Union[tuple, SamHQVisionEncoderOutput]:
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
@ -146,41 +146,20 @@ class SamHQVisionEncoder(SamVisionEncoder):
if self.pos_embed is not None:
hidden_states = hidden_states + self.pos_embed
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
intermediate_embeddings = []
for layer_module in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
hidden_states = layer_outputs[0]
hidden_states = layer_module(hidden_states)
# Collect embeddings from non-windowed blocks
if hasattr(layer_module, "window_size") and layer_module.window_size == 0:
intermediate_embeddings.append(hidden_states)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states = self.neck(hidden_states)
if not return_dict:
outputs = (hidden_states, intermediate_embeddings)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_self_attentions,)
return outputs
return SamHQVisionEncoderOutput(
last_hidden_state=hidden_states,
intermediate_embeddings=intermediate_embeddings,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
@ -251,9 +230,9 @@ class SamHQMaskDecoder(nn.Module):
multimask_output: bool,
hq_token_only: bool,
intermediate_embeddings: Optional[list[torch.Tensor]] = None,
output_attentions: Optional[bool] = None,
attention_similarity: Optional[torch.Tensor] = None,
target_embedding: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Predict high-quality masks given image and prompt embeddings.
@ -315,18 +294,17 @@ class SamHQMaskDecoder(nn.Module):
else:
tokens = output_tokens
point_embeddings = tokens.to(self.iou_token.weight.dtype)
image_embeddings = image_embeddings + dense_prompt_embeddings
image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
point_embedding, image_embeddings, attentions = self.transformer(
point_embedding, image_embeddings = self.transformer(
point_embeddings=point_embeddings,
image_embeddings=image_embeddings,
image_positional_embeddings=image_positional_embeddings,
attention_similarity=attention_similarity,
target_embedding=target_embedding,
output_attentions=output_attentions,
**kwargs,
)
iou_token_out = point_embedding[:, :, 0, :]
mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
@ -400,11 +378,6 @@ class SamHQMaskDecoder(nn.Module):
masks = masks_sam + masks_hq
outputs = (masks, iou_pred)
if output_attentions:
outputs = outputs + (attentions,)
else:
outputs = outputs + (None,)
return outputs
@ -442,9 +415,7 @@ class SamHQModel(SamModel):
def get_image_embeddings(
self,
pixel_values,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
):
r"""
Returns the image embeddings by passing the pixel values through the vision encoder.
@ -462,15 +433,13 @@ class SamHQModel(SamModel):
"""
vision_output = self.vision_encoder(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
image_embeddings = vision_output[0]
intermediate_embeddings = vision_output[1]
return image_embeddings, intermediate_embeddings
@can_return_tuple
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
@ -483,11 +452,8 @@ class SamHQModel(SamModel):
hq_token_only: bool = False,
attention_similarity: Optional[torch.FloatTensor] = None,
target_embedding: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
intermediate_embeddings: Optional[list[torch.FloatTensor]] = None,
**kwargs,
**kwargs: Unpack[TransformersKwargs],
) -> list[dict[str, torch.Tensor]]:
r"""
input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
@ -576,12 +542,6 @@ class SamHQModel(SamModel):
... )
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None and image_embeddings is None:
raise ValueError("Either pixel_values or image_embeddings must be provided.")
@ -614,32 +574,10 @@ class SamHQModel(SamModel):
batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
vision_attentions = None
vision_hidden_states = None
if pixel_values is not None:
vision_outputs = self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if return_dict:
image_embeddings = vision_outputs.last_hidden_state
intermediate_embeddings = vision_outputs.intermediate_embeddings
if output_hidden_states:
vision_hidden_states = vision_outputs.hidden_states
if output_attentions:
vision_attentions = vision_outputs.attentions
else:
image_embeddings = vision_outputs[0]
intermediate_embeddings = vision_outputs[1]
if output_hidden_states:
vision_hidden_states = vision_outputs[2]
if output_attentions:
vision_attentions = vision_outputs[-1]
vision_outputs = self.vision_encoder(pixel_values, **kwargs)
image_embeddings = vision_outputs.last_hidden_state
intermediate_embeddings = vision_outputs.intermediate_embeddings
if input_points is not None and input_labels is None:
input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
@ -651,7 +589,7 @@ class SamHQModel(SamModel):
)
# Predict masks
low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
low_res_masks, iou_predictions = self.mask_decoder(
image_embeddings=image_embeddings,
image_positional_embeddings=image_positional_embeddings,
sparse_prompt_embeddings=sparse_embeddings,
@ -661,24 +599,12 @@ class SamHQModel(SamModel):
intermediate_embeddings=intermediate_embeddings,
attention_similarity=attention_similarity,
target_embedding=target_embedding,
output_attentions=output_attentions,
)
if not return_dict:
output = (iou_predictions, low_res_masks)
if output_hidden_states:
output = output + (vision_hidden_states,)
if output_attentions:
output = output + (vision_attentions, mask_decoder_attentions)
return output
return SamHQImageSegmentationOutput(
iou_scores=iou_predictions,
pred_masks=low_res_masks,
vision_hidden_states=vision_hidden_states,
vision_attentions=vision_attentions,
mask_decoder_attentions=mask_decoder_attentions,
vision_hidden_states=vision_outputs.hidden_states,
vision_attentions=vision_outputs.attentions,
)

View File

@ -982,7 +982,7 @@ def check_model_inputs(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
use_cache = kwargs.get("use_cache", self.config.use_cache)
use_cache = kwargs.get("use_cache", getattr(self.config, "use_cache", False))
return_dict = self.config.return_dict if hasattr(self, "config") else True
return_dict = kwargs.pop("return_dict", return_dict)
if return_dict is None:
@ -993,7 +993,7 @@ def check_model_inputs(func):
bound.apply_defaults()
all_args = bound.arguments
if self.gradient_checkpointing and self.training and use_cache:
if getattr(self, "gradient_checkpointing", False) and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
@ -1042,7 +1042,10 @@ def check_model_inputs(func):
h.remove()
for key in collected_outputs:
if key == "hidden_states":
collected_outputs[key] += (outputs.last_hidden_state,)
if hasattr(outputs, "vision_hidden_states"):
collected_outputs[key] += (outputs.vision_hidden_states,)
else:
collected_outputs[key] += (outputs.last_hidden_state,)
outputs[key] = collected_outputs[key]
elif key == "attentions":
if isinstance(capture_flags[key], list) and len(capture_flags[key]) == 2: