From 6eb00dd2f0283f46d21ce9466d8d4e21dfd02550 Mon Sep 17 00:00:00 2001 From: Magnus <97634880+MagnusS0@users.noreply.github.com> Date: Tue, 17 Dec 2024 14:46:05 +0100 Subject: [PATCH] Support for SDPA for SAM models (#34110) * feat: add support for sdpa and gradient checkpointing * fix: ruff format * fix: config sdpa * fix: sdpa layer naming convention * fix: update test_eager_matches_sdpa_inference to handle vision_hidden_states * test: skip incompatible tests and fix loading issue with sdpa - Updated tests to skip cases flash and dynamic compile. - Minor adjustment to ensure correct loading of model with sdpa for dispatch test. * style: apply Ruff formatting * ruff fix again after rebase * [run-slow] sam * [run-slow] sam * refactor: Address review comments and improve sub-config handling in SAM model tests - Added attributes for sub_configs as per PR #34410. - Enabled tests for configs, ensuring the composite model (SAM) has several sub-configs in the main config. - Added class attribute _is_composite=True to the tester class - test_sdpa_can_dispatch_composite_models added * [run-slow] sam * style: ruff * [run-slow] sam * style: ruff again ... * [run-slow] sam --- .../models/sam/configuration_sam.py | 11 ++ src/transformers/models/sam/modeling_sam.py | 167 +++++++++++++++++- tests/models/sam/test_modeling_sam.py | 83 +++++++-- tests/test_modeling_common.py | 26 +-- 4 files changed, 256 insertions(+), 31 deletions(-) diff --git a/src/transformers/models/sam/configuration_sam.py b/src/transformers/models/sam/configuration_sam.py index b0045655d20..22a237615d1 100644 --- a/src/transformers/models/sam/configuration_sam.py +++ b/src/transformers/models/sam/configuration_sam.py @@ -46,6 +46,8 @@ class SamPromptEncoderConfig(PretrainedConfig): The non-linear activation function in the encoder and pooler. """ + base_config_key = "prompt_encoder_config" + def __init__( self, hidden_size=256, @@ -102,6 +104,8 @@ class SamMaskDecoderConfig(PretrainedConfig): """ + base_config_key = "mask_decoder_config" + def __init__( self, hidden_size=256, @@ -181,6 +185,8 @@ class SamVisionConfig(PretrainedConfig): hidden_size`. """ + base_config_key = "vision_config" + def __init__( self, hidden_size=768, @@ -278,6 +284,11 @@ class SamConfig(PretrainedConfig): ```""" model_type = "sam" + sub_configs = { + "prompt_encoder_config": SamPromptEncoderConfig, + "mask_decoder_config": SamMaskDecoderConfig, + "vision_config": SamVisionConfig, + } def __init__( self, diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index c99fb9d7e86..b935bc9e421 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -246,6 +246,47 @@ class SamAttention(nn.Module): return out +class SamSdpaAttention(SamAttention): + """ + SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. Using SDPA instead of the default attention. 
+ """ + + def __init__(self, config, downsample_rate=None): + super().__init__(config, downsample_rate) + + def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor: + # Input projections + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + point_batch_size = query.shape[1] + # Separate into heads + query = self._separate_heads(query, self.num_attention_heads) + key = self._separate_heads(key, self.num_attention_heads) + value = self._separate_heads(value, self.num_attention_heads) + + # Scaled dot product attention + attn_mask = None + if attention_similarity is not None: + attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1) + + out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask) + + # Get output + out = self._recombine_heads(out, point_batch_size) + out = self.out_proj(out) + + return out + + +SAM_ATTENTION_CLASSES = { + "eager": SamAttention, + "sdpa": SamSdpaAttention, +} + + class SamTwoWayAttentionBlock(nn.Module): def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False): """ @@ -266,18 +307,21 @@ class SamTwoWayAttentionBlock(nn.Module): self.hidden_size = config.hidden_size self.layer_norm_eps = config.layer_norm_eps - self.self_attn = SamAttention(config, downsample_rate=1) + self.self_attn = SAM_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1) self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) - self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate) + self.cross_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation]( + config, downsample_rate=attention_downsample_rate + ) self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self.mlp = SamMLPBlock(config) self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) - self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate) - + self.cross_attn_image_to_token = SAM_ATTENTION_CLASSES[config._attn_implementation]( + config, downsample_rate=attention_downsample_rate + ) self.skip_first_layer_pe = skip_first_layer_pe def forward( @@ -344,7 +388,7 @@ class SamTwoWayTransformer(nn.Module): for i in range(self.num_hidden_layers): self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0))) - self.final_attn_token_to_image = SamAttention(config) + self.final_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](config) self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size) def forward( @@ -431,7 +475,7 @@ class SamFeedForward(nn.Module): class SamMaskDecoder(nn.Module): def __init__(self, config: SamMaskDecoderConfig): super().__init__() - + self.config = config self.hidden_size = config.hidden_size self.num_multimask_outputs = config.num_multimask_outputs @@ -856,11 +900,118 @@ class SamVisionAttention(nn.Module): return outputs +class SamVisionSdpaAttention(SamVisionAttention): + """ + Multi-head Attention block with relative position embeddings. + Using SDPA instead of the default attention. 
+ """ + + def __init__(self, config, window_size): + super().__init__(config, window_size) + + def add_decomposed_rel_pos( + self, + query: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], + ) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + This method is reimplemented to follow the implementation in: + https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py # noqa B950 + This implementation is more memory efficient when using SDPA in the forward method. + Args: + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) + relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) + + batch_size, _, dim = query.shape + reshaped_query = query.reshape(batch_size, query_height, query_width, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) + rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) + rel_h = rel_h.unsqueeze(-1) + rel_w = rel_w.unsqueeze(-2) + rel_h = rel_h.reshape(batch_size, query_height * query_width, key_height, 1) + rel_w = rel_w.reshape(batch_size, query_height * query_width, 1, key_width) + + return rel_h, rel_w + + def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (B * nHead, H * W, C) + query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) + + rel_h, rel_w = None, None + if self.use_rel_pos: + rel_h, rel_w = self.add_decomposed_rel_pos( + query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + + query = query.view(batch_size, self.num_attention_heads, height * width, -1) + key = key.view(batch_size, self.num_attention_heads, height * width, -1) + value = value.view(batch_size, self.num_attention_heads, height * width, -1) + + if self.use_rel_pos: + rel_h = rel_h.view(batch_size, self.num_attention_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3)) + rel_w = rel_w.view(batch_size, self.num_attention_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3)) + attn_bias = (rel_h + rel_w).view( + batch_size, self.num_attention_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4) + ) + attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias) + else: + attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value) + + attn_output = ( + attn_output.view(batch_size, self.num_attention_heads, height, width, -1) + .permute(0, 2, 3, 1, 4) + 
.reshape(batch_size, height, width, -1) + ) + + attn_output = self.proj(attn_output) + + if output_attentions: + # For output_attentions, calculate the attention weights + attn_weights = (query @ key.transpose(-2, -1)) * self.scale + if attn_bias is not None: + attn_weights = attn_weights + attn_bias + attn_weights = F.softmax(attn_weights, dim=-1) + outputs = (attn_output, attn_weights) + else: + outputs = (attn_output, None) + + return outputs + + +SAM_VISION_ATTENTION_CLASSES = { + "eager": SamVisionAttention, + "sdpa": SamVisionSdpaAttention, +} + + class SamVisionLayer(nn.Module): def __init__(self, config, window_size): super().__init__() self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attn = SamVisionAttention(config, window_size) + self.attn = SAM_VISION_ATTENTION_CLASSES[config._attn_implementation](config, window_size) self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = SamMLPBlock(config) self.window_size = window_size @@ -1071,6 +1222,8 @@ class SamPreTrainedModel(PreTrainedModel): base_model_prefix = "sam" main_input_name = "pixel_values" _no_split_modules = ["SamVisionAttention"] + supports_gradient_checkpointing = True + _supports_sdpa = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py index 7faace0096c..351016716a0 100644 --- a/tests/models/sam/test_modeling_sam.py +++ b/tests/models/sam/test_modeling_sam.py @@ -14,12 +14,13 @@ # limitations under the License. """Testing suite for the PyTorch SAM model.""" +import tempfile import unittest import requests from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline -from transformers.testing_utils import cleanup, require_torch, slow, torch_device +from transformers.testing_utils import cleanup, require_torch, require_torch_sdpa, slow, torch_device from transformers.utils import is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -295,6 +296,7 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): test_resize_embeddings = False test_head_masking = False test_torchscript = False + _is_composite = True # TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working def is_pipeline_test_to_skip( @@ -311,22 +313,13 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): def setUp(self): self.model_tester = SamModelTester(self) - self.vision_config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False) - self.prompt_encoder_config_tester = ConfigTester( - self, - config_class=SamPromptEncoderConfig, - has_text_modality=False, - num_attention_heads=12, - num_hidden_layers=2, - ) - self.mask_decoder_config_tester = ConfigTester( - self, config_class=SamMaskDecoderConfig, has_text_modality=False + common_properties = ["initializer_range"] + self.config_tester = ConfigTester( + self, config_class=SamConfig, has_text_modality=False, common_properties=common_properties ) def test_config(self): - self.vision_config_tester.run_common_tests() - self.prompt_encoder_config_tester.run_common_tests() - self.mask_decoder_config_tester.run_common_tests() + self.config_tester.run_common_tests() @unittest.skip(reason="SAM's vision encoder does not use inputs_embeds") def test_inputs_embeds(self): @@ -450,6 +443,68 @@ class SamModelTest(ModelTesterMixin, 
PipelineTesterMixin, unittest.TestCase): model = SamModel.from_pretrained(model_name) self.assertIsNotNone(model) + @require_torch_sdpa + def test_sdpa_can_compile_dynamic(self): + self.skipTest(reason="SAM model can't be compiled dynamic yet") + + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + """ + Tests if composite models dispatch correctly on SDPA/eager when requested so when loading the model. + This tests only by looking at layer names, as usually SDPA layers are calles "SDPAAttention". + In contrast to the above test, this one checks if the "config._attn_implamentation" is a dict after the model + is loaded, because we manually replicate requested attn implementation on each sub-config when loading. + See https://github.com/huggingface/transformers/pull/32238 for more info + + The test tries to cover most general cases of composite models, VLMs with vision and text configs. Any model + that has a different set of sub-configs has to overwrite this test. + """ + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self._is_composite: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa") + model_sdpa = model_sdpa.eval().to(torch_device) + + model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device) + + # Root model determines SDPA support + attn_impl = "sdpa" if model._supports_sdpa else "eager" + + # Check config propagation to submodels that support it + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + self.assertTrue(model_sdpa.vision_encoder.config._attn_implementation == attn_impl) + self.assertTrue(model_sdpa.mask_decoder.config._attn_implementation == attn_impl) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + self.assertTrue(model_eager.vision_encoder.config._attn_implementation == "eager") + self.assertTrue(model_eager.mask_decoder.config._attn_implementation == "eager") + + # Verify SDPA/eager layer presence + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + + if not has_sdpa and attn_impl == "sdpa": + raise ValueError("The SDPA model should have SDPA attention layers") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + def prepare_image(): img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 13eacc4a596..3aaf18c9454 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4202,16 +4202,20 @@ class ModelTesterMixin: outputs_eager = model_eager(**prepared_inputs) outputs_sdpa = model_sdpa(**prepared_inputs) - logits_eager = ( - outputs_eager.hidden_states[-1] - if not is_encoder_decoder - else 
outputs_eager.decoder_hidden_states[-1] - ) - logits_sdpa = ( - outputs_sdpa.hidden_states[-1] - if not is_encoder_decoder - else outputs_sdpa.decoder_hidden_states[-1] - ) + if hasattr(outputs_eager, "vision_hidden_states"): + logits_eager = outputs_eager.vision_hidden_states[-1] + logits_sdpa = outputs_sdpa.vision_hidden_states[-1] + else: + logits_eager = ( + outputs_eager.hidden_states[-1] + if not is_encoder_decoder + else outputs_eager.decoder_hidden_states[-1] + ) + logits_sdpa = ( + outputs_sdpa.hidden_states[-1] + if not is_encoder_decoder + else outputs_sdpa.decoder_hidden_states[-1] + ) if torch_device in ["cpu", "cuda"]: atol = atols[torch_device, enable_kernels, torch_dtype] @@ -4287,6 +4291,8 @@ class ModelTesterMixin: ) if config.model_type in ["idefics", "idefics2", "idefics3"]: self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input") + if config.model_type in ["sam"]: + self.skipTest(reason="SAM requires an attention_mask input for relative positional embeddings") model = model_class(config) with tempfile.TemporaryDirectory() as tmpdirname:
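
A minimal end-user sketch (not part of the patch) of what this change enables: loading SAM with attn_implementation="sdpa", which from_pretrained propagates to the vision-encoder and mask-decoder sub-configs so the new SamVisionSdpaAttention / SamSdpaAttention classes are dispatched, plus the newly supported gradient checkpointing. The checkpoint name and the point-prompt coordinates are illustrative assumptions; the image URL is the one used by prepare_image() in the slow tests.

import requests
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed checkpoint; any SAM checkpoint should behave the same way.
model = SamModel.from_pretrained("facebook/sam-vit-base", attn_implementation="sdpa").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Gradient checkpointing is now supported as well (only relevant when fine-tuning):
# model.gradient_checkpointing_enable()

img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]  # one 2D point prompt (x, y); illustrative coordinates

inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)  # runs through the SDPA attention layers added above

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
print(len(masks), masks[0].shape)

Because SamModel is a composite model, the requested attn_implementation is replicated onto each sub-config at load time, which is exactly what the new test_sdpa_can_dispatch_composite_models test verifies.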