Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
TF SAM shape flexibility fixes (#23842)

SAM shape flexibility fixes for compilation

commit ac224dee90
parent af45ec0a16
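Every hunk below swaps a static `.shape` attribute lookup for `shape_list` from `transformers.tf_utils`, which substitutes the dynamic `tf.shape` value for any dimension that is unknown at trace time. A simplified sketch of that helper, paraphrased from the library rather than taken from this diff:

import tensorflow as tf

def shape_list(tensor):
    # Static shape: Python ints where known, None where the dim is symbolic.
    static = tensor.shape.as_list()
    # Dynamic shape: a 1-D int32 tensor that is always defined inside a graph.
    dynamic = tf.shape(tensor)
    # Prefer the static int; fall back to the dynamic scalar for unknown dims.
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]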
@@ -226,7 +226,8 @@ class TFSamAttention(tf.keras.layers.Layer):
         batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states)
         hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3])
         return tf.reshape(
-            hidden_states, (batch // max(1, point_batch_size), point_batch_size, n_tokens, n_heads * c_per_head)
+            hidden_states,
+            (batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head),
         )
 
     def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor:
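Why the `max` swap matters: under `tf.function`, `point_batch_size` can be a scalar tensor, and Python's built-in `max` would try to evaluate it as a boolean during tracing. `tf.reduce_max` keeps the comparison inside the graph. A minimal sketch (function name is illustrative, not from the PR):

import tensorflow as tf

@tf.function(jit_compile=True)
def recombine_batch(batch, point_batch_size):
    # Stays symbolic, so tracing and XLA compilation succeed even when
    # point_batch_size has no concrete value yet.
    return batch // tf.reduce_max([1, point_batch_size])

# max(1, point_batch_size) here would raise OperatorNotAllowedInGraphError
# ("using a tf.Tensor as a Python bool is not allowed") during tracing.
print(recombine_batch(tf.constant(8), tf.constant(2)))  # tf.Tensor(4, ...)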
@@ -509,7 +510,7 @@ class TFSamMaskDecoder(tf.keras.layers.Layer):
         # Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only
         # happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced
         # it with an explicit shape check to avoid data-dependent control flow which breaks XLA.
-        if sparse_prompt_embeddings.shape[1] != 0:
+        if shape_list(sparse_prompt_embeddings)[1] != 0:
             tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2)
         else:
             tokens = output_tokens
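The in-diff comment states the key constraint: XLA cannot compile a branch whose condition depends on tensor values (like `tf.reduce_sum(...) == 0`), but a branch on a dimension is resolved while Python traces the function. A hedged illustration with simplified 3-D shapes:

import tensorflow as tf

@tf.function(jit_compile=True)
def build_tokens(output_tokens, sparse_prompt_embeddings):
    # shape[1] is a Python int known at trace time, so this `if` disappears
    # from the compiled graph; only the chosen branch is traced.
    if sparse_prompt_embeddings.shape[1] != 0:
        return tf.concat([output_tokens, sparse_prompt_embeddings], axis=1)
    return output_tokens

print(build_tokens(tf.zeros((1, 4, 8)), tf.zeros((1, 2, 8))).shape)  # (1, 6, 8)

# A value-dependent check such as
#   if tf.reduce_sum(sparse_prompt_embeddings) == 0:
# would instead fail tracing, since its truth value is unknown until runtime.

The diff's `shape_list(...)[1]` form keeps the check consistent with the rest of the file while still resolving at trace time when that dimension is static.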
@@ -695,8 +696,8 @@ class TFSamPromptEncoder(tf.keras.layers.Layer):
         """Embeds point prompts."""
         points = points + 0.5  # Shift to center of pixel
         if pad:
-            target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1])
-            target_labels_shape = (points.shape[0], points.shape[1], 1)
+            target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1])
+            target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1)
             padding_point = tf.zeros(target_point_shape, dtype=points.dtype)
             padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype)
             points = tf.concat([points, padding_point], axis=2)
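With a flexible input signature, `points.shape[0]` is `None`, which `tf.zeros` cannot accept as a dimension; the scalar tensors returned by `shape_list` work fine. A minimal sketch of the padding pattern (the signature and function name are assumptions for illustration):

import tensorflow as tf
from transformers.tf_utils import shape_list

@tf.function(input_signature=[tf.TensorSpec([None, None, None, 2], tf.float32)])
def pad_points(points):
    batch, point_batch, _, dim = shape_list(points)  # scalar tensors while tracing
    # tf.zeros happily consumes a shape tuple mixing ints and scalar tensors.
    padding_point = tf.zeros((batch, point_batch, 1, dim), dtype=points.dtype)
    return tf.concat([points, padding_point], axis=2)

print(pad_points(tf.zeros((2, 3, 5, 2))).shape)  # (2, 3, 6, 2)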
@@ -722,12 +723,12 @@ class TFSamPromptEncoder(tf.keras.layers.Layer):
     def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor:
         """Embeds box prompts."""
         boxes = boxes + 0.5  # Shift to center of pixel
-        batch_size, nb_boxes = boxes.shape[:2]
+        batch_size, nb_boxes = shape_list(boxes)[:2]
         coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2))
         input_shape = (self.input_image_size, self.input_image_size)
         corner_embedding = self.shared_embedding(coords, input_shape)
         corner_embedding += tf.where(
-            tf.range(corner_embedding.shape[2])[None, None, :, None] == 0,
+            tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0,
             self.point_embed[2][0],
             self.point_embed[3][0],
         )
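`tf.reshape` is the other motivating case: it rejects `None` entries in a target shape, so the unpacked dimensions must be concrete ints or scalar tensors. A sketch under the same assumptions as above:

import tensorflow as tf
from transformers.tf_utils import shape_list

@tf.function(input_signature=[tf.TensorSpec([None, None, 4], tf.float32)])
def boxes_to_corners(boxes):
    batch_size, nb_boxes = shape_list(boxes)[:2]
    # With boxes.shape[:2] this would be (None, None) and tf.reshape would fail.
    return tf.reshape(boxes, (batch_size, nb_boxes, 2, 2))

print(boxes_to_corners(tf.zeros((2, 7, 4))).shape)  # (2, 7, 2, 2)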
@@ -754,7 +755,7 @@ class TFSamPromptEncoder(tf.keras.layers.Layer):
         """
         sparse_embeddings = None
         if input_points is not None:
-            batch_size, point_batch_size = input_points.shape[:2]
+            batch_size, point_batch_size = shape_list(input_points)[:2]
             if input_labels is None:
                 raise ValueError("If points are provided, labels must also be provided.")
             point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
@@ -763,7 +764,7 @@ class TFSamPromptEncoder(tf.keras.layers.Layer):
             )
             sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2)
         if input_boxes is not None:
-            batch_size = input_boxes.shape[0]
+            batch_size = shape_list(input_boxes)[0]
             box_embeddings = self._embed_boxes(input_boxes)
             if sparse_embeddings is None:
                 sparse_embeddings = box_embeddings
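These two hunks repeat the same pattern. Note that outside of `tf.function` every dimension is concrete, so `shape_list` returns plain Python ints and the edits are behavior-preserving in eager mode:

import tensorflow as tf
from transformers.tf_utils import shape_list

boxes = tf.zeros((2, 3, 4))
print(boxes.shape[:2])        # (2, 3)
print(shape_list(boxes)[:2])  # [2, 3] -- same values, eager behavior unchanged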
@@ -1376,8 +1377,8 @@ class TFSamModel(TFSamPreTrainedModel):
                     " got {}.".format(input_boxes.shape),
                 )
             if input_points is not None and input_boxes is not None:
-                point_batch_size = input_points.shape[1]
-                box_batch_size = input_boxes.shape[1]
+                point_batch_size = shape_list(input_points)[1]
+                box_batch_size = shape_list(input_boxes)[1]
                 if point_batch_size != box_batch_size:
                     raise ValueError(
                         "You should provide as many bounding boxes as input points per box. Got {} and {}.".format(
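Taken together, the hunks let the TF SAM forward pass be traced once with flexible dimensions instead of retracing per input shape. A hedged end-to-end sketch: the checkpoint name is a real SAM checkpoint, but the exact signature and dtypes below are illustrative assumptions, not taken from this PR:

import tensorflow as tf
from transformers import TFSamModel

model = TFSamModel.from_pretrained("facebook/sam-vit-base")

@tf.function(
    input_signature=[
        tf.TensorSpec([None, 3, 1024, 1024], tf.float32),  # pixel_values
        tf.TensorSpec([None, None, None, 2], tf.float32),  # input_points
        tf.TensorSpec([None, None, None], tf.int32),       # input_labels
    ]
)
def segment(pixel_values, input_points, input_labels):
    outputs = model(pixel_values=pixel_values, input_points=input_points, input_labels=input_labels)
    # One trace now serves any batch size / point-batch size / points-per-batch.
    return outputs.pred_masks, outputs.iou_scores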