Samhq model addition (#35147)

* added the configuartion for sam_hq

* added the modeelling for sam_hq

* added the sam hq mask decoder with hq features

* added the code for the samhq

* added the code for the samhq

* added the code for the samhq

* Delete src/transformers/models/sam_hq/modelling_sam_hq.py

* added the code for the samhq

* added the code for the samhq

* added the chnages for the modeelling

* added the code for sam hq for image processing

* added code for the sam hq model

* added the required changes

* added the changes

* added the key mappings for the sam hq

* adding the working code of samhq

* added the required files

* adding the pt object

* added the push to hub account

* added the args for the sam maks  decoder

* added the args for the sam hq vision config

* aded the some more documentation

* removed the unecessary spaces

* all required chnages

* removed the image processor

* added the required file

* added the changes for the checkcopies

* added the code for modular file

* added the changes for the __init file

* added the code for the interm embeds

* added the code for sam hq

* added the changes for modular file

* added the test file

* added the changes required

* added the changes required

* added the code for the

* added the cl errors

* added the changes

* added the required changes

* added the some code

* added the code for the removing image processor

* added the test dimensins

* added the code for the removing extra used variables

* added the code for modeluar file hf_mlp for a better name

* removed abbrevaation in core functionality

* removed abbrevaation in core functionality

* .contiguous() method is often used to ensure that the tensor is stored in a contiguous block of memory

* added the code which is after make fixup

* added some test for the intermediate embeddings test

* added the code for the torch support in sam hq

* added the code for the updated modular file

* added the changes for documentations as mentioned

* removed the heading

* add the changes for the code

* first mentioned issue resolved

* added the changes code to processor

* added the easy loading to init file

* added the changes to code

* added the code to changes

* added the code to work

* added the code for sam hq

* added the code for sam hq

* added the code for the point pad value

* added the small test for the image embeddings and intermediate embedding

* added the code

* added the code

* added the code for the tests

* added the code

* added ythe code for the processor file

* added the code

* added the code

* added the code

* added the code

* added the code

* added the code for tests and some checks

* added some code

* added the code

* added the code

* added some code

* added some code

* added the changes for required

* added the code

* added the code

* added the code

* added the code

* added the code

* added the code

* added the code

* added the code

* added the code

* added the code

* added some changes

* added some changes

* removed spaces and quality checks

* added some code

* added some code

* added some code

* added code quality checks

* added the checks for quality checks

* addded some code which fixes test_inference_mask_generation_no_point

* added code for the test_inference_mask_generation_one_point_one_bb

* added code for the test_inference_mask_generation_one_point_one_bb_zero

* added code for the test_inference_mask_generation_one_box

* added some code in modelling for testing

* added some code which sort maks with high score

* added some code

* added some code

* added some code for the move KEYS_TO_MODIFY_MAPPING

* added some code for the  unsqueeze removal

* added some code for the  unsqueeze removal

* added some code

* added some code

* add some code

* added some code

* added some code

* added some testign values changed

* added changes to code in sam hq for readbility purpose

* added pre commit checks

* added the fix samvisionmodel for compatibilty

* added the changes made on sam by cyyever

* fixed the tests for samhq

* added some the code

* added some code related to init file issue during merge conflicts

* remobved the merge conflicts

* added changes mentioned by aruther and mobap

* added changes mentioned by aruther and mobap

* solving quality checks

* added the changes for input clearly

* added the changes

* added changes in mask generation file rgearding model inputs and  sam hq quargs  in processor file

* added changes in processor file

* added the  Setup -> setupclass conversion

* added the code mentioned for processor

* added changes for the code

* added some code

* added some code

* added some code

---------

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
This commit is contained in:
sushmanth reddy 2025-04-28 22:37:09 +05:30 committed by GitHub
parent 9c5b1319d0
commit 65e940208c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 4926 additions and 1 deletions

View File

@ -1017,6 +1017,8 @@
title: Qwen2VL
- local: model_doc/sam
title: Segment Anything
- local: model_doc/sam_hq
title: Segment Anything High Quality
- local: model_doc/shieldgemma2
title: ShieldGemma2
- local: model_doc/siglip

View File

@ -0,0 +1,127 @@
# SAM-HQ
## Overview
SAM-HQ (High-Quality Segment Anything Model) was proposed in [Segment Anything in High Quality](https://arxiv.org/pdf/2306.01567.pdf) by Lei Ke, Mingqiao Ye, Martin Danelljan, Yifan Liu, Yu-Wing Tai, Chi-Keung Tang, Fisher Yu.
The model is an enhancement to the original SAM model that produces significantly higher quality segmentation masks while maintaining SAM's original promptable design, efficiency, and zero-shot generalizability.
![example image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-output.png)
SAM-HQ introduces several key improvements over the original SAM model:
1. High-Quality Output Token: A learnable token injected into SAM's mask decoder for higher quality mask prediction
2. Global-local Feature Fusion: Combines features from different stages of the model for improved mask details
3. Training Data: Uses a carefully curated dataset of 44K high-quality masks instead of SA-1B
4. Efficiency: Adds only 0.5% additional parameters while significantly improving mask quality
5. Zero-shot Capability: Maintains SAM's strong zero-shot performance while improving accuracy
The abstract from the paper is the following:
*The recent Segment Anything Model (SAM) represents a big leap in scaling up segmentation models, allowing for powerful zero-shot capabilities and flexible prompting. Despite being trained with 1.1 billion masks, SAM's mask prediction quality falls short in many cases, particularly when dealing with objects that have intricate structures. We propose HQ-SAM, equipping SAM with the ability to accurately segment any object, while maintaining SAM's original promptable design, efficiency, and zero-shot generalizability. Our careful design reuses and preserves the pre-trained model weights of SAM, while only introducing minimal additional parameters and computation. We design a learnable High-Quality Output Token, which is injected into SAM's mask decoder and is responsible for predicting the high-quality mask. Instead of only applying it on mask-decoder features, we first fuse them with early and final ViT features for improved mask details. To train our introduced learnable parameters, we compose a dataset of 44K fine-grained masks from several sources. HQ-SAM is only trained on the introduced dataset of 44k masks, which takes only 4 hours on 8 GPUs.*
Tips:
- SAM-HQ produces higher quality masks than the original SAM model, particularly for objects with intricate structures and fine details
- The model predicts binary masks with more accurate boundaries and better handling of thin structures
- Like SAM, the model performs better with input 2D points and/or input bounding boxes
- You can prompt multiple points for the same image and predict a single high-quality mask
- The model maintains SAM's zero-shot generalization capabilities
- SAM-HQ only adds ~0.5% additional parameters compared to SAM
- Fine-tuning the model is not supported yet
This model was contributed by [sushmanth](https://huggingface.co/sushmanth).
The original code can be found [here](https://github.com/SysCV/SAM-HQ).
Below is an example on how to run mask generation given an image and a 2D point:
```python
import torch
from PIL import Image
import requests
from transformers import SamHQModel, SamHQProcessor
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]] # 2D location of a window in the image
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores
```
You can also process your own masks alongside the input images in the processor to be passed to the model:
```python
import torch
from PIL import Image
import requests
from transformers import SamHQModel, SamHQProcessor
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
mask_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
segmentation_map = Image.open(requests.get(mask_url, stream=True).raw).convert("1")
input_points = [[[450, 600]]] # 2D location of a window in the image
inputs = processor(raw_image, input_points=input_points, segmentation_maps=segmentation_map, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores
```
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with SAM-HQ:
- Demo notebook for using the model (coming soon)
- Paper implementation and code: [SAM-HQ GitHub Repository](https://github.com/SysCV/SAM-HQ)
## SamHQConfig
[[autodoc]] SamHQConfig
## SamHQVisionConfig
[[autodoc]] SamHQVisionConfig
## SamHQMaskDecoderConfig
[[autodoc]] SamHQMaskDecoderConfig
## SamHQPromptEncoderConfig
[[autodoc]] SamHQPromptEncoderConfig
## SamHQProcessor
[[autodoc]] SamHQProcessor
## SamHQVisionModel
[[autodoc]] SamHQVisionModel
## SamHQModel
[[autodoc]] SamHQModel
- forward

View File

@ -254,6 +254,7 @@ if TYPE_CHECKING:
from .rt_detr_v2 import *
from .rwkv import *
from .sam import *
from .sam_hq import *
from .seamless_m4t import *
from .seamless_m4t_v2 import *
from .segformer import *

View File

@ -286,6 +286,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("rt_detr_v2", "RTDetrV2Config"),
("rwkv", "RwkvConfig"),
("sam", "SamConfig"),
("sam_hq", "SamHQConfig"),
("sam_hq_vision_model", "SamHQVisionConfig"),
("sam_vision_model", "SamVisionConfig"),
("seamless_m4t", "SeamlessM4TConfig"),
("seamless_m4t_v2", "SeamlessM4Tv2Config"),
@ -658,6 +660,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
("rt_detr_v2", "RT-DETRv2"),
("rwkv", "RWKV"),
("sam", "SAM"),
("sam_hq", "SAM-HQ"),
("sam_hq_vision_model", "SamHQVisionModel"),
("sam_vision_model", "SamVisionModel"),
("seamless_m4t", "SeamlessM4T"),
("seamless_m4t_v2", "SeamlessM4Tv2"),
@ -807,6 +811,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
("qwen2_5_vl_text", "qwen2_5_vl"),
("qwen2_vl_text", "qwen2_vl"),
("sam_vision_model", "sam"),
("sam_hq_vision_model", "sam_hq"),
("llama4_text", "llama4"),
("blip_2_qformer", "blip_2"),
]

View File

@ -141,6 +141,7 @@ else:
("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
("sam", ("SamImageProcessor",)),
("sam_hq", ("SamImageProcessor",)),
("segformer", ("SegformerImageProcessor",)),
("seggpt", ("SegGptImageProcessor",)),
("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),

View File

@ -257,6 +257,8 @@ MODEL_MAPPING_NAMES = OrderedDict(
("rt_detr_v2", "RTDetrV2Model"),
("rwkv", "RwkvModel"),
("sam", "SamModel"),
("sam_hq", "SamHQModel"),
("sam_hq_vision_model", "SamHQVisionModel"),
("sam_vision_model", "SamVisionModel"),
("seamless_m4t", "SeamlessM4TModel"),
("seamless_m4t_v2", "SeamlessM4Tv2Model"),
@ -1495,6 +1497,12 @@ MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
]
)
MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
[
("sam_hq", "SamHQModel"),
]
)
MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict(
[

View File

@ -104,6 +104,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("qwen2_audio", "Qwen2AudioProcessor"),
("qwen2_vl", "Qwen2VLProcessor"),
("sam", "SamProcessor"),
("sam_hq", "SamHQProcessor"),
("seamless_m4t", "SeamlessM4TProcessor"),
("sew", "Wav2Vec2Processor"),
("sew-d", "Wav2Vec2Processor"),

View File

@ -0,0 +1,28 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_sam_hq import *
from .modeling_sam_hq import *
from .processing_samhq import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@ -0,0 +1,315 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/sam_hq/modular_sam_hq.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_sam_hq.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
class SamHQPromptEncoderConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`SamHQPromptEncoderModel`].The [`SamHQPromptEncoderModel`]
module is used to encode the input 2D points and bounding boxes. Instantiating a configuration defaults will yield a
similar configuration to that of the SAM_HQ model. The configuration is used to store the configuration of the model.
[Uminosachi/sam-hq](https://huggingface.co/Uminosachi/sam-hq) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model's output.Read the documentation from
[`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 256):
Dimensionality of the hidden states.
image_size (`int`, *optional*, defaults to 1024):
The expected output resolution of the image.
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
mask_input_channels (`int`, *optional*, defaults to 16):
The number of channels to be fed to the `MaskDecoder` module.
num_point_embeddings (`int`, *optional*, defaults to 4):
The number of point embeddings to be used.
hidden_act (`str`, *optional*, defaults to `"gelu"`):
The non-linear activation function in the encoder and pooler.
"""
base_config_key = "prompt_encoder_config"
def __init__(
self,
hidden_size=256,
image_size=1024,
patch_size=16,
mask_input_channels=16,
num_point_embeddings=4,
hidden_act="gelu",
layer_norm_eps=1e-6,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.image_size = image_size
self.patch_size = patch_size
self.image_embedding_size = image_size // patch_size
self.mask_input_channels = mask_input_channels
self.num_point_embeddings = num_point_embeddings
self.hidden_act = hidden_act
self.layer_norm_eps = layer_norm_eps
class SamHQVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`SamHQVisionModel`]. It is used to instantiate a SAM_HQ
vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
defaults will yield a similar configuration to that of the SAM_HQ ViT-h
[facebook/sam_hq-vit-huge](https://huggingface.co/facebook/sam_hq-vit-huge) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
output_channels (`int`, *optional*, defaults to 256):
Dimensionality of the output channels in the Patch Encoder.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
Number of channels in the input image.
image_size (`int`, *optional*, defaults to 1024):
Expected resolution. Target size of the resized input image.
patch_size (`int`, *optional*, defaults to 16):
Size of the patches to be extracted from the input image.
hidden_act (`str`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string)
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 1e-10):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
qkv_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to query, key, value projections.
mlp_ratio (`float`, *optional*, defaults to 4.0):
Ratio of mlp hidden dim to embedding dim.
use_abs_pos (`bool`, *optional*, defaults to `True`):
Whether to use absolute position embedding.
use_rel_pos (`bool`, *optional*, defaults to `True`):
Whether to use relative position embedding.
window_size (`int`, *optional*, defaults to 14):
Window size for relative position.
global_attn_indexes (`List[int]`, *optional*, defaults to `[2, 5, 8, 11]`):
The indexes of the global attention layers.
num_pos_feats (`int`, *optional*, defaults to 128):
The dimensionality of the position embedding.
mlp_dim (`int`, *optional*):
The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio *
hidden_size`.
Example:
```python
>>> from transformers import (
... SamHQVisionConfig,
... SamHQVisionModel,
... )
>>> # Initializing a SamHQVisionConfig with `"facebook/sam_hq-vit-huge"` style configuration
>>> configuration = SamHQVisionConfig()
>>> # Initializing a SamHQVisionModel (with random weights) from the `"facebook/sam_hq-vit-huge"` style configuration
>>> model = SamHQVisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
base_config_key = "vision_config"
model_type = "sam_hq_vision_model"
def __init__(
self,
hidden_size=768,
output_channels=256,
num_hidden_layers=12,
num_attention_heads=12,
num_channels=3,
image_size=1024,
patch_size=16,
hidden_act="gelu",
layer_norm_eps=1e-06,
attention_dropout=0.0,
initializer_range=1e-10,
qkv_bias=True,
mlp_ratio=4.0,
use_abs_pos=True,
use_rel_pos=True,
window_size=14,
global_attn_indexes=[2, 5, 8, 11],
num_pos_feats=128,
mlp_dim=None,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.output_channels = output_channels
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.image_size = image_size
self.patch_size = patch_size
self.hidden_act = hidden_act
self.layer_norm_eps = layer_norm_eps
self.attention_dropout = attention_dropout
self.initializer_range = initializer_range
self.qkv_bias = qkv_bias
self.mlp_ratio = mlp_ratio
self.use_abs_pos = use_abs_pos
self.use_rel_pos = use_rel_pos
self.window_size = window_size
self.global_attn_indexes = global_attn_indexes
self.num_pos_feats = num_pos_feats
self.mlp_dim = int(hidden_size * mlp_ratio) if mlp_dim is None else mlp_dim
class SamHQMaskDecoderConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`SamHQMaskDecoder`]. It is used to instantiate a SAM_HQ
mask decoder to the specified arguments, defining the model architecture. Instantiating a configuration defaults
will yield a similar configuration to that of the SAM_HQ-vit-h
[facebook/sam_hq-vit-huge](https://huggingface.co/facebook/sam_hq-vit-huge) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 256):
Dimensionality of the hidden states.
hidden_act (`str`, *optional*, defaults to `"relu"`):
The non-linear activation function used inside the `SamHQMaskDecoder` module.
mlp_dim (`int`, *optional*, defaults to 2048):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 2):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer encoder.
attention_downsample_rate (`int`, *optional*, defaults to 2):
The downsampling rate of the attention layer.
num_multimask_outputs (`int`, *optional*, defaults to 3):
The number of outputs from the `SamHQMaskDecoder` module. In the Segment Anything paper, this is set to 3.
iou_head_depth (`int`, *optional*, defaults to 3):
The number of layers in the IoU head module.
iou_head_hidden_dim (`int`, *optional*, defaults to 256):
The dimensionality of the hidden states in the IoU head module.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
vit_dim (`int`, *optional*, defaults to 768):
Dimensionality of the Vision Transformer (ViT) used in the `SamHQMaskDecoder` module.
"""
base_config_key = "mask_decoder_config"
def __init__(
self,
hidden_size=256,
hidden_act="relu",
mlp_dim=2048,
num_hidden_layers=2,
num_attention_heads=8,
attention_downsample_rate=2,
num_multimask_outputs=3,
iou_head_depth=3,
iou_head_hidden_dim=256,
layer_norm_eps=1e-6,
vit_dim=768,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.mlp_dim = mlp_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.attention_downsample_rate = attention_downsample_rate
self.num_multimask_outputs = num_multimask_outputs
self.iou_head_depth = iou_head_depth
self.iou_head_hidden_dim = iou_head_hidden_dim
self.layer_norm_eps = layer_norm_eps
self.vit_dim = vit_dim
class SamHQConfig(PretrainedConfig):
r"""
[`SamHQConfig`] is the configuration class to store the configuration of a [`SamHQModel`]. It is used to instantiate a
SAM-HQ model according to the specified arguments, defining the vision model, prompt-encoder model and mask decoder
configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the
SAM-HQ-ViT-H [sushmanth/sam_hq_vit_h](https://huggingface.co/sushmanth/sam_hq_vit_h) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (Union[`dict`, `SamHQVisionConfig`], *optional*):
Dictionary of configuration options used to initialize [`SamHQVisionConfig`].
prompt_encoder_config (Union[`dict`, `SamHQPromptEncoderConfig`], *optional*):
Dictionary of configuration options used to initialize [`SamHQPromptEncoderConfig`].
mask_decoder_config (Union[`dict`, `SamHQMaskDecoderConfig`], *optional*):
Dictionary of configuration options used to initialize [`SamHQMaskDecoderConfig`].
kwargs (*optional*):
Dictionary of keyword arguments.
"""
model_type = "sam_hq"
sub_configs = {
"prompt_encoder_config": SamHQPromptEncoderConfig,
"mask_decoder_config": SamHQMaskDecoderConfig,
"vision_config": SamHQVisionConfig,
}
def __init__(
self,
vision_config=None,
prompt_encoder_config=None,
mask_decoder_config=None,
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
vision_config = vision_config if vision_config is not None else {}
prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {}
mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {}
if isinstance(vision_config, SamHQVisionConfig):
vision_config = vision_config.to_dict()
if isinstance(prompt_encoder_config, SamHQPromptEncoderConfig):
prompt_encoder_config = prompt_encoder_config.to_dict()
if isinstance(mask_decoder_config, SamHQMaskDecoderConfig):
mask_decoder_config = mask_decoder_config.to_dict()
self.vision_config = SamHQVisionConfig(**vision_config)
self.prompt_encoder_config = SamHQPromptEncoderConfig(**prompt_encoder_config)
self.mask_decoder_config = SamHQMaskDecoderConfig(**mask_decoder_config)
self.initializer_range = initializer_range
__all__ = ["SamHQVisionConfig", "SamHQMaskDecoderConfig", "SamHQPromptEncoderConfig", "SamHQConfig"]

View File

@ -0,0 +1,277 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert SAM-HQ checkpoints from the original repository.
URL: https://github.com/SysCV/sam-hq
"""
import argparse
import numpy as np
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import SamHQConfig, SamHQModel, SamHQProcessor, SamHQVisionConfig, SamImageProcessor
def get_config(model_name):
if "sam_hq_vit_b" in model_name:
vision_config = SamHQVisionConfig()
vit_dim = 768 # Base model dimension
elif "sam_hq_vit_l" in model_name:
vision_config = SamHQVisionConfig(
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
global_attn_indexes=[5, 11, 17, 23],
)
vit_dim = 1024 # Large model dimension
elif "sam_hq_vit_h" in model_name:
vision_config = SamHQVisionConfig(
hidden_size=1280,
num_hidden_layers=32,
num_attention_heads=16,
global_attn_indexes=[7, 15, 23, 31],
)
vit_dim = 1280 # Huge model dimension
# Create mask decoder config with appropriate vit_dim
mask_decoder_config = {"vit_dim": vit_dim}
config = SamHQConfig(
vision_config=vision_config,
mask_decoder_config=mask_decoder_config,
)
return config
KEYS_TO_MODIFY_MAPPING = {
"iou_prediction_head.layers.0": "iou_prediction_head.proj_in",
"iou_prediction_head.layers.1": "iou_prediction_head.layers.0",
"iou_prediction_head.layers.2": "iou_prediction_head.proj_out",
"mask_decoder.output_upscaling.0": "mask_decoder.upscale_conv1",
"mask_decoder.output_upscaling.1": "mask_decoder.upscale_layer_norm",
"mask_decoder.output_upscaling.3": "mask_decoder.upscale_conv2",
"mask_downscaling.0": "mask_embed.conv1",
"mask_downscaling.1": "mask_embed.layer_norm1",
"mask_downscaling.3": "mask_embed.conv2",
"mask_downscaling.4": "mask_embed.layer_norm2",
"mask_downscaling.6": "mask_embed.conv3",
"point_embeddings": "point_embed",
"pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding",
"image_encoder": "vision_encoder",
"neck.0": "neck.conv1",
"neck.1": "neck.layer_norm1",
"neck.2": "neck.conv2",
"neck.3": "neck.layer_norm2",
"patch_embed.proj": "patch_embed.projection",
".norm": ".layer_norm",
"blocks": "layers",
# HQ-specific mappings
"mask_decoder.hf_token": "mask_decoder.hq_token",
"mask_decoder.compress_vit_feat.0": "mask_decoder.compress_vit_conv1",
"mask_decoder.compress_vit_feat.1": "mask_decoder.compress_vit_norm",
"mask_decoder.compress_vit_feat.3": "mask_decoder.compress_vit_conv2",
"mask_decoder.embedding_encoder.0": "mask_decoder.encoder_conv1",
"mask_decoder.embedding_encoder.1": "mask_decoder.encoder_norm",
"mask_decoder.embedding_encoder.3": "mask_decoder.encoder_conv2",
"mask_decoder.embedding_maskfeature.0": "mask_decoder.mask_conv1",
"mask_decoder.embedding_maskfeature.1": "mask_decoder.mask_norm",
"mask_decoder.embedding_maskfeature.3": "mask_decoder.mask_conv2",
"mask_decoder.hf_mlp": "mask_decoder.hq_mask_mlp",
# Add patterns for the output_hypernetworks_mlps and hq_mask_mlp
"output_hypernetworks_mlps.0.layers.0": "output_hypernetworks_mlps.0.proj_in",
"output_hypernetworks_mlps.0.layers.1": "output_hypernetworks_mlps.0.layers.0",
"output_hypernetworks_mlps.0.layers.2": "output_hypernetworks_mlps.0.proj_out",
"output_hypernetworks_mlps.1.layers.0": "output_hypernetworks_mlps.1.proj_in",
"output_hypernetworks_mlps.1.layers.1": "output_hypernetworks_mlps.1.layers.0",
"output_hypernetworks_mlps.1.layers.2": "output_hypernetworks_mlps.1.proj_out",
"output_hypernetworks_mlps.2.layers.0": "output_hypernetworks_mlps.2.proj_in",
"output_hypernetworks_mlps.2.layers.1": "output_hypernetworks_mlps.2.layers.0",
"output_hypernetworks_mlps.2.layers.2": "output_hypernetworks_mlps.2.proj_out",
"output_hypernetworks_mlps.3.layers.0": "output_hypernetworks_mlps.3.proj_in",
"output_hypernetworks_mlps.3.layers.1": "output_hypernetworks_mlps.3.layers.0",
"output_hypernetworks_mlps.3.layers.2": "output_hypernetworks_mlps.3.proj_out",
"hq_mask_mlp.layers.0": "hq_mask_mlp.proj_in",
"hq_mask_mlp.layers.1": "hq_mask_mlp.layers.0",
"hq_mask_mlp.layers.2": "hq_mask_mlp.proj_out",
}
def replace_keys(state_dict):
model_state_dict = {}
state_dict.pop("pixel_mean", None)
state_dict.pop("pixel_std", None)
# Process each key in the state dict
for key, value in state_dict.items():
new_key = key
# Apply static mappings from KEYS_TO_MODIFY_MAPPING
for key_to_modify, replacement in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in new_key:
new_key = new_key.replace(key_to_modify, replacement)
model_state_dict[new_key] = value
# Add mapping for shared embedding for positional embedding
if "prompt_encoder.shared_embedding.positional_embedding" in model_state_dict:
model_state_dict["shared_image_embedding.positional_embedding"] = model_state_dict[
"prompt_encoder.shared_embedding.positional_embedding"
]
# Special handling for IOU prediction head keys
# Check if we're missing the expected keys and have the converted ones instead
if (
"mask_decoder.iou_prediction_head.layers.0.weight" not in model_state_dict
and "mask_decoder.iou_prediction_head.proj_in.weight" in model_state_dict
):
# Copy the converted key back to the expected format
model_state_dict["mask_decoder.iou_prediction_head.layers.0.weight"] = model_state_dict[
"mask_decoder.iou_prediction_head.proj_in.weight"
]
model_state_dict["mask_decoder.iou_prediction_head.layers.0.bias"] = model_state_dict[
"mask_decoder.iou_prediction_head.proj_in.bias"
]
return model_state_dict
def convert_sam_hq_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, push_to_hub, hub_path):
config = get_config(model_name)
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
state_dict = replace_keys(state_dict)
image_processor = SamImageProcessor()
processor = SamHQProcessor(image_processor=image_processor)
hf_model = SamHQModel(config)
hf_model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
hf_model.load_state_dict(state_dict)
hf_model = hf_model.to(device)
# Test the model with a sample image
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[500, 375]]]
input_labels = [[1]]
# Basic test without prompts
inputs = processor(images=np.array(raw_image), return_tensors="pt").to(device)
with torch.no_grad():
hf_model(**inputs)
if model_name == "sam_hq_vit_b":
inputs = processor(
images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
).to(device)
with torch.no_grad():
hf_model(**inputs)
elif model_name == "sam_hq_vit_h":
inputs = processor(
images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
).to(device)
with torch.no_grad():
hf_model(**inputs)
input_boxes = [[[75.0, 275.0, 1725.0, 850.0]]]
inputs = processor(images=np.array(raw_image), input_boxes=input_boxes, return_tensors="pt").to(device)
with torch.no_grad():
hf_model(**inputs)
input_points = [[[400, 650], [800, 650]]]
input_labels = [[1, 1]]
inputs = processor(
images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
).to(device)
with torch.no_grad():
hf_model(**inputs)
if pytorch_dump_folder is not None:
processor.save_pretrained(pytorch_dump_folder)
hf_model.save_pretrained(pytorch_dump_folder)
if push_to_hub:
repo_id = f"{hub_path}/{model_name}"
processor.push_to_hub(repo_id)
hf_model.push_to_hub(repo_id)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
choices = ["sam_hq_vit_b", "sam_hq_vit_h", "sam_hq_vit_l"]
parser.add_argument(
"--model_name",
choices=choices,
type=str,
required=True,
help="Name of the SAM-HQ model to convert",
)
parser.add_argument(
"--checkpoint_path",
type=str,
required=False,
help="Path to the SAM-HQ checkpoint (.pth file)",
)
parser.add_argument(
"--pytorch_dump_folder_path",
type=str,
default=None,
help="Path to save the converted model",
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether to push the converted model to the hub",
)
parser.add_argument(
"--hub_path",
type=str,
default="sushmanth",
help="Hugging Face Hub path where the model will be uploaded",
)
args = parser.parse_args()
checkpoint_path = args.checkpoint_path
if checkpoint_path is None:
checkpoint_path = hf_hub_download("lkeab/hq-sam", f"{args.model_name}.pth")
convert_sam_hq_checkpoint(
args.model_name,
checkpoint_path,
args.pytorch_dump_folder_path,
args.push_to_hub,
args.hub_path,
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,737 @@
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...utils import add_start_docstrings, logging
from ..sam.configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
from ..sam.modeling_sam import (
SamFeedForward,
SamImageSegmentationOutput,
SamLayerNorm,
SamModel,
SamPreTrainedModel,
SamTwoWayTransformer,
SamVisionEncoder,
SamVisionEncoderOutput,
SamVisionModel,
)
logger = logging.get_logger(__name__)
class SamHQPromptEncoderConfig(SamPromptEncoderConfig):
r"""
This is the configuration class to store the configuration of a [`SamHQPromptEncoderModel`].The [`SamHQPromptEncoderModel`]
module is used to encode the input 2D points and bounding boxes. Instantiating a configuration defaults will yield a
similar configuration to that of the SAM_HQ model. The configuration is used to store the configuration of the model.
[Uminosachi/sam-hq](https://huggingface.co/Uminosachi/sam-hq) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model's output.Read the documentation from
[`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 256):
Dimensionality of the hidden states.
image_size (`int`, *optional*, defaults to 1024):
The expected output resolution of the image.
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
mask_input_channels (`int`, *optional*, defaults to 16):
The number of channels to be fed to the `MaskDecoder` module.
num_point_embeddings (`int`, *optional*, defaults to 4):
The number of point embeddings to be used.
hidden_act (`str`, *optional*, defaults to `"gelu"`):
The non-linear activation function in the encoder and pooler.
"""
pass
class SamHQVisionConfig(SamVisionConfig):
pass
class SamHQMaskDecoderConfig(SamMaskDecoderConfig):
r"""
vit_dim (`int`, *optional*, defaults to 768):
Dimensionality of the Vision Transformer (ViT) used in the `SamHQMaskDecoder` module.
"""
def __init__(
self,
vit_dim=768,
**super_kwargs,
):
super().__init__(**super_kwargs)
self.vit_dim = vit_dim
class SamHQConfig(SamConfig):
r"""
[`SamHQConfig`] is the configuration class to store the configuration of a [`SamHQModel`]. It is used to instantiate a
SAM-HQ model according to the specified arguments, defining the vision model, prompt-encoder model and mask decoder
configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the
SAM-HQ-ViT-H [sushmanth/sam_hq_vit_h](https://huggingface.co/sushmanth/sam_hq_vit_h) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (Union[`dict`, `SamHQVisionConfig`], *optional*):
Dictionary of configuration options used to initialize [`SamHQVisionConfig`].
prompt_encoder_config (Union[`dict`, `SamHQPromptEncoderConfig`], *optional*):
Dictionary of configuration options used to initialize [`SamHQPromptEncoderConfig`].
mask_decoder_config (Union[`dict`, `SamHQMaskDecoderConfig`], *optional*):
Dictionary of configuration options used to initialize [`SamHQMaskDecoderConfig`].
kwargs (*optional*):
Dictionary of keyword arguments.
"""
pass
@dataclass
class SamHQVisionEncoderOutput(SamVisionEncoderOutput):
"""
intermediate_embeddings (`list(torch.FloatTensor)`, *optional*):
A list of intermediate embeddings collected from certain blocks within the model, typically those without
windowed attention. Each element in the list is of shape `(batch_size, sequence_length, hidden_size)`.
This is specific to SAM-HQ and not present in base SAM.
"""
intermediate_embeddings: Optional[List[torch.FloatTensor]] = None
@dataclass
class SamHQImageSegmentationOutput(SamImageSegmentationOutput):
pass
class SamHQVisionEncoder(SamVisionEncoder):
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SamHQVisionEncoderOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.patch_embed(pixel_values)
if self.pos_embed is not None:
hidden_states = hidden_states + self.pos_embed
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
intermediate_embeddings = []
for i, layer_module in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
)
else:
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
hidden_states = layer_outputs[0]
# Collect embeddings from non-windowed blocks
if hasattr(layer_module, "window_size") and layer_module.window_size == 0:
intermediate_embeddings.append(hidden_states)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states = self.neck(hidden_states)
if not return_dict:
outputs = (hidden_states, intermediate_embeddings)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_self_attentions,)
return outputs
return SamHQVisionEncoderOutput(
last_hidden_state=hidden_states,
intermediate_embeddings=intermediate_embeddings,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class SamHQLayerNorm(SamLayerNorm):
pass
class SamHQTwoWayTransformer(SamTwoWayTransformer):
pass
class SamHQFeedForward(SamFeedForward):
pass
class SamHQMaskDecoder(nn.Module):
def __init__(self, config: SamHQMaskDecoderConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.num_multimask_outputs = config.num_multimask_outputs
self.num_mask_tokens = config.num_multimask_outputs + 1
self.iou_token = nn.Embedding(1, self.hidden_size)
self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
self.transformer = SamHQTwoWayTransformer(config)
self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
self.upscale_layer_norm = SamHQLayerNorm(self.hidden_size // 4, data_format="channels_first")
self.activation = nn.GELU()
mlps_list = []
for _ in range(self.num_mask_tokens):
mlps_list += [SamHQFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
self.iou_prediction_head = SamHQFeedForward(
self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth
)
self.hq_token = nn.Embedding(1, self.hidden_size)
self.hq_mask_mlp = SamHQFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)
self.num_mask_tokens = self.num_mask_tokens + 1
# Compress ViT features
self.compress_vit_conv1 = nn.ConvTranspose2d(config.vit_dim, self.hidden_size, kernel_size=2, stride=2)
self.compress_vit_norm = SamHQLayerNorm(self.hidden_size, data_format="channels_first")
self.compress_vit_conv2 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 8, kernel_size=2, stride=2)
# Embedding encoder
self.encoder_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
self.encoder_norm = SamHQLayerNorm(self.hidden_size // 4, data_format="channels_first")
self.encoder_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
# Embedding mask feature
self.mask_conv1 = nn.Conv2d(self.hidden_size // 8, self.hidden_size // 4, kernel_size=3, stride=1, padding=1)
self.mask_norm = SamHQLayerNorm(self.hidden_size // 4, data_format="channels_first")
self.mask_conv2 = nn.Conv2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=3, stride=1, padding=1)
def forward(
self,
image_embeddings: torch.Tensor,
image_positional_embeddings: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
hq_token_only: bool,
intermediate_embeddings: Optional[List[torch.Tensor]] = None,
output_attentions: Optional[bool] = None,
attention_similarity: Optional[torch.Tensor] = None,
target_embedding: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict high-quality masks given image and prompt embeddings.
Args:
image_embeddings (`torch.Tensor`):
The embeddings from the image encoder.
image_positional_embedding (`torch.Tensor`):
Positional encoding with the shape of image_embeddings.
sparse_prompt_embeddings (`torch.Tensor`):
The embeddings of the points and boxes.
dense_prompt_embeddings (`torch.Tensor`):
The embeddings of the mask inputs.
multimask_output (bool):
Whether to return multiple masks or a single mask.
hq_token_only (bool):
Whether to use only the high-quality token output or combine with SAM output.
intermediate_embeddings (`torch.Tensor`):
Intermediate embeddings from the vision encoder for feature fusion.
output_attentions (bool, *optional*):
Whether or not to return the attentions tensors of all attention layers.
attention_similarity (`torch.Tensor`, *optional*):
Optional tensor for attention similarity computation.
target_embedding (`torch.Tensor`, *optional*):
Optional target embedding for transformer processing.
Returns:
`Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple of tensors containing:
- A tensor of shape `(batch_size, num_prompts, num_masks, height, width)` containing the output masks.
- A tensor of shape `(batch_size, num_prompts, num_masks)` containing the iou predictions for each mask.
- (Optional) A tuple containing attention tensors if output_attentions is True.
"""
batch_size, num_channels, height, width = image_embeddings.shape
point_batch_size = sparse_prompt_embeddings.shape[1]
has_intermediate = intermediate_embeddings is not None and len(intermediate_embeddings) > 0
if has_intermediate:
vit_features = intermediate_embeddings[0].permute(0, 3, 1, 2).contiguous()
embed_encode = self.encoder_conv1(image_embeddings)
embed_encode = self.activation(self.encoder_norm(embed_encode))
embed_encode = self.encoder_conv2(embed_encode)
if has_intermediate:
compressed_vit_features = self.compress_vit_conv1(vit_features)
compressed_vit_features = self.activation(self.compress_vit_norm(compressed_vit_features))
compressed_vit_features = self.compress_vit_conv2(compressed_vit_features)
hq_features = embed_encode + compressed_vit_features
else:
hq_features = embed_encode
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hq_token.weight], dim=0)
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
if torch.any(sparse_prompt_embeddings != 0):
tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2)
else:
tokens = output_tokens
point_embeddings = tokens.to(self.iou_token.weight.dtype)
image_embeddings = image_embeddings + dense_prompt_embeddings
image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
point_embedding, image_embeddings, attentions = self.transformer(
point_embeddings=point_embeddings,
image_embeddings=image_embeddings,
image_positional_embeddings=image_positional_embeddings,
attention_similarity=attention_similarity,
target_embedding=target_embedding,
output_attentions=output_attentions,
)
iou_token_out = point_embedding[:, :, 0, :]
mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
image_embeddings = image_embeddings.transpose(2, 3).reshape(
batch_size * point_batch_size, num_channels, height, width
)
upscaled_embedding = self.upscale_conv1(image_embeddings)
upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
upscaled_embedding_hq = self.mask_conv1(upscaled_embedding)
upscaled_embedding_hq = self.activation(self.mask_norm(upscaled_embedding_hq))
upscaled_embedding_hq = self.mask_conv2(upscaled_embedding_hq)
if hq_features.shape[0] == 1:
hq_features = hq_features.repeat(batch_size * point_batch_size, 1, 1, 1)
elif hq_features.shape[0] == batch_size and batch_size * point_batch_size != batch_size:
hq_features = hq_features.repeat_interleave(point_batch_size, 0)
upscaled_embedding_hq = upscaled_embedding_hq + hq_features
hyper_in_list = []
for mask_token_index in range(self.num_mask_tokens):
if mask_token_index < self.num_mask_tokens - 1:
current_mlp = self.output_hypernetworks_mlps[mask_token_index]
else:
current_mlp = self.hq_mask_mlp
hyper_in_list += [current_mlp(mask_tokens_out[:, :, mask_token_index, :])]
hyper_in = torch.stack(hyper_in_list, dim=2)
_, num_channels, height, width = upscaled_embedding.shape
upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width)
upscaled_embedding_hq = upscaled_embedding_hq.reshape(
batch_size, point_batch_size, num_channels, height * width
)
masks_sam = (hyper_in[:, :, : self.num_mask_tokens - 1] @ upscaled_embedding).reshape(
batch_size, point_batch_size, -1, height, width
)
masks_hq = (hyper_in[:, :, self.num_mask_tokens - 1 :] @ upscaled_embedding_hq).reshape(
batch_size, point_batch_size, -1, height, width
)
masks = torch.cat([masks_sam, masks_hq], dim=2)
iou_pred = self.iou_prediction_head(iou_token_out)
if multimask_output:
mask_slice = slice(1, self.num_mask_tokens - 1)
iou_pred = iou_pred[:, :, mask_slice]
# Sort the IoU scores in descending order and get indices
iou_pred_sorted, sort_indices = torch.sort(iou_pred, dim=2, descending=True)
# Reorder the masks according to sorted scores
masks_sam = masks[:, :, mask_slice, :, :]
masks_sam = torch.gather(
masks_sam,
2,
sort_indices[..., None, None].expand(-1, -1, -1, masks_sam.shape[3], masks_sam.shape[4]),
)
# Update iou_pred with sorted scores
iou_pred = iou_pred_sorted
else:
mask_slice = slice(0, 1)
iou_pred = iou_pred[:, :, mask_slice]
masks_sam = masks[:, :, mask_slice, :, :]
masks_hq = masks[:, :, slice(self.num_mask_tokens - 1, self.num_mask_tokens), :, :]
if hq_token_only:
masks = masks_hq
else:
masks = masks_sam + masks_hq
outputs = (masks, iou_pred)
if output_attentions:
outputs = outputs + (attentions,)
else:
outputs = outputs + (None,)
return outputs
class SamHQPreTrainedModel(SamPreTrainedModel):
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, SamHQVisionEncoder):
if module.pos_embed is not None:
module.pos_embed.data.zero_()
SAM_HQ_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`SamHQConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"""The vision model from SAM-HQ without any head or projection on top.""",
SAM_HQ_START_DOCSTRING,
)
class SamHQVisionModel(SamVisionModel):
pass
@add_start_docstrings(
"Segment Anything Model HQ (SAM-HQ) for generating masks,given an input image and",
" optional 2D location and bounding boxes.",
SAM_HQ_START_DOCSTRING,
)
class SamHQModel(SamModel):
_tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
_keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
def __init__(self, config):
super().__init__(config)
self.vision_encoder = SamHQVisionEncoder(config.vision_config)
self.mask_decoder = SamHQMaskDecoder(config.mask_decoder_config)
self.post_init()
@torch.no_grad()
def get_image_embeddings(
self,
pixel_values,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
Returns the image embeddings by passing the pixel values through the vision encoder.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Input pixel values
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
vision_output = self.vision_encoder(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
image_embeddings = vision_output[0]
intermediate_embeddings = vision_output[1]
return image_embeddings, intermediate_embeddings
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
input_points: Optional[torch.FloatTensor] = None,
input_labels: Optional[torch.LongTensor] = None,
input_boxes: Optional[torch.FloatTensor] = None,
input_masks: Optional[torch.LongTensor] = None,
image_embeddings: Optional[torch.FloatTensor] = None,
multimask_output: bool = True,
hq_token_only: bool = False,
attention_similarity: Optional[torch.FloatTensor] = None,
target_embedding: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
intermediate_embeddings: Optional[List[torch.FloatTensor]] = None,
**kwargs,
) -> List[Dict[str, torch.Tensor]]:
r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`SamHQProcessor`]. See [`SamHQProcessor.__call__`] for
details.
input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much
better results. The points can be obtained by passing a list of list of list to the processor that will
create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
coordinates of the point. If a different number of points is passed either for each image, or for each
mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
computation of the embedding will be skipped for these points using the labels.
input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
official implementation, there are 3 types of labels
- `1`: the point is a point that contains the object of interest
- `0`: the point is a point that does not contain the object of interest
- `-1`: the point corresponds to the background
We added the label:
- `-10`: the point is a padding point, thus should be ignored by the prompt encoder
The padding labels should be automatically done by the processor.
input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to
much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
size, the number of boxes per image and the coordinates of the top left and botton right point of the box.
In the order (`x1`, `y1`, `x2`, `y2`):
- `x1`: the x coordinate of the top left point of the input box
- `y1`: the y coordinate of the top left point of the input box
- `x2`: the x coordinate of the bottom right point of the input box
- `y2`: the y coordinate of the bottom right point of the input box
input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
SAM_HQ model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be
manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
Image embeddings, this is used by the mask decder to generate masks and iou scores. For more memory
efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
multimask_output (`bool`, *optional*):
In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
"best" mask, by specifying `multimask_output=False`.
hq_token_only (`bool`, *optional*, defaults to `False`):
Whether to use only the HQ token path for mask generation. When False, combines both standard and HQ paths.
This is specific to SAM-HQ's architecture.
attention_similarity (`torch.FloatTensor`, *optional*):
Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
target_embedding (`torch.FloatTensor`, *optional*):
Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
intermediate_embeddings (`List[torch.FloatTensor]`, *optional*):
Intermediate embeddings from vision encoder's non-windowed blocks, used by SAM-HQ for enhanced mask quality.
Required when providing pre-computed image_embeddings instead of pixel_values.
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoModel, AutoProcessor
>>> model = AutoModel.from_pretrained("sushmanth/sam_hq_vit_b")
>>> processor = AutoProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
>>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
>>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
>>> input_points = [[[400, 650]]] # 2D location of a window on the car
>>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
>>> # Get high-quality segmentation mask
>>> outputs = model(**inputs)
>>> # For high-quality mask only
>>> outputs = model(**inputs, hq_token_only=True)
>>> # Postprocess masks
>>> masks = processor.post_process_masks(
... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
... )
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None and image_embeddings is None:
raise ValueError("Either pixel_values or image_embeddings must be provided.")
if pixel_values is not None and image_embeddings is not None:
raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
if input_points is not None and len(input_points.shape) != 4:
raise ValueError(
"The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`."
f" got {input_points.shape}."
)
if input_boxes is not None and len(input_boxes.shape) != 3:
raise ValueError(
"The input_boxes must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`."
f" got {input_boxes.shape}."
)
# Add validation for point and box batch sizes
if input_points is not None and input_boxes is not None:
point_batch_size = input_points.shape[1]
box_batch_size = input_boxes.shape[1]
if point_batch_size != box_batch_size:
raise ValueError(
"You should provide as many bounding boxes as input points per box. Got {} and {}.".format(
point_batch_size, box_batch_size
)
)
image_positional_embeddings = self.get_image_wide_positional_embeddings()
# repeat with batch size
batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
vision_attentions = None
vision_hidden_states = None
if pixel_values is not None:
vision_outputs = self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if return_dict:
image_embeddings = vision_outputs.last_hidden_state
intermediate_embeddings = vision_outputs.intermediate_embeddings
if output_hidden_states:
vision_hidden_states = vision_outputs.hidden_states
if output_attentions:
vision_attentions = vision_outputs.attentions
else:
image_embeddings = vision_outputs[0]
intermediate_embeddings = vision_outputs[1]
if output_hidden_states:
vision_hidden_states = vision_outputs[2]
if output_attentions:
vision_attentions = vision_outputs[-1]
if input_points is not None and input_labels is None:
input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
sparse_embeddings, dense_embeddings = self.prompt_encoder(
input_points=input_points,
input_labels=input_labels,
input_boxes=input_boxes,
input_masks=input_masks,
)
# Predict masks
low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
image_embeddings=image_embeddings,
image_positional_embeddings=image_positional_embeddings,
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
hq_token_only=hq_token_only,
intermediate_embeddings=intermediate_embeddings,
attention_similarity=attention_similarity,
target_embedding=target_embedding,
output_attentions=output_attentions,
)
if not return_dict:
output = (iou_predictions, low_res_masks)
if output_hidden_states:
output = output + (vision_hidden_states,)
if output_attentions:
output = output + (vision_attentions, mask_decoder_attentions)
return output
return SamHQImageSegmentationOutput(
iou_scores=iou_predictions,
pred_masks=low_res_masks,
vision_hidden_states=vision_hidden_states,
vision_attentions=vision_attentions,
mask_decoder_attentions=mask_decoder_attentions,
)
__all__ = [
"SamHQVisionConfig",
"SamHQMaskDecoderConfig",
"SamHQPromptEncoderConfig",
"SamHQConfig",
"SamHQModel",
"SamHQPreTrainedModel",
"SamHQVisionModel",
]

