TF CI fix for Segformer (#24426)

Fix segformer so compilation can figure out the channel dim
This commit is contained in:
Matt 2023-06-22 15:49:13 +01:00 committed by GitHub
parent 754f61ca05
commit a1c4b63076
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -710,21 +710,20 @@ class TFSegformerDecodeHead(TFSegformerPreTrainedModel):
self.config = config self.config = config
def call(self, encoder_hidden_states, training: bool = False): def call(self, encoder_hidden_states, training: bool = False):
batch_size = shape_list(encoder_hidden_states[-1])[0]
all_hidden_states = () all_hidden_states = ()
for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.mlps): for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.mlps):
if self.config.reshape_last_stage is False and len(shape_list(encoder_hidden_state)) == 3: if self.config.reshape_last_stage is False and len(shape_list(encoder_hidden_state)) == 3:
height = tf.math.sqrt(tf.cast(shape_list(encoder_hidden_state)[1], tf.float32)) height = tf.math.sqrt(tf.cast(shape_list(encoder_hidden_state)[1], tf.float32))
height = width = tf.cast(height, tf.int32) height = width = tf.cast(height, tf.int32)
encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1)) channel_dim = shape_list(encoder_hidden_state)[-1]
encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim))
# unify channel dimension # unify channel dimension
encoder_hidden_state = tf.transpose(encoder_hidden_state, perm=[0, 2, 3, 1]) encoder_hidden_state = tf.transpose(encoder_hidden_state, perm=[0, 2, 3, 1])
height = shape_list(encoder_hidden_state)[1] height, width = shape_list(encoder_hidden_state)[1:3]
width = shape_list(encoder_hidden_state)[2]
encoder_hidden_state = mlp(encoder_hidden_state) encoder_hidden_state = mlp(encoder_hidden_state)
encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1)) channel_dim = shape_list(encoder_hidden_state)[-1]
encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim))
# upsample # upsample
temp_state = tf.transpose(encoder_hidden_states[0], perm=[0, 2, 3, 1]) temp_state = tf.transpose(encoder_hidden_states[0], perm=[0, 2, 3, 1])