mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix TF bad words filter with XLA (#18286)
* Fix bad words filter in XLA generation * Remove my cool debug breakpoints (again)
This commit is contained in:
parent
f4e172716b
commit
45a1475462
@ -332,7 +332,7 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
|
||||
def _len_greater_than_cur_len():
|
||||
# Otherwise, if the bad sequence is longer than the current length they can't ever match
|
||||
return tf.cond(
|
||||
tf.math.greater(self.bad_word_seqs_len[bad_word_seq_number], row_input_ids.shape[0]),
|
||||
tf.math.greater(self.bad_word_seqs_len[bad_word_seq_number], tf.shape(row_input_ids)[0]),
|
||||
lambda: tf.zeros((), dtype=tf.bool),
|
||||
_match_found,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user