TF: Fix generation repetition penalty with XLA (#18648)

This commit is contained in:
Joao Gante 2022-08-16 13:30:52 +01:00 committed by GitHub
parent 81ab11124f
commit fd9aa82b07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -262,9 +262,11 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
# Scatters the penalties
token_penalties = tf.ones(logits.shape)
batch_size = input_ids.shape[0]
seq_len = tf.shape(input_ids)[1] # the sequence length has dynamic size, hence the dynamic shape
indexable_prev_input_ids = tf.concat(
(
tf.expand_dims(tf.repeat(tf.range(input_ids.shape[0]), input_ids.shape[1]), axis=-1),
tf.expand_dims(tf.repeat(tf.range(batch_size), seq_len), axis=-1),
tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1),
),
axis=1,