From 7562366d4be98152059e9f8e923bfb1bad600cb5 Mon Sep 17 00:00:00 2001 From: Aya <65711439+Ayaa17@users.noreply.github.com> Date: Tue, 27 Aug 2024 17:44:09 +0800 Subject: [PATCH] fix: multilingual midel convert to tflite get wrong token (#32079) * fix: multilingual midel convert to tflite get wrong token * fix: modify test_force_tokens_logits_processor the checking value as scores.dtype.min --------- Co-authored-by: kent.sc.hung Co-authored-by: Aya <[kent831217@gmail.com]> --- src/transformers/generation/tf_logits_process.py | 2 +- tests/generation/test_tf_logits_process.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/tf_logits_process.py b/src/transformers/generation/tf_logits_process.py index 58824b7b007..91e20fe02f7 100644 --- a/src/transformers/generation/tf_logits_process.py +++ b/src/transformers/generation/tf_logits_process.py @@ -581,7 +581,7 @@ class TFForceTokensLogitsProcessor(TFLogitsProcessor): batch_size = scores.shape[0] current_token = self.force_token_array[generation_idx] - new_scores = tf.ones_like(scores, dtype=scores.dtype) * -float("inf") + new_scores = tf.zeros_like(scores, dtype=scores.dtype) + tf.constant([scores.dtype.min]) indices = tf.stack((tf.range(batch_size), tf.tile([current_token], [batch_size])), axis=1) updates = tf.zeros((batch_size,), dtype=scores.dtype) new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates) diff --git a/tests/generation/test_tf_logits_process.py b/tests/generation/test_tf_logits_process.py index e87c843d9cb..f06f5695b1c 100644 --- a/tests/generation/test_tf_logits_process.py +++ b/tests/generation/test_tf_logits_process.py @@ -406,7 +406,12 @@ class TFLogitsProcessorTest(unittest.TestCase): non_forced_inds = [i for i in range(vocab_size) if i != force_token_map[cur_len]] self.assertTrue( - tf.math.reduce_all(tf.math.is_inf(tf.gather(scores, [non_forced_inds], axis=1))), + tf.math.reduce_all( + tf.experimental.numpy.isclose( + tf.gather(scores, [non_forced_inds], axis=1), + tf.constant(scores.dtype.min), + ) + ) ) # check that if the cur_len is not contained in the force_token_map, the logits are not modified