correct beam search sampling

This commit is contained in:
Patrick von Platen 2020-03-04 12:02:57 +01:00
parent c4c4c9998a
commit 7a89a3e493

View File

@ -760,9 +760,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
]
# scores for each sentence in the beam
beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
beam_scores = tf.reshape(tf.concat([beam_scores_begin, beam_scores_end], -1), (batch_size * num_beams,))
if do_sample is False:
beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1)
else:
beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))
# cache compute states
past = None
@ -790,23 +795,24 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
next_token_logits = next_token_logits / temperature
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
_scores = scores + tf.broadcast_to(
beam_scores[:, None], (batch_size * num_beams, vocab_size)
) # (batch_size * num_beams, vocab_size)
# Top-p/top-k filtering
next_token_logits = tf_top_k_top_p_filtering(
next_token_logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
_scores = tf_top_k_top_p_filtering(
_scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
) # (batch_size * num_beams, vocab_size)
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
_scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size))
next_tokens = tf.random.categorical(
next_token_logits, dtype=tf.int32, num_samples=2
) # (batch_size * num_beams, vocab_size)
_scores, dtype=tf.int32, num_samples=2 * num_beams
) # (batch_size, 2 * num_beams)
# Compute next scores
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
_scores = tf.gather(scores, next_tokens, batch_dims=1) # (batch_size * num_beams, 2)
next_scores = _scores + tf.broadcast_to(
beam_scores[:, None], (batch_size * num_beams, 2)
) # (batch_size * num_beams, 2)
# Match shape of greedy beam search
next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams)) # (batch_size, 2 * num_beams)
next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams)) # (batch_size, 2 * num_beams)
next_scores = tf.gather(_scores, next_tokens, batch_dims=1) # (batch_size, 2 * num_beams)
else:
# do greedy beam search
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)