View File

@ -0,0 +1,330 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for SAMHQ.
"""
from copy import deepcopy
from typing import List, Optional, Union
import numpy as np
from ...image_utils import ImageInput, VideoInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput
from ...utils import is_torch_available
if is_torch_available():
import torch
class SamHQImagesKwargs(ImagesKwargs):
segmentation_maps: Optional[ImageInput]
input_points: Optional[List[List[float]]]
input_labels: Optional[List[List[int]]]
input_boxes: Optional[List[List[List[float]]]]
point_pad_value: Optional[int]
class SamHQProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: SamHQImagesKwargs
_defaults = {
"images_kwargs": {
"point_pad_value": None,
}
}
class SamHQProcessor(ProcessorMixin):
r"""
Constructs a SAM HQ processor which wraps a SAM image processor and an 2D points & Bounding boxes processor into a
single processor.
[`SamHQProcessor`] offers all the functionalities of [`SamImageProcessor`]. See the docstring of
[`~SamImageProcessor.__call__`] for more information.
Args:
image_processor (`SamImageProcessor`):
An instance of [`SamImageProcessor`]. The image processor is a required input.
"""
attributes = ["image_processor"]
image_processor_class = "SamImageProcessor"
optional_call_args = [
"segmentation_maps",
"input_points",
"input_labels",
"input_boxes",
]
def __init__(self, image_processor):
super().__init__(image_processor)
# Ensure image_processor is properly initialized
if not hasattr(self, "image_processor"):
raise ValueError("image_processor was not properly initialized")
if not hasattr(self.image_processor, "size"):
raise ValueError("image_processor.size is not set")
self.target_size = self.image_processor.size["longest_edge"]
def __call__(
self,
images: Optional[ImageInput] = None,
# The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes`
# arguments that may be passed as a positional argument.
# See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
# or this conversation for more context:
# https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
# This behavior is only needed for backward compatibility and will be removed in future versions.
*args, # to be deprecated
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
audio: Optional[AudioInput] = None,
video: Optional[VideoInput] = None,
**kwargs: Unpack[SamHQProcessorKwargs],
) -> BatchEncoding:
"""
This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D
points and bounding boxes for the model if they are provided.
"""
output_kwargs = self._merge_kwargs(
SamHQProcessorKwargs,
tokenizer_init_kwargs={},
**kwargs,
**self.prepare_and_validate_optional_call_args(*args),
)
input_points = output_kwargs["images_kwargs"].pop("input_points", None)
input_labels = output_kwargs["images_kwargs"].pop("input_labels", None)
input_boxes = output_kwargs["images_kwargs"].pop("input_boxes", None)
encoding_image_processor = self.image_processor(
images,
**output_kwargs["images_kwargs"],
)
original_sizes = encoding_image_processor["original_sizes"]
if hasattr(original_sizes, "numpy"):
original_sizes = original_sizes.numpy()
input_points, input_labels, input_boxes = self._check_and_preprocess_points(
input_points=input_points,
input_labels=input_labels,
input_boxes=input_boxes,
)
encoding_image_processor = self._normalize_and_convert(
encoding_image_processor,
original_sizes,
input_points=input_points,
input_labels=input_labels,
input_boxes=input_boxes,
return_tensors=output_kwargs["common_kwargs"].get("return_tensors"),
point_pad_value=output_kwargs["images_kwargs"].get("point_pad_value"),
)
return encoding_image_processor
def _normalize_and_convert(
self,
encoding_image_processor,
original_sizes,
input_points=None,
input_labels=None,
input_boxes=None,
return_tensors="pt",
point_pad_value=-10,
):
"""
Normalize and convert the image processor output to the expected format.
"""
# Process input points
if input_points is not None:
input_points = self._normalize_batch_coordinates(input_points, original_sizes)
if not all(point.shape == input_points[0].shape for point in input_points):
if input_labels is not None:
input_points, input_labels = self._pad_points_and_labels(
input_points, input_labels, point_pad_value
)
input_points = np.array(input_points)
# Process input labels
if input_labels is not None:
input_labels = np.array(input_labels)
# Process input boxes
if input_boxes is not None:
input_boxes = self._normalize_batch_coordinates(input_boxes, original_sizes, is_bounding_box=True)
input_boxes = np.array(input_boxes)
# Update processor with converted inputs
if input_boxes is not None:
encoding_image_processor["input_boxes"] = self._to_tensor(input_boxes, 3, return_tensors)
if input_points is not None:
encoding_image_processor["input_points"] = self._to_tensor(input_points, 4, return_tensors)
if input_labels is not None:
encoding_image_processor["input_labels"] = self._to_tensor(input_labels, 3, return_tensors)
return encoding_image_processor
def _pad_points_and_labels(self, input_points, input_labels, point_pad_value):
r"""
The method pads the 2D points and labels to the maximum number of points in the batch.
"""
expected_nb_points = max([point.shape[0] for point in input_points])
processed_input_points = []
for i, point in enumerate(input_points):
if point.shape[0] != expected_nb_points:
point = np.concatenate(
[point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0
)
input_labels[i] = np.append(input_labels[i], [point_pad_value])
processed_input_points.append(point)
input_points = processed_input_points
return input_points, input_labels
def _normalize_coordinates(
self, target_size: int, coords: np.ndarray, original_size, is_bounding_box=False
) -> np.ndarray:
"""
Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H,W) format.
"""
old_h, old_w = original_size
new_h, new_w = self.image_processor._get_preprocess_shape(original_size, longest_edge=target_size)
coords = deepcopy(coords).astype(float)
if is_bounding_box:
coords = coords.reshape(-1, 2, 2)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
if is_bounding_box:
coords = coords.reshape(-1, 4)
return coords
def _preprocess_input(self, inputs, error_message, expected_nesting=1, dtype=None):
"""
Preprocess input by converting torch tensors to numpy arrays and validating structure.
Args:
inputs: The input to process
error_message: Error message if validation fails
expected_nesting: Expected nesting level (1 for points/labels, 2 for boxes)
dtype: Optional data type for numpy array conversion
Returns:
Processed input as list of numpy arrays or None
"""
if inputs is None:
return None
# Convert torch tensor to list if applicable
if hasattr(inputs, "numpy"):
inputs = inputs.numpy().tolist()
# Validate structure based on expected nesting
valid = isinstance(inputs, list)
current = inputs
for _ in range(expected_nesting):
if not valid or not current:
break
valid = valid and isinstance(current[0], list)
current = current[0] if current else None
if not valid:
raise ValueError(error_message)
# Convert to numpy arrays
return [np.array(item, dtype=dtype) for item in inputs]
def _check_and_preprocess_points(
self,
input_points=None,
input_labels=None,
input_boxes=None,
):
r"""
Check and preprocesses the 2D points, labels and bounding boxes. It checks if the input is valid and if they
are, it converts the coordinates of the points and bounding boxes. If a user passes directly a `torch.Tensor`,
it is converted to a `numpy.ndarray` and then to a `list`.
"""
# Process each input type
input_points = self._preprocess_input(input_points, "Input points must be a list of list of floating points.")
input_labels = self._preprocess_input(input_labels, "Input labels must be a list of list integers.")
input_boxes = self._preprocess_input(
input_boxes,
"Input boxes must be a list of list of list of floating points.",
expected_nesting=2,
dtype=np.float32,
)
return input_points, input_labels, input_boxes
@property
def model_input_names(self):
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(image_processor_input_names))
def post_process_masks(self, *args, **kwargs):
return self.image_processor.post_process_masks(*args, **kwargs)
def _to_tensor(self, array, min_dim, return_tensors):
"""
Convert numpy array to tensor and ensure proper dimensionality.
Args:
array: The numpy array to convert
min_dim: The minimum number of dimensions the result should have
return_tensors: The type of tensors to return (e.g., "pt" for PyTorch tensors)
Returns:
The converted array or tensor with proper dimensions
"""
if return_tensors == "pt":
array = torch.from_numpy(array)
return array.unsqueeze(1) if array.ndim < min_dim else array
return array
def _normalize_batch_coordinates(self, inputs, original_sizes, is_bounding_box=False):
"""
Normalize coordinates based on original sizes.
Args:
inputs: List of coordinate arrays
original_sizes: Original sizes of the images
is_bounding_box: Whether inputs are bounding boxes
Returns:
Normalized coordinates as list
"""
if len(original_sizes) != len(inputs):
# Use first original size for all inputs
return [
self._normalize_coordinates(self.target_size, item, original_sizes[0], is_bounding_box=is_bounding_box)
for item in inputs
]
else:
# Use paired original sizes for each input
return [
self._normalize_coordinates(self.target_size, item, size, is_bounding_box=is_bounding_box)
for item, size in zip(inputs, original_sizes)
]
__all__ = ["SamHQProcessor"]

View File

@ -189,7 +189,17 @@ class MaskGenerationPipeline(ChunkPipeline):
inference_context = self.get_inference_context()
with inference_context():
model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
image_embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values"))
embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values"))
# Handle both SAM (single tensor) and SAM-HQ (tuple) outputs
if isinstance(embeddings, tuple):
image_embeddings, intermediate_embeddings = embeddings
model_inputs["intermediate_embeddings"] = intermediate_embeddings
else:
image_embeddings = embeddings
# TODO: Identifying the model by the type of its returned embeddings is brittle.
# Consider using a more robust method for distinguishing model types here.
model_inputs["image_embeddings"] = image_embeddings
n_points = grid_points.shape[1]

View File

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,167 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_torchvision, require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_processing_common import ProcessorTesterMixin, prepare_image_inputs
if is_vision_available():
from PIL import Image
from transformers import AutoProcessor, SamHQProcessor, SamImageProcessor
if is_torch_available():
import torch
@require_vision
@require_torchvision
class SamHQProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = SamHQProcessor
@classmethod
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = SamImageProcessor()
processor = SamHQProcessor(image_processor)
processor.save_pretrained(self.tmpdirname)
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
@classmethod
def tearDown(self):
shutil.rmtree(self.tmpdirname)
# Processor tester class can't use ProcessorTesterMixin atm because the processor is atypical e.g. only contains an image processor
def prepare_image_inputs(self):
"""This function prepares a list of PIL images."""
return prepare_image_inputs()
def prepare_mask_inputs(self):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
"""
mask_inputs = [np.random.randint(255, size=(30, 400), dtype=np.uint8)]
mask_inputs = [Image.fromarray(x) for x in mask_inputs]
return mask_inputs
def test_tokenizer_defaults_preserved_by_kwargs(self):
self.skipTest("SamHQProcessor does not have a tokenizer")
def test_image_processor_defaults_preserved_by_image_kwargs(self):
self.skipTest("SamHQProcessor does not have a tokenizer")
def test_chat_template_save_loading(self):
self.skipTest("SamHQProcessor does not have a tokenizer")
def test_kwargs_overrides_default_image_processor_kwargs(self):
self.skipTest("SamHQProcessor does not have a tokenizer")
def test_kwargs_overrides_default_tokenizer_kwargs(self):
self.skipTest("SamHQProcessor does not have a tokenizer")
def test_unstructured_kwargs(self):
self.skipTest("SamHQProcessor does not have a tokenizer")
def test_unstructured_kwargs_batched(self):
self.skipTest("SamHQProcessor does not have a tokenizer")
def test_doubly_passed_kwargs(self):
self.skipTest("SamHQProcessor does not have a tokenizer")
def test_structured_kwargs_nested(self):
self.skipTest("SamHQProcessor does not have a tokenizer")
def test_structured_kwargs_nested_from_dict(self):
self.skipTest("SamHQProcessor does not have a tokenizer")
def test_save_load_pretrained_additional_features(self):
self.skipTest("SamHQProcessor does not have a tokenizer")
def test_image_processor_no_masks(self):
image_processor = self.get_image_processor()
processor = SamHQProcessor(image_processor=image_processor)
image_input = self.prepare_image_inputs()
input_feat_extract = image_processor(image_input, return_tensors="pt")
input_processor = processor(images=image_input, return_tensors="pt")
for key in input_feat_extract.keys():
self.assertAlmostEqual(input_feat_extract[key].sum().item(), input_processor[key].sum().item(), delta=1e-2)
for image in input_feat_extract.pixel_values:
self.assertEqual(image.shape, (3, 1024, 1024))
for original_size in input_feat_extract.original_sizes:
np.testing.assert_array_equal(original_size, np.array([30, 400]))
for reshaped_input_size in input_feat_extract.reshaped_input_sizes:
np.testing.assert_array_equal(
reshaped_input_size, np.array([77, 1024])
) # reshaped_input_size value is before padding
def test_image_processor_with_masks(self):
image_processor = self.get_image_processor()
processor = SamHQProcessor(image_processor=image_processor)
image_input = self.prepare_image_inputs()
mask_input = self.prepare_mask_inputs()
input_feat_extract = image_processor(images=image_input, segmentation_maps=mask_input, return_tensors="pt")
input_processor = processor(images=image_input, segmentation_maps=mask_input, return_tensors="pt")
for key in input_feat_extract.keys():
self.assertAlmostEqual(input_feat_extract[key].sum().item(), input_processor[key].sum().item(), delta=1e-2)
for label in input_feat_extract.labels:
self.assertEqual(label.shape, (256, 256))
@require_torch
def test_post_process_masks(self):
image_processor = self.get_image_processor()
processor = SamHQProcessor(image_processor=image_processor)
dummy_masks = [torch.ones((1, 3, 5, 5))]
original_sizes = [[1764, 2646]]
reshaped_input_size = [[683, 1024]]
masks = processor.post_process_masks(dummy_masks, original_sizes, reshaped_input_size)
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
masks = processor.post_process_masks(
dummy_masks, torch.tensor(original_sizes), torch.tensor(reshaped_input_size)
)
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
# should also work with np
dummy_masks = [np.ones((1, 3, 5, 5))]
masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size))
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
dummy_masks = [[1, 0], [0, 1]]
with self.assertRaises(ValueError):
masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size))

