Fix TF bad words filter with XLA (#18286)

* Fix bad words filter in XLA generation

* Remove my cool debug breakpoints (again)
This commit is contained in:
Matt 2022-07-25 15:19:39 -04:00 committed by GitHub
parent f4e172716b
commit 45a1475462
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -332,7 +332,7 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
def _len_greater_than_cur_len():
# Otherwise, if the bad sequence is longer than the current length they can't ever match
return tf.cond(
tf.math.greater(self.bad_word_seqs_len[bad_word_seq_number], row_input_ids.shape[0]),
tf.math.greater(self.bad_word_seqs_len[bad_word_seq_number], tf.shape(row_input_ids)[0]),
lambda: tf.zeros((), dtype=tf.bool),
_match_found,
)