diff --git a/docs/source/en/model_doc/sam.md b/docs/source/en/model_doc/sam.md
index cd9e3f5c3c4..58cbfbfb219 100644
--- a/docs/source/en/model_doc/sam.md
+++ b/docs/source/en/model_doc/sam.md
@@ -149,12 +149,24 @@ alt="drawing" width="900"/>
 
 [[autodoc]] SamImageProcessor
 
+## SamVisionModel
+
+[[autodoc]] SamVisionModel
+    - forward
+
+
 ## SamModel
 
 [[autodoc]] SamModel
     - forward
 
+## TFSamVisionModel
+
+[[autodoc]] TFSamVisionModel
+    - call
+
+
 ## TFSamModel
 
 [[autodoc]] TFSamModel
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index d07618a4926..82b57928f3a 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -3589,6 +3589,7 @@ else:
         [
             "SamModel",
             "SamPreTrainedModel",
+            "SamVisionModel",
         ]
     )
     _import_structure["models.seamless_m4t"].extend(
@@ -4757,6 +4758,7 @@ else:
         [
             "TFSamModel",
             "TFSamPreTrainedModel",
+            "TFSamVisionModel",
         ]
     )
     _import_structure["models.segformer"].extend(
@@ -8431,6 +8433,7 @@ if TYPE_CHECKING:
     from .models.sam import (
         SamModel,
         SamPreTrainedModel,
+        SamVisionModel,
     )
     from .models.seamless_m4t import (
         SeamlessM4TCodeHifiGan,
@@ -9372,6 +9375,7 @@ if TYPE_CHECKING:
     from .models.sam import (
         TFSamModel,
         TFSamPreTrainedModel,
+        TFSamVisionModel,
     )
     from .models.segformer import (
         TFSegformerDecodeHead,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index f5673e09cb7..9937b55a8b0 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -273,6 +273,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("rt_detr_v2", "RTDetrV2Config"),
         ("rwkv", "RwkvConfig"),
         ("sam", "SamConfig"),
+        ("sam_vision_model", "SamVisionConfig"),
         ("seamless_m4t", "SeamlessM4TConfig"),
         ("seamless_m4t_v2", "SeamlessM4Tv2Config"),
         ("segformer", "SegformerConfig"),
@@ -630,6 +631,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("rt_detr_v2", "RT-DETRv2"),
         ("rwkv", "RWKV"),
         ("sam", "SAM"),
+        ("sam_vision_model", "SamVisionModel"),
         ("seamless_m4t", "SeamlessM4T"),
         ("seamless_m4t_v2", "SeamlessM4Tv2"),
         ("segformer", "SegFormer"),
@@ -773,6 +775,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
         ("chinese_clip_vision_model", "chinese_clip"),
         ("rt_detr_resnet", "rt_detr"),
         ("granitevision", "llava_next"),
+        ("sam_vision_model", "sam"),
     ]
 )
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 7498f6d22c1..5df90ff5a87 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -249,6 +249,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("rt_detr_v2", "RTDetrV2Model"),
         ("rwkv", "RwkvModel"),
         ("sam", "SamModel"),
+        ("sam_vision_model", "SamVisionModel"),
         ("seamless_m4t", "SeamlessM4TModel"),
         ("seamless_m4t_v2", "SeamlessM4Tv2Model"),
         ("segformer", "SegformerModel"),
diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py
index 906fe411d0f..67b69e2c610 100644
--- a/src/transformers/models/auto/modeling_tf_auto.py
+++ b/src/transformers/models/auto/modeling_tf_auto.py
@@ -80,6 +80,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
         ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
         ("roformer", "TFRoFormerModel"),
         ("sam", "TFSamModel"),
+        ("sam_vision_model", "TFSamVisionModel"),
         ("segformer", "TFSegformerModel"),
         ("speech_to_text", "TFSpeech2TextModel"),
         ("swiftformer", "TFSwiftFormerModel"),
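The auto-mapping entries above are what make the standalone encoder discoverable through the `Auto*` classes. A minimal sketch of what that enables (illustrative only, not part of the patch; it assumes the patched library is installed and uses made-up small hyperparameters to keep the model tiny):

```python
# Sanity-check of the new "sam_vision_model" auto-mappings.
from transformers import AutoConfig, AutoModel, SamVisionConfig, SamVisionModel

# AutoConfig.for_model resolves the newly registered model_type; the extra kwargs only shrink the model.
config = AutoConfig.for_model(
    "sam_vision_model", hidden_size=32, num_hidden_layers=2, num_attention_heads=4, image_size=64, patch_size=16
)
assert isinstance(config, SamVisionConfig)

# AutoModel picks up the new MODEL_MAPPING_NAMES entry and builds the standalone encoder (random weights).
model = AutoModel.from_config(config)
assert isinstance(model, SamVisionModel)
```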
diff --git a/src/transformers/models/sam/configuration_sam.py b/src/transformers/models/sam/configuration_sam.py
index e0a759dbf11..0f4a2b0893a 100644
--- a/src/transformers/models/sam/configuration_sam.py
+++ b/src/transformers/models/sam/configuration_sam.py
@@ -183,9 +183,27 @@ class SamVisionConfig(PretrainedConfig):
         mlp_dim (`int`, *optional*):
             The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio *
             hidden_size`.
-    """
+
+    Example:
+
+    ```python
+    >>> from transformers import (
+    ...     SamVisionConfig,
+    ...     SamVisionModel,
+    ... )
+
+    >>> # Initializing a SamVisionConfig with `"facebook/sam-vit-huge"` style configuration
+    >>> configuration = SamVisionConfig()
+
+    >>> # Initializing a SamVisionModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration
+    >>> model = SamVisionModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
 
     base_config_key = "vision_config"
+    model_type = "sam_vision_model"
 
     def __init__(
         self,
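With `model_type = "sam_vision_model"` now registered on `SamVisionConfig`, the vision sub-config of a composite `SamConfig` and the standalone config are the same type. A small sketch of that relationship (illustrative, not part of the patch):

```python
# How the standalone config relates to the composite SamConfig.
from transformers import SamConfig, SamVisionConfig

full_config = SamConfig()                  # "facebook/sam-vit-huge"-style defaults
vision_config = full_config.vision_config  # the nested vision sub-config

assert isinstance(vision_config, SamVisionConfig)
assert vision_config.model_type == "sam_vision_model"  # newly registered by this patch

# SamVisionModel(vision_config) would build the encoder on its own with random weights;
# base_config_key = "vision_config" is what ties this sub-config to the composite checkpoint format.
```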
diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py
index 7fdb7a81dd4..50e0e86ddee 100644
--- a/src/transformers/models/sam/modeling_sam.py
+++ b/src/transformers/models/sam/modeling_sam.py
@@ -27,7 +27,13 @@ from torch import Tensor, nn
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
-from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
 from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
 
 
@@ -1280,6 +1286,61 @@ SAM_INPUTS_DOCSTRING = r"""
 """
 
 
+SAM_VISION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
+            details.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    """The vision model from Sam without any head or projection on top.""",
+    SAM_START_DOCSTRING,
+)
+class SamVisionModel(SamPreTrainedModel):
+    config_class = SamVisionConfig
+    main_input_name = "pixel_values"
+
+    def __init__(self, config: SamVisionConfig):
+        super().__init__(config)
+        self.vision_encoder = SamVisionEncoder(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.vision_encoder.patch_embed
+
+    @add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=SamVisionEncoderOutput, config_class=SamVisionConfig)
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SamVisionEncoderOutput]:
+        r"""
+        Returns:
+
+        """
+        return self.vision_encoder(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+
 @add_start_docstrings(
     "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
     " optional 2D location and bounding boxes.",
@@ -1522,4 +1583,4 @@ class SamModel(SamPreTrainedModel):
         )
 
 
-__all__ = ["SamModel", "SamPreTrainedModel"]
+__all__ = ["SamVisionModel", "SamModel", "SamPreTrainedModel"]
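For reference, a minimal forward-pass sketch for the new `SamVisionModel` (illustrative only, not part of the patch; the tiny hyperparameters are made up so the snippet runs quickly on CPU, whereas real SAM checkpoints use `image_size=1024`, `patch_size=16` and a much larger hidden size):

```python
import torch

from transformers import SamVisionConfig, SamVisionModel

config = SamVisionConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    image_size=64,
    patch_size=16,
    output_channels=16,
    global_attn_indexes=[1],
)
model = SamVisionModel(config).eval()

pixel_values = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    outputs = model(pixel_values, output_hidden_states=True, output_attentions=True)

# The encoder returns the spatial feature map produced by the neck, not a token sequence:
# (batch_size, output_channels, image_size // patch_size, image_size // patch_size)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 16, 4, 4])
# per-layer intermediate states and attention maps are also exposed on the SamVisionEncoderOutput
print(len(outputs.hidden_states), len(outputs.attentions))
```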
diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py
index 29a2335cb12..d0462499c5a 100644
--- a/src/transformers/models/sam/modeling_tf_sam.py
+++ b/src/transformers/models/sam/modeling_tf_sam.py
@@ -30,7 +30,13 @@ from ...activations_tf import ACT2FN
 from ...modeling_tf_outputs import TFBaseModelOutput
 from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, keras, shape_list, unpack_inputs
 from ...tf_utils import flatten, functional_layernorm
-from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
 from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
 
 
@@ -1400,6 +1406,70 @@ SAM_INPUTS_DOCSTRING = r"""
 """
 
 
+SAM_VISION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
+            details.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    """The vision model from Sam without any head or projection on top.""",
+    SAM_START_DOCSTRING,
+)
+class TFSamVisionModel(TFSamPreTrainedModel):
+    config_class = SamVisionConfig
+    main_input_name = "pixel_values"
+
+    def __init__(self, config: SamVisionConfig, **kwargs):
+        super().__init__(config, **kwargs)
+        self.vision_encoder = TFSamVisionEncoder(config, name="vision_encoder")
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "vision_encoder", None) is not None:
+            with tf.name_scope(self.vision_encoder.name):
+                self.vision_encoder.build(None)
+
+    def get_input_embeddings(self):
+        return self.vision_encoder.patch_embed
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFSamVisionEncoderOutput, config_class=SamVisionConfig)
+    def call(
+        self,
+        pixel_values: TFModelInputType | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool = False,
+        **kwargs,
+    ) -> TFSamVisionEncoderOutput | Tuple[tf.Tensor]:
+        r"""
+        Returns:
+
+        """
+        return self.vision_encoder(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+
 @add_start_docstrings(
     "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
     " optional 2D location and bounding boxes.",
@@ -1653,4 +1723,4 @@ class TFSamModel(TFSamPreTrainedModel):
                 self.mask_decoder.build(None)
 
 
-__all__ = ["TFSamModel", "TFSamPreTrainedModel"]
+__all__ = ["TFSamVisionModel", "TFSamModel", "TFSamPreTrainedModel"]
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 7d980c739a0..79b2fd4e232 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -8836,6 +8836,13 @@ class SamPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class SamVisionModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class SeamlessM4TCodeHifiGan(metaclass=DummyObject):
     _backends = ["torch"]
diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py
index cade9ade698..985445fbba5 100644
--- a/src/transformers/utils/dummy_tf_objects.py
+++ b/src/transformers/utils/dummy_tf_objects.py
@@ -2375,6 +2375,13 @@ class TFSamPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["tf"])
 
 
+class TFSamVisionModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
 class TFSegformerDecodeHead(metaclass=DummyObject):
     _backends = ["tf"]
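The TensorFlow class mirrors the PyTorch one; a parallel sketch under the same assumptions (tiny made-up config, random weights, channels-first `pixel_values` as documented in `SAM_VISION_INPUTS_DOCSTRING` above):

```python
import tensorflow as tf

from transformers import SamVisionConfig, TFSamVisionModel

config = SamVisionConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    image_size=64,
    patch_size=16,
    output_channels=16,
    global_attn_indexes=[1],
)
model = TFSamVisionModel(config)

# pixel_values are channels-first, matching the PyTorch model and the docstring above
pixel_values = tf.random.uniform((1, 3, 64, 64))
outputs = model(pixel_values, output_attentions=True)
print(outputs.last_hidden_state.shape)  # (1, 16, 4, 4)
```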
diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py
index 3c7e5a34b62..0f19b29d902 100644
--- a/tests/models/sam/test_modeling_sam.py
+++ b/tests/models/sam/test_modeling_sam.py
@@ -32,13 +32,243 @@ if is_torch_available():
     import torch
     from torch import nn
 
-    from transformers import SamModel, SamProcessor
+    from transformers import SamModel, SamProcessor, SamVisionModel
 
 
 if is_vision_available():
     from PIL import Image
 
 
+class SamVisionModelTester:
+    def __init__(
+        self,
+        parent,
+        hidden_size=36,
+        intermediate_size=72,
+        projection_dim=62,
+        output_channels=32,
+        num_hidden_layers=2,
+        num_attention_heads=4,
+        num_channels=3,
+        image_size=24,
+        patch_size=2,
+        hidden_act="gelu",
+        layer_norm_eps=1e-06,
+        dropout=0.0,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        initializer_factor=1.0,
+        qkv_bias=True,
+        mlp_ratio=4.0,
+        use_abs_pos=True,
+        use_rel_pos=True,
+        rel_pos_zero_init=False,
+        window_size=14,
+        global_attn_indexes=[2, 5, 8, 11],
+        num_pos_feats=16,
+        mlp_dim=None,
+        batch_size=2,
+    ):
+        self.parent = parent
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.projection_dim = projection_dim
+        self.output_channels = output_channels
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.hidden_act = hidden_act
+        self.layer_norm_eps = layer_norm_eps
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.initializer_range = initializer_range
+        self.initializer_factor = initializer_factor
+        self.qkv_bias = qkv_bias
+        self.mlp_ratio = mlp_ratio
+        self.use_abs_pos = use_abs_pos
+        self.use_rel_pos = use_rel_pos
+        self.rel_pos_zero_init = rel_pos_zero_init
+        self.window_size = window_size
+        self.global_attn_indexes = global_attn_indexes
+        self.num_pos_feats = num_pos_feats
+        self.mlp_dim = mlp_dim
+        self.batch_size = batch_size
+
+        # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.seq_length = num_patches + 1
+
+    def get_config(self):
+        return SamVisionConfig(
+            image_size=self.image_size,
+            patch_size=self.patch_size,
+            num_channels=self.num_channels,
+            hidden_size=self.hidden_size,
+            projection_dim=self.projection_dim,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            intermediate_size=self.intermediate_size,
+            dropout=self.dropout,
+            attention_dropout=self.attention_dropout,
+            initializer_range=self.initializer_range,
+            initializer_factor=self.initializer_factor,
+            output_channels=self.output_channels,
+            qkv_bias=self.qkv_bias,
+            mlp_ratio=self.mlp_ratio,
+            use_abs_pos=self.use_abs_pos,
+            use_rel_pos=self.use_rel_pos,
+            rel_pos_zero_init=self.rel_pos_zero_init,
+            window_size=self.window_size,
+            global_attn_indexes=self.global_attn_indexes,
+            num_pos_feats=self.num_pos_feats,
+            mlp_dim=self.mlp_dim,
+        )
+
+    def prepare_config_and_inputs(self):
+        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+        config = self.get_config()
+
+        return config, pixel_values
+
+    def create_and_check_model(self, config, pixel_values):
+        model = SamVisionModel(config=config)
+        model.to(torch_device)
+        model.eval()
+        with torch.no_grad():
+            result = model(pixel_values)
+        output_size = self.image_size // self.patch_size
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.output_channels, output_size, output_size)
+        )
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, pixel_values = config_and_inputs
+        inputs_dict = {"pixel_values": pixel_values}
+        return config, inputs_dict
+
+
+@require_torch
+class SamVisionModelTest(ModelTesterMixin, unittest.TestCase):
+    """
+    Here we also overwrite some of the tests of test_modeling_common.py, as SAM's vision encoder does not use
+    input_ids, inputs_embeds, attention_mask and seq_length.
+    """
+
+    all_model_classes = (SamVisionModel,) if is_torch_available() else ()
+    fx_compatible = False
+    test_pruning = False
+    test_resize_embeddings = False
+    test_head_masking = False
+    test_torchscript = False
+    test_torch_exportable = True
+
+    def setUp(self):
+        self.model_tester = SamVisionModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    @unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
+    def test_inputs_embeds(self):
+        pass
+
+    def test_model_get_set_embeddings(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+            x = model.get_output_embeddings()
+            self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+    def test_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
+    def test_attention_outputs(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.return_dict = True
+
+        expected_attention_shape = (
+            self.model_tester.batch_size * self.model_tester.num_attention_heads,
+            196,
+            196,
+        )
+
+        for model_class in self.all_model_classes:
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = False
+            config.return_dict = True
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+            attentions = outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            # check that output_attentions also work using config
+            del inputs_dict["output_attentions"]
+            config.output_attentions = True
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            self.assertListEqual(
+                list(attentions[0].shape[-4:]),
+                list(expected_attention_shape),
+            )
+
+    @unittest.skip(reason="SamVisionModel does not support training")
+    def test_training(self):
+        pass
+
+    @unittest.skip(reason="SamVisionModel does not support training")
+    def test_training_gradient_checkpointing(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
+    @unittest.skip(reason="SamVisionModel has no base class and is not available in MODEL_MAPPING")
+    def test_save_load_fast_init_from_base(self):
+        pass
+
+    @unittest.skip(reason="SamVisionModel has no base class and is not available in MODEL_MAPPING")
+    def test_save_load_fast_init_to_base(self):
+        pass
+
+    @unittest.skip(reason="SamVisionModel does not support training")
+    def test_retain_grad_hidden_states_attentions(self):
+        pass
+
+    @unittest.skip(reason="Hidden_states is tested in create_and_check_model tests")
+    def test_hidden_states_output(self):
+        pass
+
+    @require_torch_sdpa
+    def test_sdpa_can_compile_dynamic(self):
+        self.skipTest(reason="SAM model can't be compiled dynamic yet")
+
+
 class SamPromptEncoderTester:
     def __init__(
         self,
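A note on the hard-coded `expected_attention_shape` in `test_attention_outputs` above: with `image_size=24` and `patch_size=2` the patch grid is 12x12, and since neither of the two test layers appears in `global_attn_indexes`, both use windowed attention, which pads the grid up to the 14x14 `window_size`; the attention weights fold batch, windows (one here) and heads into the leading dimension. A short sketch of that arithmetic, as I read the test constants (not part of the patch):

```python
# Reproducing the expected attention shape used in test_attention_outputs.
batch_size, num_attention_heads = 2, 4
image_size, patch_size, window_size = 24, 2, 14

grid = image_size // patch_size                # 12x12 grid of patch tokens
assert grid <= window_size                     # the whole grid fits into a single (padded) window
tokens_per_window = window_size * window_size  # padded to 14x14 -> 196 tokens

expected_attention_shape = (batch_size * num_attention_heads, tokens_per_window, tokens_per_window)
print(expected_attention_shape)                # (8, 196, 196)
```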
diff --git a/tests/models/sam/test_modeling_tf_sam.py b/tests/models/sam/test_modeling_tf_sam.py
index 478c2c3873b..305f429ea0e 100644
--- a/tests/models/sam/test_modeling_tf_sam.py
+++ b/tests/models/sam/test_modeling_tf_sam.py
@@ -34,13 +34,204 @@ from ...test_pipeline_mixin import PipelineTesterMixin
 
 if is_tf_available():
     import tensorflow as tf
 
-    from transformers import SamProcessor, TFSamModel
+    from transformers import SamProcessor, TFSamModel, TFSamVisionModel
     from transformers.modeling_tf_utils import keras
 
 
 if is_vision_available():
     from PIL import Image
 
 
+class TFSamVisionModelTester:
+    def __init__(
+        self,
+        parent,
+        hidden_size=36,
+        intermediate_size=72,
+        projection_dim=62,
+        output_channels=32,
+        num_hidden_layers=2,
+        num_attention_heads=4,
+        num_channels=3,
+        image_size=24,
+        patch_size=2,
+        hidden_act="gelu",
+        layer_norm_eps=1e-06,
+        dropout=0.0,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        initializer_factor=1.0,
+        qkv_bias=True,
+        mlp_ratio=4.0,
+        use_abs_pos=True,
+        use_rel_pos=True,
+        rel_pos_zero_init=False,
+        window_size=14,
+        global_attn_indexes=[2, 5, 8, 11],
+        num_pos_feats=16,
+        mlp_dim=None,
+        batch_size=2,
+    ):
+        self.parent = parent
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.projection_dim = projection_dim
+        self.output_channels = output_channels
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.hidden_act = hidden_act
+        self.layer_norm_eps = layer_norm_eps
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.initializer_range = initializer_range
+        self.initializer_factor = initializer_factor
+        self.qkv_bias = qkv_bias
+        self.mlp_ratio = mlp_ratio
+        self.use_abs_pos = use_abs_pos
+        self.use_rel_pos = use_rel_pos
+        self.rel_pos_zero_init = rel_pos_zero_init
+        self.window_size = window_size
+        self.global_attn_indexes = global_attn_indexes
+        self.num_pos_feats = num_pos_feats
+        self.mlp_dim = mlp_dim
+        self.batch_size = batch_size
+
+    def get_config(self):
+        return SamVisionConfig(
+            image_size=self.image_size,
+            patch_size=self.patch_size,
+            num_channels=self.num_channels,
+            hidden_size=self.hidden_size,
+            projection_dim=self.projection_dim,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            intermediate_size=self.intermediate_size,
+            dropout=self.dropout,
+            attention_dropout=self.attention_dropout,
+            initializer_range=self.initializer_range,
+            initializer_factor=self.initializer_factor,
+            output_channels=self.output_channels,
+            qkv_bias=self.qkv_bias,
+            mlp_ratio=self.mlp_ratio,
+            use_abs_pos=self.use_abs_pos,
+            use_rel_pos=self.use_rel_pos,
+            rel_pos_zero_init=self.rel_pos_zero_init,
+            window_size=self.window_size,
+            global_attn_indexes=self.global_attn_indexes,
+            num_pos_feats=self.num_pos_feats,
+            mlp_dim=self.mlp_dim,
+        )
+
+    def prepare_config_and_inputs(self):
+        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+        config = self.get_config()
+
+        return config, pixel_values
+
+    def create_and_check_model(self, config, pixel_values):
+        model = TFSamVisionModel(config=config)
+        result = model(pixel_values)
+        output_size = self.image_size // self.patch_size
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.output_channels, output_size, output_size)
+        )
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, pixel_values = config_and_inputs
+        inputs_dict = {"pixel_values": pixel_values}
+        return config, inputs_dict
+
+
+@require_tf
+class TFSamVisionModelTest(TFModelTesterMixin, unittest.TestCase):
+    """
+    Here we also overwrite some of the tests of test_modeling_common.py, as SAM's vision encoder does not use
+    input_ids, inputs_embeds, attention_mask and seq_length.
+    """
+
+    all_model_classes = (TFSamVisionModel,) if is_tf_available() else ()
+    test_pruning = False
+    test_resize_embeddings = False
+    test_head_masking = False
+    test_onnx = False
+
+    def setUp(self):
+        self.model_tester = TFSamVisionModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    @unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
+    def test_inputs_embeds(self):
+        pass
+
+    def test_model_common_attributes(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            self.assertIsInstance(model.get_input_embeddings(), (keras.layers.Layer))
+            x = model.get_output_embeddings()
+            self.assertTrue(x is None or isinstance(x, keras.layers.Dense))
+
+    def test_forward_signature(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            signature = inspect.signature(model.call)
+            # signature.parameters is an OrderedDict => so arg_names order is deterministic
+            arg_names = [*signature.parameters.keys()]
+
+            expected_arg_names = ["pixel_values"]
+            self.assertListEqual(arg_names[:1], expected_arg_names)
+
+    def test_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
+    def test_attention_outputs(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.return_dict = True
+
+        expected_attention_shape = (
+            self.model_tester.batch_size * self.model_tester.num_attention_heads,
+            196,
+            196,
+        )
+
+        for model_class in self.all_model_classes:
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = False
+            config.return_dict = True
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+            attentions = outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            # check that output_attentions also work using config
+            del inputs_dict["output_attentions"]
+            config.output_attentions = True
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            self.assertListEqual(
+                list(attentions[0].shape[-4:]),
+                list(expected_attention_shape),
+            )
+
+    @unittest.skip(reason="Hidden_states is tested in create_and_check_model tests")
+    def test_hidden_states_output(self):
+        pass
+
+
 class TFSamPromptEncoderTester:
     def __init__(
         self,
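Finally, a usage sketch tying the new class to the existing `SamProcessor` (not part of the patch; it assumes the vision weights of the released `facebook/sam-vit-base` checkpoint load directly into the standalone encoder, which is the intended use of this addition; if that assumption does not hold, the encoder can still be taken from a full `SamModel`):

```python
import requests
import torch
from PIL import Image

from transformers import SamProcessor, SamVisionModel

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamVisionModel.from_pretrained("facebook/sam-vit-base").eval()

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    image_embeddings = model(inputs["pixel_values"]).last_hidden_state

# For a ViT-B SAM checkpoint (image_size=1024, patch_size=16, output_channels=256)
# this should be a (1, 256, 64, 64) feature map.
print(image_embeddings.shape)
```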