mirror of https://github.com/huggingface/transformers.git
🚨🚨🚨 Fix sdpa in SAM and refactor relative position embeddings (#36422)
* fall back to eager if output_attentions
* improve relative position embeddings
* run modular on got_ocr2
* run-slow: sam
* fix run-length encoding
* fix tf processor errors
* update tf_sam
* fix compile error
* re-run tests
This commit is contained in:
parent fc8764c9a6
commit c53d53da89
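The hunks below apply the same refactor to the GOT-OCR2, SAM and TF-SAM vision attention classes: `add_decomposed_rel_pos`, which used to add the relative-position term into the attention map and return it, becomes `get_decomposed_rel_pos`, which only returns that term; the eager path then adds it to the attention logits, and the SDPA path reshapes it into an additive mask. As a reading aid, here is a condensed, self-contained sketch of that computation. The einsum strings and broadcast follow the diff; the `get_rel_pos` gathering/interpolation step is elided, the relative-position tables are passed in already gathered, and the toy shapes are assumptions of this sketch, not library code.

```python
import torch


def get_decomposed_rel_pos(query, relative_position_height, relative_position_width, q_size, k_size):
    """Sketch of the refactored helper: returns only the relative-position term.

    query: (batch_size, q_h * q_w, dim)
    relative_position_height: (q_h, k_h, dim)  # already gathered; get_rel_pos() elided
    relative_position_width:  (q_w, k_w, dim)
    """
    query_height, query_width = q_size
    key_height, key_width = k_size
    batch_size, _, dim = query.shape
    reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)  # (B, q_h, q_w, k_h)
    rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)   # (B, q_h, q_w, k_w)
    # Broadcast-sum instead of adding into the attention map in place:
    return rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]                          # (B, q_h, q_w, k_h, k_w)


# Eager consumption, mirroring the updated forward() (head dimension folded away for brevity):
batch_size, q_h, q_w, k_h, k_w, dim = 2, 4, 4, 4, 4, 8
query = torch.randn(batch_size, q_h * q_w, dim)
rel_pos_h = torch.randn(q_h, k_h, dim)
rel_pos_w = torch.randn(q_w, k_w, dim)
attn_weights = torch.randn(batch_size, q_h * q_w, k_h * k_w)

bias = get_decomposed_rel_pos(query, rel_pos_h, rel_pos_w, (q_h, q_w), (k_h, k_w))
attn_weights = attn_weights + bias.reshape_as(attn_weights)
print(attn_weights.shape)  # torch.Size([2, 16, 16])
```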
@@ -114,9 +114,8 @@ class GotOcr2VisionAttention(nn.Module):
 
         return rel_pos_resized[relative_coords.long()]
 
-    def add_decomposed_rel_pos(
+    def get_decomposed_rel_pos(
         self,
-        attn: torch.Tensor,
         query: torch.Tensor,
         rel_pos_h: torch.Tensor,
         rel_pos_w: torch.Tensor,
@@ -128,8 +127,6 @@ class GotOcr2VisionAttention(nn.Module):
         https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
 
         Args:
-            attn (`torch.Tensor`):
-                attention map.
             query (`torch.Tensor`):
                 query q in the attention layer with shape (batch_size, query_height * query_width, channel).
             rel_pos_h (`torch.Tensor`):
@@ -142,8 +139,8 @@ class GotOcr2VisionAttention(nn.Module):
                 spatial sequence size of key k with (key_height, key_width).
 
         Returns:
-            attn (`torch.Tensor`):
-                attention map with added relative positional embeddings.
+            decomposed_rel_pos (`torch.Tensor`):
+                decomposed relative position embeddings.
         """
         query_height, query_width = q_size
         key_height, key_width = k_size
@@ -154,10 +151,10 @@ class GotOcr2VisionAttention(nn.Module):
         reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
         rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
         rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
-        attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
-        attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
-        attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
-        return attn
+
+        decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+
+        return decomposed_rel_pos
 
     def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
         batch_size, height, width, _ = hidden_states.shape
@@ -173,9 +170,11 @@ class GotOcr2VisionAttention(nn.Module):
         attn_weights = (query * self.scale) @ key.transpose(-2, -1)
 
         if self.use_rel_pos:
-            attn_weights = self.add_decomposed_rel_pos(
-                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+            decomposed_rel_pos = self.get_decomposed_rel_pos(
+                query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
             )
+            decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
+            attn_weights = attn_weights + decomposed_rel_pos
 
         attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
 
@@ -1381,7 +1381,7 @@ def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
             continue
         btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
         counts = [] if input_mask[i, 0] == 0 else [0]
-        counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
+        counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()]
         out.append({"size": [height, width], "counts": counts})
     return out
 
@@ -1401,7 +1401,7 @@ def _mask_to_rle_tf(input_mask: "tf.Tensor"):
     # Encode run length
     out = []
     for i in range(batch_size):
-        cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
+        cur_idxs = change_indices[change_indices[:, 0] == i][:, 1] + 1
         if len(cur_idxs) == 0:
             # No changes => either all 0 or all 1
             # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
@@ -1412,7 +1412,9 @@ def _mask_to_rle_tf(input_mask: "tf.Tensor"):
             continue
         btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
         counts = [] if input_mask[i, 0] == 0 else [0]
-        counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
+        counts += (
+            [cur_idxs[0].numpy().item()] + btw_idxs.numpy().tolist() + [height * width - cur_idxs[-1].numpy().item()]
+        )
         out.append({"size": [height, width], "counts": counts})
     return out
 
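The two `_mask_to_rle_*` fixes concern the "counts" list of the run-length encoding: it should hold plain Python integers (alternating run lengths, starting with the number of leading zeros), but the last entry was previously computed as `height * width - cur_idxs[-1]`, i.e. arithmetic with a 0-d tensor, so a tensor object leaked into `counts` (and, on the TF side, the indexing form `[mask, 1]` is not valid tensor indexing). A small illustration of the `.item()` part in plain PyTorch, outside the library; the example mask `[0, 1, 1, 1]` is taken from the test comments further down.

```python
import torch

height, width = 2, 2
# 1-based indices where the column-major-flattened mask [0, 1, 1, 1] changes value.
cur_idxs = torch.tensor([1])

last_bad = height * width - cur_idxs[-1]          # tensor(3): a 0-d tensor would end up inside counts
last_good = height * width - cur_idxs[-1].item()  # 3: a plain Python int, as the RLE dict expects

counts = [cur_idxs[0].item(), last_good]
print(last_bad, type(last_bad))  # tensor(3) <class 'torch.Tensor'>
print(counts)                    # [1, 3]  -> RLE of [0, 1, 1, 1]
```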
@@ -820,9 +820,8 @@ class SamVisionAttention(nn.Module):
 
         return rel_pos_resized[relative_coords.long()]
 
-    def add_decomposed_rel_pos(
+    def get_decomposed_rel_pos(
         self,
-        attn: torch.Tensor,
         query: torch.Tensor,
         rel_pos_h: torch.Tensor,
         rel_pos_w: torch.Tensor,
@@ -834,8 +833,6 @@ class SamVisionAttention(nn.Module):
         https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
 
         Args:
-            attn (`torch.Tensor`):
-                attention map.
             query (`torch.Tensor`):
                 query q in the attention layer with shape (batch_size, query_height * query_width, channel).
             rel_pos_h (`torch.Tensor`):
@@ -848,8 +845,8 @@ class SamVisionAttention(nn.Module):
                 spatial sequence size of key k with (key_height, key_width).
 
         Returns:
-            attn (`torch.Tensor`):
-                attention map with added relative positional embeddings.
+            decomposed_rel_pos (`torch.Tensor`):
+                decomposed relative position embeddings.
         """
         query_height, query_width = q_size
         key_height, key_width = k_size
@@ -860,10 +857,10 @@ class SamVisionAttention(nn.Module):
         reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
         rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
         rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
-        attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
-        attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
-        attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
-        return attn
+
+        decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+
+        return decomposed_rel_pos
 
     def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
         batch_size, height, width, _ = hidden_states.shape
@@ -879,9 +876,11 @@ class SamVisionAttention(nn.Module):
         attn_weights = (query * self.scale) @ key.transpose(-2, -1)
 
         if self.use_rel_pos:
-            attn_weights = self.add_decomposed_rel_pos(
-                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+            decomposed_rel_pos = self.get_decomposed_rel_pos(
+                query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
             )
+            decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
+            attn_weights = attn_weights + decomposed_rel_pos
 
         attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
 
@@ -909,47 +908,19 @@ class SamVisionSdpaAttention(SamVisionAttention):
     def __init__(self, config, window_size):
         super().__init__(config, window_size)
 
-    def add_decomposed_rel_pos(
-        self,
-        query: torch.Tensor,
-        rel_pos_h: torch.Tensor,
-        rel_pos_w: torch.Tensor,
-        q_size: Tuple[int, int],
-        k_size: Tuple[int, int],
-    ) -> torch.Tensor:
-        """
-        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
-        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
-        This method is reimplemented to follow the implementation in:
-        https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py # noqa B950
-        This implementation is more memory efficient when using SDPA in the forward method.
-        Args:
-            q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
-            rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
-            rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
-            q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
-            k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
-
-        Returns:
-            attn (Tensor): attention map with added relative positional embeddings.
-        """
-        query_height, query_width = q_size
-        key_height, key_width = k_size
-        relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
-        relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
-
-        batch_size, _, dim = query.shape
-        reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
-        rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
-        rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
-        rel_h = rel_h.unsqueeze(-1)
-        rel_w = rel_w.unsqueeze(-2)
-        rel_h = rel_h.reshape(batch_size, query_height * query_width, key_height, 1)
-        rel_w = rel_w.reshape(batch_size, query_height * query_width, 1, key_width)
-
-        return rel_h, rel_w
-
     def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
+        if output_attentions:
+            logger.warning_once(
+                "`SamVisionSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+                "`output_attentions=True`. Falling back to the manual attention implementation, but "
+                "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
+                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                output_attentions=output_attentions,
+            )
+
         batch_size, height, width, _ = hidden_states.shape
         # qkv with shape (3, B, nHead, H * W, C)
         qkv = (
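This forward change is the headline fix: with the SDPA attention class, `output_attentions=True` cannot be honored by `torch.nn.functional.scaled_dot_product_attention`, so the layer now warns and delegates to the parent (eager) implementation instead of returning separately recomputed weights. If attention maps are actually needed, the warning text suggests requesting the eager implementation up front; a minimal hedged sketch (the checkpoint name is illustrative, any SAM checkpoint should behave the same):

```python
from transformers import SamModel

# Explicitly request the manual ("eager") attention implementation so that
# output_attentions=True is served directly, without the SDPA fallback warning.
model = SamModel.from_pretrained("facebook/sam-vit-base", attn_implementation="eager")
```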
@@ -960,25 +931,21 @@ class SamVisionSdpaAttention(SamVisionAttention):
         # q, k, v with shape (B * nHead, H * W, C)
         query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
 
-        rel_h, rel_w = None, None
+        attn_bias = None
         if self.use_rel_pos:
-            rel_h, rel_w = self.add_decomposed_rel_pos(
+            decomposed_rel_pos = self.get_decomposed_rel_pos(
                 query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
             )
+            decomposed_rel_pos = decomposed_rel_pos.reshape(
+                batch_size, self.num_attention_heads, height * width, height * width
+            )
+            attn_bias = decomposed_rel_pos
 
         query = query.view(batch_size, self.num_attention_heads, height * width, -1)
         key = key.view(batch_size, self.num_attention_heads, height * width, -1)
         value = value.view(batch_size, self.num_attention_heads, height * width, -1)
 
-        if self.use_rel_pos:
-            rel_h = rel_h.view(batch_size, self.num_attention_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
-            rel_w = rel_w.view(batch_size, self.num_attention_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
-            attn_bias = (rel_h + rel_w).view(
-                batch_size, self.num_attention_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
-            )
-            attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
-        else:
-            attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value)
+        attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
 
         attn_output = (
             attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
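The SDPA path now builds a single additive bias (`attn_bias`) and always calls `scaled_dot_product_attention` with `attn_mask=attn_bias` (a `None` mask is a no-op), instead of branching on `self.use_rel_pos`. This relies on the documented PyTorch behavior that a floating-point `attn_mask` is added to the attention scores before the softmax, which is exactly what the eager path does with `decomposed_rel_pos`. A quick numerical check of that equivalence on toy tensors (shapes are arbitrary, not the model's):

```python
import torch
import torch.nn.functional as F

batch, heads, seq, dim = 2, 3, 16, 8
q = torch.randn(batch, heads, seq, dim)
k = torch.randn(batch, heads, seq, dim)
v = torch.randn(batch, heads, seq, dim)
bias = torch.randn(batch, heads, seq, seq)  # stand-in for the reshaped decomposed_rel_pos

# SDPA with an additive float mask ...
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

# ... matches the eager formulation: softmax(q @ k^T / sqrt(d) + bias) @ v
scores = q @ k.transpose(-2, -1) / dim**0.5 + bias
eager_out = torch.softmax(scores, dim=-1) @ v

print(torch.allclose(sdpa_out, eager_out, atol=1e-5))  # expected: True (up to float tolerance)
```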
@@ -988,17 +955,7 @@ class SamVisionSdpaAttention(SamVisionAttention):
 
         attn_output = self.proj(attn_output)
 
-        if output_attentions:
-            # For output_attentions, calculate the attention weights
-            attn_weights = (query @ key.transpose(-2, -1)) * self.scale
-            if attn_bias is not None:
-                attn_weights = attn_weights + attn_bias
-            attn_weights = F.softmax(attn_weights, dim=-1)
-            outputs = (attn_output, attn_weights)
-        else:
-            outputs = (attn_output, None)
-
-        return outputs
+        return attn_output, None
 
 
 SAM_VISION_ATTENTION_CLASSES = {
@@ -982,9 +982,8 @@ class TFSamVisionAttention(keras.layers.Layer):
 
         return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32))
 
-    def add_decomposed_rel_pos(
+    def get_decomposed_rel_pos(
         self,
-        attn: tf.Tensor,
         query: tf.Tensor,
         rel_pos_h: tf.Tensor,
         rel_pos_w: tf.Tensor,
@@ -996,8 +995,6 @@ class TFSamVisionAttention(keras.layers.Layer):
         https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
 
         Args:
-            attn (`tf.Tensor`):
-                attention map.
             query (`tf.Tensor`):
                 query q in the attention layer with shape (batch_size, query_height * query_width, channel).
             rel_pos_h (`tf.Tensor`):
@@ -1010,8 +1007,8 @@ class TFSamVisionAttention(keras.layers.Layer):
                 spatial sequence size of key k with (key_height, key_width).
 
         Returns:
-            attn (`tf.Tensor`):
-                attention map with added relative positional embeddings.
+            decomposed_rel_pos (`torch.Tensor`):
+                decomposed relative position embeddings.
         """
         query_height, query_width = q_size
         key_height, key_width = k_size
@@ -1022,10 +1019,12 @@ class TFSamVisionAttention(keras.layers.Layer):
         reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim))
         rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
         rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
-        attn = tf.reshape(attn, (batch_size, query_height, query_width, key_height, key_width))
-        attn = attn + tf.expand_dims(rel_h, axis=-1) + tf.expand_dims(rel_w, axis=-2)
-        attn = tf.reshape(attn, (batch_size, query_height * query_width, key_height * key_width))
-        return attn
+
+        rel_h = tf.expand_dims(rel_h, axis=-1)
+        rel_w = tf.expand_dims(rel_w, axis=-2)
+        decomposed_rel_pos = rel_h + rel_w
+
+        return decomposed_rel_pos
 
     def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor:
         batch_size, height, width, _ = shape_list(hidden_states)
@@ -1039,9 +1038,11 @@ class TFSamVisionAttention(keras.layers.Layer):
         attn_weights = tf.matmul(query * self.scale, key, transpose_b=True)
 
         if self.use_rel_pos:
-            attn_weights = self.add_decomposed_rel_pos(
-                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+            decomposed_rel_pos = self.get_decomposed_rel_pos(
+                query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )
+            decomposed_rel_pos = tf.reshape(decomposed_rel_pos, shape_list(attn_weights))
+            attn_weights = attn_weights + decomposed_rel_pos
 
         attn_weights = tf.nn.softmax(attn_weights, axis=-1)
 
@@ -979,7 +979,7 @@ class ProcessorMixin(PushToHubMixin):
                     kwarg_value = kwargs.get(modality_key, "__empty__")
                 else:
                     kwarg_value = "__empty__"
-                if kwarg_value != "__empty__":
+                if not isinstance(kwarg_value, str) or kwarg_value != "__empty__":
                     output_kwargs[modality][modality_key] = kwarg_value
                     used_keys.add(modality_key)
 
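The `ProcessorMixin` tweak corresponds to the "fix tf processor errors" bullet: the sentinel test `kwarg_value != "__empty__"` is now only applied to string values, so array- or tensor-valued kwargs are treated as provided without ever being compared against the sentinel, presumably because comparing a tensor against a string can error out or yield a non-boolean, elementwise result. The predicate in isolation, as a sketch with a hypothetical `is_provided` name:

```python
import numpy as np


def is_provided(kwarg_value):
    # Mirrors the updated check: only strings are compared against the "__empty__"
    # sentinel; any non-string value (e.g. an array or tensor) counts as provided.
    return not isinstance(kwarg_value, str) or kwarg_value != "__empty__"


print(is_provided("__empty__"))       # False -> kwarg was not passed
print(is_provided("longest"))         # True  -> a real string kwarg
print(is_provided(np.zeros((2, 2))))  # True  -> never compared against the sentinel
```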
@@ -312,7 +312,7 @@ class TFSamProcessorTest(unittest.TestCase):
         # This is shape (1, 2, 2).
         # Flattened in Fortran order -> [0, 1, 1, 1].
         # The RLE for [0,1,1,1] is [1, 3].
-        input_mask = tf.tensor([[[0, 1], [1, 1]]], dtype=tf.int64)
+        input_mask = tf.constant([[[0, 1], [1, 1]]], dtype=tf.int64)
         rle = _mask_to_rle_tf(input_mask)
 
         self.assertEqual(len(rle), 1)
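The last hunk fixes the test helper call itself: `tf.tensor` is not a TensorFlow constructor, `tf.constant` is. The expectation stated in the test comments can be reproduced independently: flattening `[[0, 1], [1, 1]]` in Fortran (column-major) order gives `[0, 1, 1, 1]`, whose RLE, with counts starting at the number of leading zeros, is `[1, 3]`. A small NumPy check of that arithmetic:

```python
import numpy as np

mask = np.array([[0, 1], [1, 1]])
flat = mask.flatten(order="F")  # column-major, as the _mask_to_rle_* helpers expect
print(flat.tolist())            # [0, 1, 1, 1]

# Run-length encode, starting with the run of zeros (possibly of length 0).
counts, value, run = [], 0, 0
for pixel in flat:
    if pixel != value:
        counts.append(run)
        value, run = pixel, 0
    run += 1
counts.append(run)
print(counts)                   # [1, 3]
```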