diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 794587cdde0..46710b32984 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -226,7 +226,8 @@ class TFSamAttention(tf.keras.layers.Layer): batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states) hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3]) return tf.reshape( - hidden_states, (batch // max(1, point_batch_size), point_batch_size, n_tokens, n_heads * c_per_head) + hidden_states, + (batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head), ) def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor: @@ -509,7 +510,7 @@ class TFSamMaskDecoder(tf.keras.layers.Layer): # Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only # happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced # it with an explicit shape check to avoid data-dependent control flow which breaks XLA. - if sparse_prompt_embeddings.shape[1] != 0: + if shape_list(sparse_prompt_embeddings)[1] != 0: tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2) else: tokens = output_tokens @@ -695,8 +696,8 @@ class TFSamPromptEncoder(tf.keras.layers.Layer): """Embeds point prompts.""" points = points + 0.5 # Shift to center of pixel if pad: - target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1]) - target_labels_shape = (points.shape[0], points.shape[1], 1) + target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1]) + target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1) padding_point = tf.zeros(target_point_shape, dtype=points.dtype) padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype) points = tf.concat([points, padding_point], axis=2) @@ -722,12 +723,12 @@ class TFSamPromptEncoder(tf.keras.layers.Layer): def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor: """Embeds box prompts.""" boxes = boxes + 0.5 # Shift to center of pixel - batch_size, nb_boxes = boxes.shape[:2] + batch_size, nb_boxes = shape_list(boxes)[:2] coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2)) input_shape = (self.input_image_size, self.input_image_size) corner_embedding = self.shared_embedding(coords, input_shape) corner_embedding += tf.where( - tf.range(corner_embedding.shape[2])[None, None, :, None] == 0, + tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0, self.point_embed[2][0], self.point_embed[3][0], ) @@ -754,7 +755,7 @@ class TFSamPromptEncoder(tf.keras.layers.Layer): """ sparse_embeddings = None if input_points is not None: - batch_size, point_batch_size = input_points.shape[:2] + batch_size, point_batch_size = shape_list(input_points)[:2] if input_labels is None: raise ValueError("If points are provided, labels must also be provided.") point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) @@ -763,7 +764,7 @@ class TFSamPromptEncoder(tf.keras.layers.Layer): ) sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2) if input_boxes is not None: - batch_size = input_boxes.shape[0] + batch_size = shape_list(input_boxes)[0] box_embeddings = self._embed_boxes(input_boxes) if sparse_embeddings is None: sparse_embeddings = box_embeddings @@ -1376,8 +1377,8 @@ class TFSamModel(TFSamPreTrainedModel): " got {}.".format(input_boxes.shape), ) if input_points is not None and input_boxes is not None: - point_batch_size = input_points.shape[1] - box_batch_size = input_boxes.shape[1] + point_batch_size = shape_list(input_points)[1] + box_batch_size = shape_list(input_boxes)[1] if point_batch_size != box_batch_size: raise ValueError( "You should provide as many bounding boxes as input points per box. Got {} and {}.".format(