# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch SAM2 model."""
import gc
import tempfile
import unittest
import requests
from transformers import (
Sam2Config,
Sam2MaskDecoderConfig,
Sam2MemoryAttentionConfig,
Sam2MemoryEncoderConfig,
Sam2Processor,
Sam2PromptEncoderConfig,
Sam2VisionConfig,
)
from transformers.testing_utils import (
backend_empty_cache,
require_torch,
require_torch_sdpa,
slow,
torch_device,
)
from transformers.utils import is_torch_available, is_vision_available
from transformers.video_utils import load_video
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from torch import nn
from transformers import Sam2Model, Sam2Processor, Sam2VisionModel
if is_vision_available():
from PIL import Image
class Sam2VisionModelTester:
def __init__(
self,
parent,
hidden_size=12,
num_channels=3,
image_size=128,
patch_kernel_size=7,
patch_stride=4,
patch_padding=3,
batch_size=2,
dim_mul=2.0,
stages=[1, 2, 7, 2],
backbone_channel_list=[96, 48, 24, 12],
backbone_feature_sizes=[[32, 32], [16, 16], [8, 8]],
fpn_hidden_size=32,
is_training=False,
):
self.parent = parent
self.hidden_size = hidden_size
self.image_size = image_size
self.num_channels = num_channels
self.patch_kernel_size = patch_kernel_size
self.patch_stride = patch_stride
self.patch_padding = patch_padding
self.batch_size = batch_size
self.is_training = is_training
self.stages = stages
self.dim_mul = dim_mul
self.backbone_channel_list = backbone_channel_list
self.backbone_feature_sizes = backbone_feature_sizes
self.fpn_hidden_size = fpn_hidden_size
def get_config(self):
return Sam2VisionConfig(
hidden_size=self.hidden_size,
image_size=self.image_size,
patch_kernel_size=self.patch_kernel_size,
patch_stride=self.patch_stride,
patch_padding=self.patch_padding,
num_channels=self.num_channels,
stages=self.stages,
backbone_channel_list=self.backbone_channel_list,
backbone_feature_sizes=self.backbone_feature_sizes,
fpn_hidden_size=self.fpn_hidden_size,
)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
config = self.get_config()
return config, pixel_values
def create_and_check_model(self, config, pixel_values):
model = Sam2VisionModel(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(pixel_values)
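# The Hiera-style backbone halves the spatial resolution and scales the channel dim by `dim_mul` at each
# stage transition, so with the tester defaults (image_size=128, patch_stride=4, 4 stages, dim_mul=2.0,
# hidden_size=12) the final feature map is expected to be 4x4 with 96 channels.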
output_size = self.image_size // self.patch_stride // (self.dim_mul * len(self.stages))
output_channels = self.hidden_size * self.dim_mul * len(self.stages)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, output_size, output_size, output_channels)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class Sam2VisionModelTest(ModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as SAM2's vision encoder does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (Sam2VisionModel,) if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torchscript = False
test_torch_exportable = True
def setUp(self):
self.model_tester = Sam2VisionModelTester(self)
self.config_tester = ConfigTester(self, config_class=Sam2VisionConfig, has_text_modality=False)
def test_config(self):
self.config_tester.create_and_test_config_to_json_string()
self.config_tester.create_and_test_config_to_json_file()
self.config_tester.create_and_test_config_from_and_save_pretrained()
self.config_tester.create_and_test_config_with_num_labels()
self.config_tester.check_config_can_be_init_without_params()
self.config_tester.check_config_arguments_init()
@unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
def test_inputs_embeds(self):
pass
def test_model_get_set_embeddings(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
# Overriding as attention shape depends on window_size
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class._from_config(config, attn_implementation="eager")
config = model.config
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
expected_num_attentions = sum(self.model_tester.stages)
self.assertEqual(len(attentions), expected_num_attentions)
# check that output_attentions also works when set via the config
del inputs_dict["output_attentions"]
config.output_attentions = True
window_size = config.window_spec[0]
out_dim = config.hidden_size
patch_stride = config.patch_stride
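# Attention outputs are reported per window rather than over the full patch sequence: the
# (image_size / patch_stride) patch grid is partitioned into windows of `window_size` patches per side,
# and the window count below is flattened across the batch dimension.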
num_windows = self.model_tester.batch_size * (config.image_size // (window_size * patch_stride)) ** 2
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), expected_num_attentions)
self.assertListEqual(
list(attentions[0].shape[-4:]),
[num_windows, window_size, window_size, out_dim],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), expected_num_attentions)
self.assertListEqual(
list(attentions[0].shape[-4:]),
[num_windows, window_size, window_size, out_dim],
)
# Overriding as attention shape depends on window_size
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class, image_size):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
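# One hidden state per transformer block across all stages, plus one for the initial patch embeddings.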
expected_num_layers = sum(self.model_tester.stages) + 1
self.assertEqual(len(hidden_states), expected_num_layers)
self.assertListEqual(
list(hidden_states[0].shape[-4:]),
[
self.model_tester.batch_size,
self.model_tester.image_size // self.model_tester.patch_stride,
self.model_tester.image_size // self.model_tester.patch_stride,
self.model_tester.hidden_size,
],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
image_size = self.model_tester.image_size
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class, image_size)
# check that output_hidden_states also works when set via the config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class, image_size)
# Override as the difference is slightly higher than the default threshold
def test_batching_equivalence(self, atol=5e-4, rtol=5e-4):
super().test_batching_equivalence(atol=atol, rtol=rtol)
@require_torch_sdpa
def test_sdpa_can_compile_dynamic(self):
self.skipTest(reason="SAM model can't be compiled dynamic yet")
class Sam2PromptEncoderTester:
def __init__(
self,
hidden_size=32,
input_image_size=128,
patch_size=16,
mask_input_channels=8,
num_point_embeddings=4,
hidden_act="gelu",
):
self.hidden_size = hidden_size
self.input_image_size = input_image_size
self.patch_size = patch_size
self.mask_input_channels = mask_input_channels
self.num_point_embeddings = num_point_embeddings
self.hidden_act = hidden_act
def get_config(self):
return Sam2PromptEncoderConfig(
image_size=self.input_image_size,
patch_size=self.patch_size,
mask_input_channels=self.mask_input_channels,
hidden_size=self.hidden_size,
num_point_embeddings=self.num_point_embeddings,
hidden_act=self.hidden_act,
)
def prepare_config_and_inputs(self):
dummy_points = floats_tensor([self.batch_size, 3, 2])
config = self.get_config()
return config, dummy_points
class Sam2MaskDecoderTester:
def __init__(
self,
hidden_size=32,
hidden_act="relu",
mlp_dim=64,
num_hidden_layers=2,
num_attention_heads=4,
attention_downsample_rate=2,
num_multimask_outputs=3,
iou_head_depth=3,
iou_head_hidden_dim=32,
batch_size=2,
):
self.batch_size = batch_size
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.mlp_dim = mlp_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.attention_downsample_rate = attention_downsample_rate
self.num_multimask_outputs = num_multimask_outputs
self.iou_head_depth = iou_head_depth
self.iou_head_hidden_dim = iou_head_hidden_dim
def get_config(self):
return Sam2MaskDecoderConfig(
hidden_size=self.hidden_size,
hidden_act=self.hidden_act,
mlp_dim=self.mlp_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
attention_downsample_rate=self.attention_downsample_rate,
num_multimask_outputs=self.num_multimask_outputs,
iou_head_depth=self.iou_head_depth,
iou_head_hidden_dim=self.iou_head_hidden_dim,
)
def prepare_config_and_inputs(self):
config = self.get_config()
dummy_inputs = {
"image_embedding": floats_tensor([self.batch_size, self.hidden_size]),
}
return config, dummy_inputs
class Sam2MemoryEncoderTester:
def __init__(
self,
hidden_size=32,
num_heads=1,
num_channels=3,
image_size=64,
patch_kernel_size=2,
patch_stride=2,
patch_padding=1,
mask_downsampler_embed_dim=32,
memory_fuser_embed_dim=32,
batch_size=2,
):
self.batch_size = batch_size
self.hidden_size = hidden_size
self.num_heads = num_heads
self.num_channels = num_channels
self.image_size = image_size
self.patch_kernel_size = patch_kernel_size
self.patch_stride = patch_stride
self.patch_padding = patch_padding
self.mask_downsampler_embed_dim = mask_downsampler_embed_dim
self.memory_fuser_embed_dim = memory_fuser_embed_dim
def get_config(self):
return Sam2MemoryEncoderConfig(
hidden_size=self.hidden_size,
num_heads=self.num_heads,
num_channels=self.num_channels,
image_size=self.image_size,
patch_kernel_size=self.patch_kernel_size,
patch_stride=self.patch_stride,
patch_padding=self.patch_padding,
mask_downsampler_embed_dim=self.mask_downsampler_embed_dim,
memory_fuser_embed_dim=self.memory_fuser_embed_dim,
)
def prepare_config_and_inputs(self):
config = self.get_config()
dummy_inputs = {
"image_embedding": floats_tensor([self.batch_size, self.hidden_size]),
}
return config, dummy_inputs
class Sam2ModelTester:
def __init__(
self,
parent,
num_channels=3,
image_size=128,
hidden_size=12,
patch_kernel_size=7,
patch_stride=4,
patch_padding=3,
dim_mul=2.0,
stages=[1, 2, 7, 2],
backbone_channel_list=[96, 48, 24, 12],
backbone_feature_sizes=[[32, 32], [16, 16], [8, 8]],
fpn_hidden_size=32,
batch_size=2,
is_training=False,
):
self.parent = parent
self.image_size = image_size
self.hidden_size = hidden_size
self.patch_kernel_size = patch_kernel_size
self.patch_stride = patch_stride
self.patch_padding = patch_padding
self.dim_mul = dim_mul
self.stages = stages
self.backbone_channel_list = backbone_channel_list
self.backbone_feature_sizes = backbone_feature_sizes
self.fpn_hidden_size = fpn_hidden_size
self.batch_size = batch_size
self.num_channels = num_channels
self.is_training = is_training
self.prompt_encoder_tester = Sam2PromptEncoderTester()
self.mask_decoder_tester = Sam2MaskDecoderTester()
self.memory_encoder_tester = Sam2MemoryEncoderTester()
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
config = self.get_config()
return config, pixel_values
def get_config(self):
vision_config = Sam2VisionConfig(
hidden_size=self.hidden_size,
num_channels=self.num_channels,
image_size=self.image_size,
patch_kernel_size=self.patch_kernel_size,
patch_stride=self.patch_stride,
patch_padding=self.patch_padding,
dim_mul=self.dim_mul,
stages=self.stages,
backbone_channel_list=self.backbone_channel_list,
backbone_feature_sizes=self.backbone_feature_sizes,
fpn_hidden_size=self.fpn_hidden_size,
)
memory_attention_config = Sam2MemoryAttentionConfig(
hidden_size=self.hidden_size,
num_layers=1,
dim_feedforward=32,
)
prompt_encoder_config = self.prompt_encoder_tester.get_config()
mask_decoder_config = self.mask_decoder_tester.get_config()
memory_encoder_config = self.memory_encoder_tester.get_config()
return Sam2Config(
vision_config=vision_config,
prompt_encoder_config=prompt_encoder_config,
mask_decoder_config=mask_decoder_config,
memory_attention_config=memory_attention_config,
memory_encoder_config=memory_encoder_config,
image_size=self.image_size,
)
def create_and_check_model(self, config, pixel_values):
model = Sam2Model(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(pixel_values)
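# Without explicit prompts the point-batch dimension defaults to 1, and multimask output (the default)
# yields 3 candidate masks per prompt with one IoU score each, hence the (batch_size, 1, 3) shapes below.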
self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3))
self.parent.assertEqual(result.low_res_masks.shape[:3], (self.batch_size, 1, 3))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class Sam2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as SAM2's vision encoder does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (Sam2Model,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": Sam2Model, "mask-generation": Sam2Model} if is_torch_available() else {}
)
fx_compatible = False
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torchscript = False
_is_composite = True
def setUp(self):
self.model_tester = Sam2ModelTester(self)
common_properties = ["initializer_range"]
self.config_tester = ConfigTester(
self, config_class=Sam2Config, has_text_modality=False, common_properties=common_properties
)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
def test_inputs_embeds(self):
pass
def test_model_get_set_embeddings(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
# Overriding as attention shape depends on window_size
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class._from_config(config, attn_implementation="eager")
config = model.config
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.vision_attentions
expected_num_attentions = sum(self.model_tester.stages)
self.assertEqual(len(attentions), expected_num_attentions)
# check that output_attentions also works when set via the config
del inputs_dict["output_attentions"]
config.output_attentions = True
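# Same window-partition shape reasoning as in Sam2VisionModelTest.test_attention_outputs above.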
window_size = config.vision_config.window_spec[0]
out_dim = self.model_tester.hidden_size
patch_stride = self.model_tester.patch_stride
num_windows = (
self.model_tester.batch_size * (self.model_tester.image_size // (window_size * patch_stride)) ** 2
)
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.vision_attentions
self.assertEqual(len(attentions), expected_num_attentions)
self.assertListEqual(
list(attentions[0].shape[-4:]),
[num_windows, window_size, window_size, out_dim],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.vision_attentions
self.assertEqual(len(attentions), expected_num_attentions)
self.assertListEqual(
list(attentions[0].shape[-4:]),
[num_windows, window_size, window_size, out_dim],
)
# Override as Sam2Model has different sub-modules
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
"""
Tests that composite models dispatch correctly to SDPA or eager attention when requested at load time.
This test only looks at layer names, as SDPA attention layers are usually named "SdpaAttention".
In contrast to the test above, this one checks that `config._attn_implementation` is propagated to each
sub-config after the model is loaded, because we manually replicate the requested attn implementation on
every sub-config when loading.
See https://github.com/huggingface/transformers/pull/32238 for more info
The test tries to cover the most general cases of composite models, i.e. VLMs with vision and text configs. Any model
that has a different set of sub-configs has to overwrite this test.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self._is_composite:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
model_sdpa = model_sdpa.eval().to(torch_device)
vision_encoder_sdpa = getattr(model_sdpa, "vision_encoder")
mask_decoder_sdpa = getattr(model_sdpa, "mask_decoder")
# The requested "sdpa" implementation should be propagated to each sub-config
# Each sub-model then dispatches to SDPA if it supports it (the eager-loaded model below is checked to contain no SDPA layers)
self.assertTrue(mask_decoder_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(vision_encoder_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(getattr(model_eager, "mask_decoder").config._attn_implementation == "eager")
self.assertTrue(getattr(model_eager, "vision_encoder").config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if (
class_name.endswith("Attention")
and getattr(submodule, "config", None)
and submodule.config._attn_implementation == "sdpa"
):
raise ValueError("The eager model should not have SDPA attention layers")
# Override as Sam2Model exposes vision_hidden_states rather than hidden_states
def flash_attn_inference_equivalence(self, attn_implementation: str, padding_side: str):
r"""
Tests the equivalence between the eager and flash attention implementations.
This test is only for inference and runs with `torch_dtype=torch.bfloat16`.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
for model_class in self.all_model_classes:
if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
):
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
dummy_input = inputs_dict[model.main_input_name][:1]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
dummy_attention_mask = dummy_attention_mask[:1]
if padding_side == "left":
dummy_attention_mask[:, 1:] = 1
dummy_attention_mask[:, :1] = 0
else:
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
if model.config.is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
else:
outputs = model(dummy_input, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
logits = outputs.vision_hidden_states[-1]
logits_fa = outputs_fa.vision_hidden_states[-1]
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
if model.config.is_encoder_decoder:
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
else:
other_inputs = {
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
logits = outputs.vision_hidden_states[-1]
logits_fa = outputs_fa.vision_hidden_states[-1]
if padding_side == "left":
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
# check with inference + dropout
model.train()
_ = model_fa(dummy_input, **other_inputs)
else:
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
# Override as the difference is slightly higher than the default threshold
def test_batching_equivalence(self, atol=5e-4, rtol=5e-4):
super().test_batching_equivalence(atol=atol, rtol=rtol)
@unittest.skip(reason="Sam2Model does not support training")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="Hidden_states is tested in sub modules tests")
def test_hidden_states_output(self):
pass
# @slow
# def test_model_from_pretrained(self):
# model_name = "facebook/sam-vit-huge"
# model = SamModel.from_pretrained(model_name)
# self.assertIsNotNone(model)
@require_torch_sdpa
def test_sdpa_can_compile_dynamic(self):
self.skipTest(reason="SAM2 model can't be compiled dynamic yet")
def prepare_image():
img_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
return raw_image
def prepare_dog_img():
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
return raw_image
def prepare_video():
video_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/bedroom.mp4"
raw_video, _ = load_video(video_url)
return raw_video
@slow
class Sam2ModelIntegrationTest(unittest.TestCase):
def setUp(self):
super().setUp()
self.model = Sam2Model.from_pretrained("../sam2_hf_implem/sam2_tiny_hf", attn_implementation="sdpa")
self.processor = Sam2Processor.from_pretrained("../sam2_hf_implem/sam2_tiny_hf")
self.model.to(torch_device)
self.model.eval()
def tearDown(self):
super().tearDown()
# clean up as much of the GPU memory occupied by PyTorch as possible
gc.collect()
backend_empty_cache(torch_device)
def test_inference_mask_generation_no_point(self):
pass
# model = Sam2Model.from_pretrained("facebook/sam2-vit-base")
# processor = SamProcessor.from_pretrained("facebook/sam2-vit-base")
# model.to(torch_device)
# model.eval()
# raw_image = prepare_image()
# inputs = processor(images=raw_image, return_tensors="pt").to(torch_device)
# with torch.no_grad():
# outputs = model(**inputs)
# scores = outputs.iou_scores.squeeze()
# masks = outputs.pred_masks[0, 0, 0, 0, :3]
# self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.4515), atol=2e-4))
# self.assertTrue(torch.allclose(masks, torch.tensor([-4.1800, -3.4948, -3.4481]).to(torch_device), atol=2e-4))
def test_inference_mask_generation_one_point_multimask(self):
raw_image = prepare_image()
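# Point prompts are nested as [images, objects, points per object, (x, y)] in original-image coordinates;
# labels follow the same nesting without the coordinate dimension (1 = foreground).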
input_points = [[[[500, 375]]]]
input_labels = [[[1]]]
inputs = self.processor(
images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt"
).to(torch_device)
# to_tensor = ToTensor()
# transforms = torch.jit.script(
# nn.Sequential(
# Resize((1024, 1024)),
# Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
# )
# )
# inputs["pixel_values"] = transforms(to_tensor(raw_image)).unsqueeze(0).to("cuda")
with torch.no_grad():
outputs = self.model(**inputs)
self.assertEqual(outputs.iou_scores.shape, (1, 1, 3))
self.assertEqual(outputs.low_res_masks.shape, (1, 1, 3, 256, 256))
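# Rank the 3 candidate masks by predicted IoU so the expected values are compared in a deterministic order.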
sorted_indices = torch.argsort(outputs.iou_scores.squeeze(), descending=True)
scores = outputs.iou_scores.squeeze()[sorted_indices]
masks_logits = outputs.low_res_masks.squeeze()[sorted_indices][0, :3, :3]
torch.testing.assert_close(
scores, torch.tensor([0.9546, 0.4937, 0.0428]).to(torch_device), atol=1e-4, rtol=1e-4
)
torch.testing.assert_close(
masks_logits,
torch.tensor(
[[-25.0963, -41.5728, -30.8723], [-34.7112, -30.7988, -36.4013], [-25.3061, -37.4575, -33.1899]]
).to(torch_device),
atol=1e-4,
rtol=1e-4,
)
def test_inference_mask_generation_one_point_no_multimask(self):
raw_image = prepare_image()
input_points = [[[[500, 375]]]]
input_labels = [[[1]]]
inputs = self.processor(
images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt"
).to(torch_device)
with torch.no_grad():
outputs = self.model(**inputs, multimask_output=False)
self.assertEqual(outputs.iou_scores.shape, (1, 1, 1))
self.assertEqual(outputs.low_res_masks.shape, (1, 1, 1, 256, 256))
scores = outputs.iou_scores.squeeze((0, 1))
masks_logits = outputs.low_res_masks.squeeze((0, 1))[0, :3, :3]
torch.testing.assert_close(scores, torch.tensor([0.9366]).to(torch_device), atol=1e-4, rtol=1e-4)
torch.testing.assert_close(
masks_logits,
torch.tensor(
[[-7.1674, -13.4459, -9.6908], [-10.6038, -9.7242, -12.4059], [-7.4478, -12.4997, -10.5906]]
).to(torch_device),
atol=1e-4,
rtol=1e-4,
)
def test_inference_mask_generation_video_one_point(self):
pass
# raw_video = prepare_video()
# self.processor.init_state(video_path="./videos/bedroom_light")
# inputs = processor.add_new_points_or_box(
# frame_idx=0,
# obj_id=1,
# points=[[[[210, 350]]]],
# labels=[[[1]]],
# )
# def test_inference_mask_generation_one_point_one_bb(self):
# model = Sam2Model.from_pretrained("../sam2_hf_implem/sam2_tiny_hf")
# processor = SamProcessor.from_pretrained("../sam2_hf_implem/sam2_tiny_hf")
# model.to(torch_device)
# model.eval()
# raw_image = prepare_image()
# input_boxes = [[[[650, 900, 1000, 1250]]]]
# input_points = [[[[820, 1080]]]]
# inputs = processor(
# images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt"
# ).to(torch_device)
# with torch.no_grad():
# outputs = model(**inputs)
# scores = outputs.iou_scores.squeeze()
# masks = outputs.pred_masks[0, 0, 0, 0, :3]
# self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9566), atol=2e-4))
# self.assertTrue(
# torch.allclose(masks, torch.tensor([-12.7729, -12.3665, -12.6061]).to(torch_device), atol=2e-4)
# )
def test_inference_mask_generation_batched_points_batched_images(self):
raw_image1 = prepare_image()
raw_image2 = prepare_dog_img()
input_points = [[[[500, 375], [10, 10]]], [[[770, 200], [730, 120]]]]
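# Label semantics: 1 = foreground point, 0 = background point; here -10 is used as a padding label so that point is effectively ignored.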
input_labels = [[[1, -10]], [[1, 0]]]
inputs = self.processor(
images=[raw_image1, raw_image2], input_points=input_points, input_labels=input_labels, return_tensors="pt"
).to(torch_device)
with torch.no_grad():
outputs = self.model(**inputs)
self.assertEqual(outputs.iou_scores.shape, (2, 1, 3))
self.assertEqual(outputs.low_res_masks.shape, (2, 1, 3, 256, 256))
sorted_indices = torch.argsort(outputs.iou_scores[0].squeeze(), descending=True)
scores1 = outputs.iou_scores[0].squeeze()[sorted_indices]
masks_logits1 = outputs.low_res_masks[0].squeeze()[sorted_indices][0, :3, :3]
sorted_indices = torch.argsort(outputs.iou_scores[1].squeeze(), descending=True)
scores2 = outputs.iou_scores[1].squeeze()[sorted_indices]
masks_logits2 = outputs.low_res_masks[1].squeeze()[sorted_indices][0, :3, :3]
torch.testing.assert_close(
scores1, torch.tensor([0.9584, 0.4898, 0.0445]).to(torch_device), atol=1e-4, rtol=1e-4
)
torch.testing.assert_close(
masks_logits1,
torch.tensor(
[[-22.4127, -37.7623, -27.7642], [-31.0563, -27.6730, -32.6308], [-22.4559, -33.8773, -29.5238]]
).to(torch_device),
atol=1e-4,
rtol=1e-4,
)
torch.testing.assert_close(
scores2, torch.tensor([0.9504, 0.8117, 0.7426]).to(torch_device), atol=1e-4, rtol=1e-4
)
torch.testing.assert_close(
masks_logits2,
torch.tensor(
[[-13.1202, -17.3222, -14.9687], [-16.2375, -12.7737, -17.6353], [-13.5025, -17.1528, -15.6627]]
).to(torch_device),
atol=1e-4,
rtol=1e-4,
)
# def test_inference_mask_generation_one_point_one_bb_zero(self):
# model = Sam2Model.from_pretrained("facebook/sam2-vit-base")
# processor = SamProcessor.from_pretrained("facebook/sam2-vit-base")
# model.to(torch_device)
# model.eval()
# raw_image = prepare_image()
# input_boxes = [[[620, 900, 1000, 1255]]]
# input_points = [[[820, 1080]]]
# labels = [[0]]
# inputs = processor(
# images=raw_image,
# input_boxes=input_boxes,
# input_points=input_points,
# input_labels=labels,
# return_tensors="pt",
# ).to(torch_device)
# with torch.no_grad():
# outputs = model(**inputs)
# scores = outputs.iou_scores.squeeze()
# self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7894), atol=1e-4))
# def test_inference_mask_generation_two_points_batched(self):
# model = Sam2Model.from_pretrained("facebook/sam2-vit-base")
# processor = SamProcessor.from_pretrained("facebook/sam2-vit-base")
# model.to(torch_device)
# model.eval()
# raw_image = prepare_image()
# input_points = [[[400, 650], [800, 650]], [[400, 650]]]
# input_labels = [[1, 1], [1]]
# inputs = processor(
# images=[raw_image, raw_image], input_points=input_points, input_labels=input_labels, return_tensors="pt"
# ).to(torch_device)
# with torch.no_grad():
# outputs = model(**inputs)
# scores = outputs.iou_scores.squeeze()
# self.assertTrue(torch.allclose(scores[0][-1], torch.tensor(0.9762), atol=1e-4))
# self.assertTrue(torch.allclose(scores[1][-1], torch.tensor(0.9637), atol=1e-4))
# def test_inference_mask_generation_one_box(self):
# model = Sam2Model.from_pretrained("facebook/sam2-vit-base")
# processor = SamProcessor.from_pretrained("facebook/sam2-vit-base")
# model.to(torch_device)
# model.eval()
# raw_image = prepare_image()
# input_boxes = [[[75, 275, 1725, 850]]]
# inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device)
# with torch.no_grad():
# outputs = model(**inputs)
# scores = outputs.iou_scores.squeeze()
# self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7937), atol=1e-4))
# def test_inference_mask_generation_batched_image_one_point(self):
# model = Sam2Model.from_pretrained("facebook/sam2-vit-base")
# processor = SamProcessor.from_pretrained("facebook/sam2-vit-base")
# model.to(torch_device)
# model.eval()
# raw_image = prepare_image()
# raw_dog_image = prepare_dog_img()
# input_points = [[[820, 1080]], [[220, 470]]]
# inputs = processor(images=[raw_image, raw_dog_image], input_points=input_points, return_tensors="pt").to(
# torch_device
# )
# with torch.no_grad():
# outputs = model(**inputs)
# scores_batched = outputs.iou_scores.squeeze()
# input_points = [[[220, 470]]]
# inputs = processor(images=raw_dog_image, input_points=input_points, return_tensors="pt").to(torch_device)
# with torch.no_grad():
# outputs = model(**inputs)
# scores_single = outputs.iou_scores.squeeze()
# self.assertTrue(torch.allclose(scores_batched[1, :], scores_single, atol=1e-4))
# def test_inference_mask_generation_two_points_point_batch(self):
# model = Sam2Model.from_pretrained("facebook/sam2-vit-base")
# processor = SamProcessor.from_pretrained("facebook/sam2-vit-base")
# model.to(torch_device)
# model.eval()
# raw_image = prepare_image()
# input_points = torch.Tensor([[[400, 650]], [[220, 470]]]).cpu() # fmt: skip
# input_points = input_points.unsqueeze(0)
# inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(torch_device)
# with torch.no_grad():
# outputs = model(**inputs)
# iou_scores = outputs.iou_scores.cpu()
# self.assertTrue(iou_scores.shape == (1, 2, 3))
# torch.testing.assert_close(
# iou_scores, torch.tensor([[[0.9105, 0.9825, 0.9675], [0.7646, 0.7943, 0.7774]]]), atol=1e-4, rtol=1e-4
# )
# def test_inference_mask_generation_three_boxes_point_batch(self):
# model = Sam2Model.from_pretrained("facebook/sam2-vit-base")
# processor = SamProcessor.from_pretrained("facebook/sam2-vit-base")
# model.to(torch_device)
# model.eval()
# raw_image = prepare_image()
# # fmt: off
# input_boxes = torch.Tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]]).cpu()
# EXPECTED_IOU = torch.tensor([[[0.9773, 0.9881, 0.9522],
# [0.5996, 0.7661, 0.7937],
# [0.5996, 0.7661, 0.7937]]])
# # fmt: on
# input_boxes = input_boxes.unsqueeze(0)
# inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device)
# with torch.no_grad():
# outputs = model(**inputs)
# iou_scores = outputs.iou_scores.cpu()
# self.assertTrue(iou_scores.shape == (1, 3, 3))
# torch.testing.assert_close(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)
# def test_dummy_pipeline_generation(self):
# generator = pipeline("mask-generation", model="facebook/sam2-vit-base", device=torch_device)
# raw_image = prepare_image()
# _ = generator(raw_image, points_per_batch=64)