mean does not exist in TF2

This commit is contained in:
Lysandre 2019-11-19 18:14:14 -05:00
parent f3386d9383
commit 3de31f8d28

View File

@ -454,7 +454,7 @@ class TFSequenceSummary(tf.keras.layers.Layer):
elif self.summary_type == 'first':
output = hidden_states[:, 0]
elif self.summary_type == 'mean':
output = tf.mean(hidden_states, axis=1)
output = tf.reduce_mean(hidden_states, axis=1)
elif self.summary_type == 'cls_index':
hidden_shape = shape_list(hidden_states) # e.g. [batch, num choices, seq length, hidden dims]
if cls_index is None: