diff --git a/src/transformers/generation_tf_logits_process.py b/src/transformers/generation_tf_logits_process.py
index 7b3f876212b..f17ed046868 100644
--- a/src/transformers/generation_tf_logits_process.py
+++ b/src/transformers/generation_tf_logits_process.py
@@ -262,9 +262,11 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
         # Scatters the penalties
         token_penalties = tf.ones(logits.shape)
+        batch_size = input_ids.shape[0]
+        seq_len = tf.shape(input_ids)[1]  # the sequence length has dynamic size, hence the dynamic shape
         indexable_prev_input_ids = tf.concat(
             (
-                tf.expand_dims(tf.repeat(tf.range(input_ids.shape[0]), input_ids.shape[1]), axis=-1),
+                tf.expand_dims(tf.repeat(tf.range(batch_size), seq_len), axis=-1),
                 tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1),
             ),
             axis=1,
         )
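
For context, a minimal standalone sketch (not part of the patch; the function name build_row_indices, the input_signature, and the example values are assumed for illustration) of why the dynamic shape is needed: when the function is traced with an unknown sequence length, the static shape entry input_ids.shape[1] is None, whereas tf.shape(input_ids)[1] yields a scalar tensor that tf.repeat can consume.

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[2, None], dtype=tf.int32)])
def build_row_indices(input_ids):
    batch_size = input_ids.shape[0]  # static shape: known to be 2 at trace time
    # input_ids.shape[1] is None during tracing, so it cannot be passed to tf.repeat;
    # tf.shape(input_ids)[1] returns a scalar tensor that handles the dynamic length.
    seq_len = tf.shape(input_ids)[1]
    return tf.expand_dims(tf.repeat(tf.range(batch_size), seq_len), axis=-1)

print(build_row_indices(tf.constant([[5, 6, 7], [8, 9, 10]])))  # shape (6, 1): [[0], [0], [0], [1], [1], [1]]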