🚨🚨🚨 Fix sdpa in SAM and refactor relative position embeddings (#36422)

* fall back to eager if output_attentions

* improve relative position embeddings

* run modular on got_ocr2

* run-slow: sam

* fix run-length encoding

* fix tf processor errors

* update tf_sam

* fix compile error

* re-run tests
This commit is contained in:
Armaghan Shakir 2025-03-17 14:39:52 +05:00 committed by GitHub
parent fc8764c9a6
commit c53d53da89
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 62 additions and 103 deletions

View File

@ -114,9 +114,8 @@ class GotOcr2VisionAttention(nn.Module):
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
def get_decomposed_rel_pos(
self,
attn: torch.Tensor,
query: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
@ -128,8 +127,6 @@ class GotOcr2VisionAttention(nn.Module):
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
Args:
attn (`torch.Tensor`):
attention map.
query (`torch.Tensor`):
query q in the attention layer with shape (batch_size, query_height * query_width, channel).
rel_pos_h (`torch.Tensor`):
@ -142,8 +139,8 @@ class GotOcr2VisionAttention(nn.Module):
spatial sequence size of key k with (key_height, key_width).
Returns:
attn (`torch.Tensor`):
attention map with added relative positional embeddings.
decomposed_rel_pos (`torch.Tensor`):
decomposed relative position embeddings.
"""
query_height, query_width = q_size
key_height, key_width = k_size
@ -154,10 +151,10 @@ class GotOcr2VisionAttention(nn.Module):
reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
return attn
decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
return decomposed_rel_pos
def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
@ -173,9 +170,11 @@ class GotOcr2VisionAttention(nn.Module):
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
if self.use_rel_pos:
attn_weights = self.add_decomposed_rel_pos(
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
decomposed_rel_pos = self.get_decomposed_rel_pos(
query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
attn_weights = attn_weights + decomposed_rel_pos
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)

View File

@ -1381,7 +1381,7 @@ def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
continue
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
counts = [] if input_mask[i, 0] == 0 else [0]
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()]
out.append({"size": [height, width], "counts": counts})
return out
@ -1401,7 +1401,7 @@ def _mask_to_rle_tf(input_mask: "tf.Tensor"):
# Encode run length
out = []
for i in range(batch_size):
cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
cur_idxs = change_indices[change_indices[:, 0] == i][:, 1] + 1
if len(cur_idxs) == 0:
# No changes => either all 0 or all 1
# If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
@ -1412,7 +1412,9 @@ def _mask_to_rle_tf(input_mask: "tf.Tensor"):
continue
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
counts = [] if input_mask[i, 0] == 0 else [0]
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
counts += (
[cur_idxs[0].numpy().item()] + btw_idxs.numpy().tolist() + [height * width - cur_idxs[-1].numpy().item()]
)
out.append({"size": [height, width], "counts": counts})
return out

View File

@ -820,9 +820,8 @@ class SamVisionAttention(nn.Module):
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
def get_decomposed_rel_pos(
self,
attn: torch.Tensor,
query: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
@ -834,8 +833,6 @@ class SamVisionAttention(nn.Module):
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
Args:
attn (`torch.Tensor`):
attention map.
query (`torch.Tensor`):
query q in the attention layer with shape (batch_size, query_height * query_width, channel).
rel_pos_h (`torch.Tensor`):
@ -848,8 +845,8 @@ class SamVisionAttention(nn.Module):
spatial sequence size of key k with (key_height, key_width).
Returns:
attn (`torch.Tensor`):
attention map with added relative positional embeddings.
decomposed_rel_pos (`torch.Tensor`):
decomposed relative position embeddings.
"""
query_height, query_width = q_size
key_height, key_width = k_size
@ -860,10 +857,10 @@ class SamVisionAttention(nn.Module):
reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
return attn
decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
return decomposed_rel_pos
def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
@ -879,9 +876,11 @@ class SamVisionAttention(nn.Module):
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
if self.use_rel_pos:
attn_weights = self.add_decomposed_rel_pos(
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
decomposed_rel_pos = self.get_decomposed_rel_pos(
query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
attn_weights = attn_weights + decomposed_rel_pos
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
@ -909,47 +908,19 @@ class SamVisionSdpaAttention(SamVisionAttention):
def __init__(self, config, window_size):
super().__init__(config, window_size)
def add_decomposed_rel_pos(
self,
query: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
This method is reimplemented to follow the implementation in:
https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py # noqa B950
This implementation is more memory efficient when using SDPA in the forward method.
Args:
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
query_height, query_width = q_size
key_height, key_width = k_size
relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
batch_size, _, dim = query.shape
reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
rel_h = rel_h.unsqueeze(-1)
rel_w = rel_w.unsqueeze(-2)
rel_h = rel_h.reshape(batch_size, query_height * query_width, key_height, 1)
rel_w = rel_w.reshape(batch_size, query_height * query_width, 1, key_width)
return rel_h, rel_w
def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
if output_attentions:
logger.warning_once(
"`SamVisionSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
output_attentions=output_attentions,
)
batch_size, height, width, _ = hidden_states.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = (
@ -960,25 +931,21 @@ class SamVisionSdpaAttention(SamVisionAttention):
# q, k, v with shape (B * nHead, H * W, C)
query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
rel_h, rel_w = None, None
attn_bias = None
if self.use_rel_pos:
rel_h, rel_w = self.add_decomposed_rel_pos(
decomposed_rel_pos = self.get_decomposed_rel_pos(
query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
decomposed_rel_pos = decomposed_rel_pos.reshape(
batch_size, self.num_attention_heads, height * width, height * width
)
attn_bias = decomposed_rel_pos
query = query.view(batch_size, self.num_attention_heads, height * width, -1)
key = key.view(batch_size, self.num_attention_heads, height * width, -1)
value = value.view(batch_size, self.num_attention_heads, height * width, -1)
if self.use_rel_pos:
rel_h = rel_h.view(batch_size, self.num_attention_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
rel_w = rel_w.view(batch_size, self.num_attention_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
attn_bias = (rel_h + rel_w).view(
batch_size, self.num_attention_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
)
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value)
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
attn_output = (
attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
@ -988,17 +955,7 @@ class SamVisionSdpaAttention(SamVisionAttention):
attn_output = self.proj(attn_output)
if output_attentions:
# For output_attentions, calculate the attention weights
attn_weights = (query @ key.transpose(-2, -1)) * self.scale
if attn_bias is not None:
attn_weights = attn_weights + attn_bias
attn_weights = F.softmax(attn_weights, dim=-1)
outputs = (attn_output, attn_weights)
else:
outputs = (attn_output, None)
return outputs
return attn_output, None
SAM_VISION_ATTENTION_CLASSES = {

View File

@ -982,9 +982,8 @@ class TFSamVisionAttention(keras.layers.Layer):
return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32))
def add_decomposed_rel_pos(
def get_decomposed_rel_pos(
self,
attn: tf.Tensor,
query: tf.Tensor,
rel_pos_h: tf.Tensor,
rel_pos_w: tf.Tensor,
@ -996,8 +995,6 @@ class TFSamVisionAttention(keras.layers.Layer):
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
Args:
attn (`tf.Tensor`):
attention map.
query (`tf.Tensor`):
query q in the attention layer with shape (batch_size, query_height * query_width, channel).
rel_pos_h (`tf.Tensor`):
@ -1010,8 +1007,8 @@ class TFSamVisionAttention(keras.layers.Layer):
spatial sequence size of key k with (key_height, key_width).
Returns:
attn (`tf.Tensor`):
attention map with added relative positional embeddings.
decomposed_rel_pos (`torch.Tensor`):
decomposed relative position embeddings.
"""
query_height, query_width = q_size
key_height, key_width = k_size
@ -1022,10 +1019,12 @@ class TFSamVisionAttention(keras.layers.Layer):
reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim))
rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
attn = tf.reshape(attn, (batch_size, query_height, query_width, key_height, key_width))
attn = attn + tf.expand_dims(rel_h, axis=-1) + tf.expand_dims(rel_w, axis=-2)
attn = tf.reshape(attn, (batch_size, query_height * query_width, key_height * key_width))
return attn
rel_h = tf.expand_dims(rel_h, axis=-1)
rel_w = tf.expand_dims(rel_w, axis=-2)
decomposed_rel_pos = rel_h + rel_w
return decomposed_rel_pos
def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor:
batch_size, height, width, _ = shape_list(hidden_states)
@ -1039,9 +1038,11 @@ class TFSamVisionAttention(keras.layers.Layer):
attn_weights = tf.matmul(query * self.scale, key, transpose_b=True)
if self.use_rel_pos:
attn_weights = self.add_decomposed_rel_pos(
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
decomposed_rel_pos = self.get_decomposed_rel_pos(
query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
decomposed_rel_pos = tf.reshape(decomposed_rel_pos, shape_list(attn_weights))
attn_weights = attn_weights + decomposed_rel_pos
attn_weights = tf.nn.softmax(attn_weights, axis=-1)

View File

@ -979,7 +979,7 @@ class ProcessorMixin(PushToHubMixin):
kwarg_value = kwargs.get(modality_key, "__empty__")
else:
kwarg_value = "__empty__"
if kwarg_value != "__empty__":
if not isinstance(kwarg_value, str) or kwarg_value != "__empty__":
output_kwargs[modality][modality_key] = kwarg_value
used_keys.add(modality_key)

View File

@ -312,7 +312,7 @@ class TFSamProcessorTest(unittest.TestCase):
# This is shape (1, 2, 2).
# Flattened in Fortran order -> [0, 1, 1, 1].
# The RLE for [0,1,1,1] is [1, 3].
input_mask = tf.tensor([[[0, 1], [1, 1]]], dtype=tf.int64)
input_mask = tf.constant([[[0, 1], [1, 1]]], dtype=tf.int64)
rle = _mask_to_rle_tf(input_mask)
self.assertEqual(len(rle), 1)