Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
Support for SDPA for SAM models (#34110)
* feat: add support for sdpa and gradient checkpointing
* fix: ruff format
* fix: config sdpa
* fix: sdpa layer naming convention
* fix: update test_eager_matches_sdpa_inference to handle vision_hidden_states
* test: skip incompatible tests and fix loading issue with sdpa
  - Updated tests to skip flash-attention and dynamic-compile cases.
  - Minor adjustment to ensure correct loading of the model with sdpa for the dispatch test.
* style: apply Ruff formatting
* ruff fix again after rebase
* [run-slow] sam
* [run-slow] sam
* refactor: Address review comments and improve sub-config handling in SAM model tests
  - Added attributes for sub_configs as per PR #34410.
  - Enabled tests for configs, ensuring the composite model (SAM) has several sub-configs in the main config.
  - Added class attribute _is_composite=True to the tester class.
  - Added test_sdpa_can_dispatch_composite_models.
* [run-slow] sam
* style: ruff
* [run-slow] sam
* style: ruff again ...
* [run-slow] sam
This commit is contained in: parent 747f361da1, commit 6eb00dd2f0
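For context (not part of the commit itself): once this change is in, SDPA can be requested for SAM through the standard `attn_implementation` argument of `from_pretrained`. A minimal usage sketch; the checkpoint name and the car image are the usual public examples, and the point prompt is an arbitrary pixel coordinate:

    import requests
    import torch
    from PIL import Image
    from transformers import SamModel, SamProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
    # attn_implementation="sdpa" selects the Sam*SdpaAttention classes added in this commit
    model = SamModel.from_pretrained("facebook/sam-vit-base", attn_implementation="sdpa").to(device)

    img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
    inputs = processor(raw_image, input_points=[[[450, 600]]], return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"], input_points=inputs["input_points"])
    print(outputs.pred_masks.shape)  # predicted masks for the single point prompt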
@@ -46,6 +46,8 @@ class SamPromptEncoderConfig(PretrainedConfig):
             The non-linear activation function in the encoder and pooler.
     """
 
+    base_config_key = "prompt_encoder_config"
+
     def __init__(
         self,
         hidden_size=256,
@@ -102,6 +104,8 @@ class SamMaskDecoderConfig(PretrainedConfig):
 
     """
 
+    base_config_key = "mask_decoder_config"
+
     def __init__(
         self,
         hidden_size=256,
@@ -181,6 +185,8 @@ class SamVisionConfig(PretrainedConfig):
             hidden_size`.
     """
 
+    base_config_key = "vision_config"
+
     def __init__(
         self,
         hidden_size=768,
@@ -278,6 +284,11 @@ class SamConfig(PretrainedConfig):
     ```"""
 
     model_type = "sam"
+    sub_configs = {
+        "prompt_encoder_config": SamPromptEncoderConfig,
+        "mask_decoder_config": SamMaskDecoderConfig,
+        "vision_config": SamVisionConfig,
+    }
 
     def __init__(
         self,
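Aside (not part of the diff): the `sub_configs` mapping added above tells generic composite-model machinery which attributes of `SamConfig` hold which sub-config class. A small sketch of the invariant it encodes:

    from transformers import SamConfig

    config = SamConfig()  # default sub-configs are created automatically
    for key, sub_config_class in SamConfig.sub_configs.items():
        assert isinstance(getattr(config, key), sub_config_class)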
@@ -246,6 +246,47 @@ class SamAttention(nn.Module):
         return out
 
 
+class SamSdpaAttention(SamAttention):
+    """
+    SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
+    values. Using SDPA instead of the default attention.
+    """
+
+    def __init__(self, config, downsample_rate=None):
+        super().__init__(config, downsample_rate)
+
+    def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor:
+        # Input projections
+        query = self.q_proj(query)
+        key = self.k_proj(key)
+        value = self.v_proj(value)
+
+        point_batch_size = query.shape[1]
+        # Separate into heads
+        query = self._separate_heads(query, self.num_attention_heads)
+        key = self._separate_heads(key, self.num_attention_heads)
+        value = self._separate_heads(value, self.num_attention_heads)
+
+        # Scaled dot product attention
+        attn_mask = None
+        if attention_similarity is not None:
+            attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1)
+
+        out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
+
+        # Get output
+        out = self._recombine_heads(out, point_batch_size)
+        out = self.out_proj(out)
+
+        return out
+
+
+SAM_ATTENTION_CLASSES = {
+    "eager": SamAttention,
+    "sdpa": SamSdpaAttention,
+}
+
+
 class SamTwoWayAttentionBlock(nn.Module):
     def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
         """
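Aside (not part of the diff): `F.scaled_dot_product_attention` computes the same softmax attention as the eager path it replaces, so the swap above is numerically a drop-in. A self-contained sanity sketch with made-up shapes:

    import math
    import torch
    import torch.nn.functional as F

    # (batch, heads, tokens, head_dim) -- arbitrary illustrative shapes
    query, key, value = (torch.randn(2, 8, 16, 32) for _ in range(3))

    eager = torch.softmax(query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)), dim=-1) @ value
    sdpa = F.scaled_dot_product_attention(query, key, value)
    print(torch.allclose(eager, sdpa, atol=1e-5))  # expected: True, up to numerical tolerance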
@@ -266,18 +307,21 @@ class SamTwoWayAttentionBlock(nn.Module):
         self.hidden_size = config.hidden_size
         self.layer_norm_eps = config.layer_norm_eps
 
-        self.self_attn = SamAttention(config, downsample_rate=1)
+        self.self_attn = SAM_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1)
         self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
 
-        self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
+        self.cross_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](
+            config, downsample_rate=attention_downsample_rate
+        )
         self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
 
         self.mlp = SamMLPBlock(config)
         self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
 
         self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
-        self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
+        self.cross_attn_image_to_token = SAM_ATTENTION_CLASSES[config._attn_implementation](
+            config, downsample_rate=attention_downsample_rate
+        )
         self.skip_first_layer_pe = skip_first_layer_pe
 
     def forward(
@@ -344,7 +388,7 @@ class SamTwoWayTransformer(nn.Module):
         for i in range(self.num_hidden_layers):
             self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
 
-        self.final_attn_token_to_image = SamAttention(config)
+        self.final_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](config)
         self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
 
     def forward(
@@ -431,7 +475,7 @@ class SamFeedForward(nn.Module):
 class SamMaskDecoder(nn.Module):
     def __init__(self, config: SamMaskDecoderConfig):
         super().__init__()
+        self.config = config
         self.hidden_size = config.hidden_size
 
         self.num_multimask_outputs = config.num_multimask_outputs
@@ -856,11 +900,118 @@ class SamVisionAttention(nn.Module):
         return outputs
 
 
+class SamVisionSdpaAttention(SamVisionAttention):
+    """
+    Multi-head Attention block with relative position embeddings.
+    Using SDPA instead of the default attention.
+    """
+
+    def __init__(self, config, window_size):
+        super().__init__(config, window_size)
+
+    def add_decomposed_rel_pos(
+        self,
+        query: torch.Tensor,
+        rel_pos_h: torch.Tensor,
+        rel_pos_w: torch.Tensor,
+        q_size: Tuple[int, int],
+        k_size: Tuple[int, int],
+    ) -> torch.Tensor:
+        """
+        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
+        This method is reimplemented to follow the implementation in:
+        https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py # noqa B950
+        This implementation is more memory efficient when using SDPA in the forward method.
+        Args:
+            q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
+            rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
+            rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
+            q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+            k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+
+        Returns:
+            attn (Tensor): attention map with added relative positional embeddings.
+        """
+        query_height, query_width = q_size
+        key_height, key_width = k_size
+        relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
+        relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
+
+        batch_size, _, dim = query.shape
+        reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
+        rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
+        rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
+        rel_h = rel_h.unsqueeze(-1)
+        rel_w = rel_w.unsqueeze(-2)
+        rel_h = rel_h.reshape(batch_size, query_height * query_width, key_height, 1)
+        rel_w = rel_w.reshape(batch_size, query_height * query_width, 1, key_width)
+
+        return rel_h, rel_w
+
+    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
+        batch_size, height, width, _ = hidden_states.shape
+        # qkv with shape (3, B, nHead, H * W, C)
+        qkv = (
+            self.qkv(hidden_states)
+            .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
+            .permute(2, 0, 3, 1, 4)
+        )
+        # q, k, v with shape (B * nHead, H * W, C)
+        query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
+
+        rel_h, rel_w = None, None
+        if self.use_rel_pos:
+            rel_h, rel_w = self.add_decomposed_rel_pos(
+                query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+            )
+
+        query = query.view(batch_size, self.num_attention_heads, height * width, -1)
+        key = key.view(batch_size, self.num_attention_heads, height * width, -1)
+        value = value.view(batch_size, self.num_attention_heads, height * width, -1)
+
+        if self.use_rel_pos:
+            rel_h = rel_h.view(batch_size, self.num_attention_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
+            rel_w = rel_w.view(batch_size, self.num_attention_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
+            attn_bias = (rel_h + rel_w).view(
+                batch_size, self.num_attention_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
+            )
+            attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
+        else:
+            attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value)
+
+        attn_output = (
+            attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
+            .permute(0, 2, 3, 1, 4)
+            .reshape(batch_size, height, width, -1)
+        )
+
+        attn_output = self.proj(attn_output)
+
+        if output_attentions:
+            # For output_attentions, calculate the attention weights
+            attn_weights = (query @ key.transpose(-2, -1)) * self.scale
+            if attn_bias is not None:
+                attn_weights = attn_weights + attn_bias
+            attn_weights = F.softmax(attn_weights, dim=-1)
+            outputs = (attn_output, attn_weights)
+        else:
+            outputs = (attn_output, None)
+
+        return outputs
+
+
+SAM_VISION_ATTENTION_CLASSES = {
+    "eager": SamVisionAttention,
+    "sdpa": SamVisionSdpaAttention,
+}
+
+
 class SamVisionLayer(nn.Module):
     def __init__(self, config, window_size):
         super().__init__()
         self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.attn = SamVisionAttention(config, window_size)
+        self.attn = SAM_VISION_ATTENTION_CLASSES[config._attn_implementation](config, window_size)
         self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.mlp = SamMLPBlock(config)
         self.window_size = window_size
@@ -1071,6 +1222,8 @@ class SamPreTrainedModel(PreTrainedModel):
     base_model_prefix = "sam"
     main_input_name = "pixel_values"
     _no_split_modules = ["SamVisionAttention"]
+    supports_gradient_checkpointing = True
+    _supports_sdpa = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
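Aside (not part of the diff): with `_supports_sdpa` and `supports_gradient_checkpointing` set, both features can be enabled through the usual APIs. A sketch, assuming the public `facebook/sam-vit-base` checkpoint:

    import torch
    from transformers import SamModel

    model = SamModel.from_pretrained(
        "facebook/sam-vit-base",
        attn_implementation="sdpa",  # dispatch to the Sdpa attention classes added in this commit
        torch_dtype=torch.float16,
    )
    model.gradient_checkpointing_enable()  # now allowed because supports_gradient_checkpointing = True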
@@ -14,12 +14,13 @@
 # limitations under the License.
 """Testing suite for the PyTorch SAM model."""
 
+import tempfile
 import unittest
 
 import requests
 
 from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline
-from transformers.testing_utils import cleanup, require_torch, slow, torch_device
+from transformers.testing_utils import cleanup, require_torch, require_torch_sdpa, slow, torch_device
 from transformers.utils import is_torch_available, is_vision_available
 
 from ...test_configuration_common import ConfigTester
@@ -295,6 +296,7 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     test_resize_embeddings = False
     test_head_masking = False
     test_torchscript = False
+    _is_composite = True
 
     # TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
     def is_pipeline_test_to_skip(
@@ -311,22 +313,13 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
 
     def setUp(self):
         self.model_tester = SamModelTester(self)
-        self.vision_config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False)
-        self.prompt_encoder_config_tester = ConfigTester(
-            self,
-            config_class=SamPromptEncoderConfig,
-            has_text_modality=False,
-            num_attention_heads=12,
-            num_hidden_layers=2,
-        )
-        self.mask_decoder_config_tester = ConfigTester(
-            self, config_class=SamMaskDecoderConfig, has_text_modality=False
+        common_properties = ["initializer_range"]
+        self.config_tester = ConfigTester(
+            self, config_class=SamConfig, has_text_modality=False, common_properties=common_properties
         )
 
     def test_config(self):
-        self.vision_config_tester.run_common_tests()
-        self.prompt_encoder_config_tester.run_common_tests()
-        self.mask_decoder_config_tester.run_common_tests()
+        self.config_tester.run_common_tests()
 
     @unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
     def test_inputs_embeds(self):
@@ -450,6 +443,68 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         model = SamModel.from_pretrained(model_name)
         self.assertIsNotNone(model)
 
+    @require_torch_sdpa
+    def test_sdpa_can_compile_dynamic(self):
+        self.skipTest(reason="SAM model can't be compiled dynamic yet")
+
+    @require_torch_sdpa
+    def test_sdpa_can_dispatch_composite_models(self):
+        """
+        Tests if composite models dispatch correctly on SDPA/eager when requested to do so when loading the model.
+        This tests only by looking at layer names, as usually SDPA layers are called "SdpaAttention".
+        In contrast to the above test, this one checks if "config._attn_implementation" is a dict after the model
+        is loaded, because we manually replicate the requested attn implementation on each sub-config when loading.
+        See https://github.com/huggingface/transformers/pull/32238 for more info
+
+        The test tries to cover most general cases of composite models, VLMs with vision and text configs. Any model
+        that has a different set of sub-configs has to overwrite this test.
+        """
+        if not self.has_attentions:
+            self.skipTest(reason="Model architecture does not support attentions")
+
+        if not self._is_composite:
+            self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
+
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
+                model_sdpa = model_sdpa.eval().to(torch_device)
+
+                model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
+                model_eager = model_eager.eval().to(torch_device)
+
+                # Root model determines SDPA support
+                attn_impl = "sdpa" if model._supports_sdpa else "eager"
+
+                # Check config propagation to submodels that support it
+                self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
+                self.assertTrue(model_sdpa.vision_encoder.config._attn_implementation == attn_impl)
+                self.assertTrue(model_sdpa.mask_decoder.config._attn_implementation == attn_impl)
+
+                self.assertTrue(model_eager.config._attn_implementation == "eager")
+                self.assertTrue(model_eager.vision_encoder.config._attn_implementation == "eager")
+                self.assertTrue(model_eager.mask_decoder.config._attn_implementation == "eager")
+
+                # Verify SDPA/eager layer presence
+                has_sdpa = False
+                for name, submodule in model_sdpa.named_modules():
+                    class_name = submodule.__class__.__name__
+                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+                        has_sdpa = True
+                        break
+
+                if not has_sdpa and attn_impl == "sdpa":
+                    raise ValueError("The SDPA model should have SDPA attention layers")
+
+                for name, submodule in model_eager.named_modules():
+                    class_name = submodule.__class__.__name__
+                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+                        raise ValueError("The eager model should not have SDPA attention layers")
+
 
 def prepare_image():
     img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
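Aside (not part of the diff): the behaviour the new dispatch test locks in can also be checked by hand; the attribute names mirror the assertions above, and the checkpoint is the usual public one:

    from transformers import SamModel

    model = SamModel.from_pretrained("facebook/sam-vit-base", attn_implementation="sdpa")
    print(model.config._attn_implementation)                  # "sdpa"
    print(model.vision_encoder.config._attn_implementation)   # propagated to the vision sub-config
    print(model.mask_decoder.config._attn_implementation)     # propagated to the mask-decoder sub-config
    print(any("SdpaAttention" in m.__class__.__name__ for m in model.modules()))  # True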
@@ -4202,16 +4202,20 @@ class ModelTesterMixin:
                outputs_eager = model_eager(**prepared_inputs)
                outputs_sdpa = model_sdpa(**prepared_inputs)
 
-                logits_eager = (
-                    outputs_eager.hidden_states[-1]
-                    if not is_encoder_decoder
-                    else outputs_eager.decoder_hidden_states[-1]
-                )
-                logits_sdpa = (
-                    outputs_sdpa.hidden_states[-1]
-                    if not is_encoder_decoder
-                    else outputs_sdpa.decoder_hidden_states[-1]
-                )
+                if hasattr(outputs_eager, "vision_hidden_states"):
+                    logits_eager = outputs_eager.vision_hidden_states[-1]
+                    logits_sdpa = outputs_sdpa.vision_hidden_states[-1]
+                else:
+                    logits_eager = (
+                        outputs_eager.hidden_states[-1]
+                        if not is_encoder_decoder
+                        else outputs_eager.decoder_hidden_states[-1]
+                    )
+                    logits_sdpa = (
+                        outputs_sdpa.hidden_states[-1]
+                        if not is_encoder_decoder
+                        else outputs_sdpa.decoder_hidden_states[-1]
+                    )
 
                if torch_device in ["cpu", "cuda"]:
                    atol = atols[torch_device, enable_kernels, torch_dtype]
@@ -4287,6 +4291,8 @@ class ModelTesterMixin:
            )
        if config.model_type in ["idefics", "idefics2", "idefics3"]:
            self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
+       if config.model_type in ["sam"]:
+           self.skipTest(reason="SAM requires an attention_mask input for relative positional embeddings")
        model = model_class(config)
 
        with tempfile.TemporaryDirectory() as tmpdirname: