Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
Create and Expose SamVisionModel as public for better accessibility (#36493)
* move encoder below
* auto modeling
* write SamVisionTester
* fix vision attention shape
* fix SamVisionTest
* minor changes to SamVisionTest
* Revert "fix vision attention shape"
This reverts commit d2a4083ae5.
* fix attention output shape in new tests
* remove encoder examples
* run modular on got_ocr2
* code formatting
* fix got_ocr2
* ruff fixes
* code quality
* add sam_vision in auto modeling and auto configuration
* remove composite test
* updated index.md
* add TFSamVisionEncoder to __init__
* fix public TFSamVisionEncoder
* remove outdated todo comment
* set test_torch_exportable
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
* rename: VisionEncoder -> VisionModel
* bring back original SamVisionEncoder
* rename back: VisionEncoderOutput -> VisionModelOutput
* undo changes in SamModelTester
* reuse SamVisionEncoder in SamVisionModel
---------
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
parent f99c279d20
commit 0710e9b1e8
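For context before the diff: a minimal sketch of how the newly public `SamVisionModel` could be used on its own to compute image embeddings. This is an illustration, not part of the commit; the checkpoint name is taken from the `SamVisionConfig` example in the diff, the image URL is a placeholder, and loading the vision-only class directly from the full SAM checkpoint is an assumption based on how the class wraps `SamVisionEncoder`.

```python
# Hedged sketch: standalone image-embedding extraction with the now-public SamVisionModel.
import requests
import torch
from PIL import Image

from transformers import SamProcessor, SamVisionModel

# "facebook/sam-vit-huge" is the checkpoint referenced in the SamVisionConfig docstring in this diff.
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamVisionModel.from_pretrained("facebook/sam-vit-huge")  # assumption: the full SAM checkpoint loads here

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # placeholder example image
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(pixel_values=inputs["pixel_values"])

# Per the tests added in this commit, the output shape is
# (batch_size, output_channels, image_size // patch_size, image_size // patch_size).
print(outputs.last_hidden_state.shape)
```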
@@ -149,12 +149,24 @@ alt="drawing" width="900"/>
 [[autodoc]] SamImageProcessor

+## SamVisionModel
+
+[[autodoc]] SamVisionModel
+    - forward
+
 ## SamModel

 [[autodoc]] SamModel
     - forward

+## TFSamVisionModel
+
+[[autodoc]] TFSamVisionModel
+    - call
+
 ## TFSamModel

 [[autodoc]] TFSamModel
@@ -3589,6 +3589,7 @@ else:
     [
         "SamModel",
         "SamPreTrainedModel",
+        "SamVisionModel",
     ]
 )
 _import_structure["models.seamless_m4t"].extend(
@@ -4757,6 +4758,7 @@ else:
     [
         "TFSamModel",
         "TFSamPreTrainedModel",
+        "TFSamVisionModel",
     ]
 )
 _import_structure["models.segformer"].extend(
@@ -8431,6 +8433,7 @@ if TYPE_CHECKING:
     from .models.sam import (
         SamModel,
         SamPreTrainedModel,
+        SamVisionModel,
     )
     from .models.seamless_m4t import (
         SeamlessM4TCodeHifiGan,
@@ -9372,6 +9375,7 @@ if TYPE_CHECKING:
     from .models.sam import (
         TFSamModel,
         TFSamPreTrainedModel,
+        TFSamVisionModel,
     )
     from .models.segformer import (
         TFSegformerDecodeHead,
@@ -273,6 +273,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("rt_detr_v2", "RTDetrV2Config"),
         ("rwkv", "RwkvConfig"),
         ("sam", "SamConfig"),
+        ("sam_vision_model", "SamVisionConfig"),
         ("seamless_m4t", "SeamlessM4TConfig"),
         ("seamless_m4t_v2", "SeamlessM4Tv2Config"),
         ("segformer", "SegformerConfig"),
@@ -630,6 +631,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("rt_detr_v2", "RT-DETRv2"),
         ("rwkv", "RWKV"),
         ("sam", "SAM"),
+        ("sam_vision_model", "SamVisionModel"),
         ("seamless_m4t", "SeamlessM4T"),
         ("seamless_m4t_v2", "SeamlessM4Tv2"),
         ("segformer", "SegFormer"),
@@ -773,6 +775,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
         ("chinese_clip_vision_model", "chinese_clip"),
         ("rt_detr_resnet", "rt_detr"),
         ("granitevision", "llava_next"),
+        ("sam_vision_model", "sam"),
     ]
 )
@@ -249,6 +249,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("rt_detr_v2", "RTDetrV2Model"),
         ("rwkv", "RwkvModel"),
         ("sam", "SamModel"),
+        ("sam_vision_model", "SamVisionModel"),
         ("seamless_m4t", "SeamlessM4TModel"),
         ("seamless_m4t_v2", "SeamlessM4Tv2Model"),
         ("segformer", "SegformerModel"),
@@ -80,6 +80,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
         ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
         ("roformer", "TFRoFormerModel"),
         ("sam", "TFSamModel"),
+        ("sam_vision_model", "TFSamVisionModel"),
         ("segformer", "TFSegformerModel"),
         ("speech_to_text", "TFSpeech2TextModel"),
         ("swiftformer", "TFSwiftFormerModel"),
@@ -183,9 +183,27 @@ class SamVisionConfig(PretrainedConfig):
         mlp_dim (`int`, *optional*):
             The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio *
             hidden_size`.
-    """
+
+    Example:
+
+    ```python
+    >>> from transformers import (
+    ...     SamVisionConfig,
+    ...     SamVisionModel,
+    ... )
+
+    >>> # Initializing a SamVisionConfig with `"facebook/sam-vit-huge"` style configuration
+    >>> configuration = SamVisionConfig()
+
+    >>> # Initializing a SamVisionModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration
+    >>> model = SamVisionModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""

     base_config_key = "vision_config"
+    model_type = "sam_vision_model"

     def __init__(
         self,
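For context (not part of the diff): because `base_config_key = "vision_config"` ties `SamVisionConfig` to the composite `SamConfig`, the standalone vision tower can be instantiated from either config object. A minimal sketch with randomly initialized weights:

```python
# Hedged sketch: building the vision tower from its own config or from the composite SamConfig.
from transformers import SamConfig, SamVisionConfig, SamVisionModel

vision_model_a = SamVisionModel(SamVisionConfig())          # directly from a standalone vision config
vision_model_b = SamVisionModel(SamConfig().vision_config)  # from the vision_config of the full SAM config
```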
@@ -27,7 +27,13 @@ from torch import Tensor, nn
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
-from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
 from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
@@ -1280,6 +1286,61 @@ SAM_INPUTS_DOCSTRING = r"""
 """


+SAM_VISION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
+            details.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    """The vision model from Sam without any head or projection on top.""",
+    SAM_START_DOCSTRING,
+)
+class SamVisionModel(SamPreTrainedModel):
+    config_class = SamVisionConfig
+    main_input_name = "pixel_values"
+
+    def __init__(self, config: SamVisionConfig):
+        super().__init__(config)
+        self.vision_encoder = SamVisionEncoder(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.vision_encoder.patch_embed
+
+    @add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=SamVisionEncoderOutput, config_class=SamVisionConfig)
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SamVisionEncoderOutput]:
+        r"""
+        Returns:
+
+        """
+        return self.vision_encoder(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+
 @add_start_docstrings(
     "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
     " optional 2D location and bounding boxes.",
@@ -1522,4 +1583,4 @@ class SamModel(SamPreTrainedModel):
         )


-__all__ = ["SamModel", "SamPreTrainedModel"]
+__all__ = ["SamVisionModel", "SamModel", "SamPreTrainedModel"]
@@ -30,7 +30,13 @@ from ...activations_tf import ACT2FN
 from ...modeling_tf_outputs import TFBaseModelOutput
 from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, keras, shape_list, unpack_inputs
 from ...tf_utils import flatten, functional_layernorm
-from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
 from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
@@ -1400,6 +1406,70 @@ SAM_INPUTS_DOCSTRING = r"""
 """


+SAM_VISION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
+            details.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    """The vision model from Sam without any head or projection on top.""",
+    SAM_START_DOCSTRING,
+)
+class TFSamVisionModel(TFSamPreTrainedModel):
+    config_class = SamVisionConfig
+    main_input_name = "pixel_values"
+
+    def __init__(self, config: SamVisionConfig, **kwargs):
+        super().__init__(config, **kwargs)
+        self.vision_encoder = TFSamVisionEncoder(config, name="vision_encoder")
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "vision_encoder", None) is not None:
+            with tf.name_scope(self.vision_encoder.name):
+                self.vision_encoder.build(None)
+
+    def get_input_embeddings(self):
+        return self.vision_encoder.patch_embed
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFSamVisionEncoderOutput, config_class=SamVisionConfig)
+    def call(
+        self,
+        pixel_values: TFModelInputType | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool = False,
+        **kwargs,
+    ) -> TFSamVisionEncoderOutput | Tuple[tf.Tensor]:
+        r"""
+        Returns:
+
+        """
+        return self.vision_encoder(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+
 @add_start_docstrings(
     "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
     " optional 2D location and bounding boxes.",
@@ -1653,4 +1723,4 @@ class TFSamModel(TFSamPreTrainedModel):
                 self.mask_decoder.build(None)


-__all__ = ["TFSamModel", "TFSamPreTrainedModel"]
+__all__ = ["TFSamVisionModel", "TFSamModel", "TFSamPreTrainedModel"]
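For context (not part of the diff): the TensorFlow counterpart is used the same way. A hedged sketch, assuming TF weights exist for the checkpoint (otherwise `from_pt=True` would be needed) and the standard `SamProcessor` preprocessing:

```python
# Hedged sketch: extracting image embeddings with the new TFSamVisionModel.
from PIL import Image

from transformers import SamProcessor, TFSamVisionModel

processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = TFSamVisionModel.from_pretrained("facebook/sam-vit-huge")  # assumption: TF weights are available

image = Image.open("example.jpg").convert("RGB")  # placeholder local image
inputs = processor(images=image, return_tensors="tf")
outputs = model(pixel_values=inputs["pixel_values"], training=False)

# Same output layout as checked in the added TF tests:
# (batch_size, output_channels, image_size // patch_size, image_size // patch_size)
print(outputs.last_hidden_state.shape)
```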
@@ -8836,6 +8836,13 @@ class SamPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class SamVisionModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class SeamlessM4TCodeHifiGan(metaclass=DummyObject):
     _backends = ["torch"]

@@ -2375,6 +2375,13 @@ class TFSamPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["tf"])


+class TFSamVisionModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
 class TFSegformerDecodeHead(metaclass=DummyObject):
     _backends = ["tf"]

@@ -32,13 +32,243 @@ if is_torch_available():
     import torch
     from torch import nn

-    from transformers import SamModel, SamProcessor
+    from transformers import SamModel, SamProcessor, SamVisionModel


 if is_vision_available():
     from PIL import Image


+class SamVisionModelTester:
+    def __init__(
+        self,
+        parent,
+        hidden_size=36,
+        intermediate_size=72,
+        projection_dim=62,
+        output_channels=32,
+        num_hidden_layers=2,
+        num_attention_heads=4,
+        num_channels=3,
+        image_size=24,
+        patch_size=2,
+        hidden_act="gelu",
+        layer_norm_eps=1e-06,
+        dropout=0.0,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        initializer_factor=1.0,
+        qkv_bias=True,
+        mlp_ratio=4.0,
+        use_abs_pos=True,
+        use_rel_pos=True,
+        rel_pos_zero_init=False,
+        window_size=14,
+        global_attn_indexes=[2, 5, 8, 11],
+        num_pos_feats=16,
+        mlp_dim=None,
+        batch_size=2,
+    ):
+        self.parent = parent
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.projection_dim = projection_dim
+        self.output_channels = output_channels
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.hidden_act = hidden_act
+        self.layer_norm_eps = layer_norm_eps
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.initializer_range = initializer_range
+        self.initializer_factor = initializer_factor
+        self.qkv_bias = qkv_bias
+        self.mlp_ratio = mlp_ratio
+        self.use_abs_pos = use_abs_pos
+        self.use_rel_pos = use_rel_pos
+        self.rel_pos_zero_init = rel_pos_zero_init
+        self.window_size = window_size
+        self.global_attn_indexes = global_attn_indexes
+        self.num_pos_feats = num_pos_feats
+        self.mlp_dim = mlp_dim
+        self.batch_size = batch_size
+
+        # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.seq_length = num_patches + 1
+
+    def get_config(self):
+        return SamVisionConfig(
+            image_size=self.image_size,
+            patch_size=self.patch_size,
+            num_channels=self.num_channels,
+            hidden_size=self.hidden_size,
+            projection_dim=self.projection_dim,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            intermediate_size=self.intermediate_size,
+            dropout=self.dropout,
+            attention_dropout=self.attention_dropout,
+            initializer_range=self.initializer_range,
+            initializer_factor=self.initializer_factor,
+            output_channels=self.output_channels,
+            qkv_bias=self.qkv_bias,
+            mlp_ratio=self.mlp_ratio,
+            use_abs_pos=self.use_abs_pos,
+            use_rel_pos=self.use_rel_pos,
+            rel_pos_zero_init=self.rel_pos_zero_init,
+            window_size=self.window_size,
+            global_attn_indexes=self.global_attn_indexes,
+            num_pos_feats=self.num_pos_feats,
+            mlp_dim=self.mlp_dim,
+        )
+
+    def prepare_config_and_inputs(self):
+        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+        config = self.get_config()
+
+        return config, pixel_values
+
+    def create_and_check_model(self, config, pixel_values):
+        model = SamVisionModel(config=config)
+        model.to(torch_device)
+        model.eval()
+        with torch.no_grad():
+            result = model(pixel_values)
+        output_size = self.image_size // self.patch_size
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.output_channels, output_size, output_size)
+        )
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, pixel_values = config_and_inputs
+        inputs_dict = {"pixel_values": pixel_values}
+        return config, inputs_dict
+
+
+@require_torch
+class SamVisionModelTest(ModelTesterMixin, unittest.TestCase):
+    """
+    Here we also overwrite some of the tests of test_modeling_common.py, as SAM's vision encoder does not use
+    input_ids, inputs_embeds, attention_mask and seq_length.
+    """
+
+    all_model_classes = (SamVisionModel,) if is_torch_available() else ()
+    fx_compatible = False
+    test_pruning = False
+    test_resize_embeddings = False
+    test_head_masking = False
+    test_torchscript = False
+    test_torch_exportable = True
+
+    def setUp(self):
+        self.model_tester = SamVisionModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    @unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
+    def test_inputs_embeds(self):
+        pass
+
+    def test_model_get_set_embeddings(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+            x = model.get_output_embeddings()
+            self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+    def test_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
+    def test_attention_outputs(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.return_dict = True
+
+        expected_attention_shape = (
+            self.model_tester.batch_size * self.model_tester.num_attention_heads,
+            196,
+            196,
+        )
+
+        for model_class in self.all_model_classes:
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = False
+            config.return_dict = True
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+            attentions = outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            # check that output_attentions also work using config
+            del inputs_dict["output_attentions"]
+            config.output_attentions = True
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            self.assertListEqual(
+                list(attentions[0].shape[-4:]),
+                list(expected_attention_shape),
+            )
+
+    @unittest.skip(reason="SamVisionModel does not support training")
+    def test_training(self):
+        pass
+
+    @unittest.skip(reason="SamVisionModel does not support training")
+    def test_training_gradient_checkpointing(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecture does not seem to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecture does not seem to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
+    @unittest.skip(reason="SamVisionModel has no base class and is not available in MODEL_MAPPING")
+    def test_save_load_fast_init_from_base(self):
+        pass
+
+    @unittest.skip(reason="SamVisionModel has no base class and is not available in MODEL_MAPPING")
+    def test_save_load_fast_init_to_base(self):
+        pass
+
+    @unittest.skip(reason="SamVisionModel does not support training")
+    def test_retain_grad_hidden_states_attentions(self):
+        pass
+
+    @unittest.skip(reason="Hidden_states is tested in create_and_check_model tests")
+    def test_hidden_states_output(self):
+        pass
+
+    @require_torch_sdpa
+    def test_sdpa_can_compile_dynamic(self):
+        self.skipTest(reason="SAM model can't be compiled dynamic yet")
+
+
 class SamPromptEncoderTester:
     def __init__(
         self,
@@ -34,13 +34,204 @@ from ...test_pipeline_mixin import PipelineTesterMixin
 if is_tf_available():
     import tensorflow as tf

-    from transformers import SamProcessor, TFSamModel
+    from transformers import SamProcessor, TFSamModel, TFSamVisionModel
     from transformers.modeling_tf_utils import keras

 if is_vision_available():
     from PIL import Image


+class TFSamVisionModelTester:
+    def __init__(
+        self,
+        parent,
+        hidden_size=36,
+        intermediate_size=72,
+        projection_dim=62,
+        output_channels=32,
+        num_hidden_layers=2,
+        num_attention_heads=4,
+        num_channels=3,
+        image_size=24,
+        patch_size=2,
+        hidden_act="gelu",
+        layer_norm_eps=1e-06,
+        dropout=0.0,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        initializer_factor=1.0,
+        qkv_bias=True,
+        mlp_ratio=4.0,
+        use_abs_pos=True,
+        use_rel_pos=True,
+        rel_pos_zero_init=False,
+        window_size=14,
+        global_attn_indexes=[2, 5, 8, 11],
+        num_pos_feats=16,
+        mlp_dim=None,
+        batch_size=2,
+    ):
+        self.parent = parent
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.projection_dim = projection_dim
+        self.output_channels = output_channels
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.hidden_act = hidden_act
+        self.layer_norm_eps = layer_norm_eps
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.initializer_range = initializer_range
+        self.initializer_factor = initializer_factor
+        self.qkv_bias = qkv_bias
+        self.mlp_ratio = mlp_ratio
+        self.use_abs_pos = use_abs_pos
+        self.use_rel_pos = use_rel_pos
+        self.rel_pos_zero_init = rel_pos_zero_init
+        self.window_size = window_size
+        self.global_attn_indexes = global_attn_indexes
+        self.num_pos_feats = num_pos_feats
+        self.mlp_dim = mlp_dim
+        self.batch_size = batch_size
+
+    def get_config(self):
+        return SamVisionConfig(
+            image_size=self.image_size,
+            patch_size=self.patch_size,
+            num_channels=self.num_channels,
+            hidden_size=self.hidden_size,
+            projection_dim=self.projection_dim,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            intermediate_size=self.intermediate_size,
+            dropout=self.dropout,
+            attention_dropout=self.attention_dropout,
+            initializer_range=self.initializer_range,
+            initializer_factor=self.initializer_factor,
+            output_channels=self.output_channels,
+            qkv_bias=self.qkv_bias,
+            mlp_ratio=self.mlp_ratio,
+            use_abs_pos=self.use_abs_pos,
+            use_rel_pos=self.use_rel_pos,
+            rel_pos_zero_init=self.rel_pos_zero_init,
+            window_size=self.window_size,
+            global_attn_indexes=self.global_attn_indexes,
+            num_pos_feats=self.num_pos_feats,
+            mlp_dim=self.mlp_dim,
+        )
+
+    def prepare_config_and_inputs(self):
+        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+        config = self.get_config()
+
+        return config, pixel_values
+
+    def create_and_check_model(self, config, pixel_values):
+        model = TFSamVisionModel(config=config)
+        result = model(pixel_values)
+        output_size = self.image_size // self.patch_size
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.output_channels, output_size, output_size)
+        )
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, pixel_values = config_and_inputs
+        inputs_dict = {"pixel_values": pixel_values}
+        return config, inputs_dict
+
+
+@require_tf
+class TFSamVisionModelTest(TFModelTesterMixin, unittest.TestCase):
+    """
+    Here we also overwrite some of the tests of test_modeling_common.py, as SAM's vision encoder does not use
+    input_ids, inputs_embeds, attention_mask and seq_length.
+    """
+
+    all_model_classes = (TFSamVisionModel,) if is_tf_available() else ()
+    test_pruning = False
+    test_resize_embeddings = False
+    test_head_masking = False
+    test_onnx = False
+
+    def setUp(self):
+        self.model_tester = TFSamVisionModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    @unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
+    def test_inputs_embeds(self):
+        pass
+
+    def test_model_common_attributes(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            self.assertIsInstance(model.get_input_embeddings(), (keras.layers.Layer))
+            x = model.get_output_embeddings()
+            self.assertTrue(x is None or isinstance(x, keras.layers.Dense))
+
+    def test_forward_signature(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            signature = inspect.signature(model.call)
+            # signature.parameters is an OrderedDict => so arg_names order is deterministic
+            arg_names = [*signature.parameters.keys()]
+
+            expected_arg_names = ["pixel_values"]
+            self.assertListEqual(arg_names[:1], expected_arg_names)
+
+    def test_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
+    def test_attention_outputs(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.return_dict = True
+
+        expected_attention_shape = (
+            self.model_tester.batch_size * self.model_tester.num_attention_heads,
+            196,
+            196,
+        )
+
+        for model_class in self.all_model_classes:
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = False
+            config.return_dict = True
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+            attentions = outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            # check that output_attentions also work using config
+            del inputs_dict["output_attentions"]
+            config.output_attentions = True
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            self.assertListEqual(
+                list(attentions[0].shape[-4:]),
+                list(expected_attention_shape),
+            )
+
+    @unittest.skip(reason="Hidden_states is tested in create_and_check_model tests")
+    def test_hidden_states_output(self):
+        pass
+
+
 class TFSamPromptEncoderTester:
     def __init__(
         self,