From a1c4b63076ed11946a24a3d2e3ab7d7e77819546 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 22 Jun 2023 15:49:13 +0100 Subject: [PATCH] TF CI fix for Segformer (#24426) Fix segformer so compilation can figure out the channel dim --- .../models/segformer/modeling_tf_segformer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/segformer/modeling_tf_segformer.py b/src/transformers/models/segformer/modeling_tf_segformer.py index b3090135afc..632382f95ed 100644 --- a/src/transformers/models/segformer/modeling_tf_segformer.py +++ b/src/transformers/models/segformer/modeling_tf_segformer.py @@ -710,21 +710,20 @@ class TFSegformerDecodeHead(TFSegformerPreTrainedModel): self.config = config def call(self, encoder_hidden_states, training: bool = False): - batch_size = shape_list(encoder_hidden_states[-1])[0] - all_hidden_states = () for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.mlps): if self.config.reshape_last_stage is False and len(shape_list(encoder_hidden_state)) == 3: height = tf.math.sqrt(tf.cast(shape_list(encoder_hidden_state)[1], tf.float32)) height = width = tf.cast(height, tf.int32) - encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1)) + channel_dim = shape_list(encoder_hidden_state)[-1] + encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim)) # unify channel dimension encoder_hidden_state = tf.transpose(encoder_hidden_state, perm=[0, 2, 3, 1]) - height = shape_list(encoder_hidden_state)[1] - width = shape_list(encoder_hidden_state)[2] + height, width = shape_list(encoder_hidden_state)[1:3] encoder_hidden_state = mlp(encoder_hidden_state) - encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1)) + channel_dim = shape_list(encoder_hidden_state)[-1] + encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim)) # upsample temp_state = tf.transpose(encoder_hidden_states[0], perm=[0, 2, 3, 1])