mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
TF: Fix generation repetition penalty with XLA (#18648)
This commit is contained in:
parent
81ab11124f
commit
fd9aa82b07
@@ -262,9 +262,11 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
         # Scatters the penalties
         token_penalties = tf.ones(logits.shape)
+        batch_size = input_ids.shape[0]
+        seq_len = tf.shape(input_ids)[1]  # the sequence length has dynamic size, hence the dynamic shape
         indexable_prev_input_ids = tf.concat(
             (
-                tf.expand_dims(tf.repeat(tf.range(input_ids.shape[0]), input_ids.shape[1]), axis=-1),
+                tf.expand_dims(tf.repeat(tf.range(batch_size), seq_len), axis=-1),
                 tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1),
             ),
             axis=1,
Loading…
Reference in New Issue
Block a user