mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
mean does not exist in TF2
This commit is contained in:
parent
f3386d9383
commit
3de31f8d28
@ -454,7 +454,7 @@ class TFSequenceSummary(tf.keras.layers.Layer):
|
||||
elif self.summary_type == 'first':
|
||||
output = hidden_states[:, 0]
|
||||
elif self.summary_type == 'mean':
|
||||
output = tf.mean(hidden_states, axis=1)
|
||||
output = tf.reduce_mean(hidden_states, axis=1)
|
||||
elif self.summary_type == 'cls_index':
|
||||
hidden_shape = shape_list(hidden_states) # e.g. [batch, num choices, seq length, hidden dims]
|
||||
if cls_index is None:
|
||||
|
Loading…
Reference in New Issue
Block a user