Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
correct beam search sampling

parent c4c4c9998a
commit 7a89a3e493
@@ -760,9 +760,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         ]
 
         # scores for each sentence in the beam
-        beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
-        beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
-        beam_scores = tf.reshape(tf.concat([beam_scores_begin, beam_scores_end], -1), (batch_size * num_beams,))
+        if do_sample is False:
+            beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
+            beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
+            beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1)
+        else:
+            beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
+
+        beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))
 
         # cache compute states
         past = None
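The initialization now depends on the decoding mode. In greedy beam search every beam starts from the same prefix, so all but the first beam get a masking offset to keep the first step from selecting the same token num_beams times; with sampling, the random draws already decorrelate the beams, so all of them can start at zero. A minimal standalone sketch of that logic follows, with batch_size and num_beams as illustrative values. Note one assumption: tf.zeros(...) * 1e-9 in the diff still evaluates to all zeros, so the sketch substitutes -1e9, the large negative mask the PyTorch implementation uses, which is presumably the intent here.

    import tensorflow as tf

    batch_size, num_beams = 2, 3   # illustrative values, not from the diff
    do_sample = False

    if do_sample is False:
        # Greedy beam search: every beam holds the same prefix, so mask all
        # but the first beam at step 0 to avoid num_beams identical picks.
        # (The diff writes tf.zeros(...) * 1e-9, which is still all zeros;
        # a large negative value such as -1e9 is presumably the intent.)
        beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
        beam_scores_end = tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * -1e9
        beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1)
    else:
        # Sampling: randomness already decorrelates the beams.
        beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)

    beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))  # (batch_size * num_beams,)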
@@ -790,23 +795,24 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 # Temperature (higher temperature => more likely to sample low probability tokens)
                 if temperature != 1.0:
                     next_token_logits = next_token_logits / temperature
+
+                scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
+                _scores = scores + tf.broadcast_to(
+                    beam_scores[:, None], (batch_size * num_beams, vocab_size)
+                )  # (batch_size * num_beams, vocab_size)
+
                 # Top-p/top-k filtering
-                next_token_logits = tf_top_k_top_p_filtering(
-                    next_token_logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
+                _scores = tf_top_k_top_p_filtering(
+                    _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
                 )  # (batch_size * num_beams, vocab_size)
                 # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
+                _scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size))
+
                 next_tokens = tf.random.categorical(
-                    next_token_logits, dtype=tf.int32, num_samples=2
-                )  # (batch_size * num_beams, vocab_size)
+                    _scores, dtype=tf.int32, num_samples=2 * num_beams
+                )  # (batch_size, 2 * num_beams)
                 # Compute next scores
-                scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
-                _scores = tf.gather(scores, next_tokens, batch_dims=1)  # (batch_size * num_beams, 2)
-                next_scores = _scores + tf.broadcast_to(
-                    beam_scores[:, None], (batch_size * num_beams, 2)
-                )  # (batch_size * num_beams, 2)
-                # Match shape of greedy beam search
-                next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
-                next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
+                next_scores = tf.gather(_scores, next_tokens, batch_dims=1)  # (batch_size, 2 * num_beams)
             else:
                 # do greedy beam search
                 scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
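The removed lines sampled 2 tokens per beam independently from the raw logits and only added the beam scores afterwards, so candidates were not drawn in proportion to their full sequence log-probability. The corrected branch adds the running beam_scores to the per-token log-probabilities first, then flattens the beam axis into the class axis so one categorical draw per batch item can pick candidates from any beam, matching the output shapes of the greedy branch. Below is a self-contained sketch of that corrected step, with dummy logits and the tf_top_k_top_p_filtering call omitted; batch_size, num_beams, and vocab_size are illustrative values, not from the diff.

    import tensorflow as tf

    batch_size, num_beams, vocab_size = 2, 3, 10   # illustrative values
    next_token_logits = tf.random.normal((batch_size * num_beams, vocab_size))  # dummy model output
    beam_scores = tf.zeros((batch_size * num_beams,), dtype=tf.float32)

    # Per-token log-probs plus each beam's running score, so a candidate's
    # score is the log-probability of the whole extended sequence.
    scores = tf.nn.log_softmax(next_token_logits, axis=-1)
    _scores = scores + tf.broadcast_to(
        beam_scores[:, None], (batch_size * num_beams, vocab_size)
    )  # (batch_size * num_beams, vocab_size)

    # (the top-k/top-p filtering of _scores would go here)

    # Flatten beams into the class axis: one distribution per batch item
    # over num_beams * vocab_size candidates, so a single draw can pick
    # tokens from any beam.
    _scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size))

    # Sample 2 * num_beams candidates (spare tokens, matching the shape
    # produced by greedy beam search) and gather their scores.
    next_tokens = tf.random.categorical(
        _scores, dtype=tf.int32, num_samples=2 * num_beams
    )  # (batch_size, 2 * num_beams)
    next_scores = tf.gather(_scores, next_tokens, batch_dims=1)  # (batch_size, 2 * num_beams)

    # Recover the beam index and vocabulary id of each sampled candidate.
    beam_ids = next_tokens // vocab_size   # (batch_size, 2 * num_beams)
    token_ids = next_tokens % vocab_size   # (batch_size, 2 * num_beams)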