rename variable

This commit is contained in:
Patrick von Platen 2020-03-09 20:25:09 +01:00
parent cf06290565
commit 10989715d0

View File

@ -990,7 +990,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
next_scores, (batch_size, num_beams * vocab_size)
) # (batch_size, num_beams * vocab_size)
next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)
next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)
assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]