TF CI fix for Segformer (#24426)
Fix Segformer so compilation can figure out the channel dim
This commit is contained in:
parent 754f61ca05
commit a1c4b63076
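
In short, the patch replaces reshapes keyed on a runtime batch_size with reshapes that keep the statically known channel dimension, so shape inference during tracing can still see the channels. A minimal sketch of the difference, using made-up sizes (an 8x8 spatial grid with 256 channels) rather than Segformer's real shapes:

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=(None, 64, 256), dtype=tf.float32)])
def reshape_old_style(hidden_state):
    # Old pattern: batch_size is a dynamic tensor, so the -1 cannot be resolved
    # at trace time and the inferred output shape is (None, 8, 8, None).
    batch_size = tf.shape(hidden_state)[0]
    out = tf.reshape(hidden_state, (batch_size, 8, 8, -1))
    print("old pattern:", out.shape)  # (None, 8, 8, None) while tracing
    return out

@tf.function(input_signature=[tf.TensorSpec(shape=(None, 64, 256), dtype=tf.float32)])
def reshape_new_style(hidden_state):
    # New pattern: read the channel dim from the static shape and let -1 absorb
    # the batch dim; the inferred output shape is (None, 8, 8, 256).
    channel_dim = hidden_state.shape[-1]
    out = tf.reshape(hidden_state, (-1, 8, 8, channel_dim))
    print("new pattern:", out.shape)  # (None, 8, 8, 256) while tracing
    return out

reshape_old_style.get_concrete_function()
reshape_new_style.get_concrete_function()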

@@ -710,21 +710,20 @@ class TFSegformerDecodeHead(TFSegformerPreTrainedModel):
         self.config = config
 
     def call(self, encoder_hidden_states, training: bool = False):
-        batch_size = shape_list(encoder_hidden_states[-1])[0]
-
         all_hidden_states = ()
         for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.mlps):
             if self.config.reshape_last_stage is False and len(shape_list(encoder_hidden_state)) == 3:
                 height = tf.math.sqrt(tf.cast(shape_list(encoder_hidden_state)[1], tf.float32))
                 height = width = tf.cast(height, tf.int32)
-                encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1))
+                channel_dim = shape_list(encoder_hidden_state)[-1]
+                encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim))
 
             # unify channel dimension
             encoder_hidden_state = tf.transpose(encoder_hidden_state, perm=[0, 2, 3, 1])
-            height = shape_list(encoder_hidden_state)[1]
-            width = shape_list(encoder_hidden_state)[2]
+            height, width = shape_list(encoder_hidden_state)[1:3]
             encoder_hidden_state = mlp(encoder_hidden_state)
-            encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1))
+            channel_dim = shape_list(encoder_hidden_state)[-1]
+            encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim))
 
         # upsample
         temp_state = tf.transpose(encoder_hidden_states[0], perm=[0, 2, 3, 1])
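
And that is why compilation cared: any Keras layer downstream of the reshape needs a defined channel dimension to build its weights when the model is traced with an unknown batch size. An illustrative check under the same made-up sizes as above, with a plain Dense layer standing in for the decode head's projections (not the repository's actual test):

import tensorflow as tf

mlp = tf.keras.layers.Dense(32)  # hypothetical stand-in for one decode-head projection

@tf.function(input_signature=[tf.TensorSpec(shape=(None, 64, 256), dtype=tf.float32)])
def decode_step(hidden_state):
    channel_dim = hidden_state.shape[-1]
    x = tf.reshape(hidden_state, (-1, 8, 8, channel_dim))  # channels stay 256
    return mlp(x)  # Dense can build its (256, 32) kernel during tracing

decode_step.get_concrete_function()  # traces cleanly

# With the old reshape, tf.reshape(hidden_state, (tf.shape(hidden_state)[0], 8, 8, -1)),
# the channel dim is None at trace time and Dense fails to build, raising an error along
# the lines of "The last dimension of the inputs to a Dense layer should be defined".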