mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
TF CI fix for Segformer (#24426)
Fix segformer so compilation can figure out the channel dim
This commit is contained in:
parent
754f61ca05
commit
a1c4b63076
@ -710,21 +710,20 @@ class TFSegformerDecodeHead(TFSegformerPreTrainedModel):
|
||||
self.config = config
|
||||
|
||||
def call(self, encoder_hidden_states, training: bool = False):
|
||||
batch_size = shape_list(encoder_hidden_states[-1])[0]
|
||||
|
||||
all_hidden_states = ()
|
||||
for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.mlps):
|
||||
if self.config.reshape_last_stage is False and len(shape_list(encoder_hidden_state)) == 3:
|
||||
height = tf.math.sqrt(tf.cast(shape_list(encoder_hidden_state)[1], tf.float32))
|
||||
height = width = tf.cast(height, tf.int32)
|
||||
encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1))
|
||||
channel_dim = shape_list(encoder_hidden_state)[-1]
|
||||
encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim))
|
||||
|
||||
# unify channel dimension
|
||||
encoder_hidden_state = tf.transpose(encoder_hidden_state, perm=[0, 2, 3, 1])
|
||||
height = shape_list(encoder_hidden_state)[1]
|
||||
width = shape_list(encoder_hidden_state)[2]
|
||||
height, width = shape_list(encoder_hidden_state)[1:3]
|
||||
encoder_hidden_state = mlp(encoder_hidden_state)
|
||||
encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1))
|
||||
channel_dim = shape_list(encoder_hidden_state)[-1]
|
||||
encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim))
|
||||
|
||||
# upsample
|
||||
temp_state = tf.transpose(encoder_hidden_states[0], perm=[0, 2, 3, 1])
|
||||
|
Loading…
Reference in New Issue
Block a user