mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00

updates

This commit is contained in:
parent 17cf5424b0
commit c4d43c5324
@@ -21,13 +21,15 @@ from typing import Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import Tensor, nn

from transformers.utils.generic import TransformersKwargs, check_model_inputs

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
    ModelOutput,
    auto_docstring,
@@ -329,7 +331,6 @@ class SamTwoWayAttentionBlock(nn.Module):
        query_point_embedding: Tensor,
        key_point_embedding: Tensor,
        attention_similarity: Tensor,
        output_attentions: bool = False,
    ):
        # Self attention block
        if self.skip_first_layer_pe:
@@ -364,15 +365,7 @@ class SamTwoWayAttentionBlock(nn.Module):
        keys = keys + attn_out

        keys = self.layer_norm4(keys)

        outputs = (queries, keys)

        if output_attentions:
            outputs = outputs + (attn_out,)
        else:
            outputs = outputs + (None,)

        return outputs
        return query, keys, attn_out


class SamTwoWayTransformer(nn.Module):
@@ -396,16 +389,7 @@ class SamTwoWayTransformer(nn.Module):
        image_positional_embeddings: Tensor,
        attention_similarity: Tensor,
        target_embedding=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        all_attentions = ()

        if image_embeddings is None:
            raise ValueError("You have to specify an image_embedding")

@@ -421,18 +405,13 @@ class SamTwoWayTransformer(nn.Module):
            if target_embedding is not None:
                queries += target_embedding

            queries, keys, attention_outputs = layer(
            queries, keys, _ = layer(
                queries=queries,
                keys=keys,
                query_point_embedding=point_embeddings,
                key_point_embedding=image_positional_embeddings,
                attention_similarity=attention_similarity,
                output_attentions=output_attentions,
            )

            if output_attentions:
                all_attentions = all_attentions + (attention_outputs,)

        # Apply the final attention layer from the points to the image
        query = queries + point_embeddings
        key = keys + image_positional_embeddings
@@ -441,7 +420,7 @@ class SamTwoWayTransformer(nn.Module):

        queries = queries + attn_out
        queries = self.layer_norm_final_attn(queries)
        return queries, keys, all_attentions
        return queries, keys


class SamFeedForward(nn.Module):
@@ -468,9 +447,11 @@ class SamFeedForward(nn.Module):
        return hidden_states


class SamMaskDecoder(nn.Module):
class SamMaskDecoder(PreTrainedModel):
    _can_return_tuple = {"attentions": (SamTwoWayAttentionBlock, 2)}

    def __init__(self, config: SamMaskDecoderConfig):
        super().__init__()
        super().__init__(config)
        self.config = config
        self.hidden_size = config.hidden_size

@@ -504,9 +485,9 @@ class SamMaskDecoder(nn.Module):
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        output_attentions: Optional[bool] = None,
        attention_similarity: Optional[torch.Tensor] = None,
        target_embedding: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.
@@ -522,8 +503,6 @@ class SamMaskDecoder(nn.Module):
                the embeddings of the mask inputs
            multimask_output (bool):
                Whether to return multiple masks or a single mask.
            output_attentions (bool, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
        """
        batch_size, num_channels, height, width = image_embeddings.shape
        point_batch_size = sparse_prompt_embeddings.shape[1]
@@ -543,13 +522,12 @@ class SamMaskDecoder(nn.Module):
        image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)

        # Run the transformer, image_positional_embedding are consumed
        point_embedding, image_embeddings, attentions = self.transformer(
        point_embedding, image_embeddings = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            output_attentions=output_attentions,
        )
        iou_token_out = point_embedding[:, :, 0, :]
        mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
@@ -583,15 +561,7 @@ class SamMaskDecoder(nn.Module):
            mask_slice = slice(0, 1)
        masks = masks[:, :, mask_slice, :, :]
        iou_pred = iou_pred[:, :, mask_slice]

        outputs = (masks, iou_pred)

        if output_attentions:
            outputs = outputs + (attentions,)
        else:
            outputs = outputs + (None,)

        return outputs
        return masks, iou_pred


class SamPositionalEmbedding(nn.Module):
@@ -887,13 +857,7 @@ class SamVisionAttention(nn.Module):
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)

        attn_output = self.proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)

        return outputs
        return attn_output, attn_weights


class SamVisionSdpaAttention(SamVisionAttention):
@@ -951,7 +915,6 @@ class SamVisionSdpaAttention(SamVisionAttention):
        )

        attn_output = self.proj(attn_output)

        return attn_output, None


@@ -1024,13 +987,8 @@ class SamVisionLayer(GradientCheckpointingLayer):
        hidden_states = hidden_states[:, :height, :width, :].contiguous()
        return hidden_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.FloatTensor]:
    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.FloatTensor]:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        # Window partition
        if self.window_size > 0:
@@ -1039,7 +997,6 @@ class SamVisionLayer(GradientCheckpointingLayer):

        hidden_states, attn_weights = self.attn(
            hidden_states=hidden_states,
            output_attentions=output_attentions,
        )
        # Reverse window partition
        if self.window_size > 0:
@@ -1048,12 +1005,7 @@ class SamVisionLayer(GradientCheckpointingLayer):
        hidden_states = residual + hidden_states
        layernorm_output = self.layer_norm2(hidden_states)
        hidden_states = hidden_states + self.mlp(layernorm_output)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs
        return hidden_states


class SamVisionNeck(nn.Module):
@@ -1076,9 +1028,11 @@ class SamVisionNeck(nn.Module):
        return hidden_states


class SamVisionEncoder(nn.Module):
class SamVisionEncoder(PreTrainedModel):
    _can_record_outputs = {"hidden_states": (SamVisionLayer, 0), "attentions": (SamVisionAttention, 0)}

    def __init__(self, config: SamVisionConfig):
        super().__init__()
        super().__init__(config)
        self.config = config
        self.image_size = config.image_size

@@ -1111,49 +1065,21 @@ class SamVisionEncoder(nn.Module):
    def get_input_embeddings(self):
        return self.patch_embed

    @can_return_tuple
    @check_model_inputs
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
        self, pixel_values: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs]
    ) -> SamVisionEncoderOutput:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.patch_embed(pixel_values)
        if self.pos_embed is not None:
            hidden_states = hidden_states + self.pos_embed

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        for layer_module in self.layers:
            hidden_states = layer_module(hidden_states)
        hidden_states = self.neck(hidden_states)

        return SamVisionEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

@@ -1197,8 +1123,6 @@ class SamVisionModel(SamPreTrainedModel):
    def __init__(self, config: SamVisionConfig):
        super().__init__(config)
        self.vision_encoder = SamVisionEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
@@ -1208,16 +1132,9 @@ class SamVisionModel(SamPreTrainedModel):
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, SamVisionEncoderOutput]:
        return self.vision_encoder(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return self.vision_encoder(pixel_values, **kwargs)


@auto_docstring(
@@ -1262,12 +1179,7 @@ class SamModel(SamPreTrainedModel):
        return positional_embedding.permute(2, 0, 1).unsqueeze(0)  # channel x height x width

    @torch.no_grad()
    def get_image_embeddings(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ):
    def get_image_embeddings(self, pixel_values, **kwargs: Unpack[TransformersKwargs]):
        r"""
        Returns the image embeddings by passing the pixel values through the vision encoder.

@@ -1281,8 +1193,7 @@ class SamModel(SamPreTrainedModel):
        """
        vision_output = self.vision_encoder(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        image_embeddings = vision_output[0]
        return image_embeddings
@@ -1333,9 +1244,7 @@ class SamModel(SamPreTrainedModel):
        multimask_output: bool = True,
        attention_similarity: Optional[torch.FloatTensor] = None,
        target_embedding: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SamImageSegmentationOutput:
        r"""
        input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
@@ -1415,11 +1324,6 @@ class SamModel(SamPreTrainedModel):
        ... )
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if pixel_values is None and image_embeddings is None:
            raise ValueError("Either pixel_values or image_embeddings must be provided.")

@@ -1453,17 +1357,10 @@ class SamModel(SamPreTrainedModel):
        vision_hidden_states = None

        if pixel_values is not None:
            vision_outputs: SamVisionEncoderOutput = self.vision_encoder(
                pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            vision_outputs: SamVisionEncoderOutput = self.vision_encoder(pixel_values, **kwargs)
            image_embeddings = vision_outputs.last_hidden_state

            if output_hidden_states:
                vision_hidden_states = vision_outputs.hidden_states
            if output_attentions:
                vision_attentions = vision_outputs.attentions
            vision_hidden_states = vision_outputs.hidden_states
            vision_attentions = vision_outputs.attentions

        if input_points is not None and input_labels is None:
            input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
@@ -1484,7 +1381,7 @@ class SamModel(SamPreTrainedModel):
            input_masks=input_masks,
        )

        low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            sparse_prompt_embeddings=sparse_embeddings,
@@ -1492,7 +1389,7 @@ class SamModel(SamPreTrainedModel):
            multimask_output=multimask_output,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            output_attentions=output_attentions,
            **kwargs,
        )

        return SamImageSegmentationOutput(
@@ -1500,7 +1397,6 @@ class SamModel(SamPreTrainedModel):
            pred_masks=low_res_masks,
            vision_hidden_states=vision_hidden_states,
            vision_attentions=vision_attentions,
            mask_decoder_attentions=mask_decoder_attentions,
        )

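Note on the SAM modeling hunks above: the per-module `output_attentions` / `output_hidden_states` plumbing is replaced by `**kwargs: Unpack[TransformersKwargs]` plus the `@check_model_inputs` decorator, which records the sub-module outputs declared in `_can_record_outputs`. A minimal usage sketch of the refactored call path follows; the checkpoint name and dummy inputs are illustrative and not part of this diff.

```python
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

# Illustrative checkpoint; any checkpoint using this SAM architecture should behave the same.
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

image = Image.new("RGB", (1024, 1024))  # dummy image, just for the sketch
inputs = processor(image, input_points=[[[450, 600]]], return_tensors="pt")

with torch.no_grad():
    # output_attentions / output_hidden_states now travel through **kwargs and are
    # collected by @check_model_inputs instead of being threaded through every module.
    outputs = model(**inputs, output_attentions=True, output_hidden_states=True)

print(outputs.pred_masks.shape)        # low-resolution masks
print(len(outputs.vision_attentions))  # one entry per recorded attention module
```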
@@ -28,11 +28,14 @@ import torch
import torch.nn.functional as F
from torch import Tensor, nn

from transformers.utils.generic import TransformersKwargs, can_return_tuple, check_model_inputs

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from ...processing_utils import Unpack
from ...utils import ModelOutput, auto_docstring, logging
from .configuration_sam_hq import SamHQConfig, SamHQMaskDecoderConfig, SamHQPromptEncoderConfig, SamHQVisionConfig


@@ -102,55 +105,6 @@ class SamHQImageSegmentationOutput(ModelOutput):
    mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None


class SamHQPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values):
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
        return embeddings


class SamHQMLPBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
        self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
        self.act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.lin1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.lin2(hidden_states)
        return hidden_states


class SamHQVisionAttention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

@@ -281,13 +235,56 @@ class SamHQVisionAttention(nn.Module):
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)

        attn_output = self.proj(attn_output)
        return attn_output, attn_weights

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)

        return outputs
class SamHQPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values):
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
        return embeddings


class SamHQMLPBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
        self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
        self.act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.lin1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.lin2(hidden_states)
        return hidden_states


class SamHQVisionSdpaAttention(SamHQVisionAttention):
@@ -345,7 +342,6 @@ class SamHQVisionSdpaAttention(SamHQVisionAttention):
        )

        attn_output = self.proj(attn_output)

        return attn_output, None


@@ -418,13 +414,8 @@ class SamHQVisionLayer(GradientCheckpointingLayer):
        hidden_states = hidden_states[:, :height, :width, :].contiguous()
        return hidden_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.FloatTensor]:
    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.FloatTensor]:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        # Window partition
        if self.window_size > 0:
@@ -433,7 +424,6 @@ class SamHQVisionLayer(GradientCheckpointingLayer):

        hidden_states, attn_weights = self.attn(
            hidden_states=hidden_states,
            output_attentions=output_attentions,
        )
        # Reverse window partition
        if self.window_size > 0:
@@ -442,12 +432,7 @@ class SamHQVisionLayer(GradientCheckpointingLayer):
        hidden_states = residual + hidden_states
        layernorm_output = self.layer_norm2(hidden_states)
        hidden_states = hidden_states + self.mlp(layernorm_output)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs
        return hidden_states


class SamHQVisionNeck(nn.Module):
@@ -470,9 +455,11 @@ class SamHQVisionNeck(nn.Module):
        return hidden_states


class SamHQVisionEncoder(nn.Module):
class SamHQVisionEncoder(PreTrainedModel):
    _can_record_outputs = {"attentions": (SamHQVisionAttention, 1)}

    def __init__(self, config: SamHQVisionConfig):
        super().__init__()
        super().__init__(config)
        self.config = config
        self.image_size = config.image_size

@@ -505,20 +492,10 @@ class SamHQVisionEncoder(nn.Module):
    def get_input_embeddings(self):
        return self.patch_embed

    @can_return_tuple
    @check_model_inputs
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        self, pixel_values: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs]
    ) -> Union[tuple, SamHQVisionEncoderOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

@@ -526,41 +503,20 @@ class SamHQVisionEncoder(nn.Module):
        if self.pos_embed is not None:
            hidden_states = hidden_states + self.pos_embed

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        intermediate_embeddings = []

        for layer_module in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
            hidden_states = layer_outputs[0]
            hidden_states = layer_module(hidden_states)

            # Collect embeddings from non-windowed blocks
            if hasattr(layer_module, "window_size") and layer_module.window_size == 0:
                intermediate_embeddings.append(hidden_states)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.neck(hidden_states)

        if not return_dict:
            outputs = (hidden_states, intermediate_embeddings)
            if output_hidden_states:
                outputs = outputs + (all_hidden_states,)
            if output_attentions:
                outputs = outputs + (all_self_attentions,)
            return outputs

        return SamHQVisionEncoderOutput(
            last_hidden_state=hidden_states,
            intermediate_embeddings=intermediate_embeddings,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

@@ -746,7 +702,6 @@ class SamHQTwoWayAttentionBlock(nn.Module):
        query_point_embedding: Tensor,
        key_point_embedding: Tensor,
        attention_similarity: Tensor,
        output_attentions: bool = False,
    ):
        # Self attention block
        if self.skip_first_layer_pe:
@@ -781,15 +736,7 @@ class SamHQTwoWayAttentionBlock(nn.Module):
        keys = keys + attn_out

        keys = self.layer_norm4(keys)

        outputs = (queries, keys)

        if output_attentions:
            outputs = outputs + (attn_out,)
        else:
            outputs = outputs + (None,)

        return outputs
        return query, keys, attn_out


class SamHQTwoWayTransformer(nn.Module):
@@ -813,16 +760,7 @@ class SamHQTwoWayTransformer(nn.Module):
        image_positional_embeddings: Tensor,
        attention_similarity: Tensor,
        target_embedding=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        all_attentions = ()

        if image_embeddings is None:
            raise ValueError("You have to specify an image_embedding")

@@ -838,18 +776,13 @@ class SamHQTwoWayTransformer(nn.Module):
            if target_embedding is not None:
                queries += target_embedding

            queries, keys, attention_outputs = layer(
            queries, keys, _ = layer(
                queries=queries,
                keys=keys,
                query_point_embedding=point_embeddings,
                key_point_embedding=image_positional_embeddings,
                attention_similarity=attention_similarity,
                output_attentions=output_attentions,
            )

            if output_attentions:
                all_attentions = all_attentions + (attention_outputs,)

        # Apply the final attention layer from the points to the image
        query = queries + point_embeddings
        key = keys + image_positional_embeddings
@@ -858,7 +791,7 @@ class SamHQTwoWayTransformer(nn.Module):

        queries = queries + attn_out
        queries = self.layer_norm_final_attn(queries)
        return queries, keys, all_attentions
        return queries, keys


class SamHQFeedForward(nn.Module):
@@ -940,9 +873,9 @@ class SamHQMaskDecoder(nn.Module):
        multimask_output: bool,
        hq_token_only: bool,
        intermediate_embeddings: Optional[list[torch.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        attention_similarity: Optional[torch.Tensor] = None,
        target_embedding: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict high-quality masks given image and prompt embeddings.
@@ -1004,18 +937,17 @@ class SamHQMaskDecoder(nn.Module):
        else:
            tokens = output_tokens
        point_embeddings = tokens.to(self.iou_token.weight.dtype)

        image_embeddings = image_embeddings + dense_prompt_embeddings
        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
        image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)

        point_embedding, image_embeddings, attentions = self.transformer(
        point_embedding, iou_token_out = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            output_attentions=output_attentions,
            **kwargs,
        )
        iou_token_out = point_embedding[:, :, 0, :]
        mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
@@ -1089,11 +1021,6 @@ class SamHQMaskDecoder(nn.Module):
            masks = masks_sam + masks_hq

        outputs = (masks, iou_pred)
        if output_attentions:
            outputs = outputs + (attentions,)
        else:
            outputs = outputs + (None,)

        return outputs


@@ -1140,8 +1067,6 @@ class SamHQVisionModel(SamHQPreTrainedModel):
    def __init__(self, config: SamHQVisionConfig):
        super().__init__(config)
        self.vision_encoder = SamHQVisionEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
@@ -1151,16 +1076,9 @@ class SamHQVisionModel(SamHQPreTrainedModel):
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, SamHQVisionEncoderOutput]:
        return self.vision_encoder(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return self.vision_encoder(pixel_values, **kwargs)


class SamHQPositionalEmbedding(nn.Module):
@@ -1371,9 +1289,7 @@ class SamHQModel(SamHQPreTrainedModel):
    def get_image_embeddings(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        r"""
        Returns the image embeddings by passing the pixel values through the vision encoder.
@@ -1391,9 +1307,6 @@ class SamHQModel(SamHQPreTrainedModel):
        """
        vision_output = self.vision_encoder(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeddings = vision_output[0]
        intermediate_embeddings = vision_output[1]
@@ -1434,7 +1347,6 @@ class SamHQModel(SamHQPreTrainedModel):
        return prompt_output

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
@@ -1447,11 +1359,8 @@ class SamHQModel(SamHQPreTrainedModel):
        hq_token_only: bool = False,
        attention_similarity: Optional[torch.FloatTensor] = None,
        target_embedding: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        intermediate_embeddings: Optional[list[torch.FloatTensor]] = None,
        **kwargs,
        **kwargs: Unpack[TransformersKwargs],
    ) -> list[dict[str, torch.Tensor]]:
        r"""
        input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
@@ -1540,12 +1449,6 @@ class SamHQModel(SamHQPreTrainedModel):
        ... )
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and image_embeddings is None:
            raise ValueError("Either pixel_values or image_embeddings must be provided.")

@@ -1578,32 +1481,10 @@ class SamHQModel(SamHQPreTrainedModel):
        batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
        image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)

        vision_attentions = None
        vision_hidden_states = None

        if pixel_values is not None:
            vision_outputs = self.vision_encoder(
                pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            if return_dict:
                image_embeddings = vision_outputs.last_hidden_state
                intermediate_embeddings = vision_outputs.intermediate_embeddings
                if output_hidden_states:
                    vision_hidden_states = vision_outputs.hidden_states
                if output_attentions:
                    vision_attentions = vision_outputs.attentions
            else:
                image_embeddings = vision_outputs[0]
                intermediate_embeddings = vision_outputs[1]
                if output_hidden_states:
                    vision_hidden_states = vision_outputs[2]
                if output_attentions:
                    vision_attentions = vision_outputs[-1]

            vision_outputs = self.vision_encoder(pixel_values, **kwargs)
            image_embeddings = vision_outputs.last_hidden_state
            intermediate_embeddings = vision_outputs.intermediate_embeddings
        if input_points is not None and input_labels is None:
            input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)

@@ -1615,7 +1496,7 @@ class SamHQModel(SamHQPreTrainedModel):
        )

        # Predict masks
        low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            sparse_prompt_embeddings=sparse_embeddings,
@@ -1625,24 +1506,12 @@ class SamHQModel(SamHQPreTrainedModel):
            intermediate_embeddings=intermediate_embeddings,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            output_attentions=output_attentions,
        )

        if not return_dict:
            output = (iou_predictions, low_res_masks)
            if output_hidden_states:
                output = output + (vision_hidden_states,)

            if output_attentions:
                output = output + (vision_attentions, mask_decoder_attentions)
            return output

        return SamHQImageSegmentationOutput(
            iou_scores=iou_predictions,
            pred_masks=low_res_masks,
            vision_hidden_states=vision_hidden_states,
            vision_attentions=vision_attentions,
            mask_decoder_attentions=mask_decoder_attentions,
            vision_hidden_states=vision_outputs.hidden_states,
            vision_attentions=vision_outputs.attentions,
        )

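The SAM-HQ hunks mirror the SAM ones, with the extra `intermediate_embeddings` that the HQ mask decoder consumes. A hedged sketch of an encoder-level call under the new interface; the checkpoint name is illustrative and not taken from this diff.

```python
import torch
from transformers import SamHQModel

model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base")  # illustrative checkpoint
pixel_values = torch.randn(1, 3, 1024, 1024)  # dummy batch at the expected 1024x1024 resolution

with torch.no_grad():
    vision_outputs = model.vision_encoder(pixel_values, output_attentions=True)

image_embeddings = vision_outputs.last_hidden_state               # consumed by the mask decoder
intermediate_embeddings = vision_outputs.intermediate_embeddings  # from non-windowed blocks (HQ branch)
attentions = vision_outputs.attentions                            # recorded via _can_record_outputs
```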
@@ -16,9 +16,11 @@
from typing import Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.utils.generic import TransformersKwargs, can_return_tuple, check_model_inputs

from ...processing_utils import Unpack
from ...utils import auto_docstring, logging
from ..sam.configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
from ..sam.modeling_sam import (
@@ -28,6 +30,7 @@ from ..sam.modeling_sam import (
    SamModel,
    SamPreTrainedModel,
    SamTwoWayTransformer,
    SamVisionAttention,
    SamVisionEncoder,
    SamVisionEncoderOutput,
    SamVisionModel,
@@ -125,20 +128,17 @@ class SamHQImageSegmentationOutput(SamImageSegmentationOutput):
    pass


class SamHQVisionEncoder(SamVisionEncoder):
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SamHQVisionEncoderOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
class SamHQVisionAttention(SamVisionAttention):
    pass


class SamHQVisionEncoder(SamVisionEncoder):
    _can_record_outputs = {"attentions": (SamHQVisionAttention, 1)}

    @check_model_inputs
    def forward(
        self, pixel_values: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs]
    ) -> Union[tuple, SamHQVisionEncoderOutput]:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

@@ -146,41 +146,20 @@ class SamHQVisionEncoder(SamVisionEncoder):
        if self.pos_embed is not None:
            hidden_states = hidden_states + self.pos_embed

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        intermediate_embeddings = []

        for layer_module in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
            hidden_states = layer_outputs[0]
            hidden_states = layer_module(hidden_states)

            # Collect embeddings from non-windowed blocks
            if hasattr(layer_module, "window_size") and layer_module.window_size == 0:
                intermediate_embeddings.append(hidden_states)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.neck(hidden_states)

        if not return_dict:
            outputs = (hidden_states, intermediate_embeddings)
            if output_hidden_states:
                outputs = outputs + (all_hidden_states,)
            if output_attentions:
                outputs = outputs + (all_self_attentions,)
            return outputs

        return SamHQVisionEncoderOutput(
            last_hidden_state=hidden_states,
            intermediate_embeddings=intermediate_embeddings,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

@@ -251,9 +230,9 @@ class SamHQMaskDecoder(nn.Module):
        multimask_output: bool,
        hq_token_only: bool,
        intermediate_embeddings: Optional[list[torch.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        attention_similarity: Optional[torch.Tensor] = None,
        target_embedding: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict high-quality masks given image and prompt embeddings.
@@ -315,18 +294,17 @@ class SamHQMaskDecoder(nn.Module):
        else:
            tokens = output_tokens
        point_embeddings = tokens.to(self.iou_token.weight.dtype)

        image_embeddings = image_embeddings + dense_prompt_embeddings
        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
        image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)

        point_embedding, image_embeddings, attentions = self.transformer(
        point_embedding, iou_token_out = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            output_attentions=output_attentions,
            **kwargs,
        )
        iou_token_out = point_embedding[:, :, 0, :]
        mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
@@ -400,11 +378,6 @@ class SamHQMaskDecoder(nn.Module):
            masks = masks_sam + masks_hq

        outputs = (masks, iou_pred)
        if output_attentions:
            outputs = outputs + (attentions,)
        else:
            outputs = outputs + (None,)

        return outputs


@@ -442,9 +415,7 @@ class SamHQModel(SamModel):
    def get_image_embeddings(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        r"""
        Returns the image embeddings by passing the pixel values through the vision encoder.
@@ -462,15 +433,13 @@ class SamHQModel(SamModel):
        """
        vision_output = self.vision_encoder(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeddings = vision_output[0]
        intermediate_embeddings = vision_output[1]

        return image_embeddings, intermediate_embeddings

    @can_return_tuple
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
@@ -483,11 +452,8 @@ class SamHQModel(SamModel):
        hq_token_only: bool = False,
        attention_similarity: Optional[torch.FloatTensor] = None,
        target_embedding: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        intermediate_embeddings: Optional[list[torch.FloatTensor]] = None,
        **kwargs,
        **kwargs: Unpack[TransformersKwargs],
    ) -> list[dict[str, torch.Tensor]]:
        r"""
        input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
@@ -576,12 +542,6 @@ class SamHQModel(SamModel):
        ... )
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and image_embeddings is None:
            raise ValueError("Either pixel_values or image_embeddings must be provided.")

@@ -614,32 +574,10 @@ class SamHQModel(SamModel):
        batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
        image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)

        vision_attentions = None
        vision_hidden_states = None

        if pixel_values is not None:
            vision_outputs = self.vision_encoder(
                pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            if return_dict:
                image_embeddings = vision_outputs.last_hidden_state
                intermediate_embeddings = vision_outputs.intermediate_embeddings
                if output_hidden_states:
                    vision_hidden_states = vision_outputs.hidden_states
                if output_attentions:
                    vision_attentions = vision_outputs.attentions
            else:
                image_embeddings = vision_outputs[0]
                intermediate_embeddings = vision_outputs[1]
                if output_hidden_states:
                    vision_hidden_states = vision_outputs[2]
                if output_attentions:
                    vision_attentions = vision_outputs[-1]

            vision_outputs = self.vision_encoder(pixel_values, **kwargs)
            image_embeddings = vision_outputs.last_hidden_state
            intermediate_embeddings = vision_outputs.intermediate_embeddings
        if input_points is not None and input_labels is None:
            input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)

@@ -651,7 +589,7 @@ class SamHQModel(SamModel):
        )

        # Predict masks
        low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            sparse_prompt_embeddings=sparse_embeddings,
@@ -661,24 +599,12 @@ class SamHQModel(SamModel):
            intermediate_embeddings=intermediate_embeddings,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            output_attentions=output_attentions,
        )

        if not return_dict:
            output = (iou_predictions, low_res_masks)
            if output_hidden_states:
                output = output + (vision_hidden_states,)

            if output_attentions:
                output = output + (vision_attentions, mask_decoder_attentions)
            return output

        return SamHQImageSegmentationOutput(
            iou_scores=iou_predictions,
            pred_masks=low_res_masks,
            vision_hidden_states=vision_hidden_states,
            vision_attentions=vision_attentions,
            mask_decoder_attentions=mask_decoder_attentions,
            vision_hidden_states=vision_outputs.hidden_states,
            vision_attentions=vision_outputs.attentions,
        )

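For intuition on `_can_record_outputs = {"attentions": (SamHQVisionAttention, 1)}` in the modular file above: the integer selects which element of the attention module's return tuple is recorded (index 1 is `attn_weights` in `(attn_output, attn_weights)`). Below is a rough stand-alone analogy built from a plain forward hook; it is a conceptual sketch, not the transformers implementation.

```python
import torch
from torch import nn


class TinyAttention(nn.Module):
    def forward(self, x):
        weights = torch.softmax(x @ x.transpose(-1, -2), dim=-1)
        return weights @ x, weights  # (attn_output, attn_weights): index 1 is what gets recorded


recorded = []


def record_index_1(module, inputs, output):
    recorded.append(output[1])  # mirrors the (ModuleClass, 1) declaration


attn = TinyAttention()
handle = attn.register_forward_hook(record_index_1)
attn(torch.randn(1, 4, 8))
handle.remove()
print(len(recorded), recorded[0].shape)  # 1 torch.Size([1, 4, 4])
```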
@@ -982,7 +982,7 @@ def check_model_inputs(func):

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        use_cache = kwargs.get("use_cache", self.config.use_cache)
        use_cache = kwargs.get("use_cache", getattr(self.config, "use_cache", False))
        return_dict = self.config.return_dict if hasattr(self, "config") else True
        return_dict = kwargs.pop("return_dict", return_dict)
        if return_dict is None:
@@ -993,7 +993,7 @@ def check_model_inputs(func):
        bound.apply_defaults()
        all_args = bound.arguments

        if self.gradient_checkpointing and self.training and use_cache:
        if getattr(self, "gradient_checkpointing", False) and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
@@ -1042,7 +1042,10 @@ def check_model_inputs(func):
            h.remove()
        for key in collected_outputs:
            if key == "hidden_states":
                collected_outputs[key] += (outputs.last_hidden_state,)
                if hasattr(outputs, "vision_hidden_states"):
                    collected_outputs[key] += (outputs.vision_hidden_states,)
                else:
                    collected_outputs[key] += (outputs.last_hidden_state,)
                outputs[key] = collected_outputs[key]
            elif key == "attentions":
                if isinstance(capture_flags[key], list) and len(capture_flags[key]) == 2:
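The `getattr` guards in the `check_model_inputs` hunks above matter for configs and models, such as the SAM vision encoders, that define neither `use_cache` nor `gradient_checkpointing`. A simplified sketch of the guard logic follows; it is a reduced stand-in, not the full wrapper.

```python
from functools import wraps


def guarded(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # Fall back to False when the config or model lacks these attributes,
        # instead of raising AttributeError on e.g. vision-only configs.
        use_cache = kwargs.get("use_cache", getattr(self.config, "use_cache", False))
        if getattr(self, "gradient_checkpointing", False) and self.training and use_cache:
            kwargs["use_cache"] = False  # the real wrapper also emits a warning here
        return func(self, *args, **kwargs)

    return wrapper
```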