View File

@ -105,6 +105,8 @@ SPECIAL_CASES_TO_ALLOW = {
"AutoformerConfig": ["num_static_real_features", "num_time_features"],
# used internally to calculate `mlp_dim`
"SamVisionConfig": ["mlp_ratio"],
# used internally to calculate `mlp_dim`
"SamHQVisionConfig": ["mlp_ratio"],
# For (head) training, but so far not implemented
"ClapAudioConfig": ["num_classes"],
# Not used, but providing useful information to users

View File

@ -1040,6 +1040,7 @@ SPECIAL_MODEL_NAMES = {
"OpenAI GPT": "GPT",
"Perceiver": "Perceiver IO",
"SAM": "Segment Anything",
"SAM_HQ": "Segment Anything High Quality",
"ViT": "Vision Transformer (ViT)",
}

View File

@ -469,6 +469,8 @@ OBJECTS_TO_IGNORE = [
"SEWForCTC",
"SamConfig",
"SamPromptEncoderConfig",
"SamHQConfig",
"SamHQPromptEncoderConfig",
"SeamlessM4TConfig", # use of unconventional markdown
"SeamlessM4Tv2Config", # use of unconventional markdown
"Seq2SeqTrainingArguments",

View File

@ -235,6 +235,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"JukeboxVQVAE",
"JukeboxPrior",
"SamModel",
"SamHQModel",
"DPTForDepthEstimation",
"DecisionTransformerGPT2Model",
"GLPNForDepthEstimation",

View File

@ -209,6 +209,7 @@ docs/source/en/model_doc/roc_bert.md
docs/source/en/model_doc/roformer.md
docs/source/en/model_doc/rwkv.md
docs/source/en/model_doc/sam.md
docs/source/en/model_doc/sam_hq.md
docs/source/en/model_doc/segformer.md
docs/source/en/model_doc/sew-d.md
docs/source/en/model_doc/sew.md