From 3de31f8d287da44a40566fb1d5c44107708b87ea Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 19 Nov 2019 18:14:14 -0500 Subject: [PATCH] mean does not exist in TF2 --- transformers/modeling_tf_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformers/modeling_tf_utils.py b/transformers/modeling_tf_utils.py index e08605d1548..8be7eaaf677 100644 --- a/transformers/modeling_tf_utils.py +++ b/transformers/modeling_tf_utils.py @@ -454,7 +454,7 @@ class TFSequenceSummary(tf.keras.layers.Layer): elif self.summary_type == 'first': output = hidden_states[:, 0] elif self.summary_type == 'mean': - output = tf.mean(hidden_states, axis=1) + output = tf.reduce_mean(hidden_states, axis=1) elif self.summary_type == 'cls_index': hidden_shape = shape_list(hidden_states) # e.g. [batch, num choices, seq length, hidden dims] if cls_index is None: