TF SAM shape flexibility fixes (#23842)

SAM shape flexibility fixes for compilation
This commit is contained in:
Matt 2023-05-30 13:08:44 +01:00 committed by GitHub
parent af45ec0a16
commit ac224dee90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -226,7 +226,8 @@ class TFSamAttention(tf.keras.layers.Layer):
batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states)
hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3])
return tf.reshape(
hidden_states, (batch // max(1, point_batch_size), point_batch_size, n_tokens, n_heads * c_per_head)
hidden_states,
(batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head),
)
def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor:
@ -509,7 +510,7 @@ class TFSamMaskDecoder(tf.keras.layers.Layer):
# Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only
# happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced
# it with an explicit shape check to avoid data-dependent control flow which breaks XLA.
if sparse_prompt_embeddings.shape[1] != 0:
if shape_list(sparse_prompt_embeddings)[1] != 0:
tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2)
else:
tokens = output_tokens
@ -695,8 +696,8 @@ class TFSamPromptEncoder(tf.keras.layers.Layer):
"""Embeds point prompts."""
points = points + 0.5 # Shift to center of pixel
if pad:
target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1])
target_labels_shape = (points.shape[0], points.shape[1], 1)
target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1])
target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1)
padding_point = tf.zeros(target_point_shape, dtype=points.dtype)
padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype)
points = tf.concat([points, padding_point], axis=2)
@ -722,12 +723,12 @@ class TFSamPromptEncoder(tf.keras.layers.Layer):
def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel
batch_size, nb_boxes = boxes.shape[:2]
batch_size, nb_boxes = shape_list(boxes)[:2]
coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2))
input_shape = (self.input_image_size, self.input_image_size)
corner_embedding = self.shared_embedding(coords, input_shape)
corner_embedding += tf.where(
tf.range(corner_embedding.shape[2])[None, None, :, None] == 0,
tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0,
self.point_embed[2][0],
self.point_embed[3][0],
)
@ -754,7 +755,7 @@ class TFSamPromptEncoder(tf.keras.layers.Layer):
"""
sparse_embeddings = None
if input_points is not None:
batch_size, point_batch_size = input_points.shape[:2]
batch_size, point_batch_size = shape_list(input_points)[:2]
if input_labels is None:
raise ValueError("If points are provided, labels must also be provided.")
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
@ -763,7 +764,7 @@ class TFSamPromptEncoder(tf.keras.layers.Layer):
)
sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2)
if input_boxes is not None:
batch_size = input_boxes.shape[0]
batch_size = shape_list(input_boxes)[0]
box_embeddings = self._embed_boxes(input_boxes)
if sparse_embeddings is None:
sparse_embeddings = box_embeddings
@ -1376,8 +1377,8 @@ class TFSamModel(TFSamPreTrainedModel):
" got {}.".format(input_boxes.shape),
)
if input_points is not None and input_boxes is not None:
point_batch_size = input_points.shape[1]
box_batch_size = input_boxes.shape[1]
point_batch_size = shape_list(input_points)[1]
box_batch_size = shape_list(input_boxes)[1]
if point_batch_size != box_batch_size:
raise ValueError(
"You should provide as many bounding boxes as input points per box. Got {} and {}.".format(