Return the permuted hidden states if return_dict=True (#18578)

This commit is contained in:
amyeroberts 2022-08-11 17:32:11 +01:00 committed by GitHub
parent f28f240828
commit c8b6ae858d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -330,7 +330,8 @@ class TFConvNextMainLayer(tf.keras.layers.Layer):
hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
hidden_states = hidden_states if output_hidden_states else ()
return (last_hidden_state, pooled_output) + hidden_states
return TFBaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,