From c53d53da89c0617f7dd5a69a2a08e6b1232b35fd Mon Sep 17 00:00:00 2001
From: Armaghan Shakir
Date: Mon, 17 Mar 2025 14:39:52 +0500
Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=F0=9F=9A=A8=F0=9F=9A=A8=20Fix=20sd?=
 =?UTF-8?q?pa=20in=20SAM=20and=20refactor=20relative=20position=20embeddin?=
 =?UTF-8?q?gs=20(#36422)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* fall back to eager if output_attentions

* improve relative position embeddings

* run modular on got_ocr2

* run-slow: sam

* fix run-length encoding

* fix tf processor errors

* update tf_sam

* fix compile error

* re-run tests
---
 .../models/got_ocr2/modeling_got_ocr2.py      |  23 ++--
 .../models/sam/image_processing_sam.py        |   8 +-
 src/transformers/models/sam/modeling_sam.py   | 105 ++++++------------
 .../models/sam/modeling_tf_sam.py             |  25 +++--
 src/transformers/processing_utils.py          |   2 +-
 tests/models/sam/test_processor_sam.py        |   2 +-
 6 files changed, 62 insertions(+), 103 deletions(-)

diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py
index 918ee2bb039..180e9f17773 100644
--- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py
+++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py
@@ -114,9 +114,8 @@ class GotOcr2VisionAttention(nn.Module):
 
         return rel_pos_resized[relative_coords.long()]
 
-    def add_decomposed_rel_pos(
+    def get_decomposed_rel_pos(
         self,
-        attn: torch.Tensor,
         query: torch.Tensor,
         rel_pos_h: torch.Tensor,
         rel_pos_w: torch.Tensor,
@@ -128,8 +127,6 @@ class GotOcr2VisionAttention(nn.Module):
         https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
 
         Args:
-            attn (`torch.Tensor`):
-                attention map.
             query (`torch.Tensor`):
                 query q in the attention layer with shape (batch_size, query_height * query_width, channel).
             rel_pos_h (`torch.Tensor`):
@@ -142,8 +139,8 @@ class GotOcr2VisionAttention(nn.Module):
                 spatial sequence size of key k with (key_height, key_width).
 
         Returns:
-            attn (`torch.Tensor`):
-                attention map with added relative positional embeddings.
+            decomposed_rel_pos (`torch.Tensor`):
+                decomposed relative position embeddings.
""" query_height, query_width = q_size key_height, key_width = k_size @@ -154,10 +151,10 @@ class GotOcr2VisionAttention(nn.Module): reshaped_query = query.reshape(batch_size, query_height, query_width, dim) rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) - attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width) - attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] - attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width) - return attn + + decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + + return decomposed_rel_pos def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: batch_size, height, width, _ = hidden_states.shape @@ -173,9 +170,11 @@ class GotOcr2VisionAttention(nn.Module): attn_weights = (query * self.scale) @ key.transpose(-2, -1) if self.use_rel_pos: - attn_weights = self.add_decomposed_rel_pos( - attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + decomposed_rel_pos = self.get_decomposed_rel_pos( + query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) ) + decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights) + attn_weights = attn_weights + decomposed_rel_pos attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index 595c822940c..903a0cbf134 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -1381,7 +1381,7 @@ def _mask_to_rle_pytorch(input_mask: "torch.Tensor"): continue btw_idxs = cur_idxs[1:] - cur_idxs[:-1] counts = [] if input_mask[i, 0] == 0 else [0] - counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]] + counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()] out.append({"size": [height, width], "counts": counts}) return out @@ -1401,7 +1401,7 @@ def _mask_to_rle_tf(input_mask: "tf.Tensor"): # Encode run length out = [] for i in range(batch_size): - cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1 + cur_idxs = change_indices[change_indices[:, 0] == i][:, 1] + 1 if len(cur_idxs) == 0: # No changes => either all 0 or all 1 # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width]. 
@@ -1412,7 +1412,9 @@ def _mask_to_rle_tf(input_mask: "tf.Tensor"):
             continue
         btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
         counts = [] if input_mask[i, 0] == 0 else [0]
-        counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
+        counts += (
+            [cur_idxs[0].numpy().item()] + btw_idxs.numpy().tolist() + [height * width - cur_idxs[-1].numpy().item()]
+        )
         out.append({"size": [height, width], "counts": counts})
 
     return out
diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py
index 9fe4c27bcac..7fdb7a81dd4 100644
--- a/src/transformers/models/sam/modeling_sam.py
+++ b/src/transformers/models/sam/modeling_sam.py
@@ -820,9 +820,8 @@ class SamVisionAttention(nn.Module):
 
         return rel_pos_resized[relative_coords.long()]
 
-    def add_decomposed_rel_pos(
+    def get_decomposed_rel_pos(
         self,
-        attn: torch.Tensor,
         query: torch.Tensor,
         rel_pos_h: torch.Tensor,
         rel_pos_w: torch.Tensor,
@@ -834,8 +833,6 @@ class SamVisionAttention(nn.Module):
         https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
 
         Args:
-            attn (`torch.Tensor`):
-                attention map.
             query (`torch.Tensor`):
                 query q in the attention layer with shape (batch_size, query_height * query_width, channel).
             rel_pos_h (`torch.Tensor`):
@@ -848,8 +845,8 @@ class SamVisionAttention(nn.Module):
                 spatial sequence size of key k with (key_height, key_width).
 
         Returns:
-            attn (`torch.Tensor`):
-                attention map with added relative positional embeddings.
+            decomposed_rel_pos (`torch.Tensor`):
+                decomposed relative position embeddings.
         """
         query_height, query_width = q_size
         key_height, key_width = k_size
@@ -860,10 +857,10 @@ class SamVisionAttention(nn.Module):
         reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
         rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
         rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
-        attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
-        attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
-        attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
-        return attn
+
+        decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+
+        return decomposed_rel_pos
 
     def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
         batch_size, height, width, _ = hidden_states.shape
@@ -879,9 +876,11 @@ class SamVisionAttention(nn.Module):
         attn_weights = (query * self.scale) @ key.transpose(-2, -1)
 
         if self.use_rel_pos:
-            attn_weights = self.add_decomposed_rel_pos(
-                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+            decomposed_rel_pos = self.get_decomposed_rel_pos(
+                query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
             )
+            decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
+            attn_weights = attn_weights + decomposed_rel_pos
 
         attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
 
@@ -909,47 +908,19 @@ class SamVisionSdpaAttention(SamVisionAttention):
     def __init__(self, config, window_size):
         super().__init__(config, window_size)
 
-    def add_decomposed_rel_pos(
-        self,
-        query: torch.Tensor,
-        rel_pos_h: torch.Tensor,
-        rel_pos_w: torch.Tensor,
-        q_size: Tuple[int, int],
-        k_size: Tuple[int, int],
-    ) -> torch.Tensor:
-        """
-        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
-        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
-        This method is reimplemented to follow the implementation in:
-        https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py # noqa B950
-        This implementation is more memory efficient when using SDPA in the forward method.
-        Args:
-            q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
-            rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
-            rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
-            q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
-            k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
-
-        Returns:
-            attn (Tensor): attention map with added relative positional embeddings.
-        """
-        query_height, query_width = q_size
-        key_height, key_width = k_size
-        relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
-        relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
-
-        batch_size, _, dim = query.shape
-        reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
-        rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
-        rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
-        rel_h = rel_h.unsqueeze(-1)
-        rel_w = rel_w.unsqueeze(-2)
-        rel_h = rel_h.reshape(batch_size, query_height * query_width, key_height, 1)
-        rel_w = rel_w.reshape(batch_size, query_height * query_width, 1, key_width)
-
-        return rel_h, rel_w
-
     def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
+        if output_attentions:
+            logger.warning_once(
+                "`SamVisionSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+                "`output_attentions=True`. Falling back to the manual attention implementation, but "
+                "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
+                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                output_attentions=output_attentions,
+            )
+
         batch_size, height, width, _ = hidden_states.shape
         # qkv with shape (3, B, nHead, H * W, C)
         qkv = (
@@ -960,25 +931,21 @@ class SamVisionSdpaAttention(SamVisionAttention):
         # q, k, v with shape (B * nHead, H * W, C)
         query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
 
-        rel_h, rel_w = None, None
+        attn_bias = None
         if self.use_rel_pos:
-            rel_h, rel_w = self.add_decomposed_rel_pos(
+            decomposed_rel_pos = self.get_decomposed_rel_pos(
                 query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
             )
+            decomposed_rel_pos = decomposed_rel_pos.reshape(
+                batch_size, self.num_attention_heads, height * width, height * width
+            )
+            attn_bias = decomposed_rel_pos
 
         query = query.view(batch_size, self.num_attention_heads, height * width, -1)
         key = key.view(batch_size, self.num_attention_heads, height * width, -1)
         value = value.view(batch_size, self.num_attention_heads, height * width, -1)
 
-        if self.use_rel_pos:
-            rel_h = rel_h.view(batch_size, self.num_attention_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
-            rel_w = rel_w.view(batch_size, self.num_attention_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
-            attn_bias = (rel_h + rel_w).view(
-                batch_size, self.num_attention_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
-            )
-            attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
-        else:
-            attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value)
+        attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
 
         attn_output = (
             attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
@@ -988,17 +955,7 @@ class SamVisionSdpaAttention(SamVisionAttention):
 
         attn_output = self.proj(attn_output)
 
-        if output_attentions:
-            # For output_attentions, calculate the attention weights
-            attn_weights = (query @ key.transpose(-2, -1)) * self.scale
-            if attn_bias is not None:
-                attn_weights = attn_weights + attn_bias
-            attn_weights = F.softmax(attn_weights, dim=-1)
-            outputs = (attn_output, attn_weights)
-        else:
-            outputs = (attn_output, None)
-
-        return outputs
+        return attn_output, None
 
 
 SAM_VISION_ATTENTION_CLASSES = {
diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py
index ee75b1bf4f2..29a2335cb12 100644
--- a/src/transformers/models/sam/modeling_tf_sam.py
+++ b/src/transformers/models/sam/modeling_tf_sam.py
@@ -982,9 +982,8 @@ class TFSamVisionAttention(keras.layers.Layer):
 
         return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32))
 
-    def add_decomposed_rel_pos(
+    def get_decomposed_rel_pos(
         self,
-        attn: tf.Tensor,
         query: tf.Tensor,
         rel_pos_h: tf.Tensor,
         rel_pos_w: tf.Tensor,
@@ -996,8 +995,6 @@ class TFSamVisionAttention(keras.layers.Layer):
         https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
 
         Args:
-            attn (`tf.Tensor`):
-                attention map.
             query (`tf.Tensor`):
                 query q in the attention layer with shape (batch_size, query_height * query_width, channel).
             rel_pos_h (`tf.Tensor`):
@@ -1010,8 +1007,8 @@ class TFSamVisionAttention(keras.layers.Layer):
                 spatial sequence size of key k with (key_height, key_width).
 
         Returns:
-            attn (`tf.Tensor`):
-                attention map with added relative positional embeddings.
+            decomposed_rel_pos (`tf.Tensor`):
+                decomposed relative position embeddings.
""" query_height, query_width = q_size key_height, key_width = k_size @@ -1022,10 +1019,12 @@ class TFSamVisionAttention(keras.layers.Layer): reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim)) rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) - attn = tf.reshape(attn, (batch_size, query_height, query_width, key_height, key_width)) - attn = attn + tf.expand_dims(rel_h, axis=-1) + tf.expand_dims(rel_w, axis=-2) - attn = tf.reshape(attn, (batch_size, query_height * query_width, key_height * key_width)) - return attn + + rel_h = tf.expand_dims(rel_h, axis=-1) + rel_w = tf.expand_dims(rel_w, axis=-2) + decomposed_rel_pos = rel_h + rel_w + + return decomposed_rel_pos def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor: batch_size, height, width, _ = shape_list(hidden_states) @@ -1039,9 +1038,11 @@ class TFSamVisionAttention(keras.layers.Layer): attn_weights = tf.matmul(query * self.scale, key, transpose_b=True) if self.use_rel_pos: - attn_weights = self.add_decomposed_rel_pos( - attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + decomposed_rel_pos = self.get_decomposed_rel_pos( + query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) ) + decomposed_rel_pos = tf.reshape(decomposed_rel_pos, shape_list(attn_weights)) + attn_weights = attn_weights + decomposed_rel_pos attn_weights = tf.nn.softmax(attn_weights, axis=-1) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 9872887fe58..c85abdc739e 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -979,7 +979,7 @@ class ProcessorMixin(PushToHubMixin): kwarg_value = kwargs.get(modality_key, "__empty__") else: kwarg_value = "__empty__" - if kwarg_value != "__empty__": + if not isinstance(kwarg_value, str) or kwarg_value != "__empty__": output_kwargs[modality][modality_key] = kwarg_value used_keys.add(modality_key) diff --git a/tests/models/sam/test_processor_sam.py b/tests/models/sam/test_processor_sam.py index 8c6cf1ce39a..d621f542873 100644 --- a/tests/models/sam/test_processor_sam.py +++ b/tests/models/sam/test_processor_sam.py @@ -312,7 +312,7 @@ class TFSamProcessorTest(unittest.TestCase): # This is shape (1, 2, 2). # Flattened in Fortran order -> [0, 1, 1, 1]. # The RLE for [0,1,1,1] is [1, 3]. - input_mask = tf.tensor([[[0, 1], [1, 1]]], dtype=tf.int64) + input_mask = tf.constant([[[0, 1], [1, 1]]], dtype=tf.int64) rle = _mask_to_rle_tf(input_mask) self.assertEqual(len(rle), 1)