mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
refactored beam search according to torch implementation
This commit is contained in:
parent
c8035e11e8
commit
9362eb4a07
@ -557,6 +557,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
else:
|
||||
assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."
|
||||
|
||||
# not allow to duplicate outputs when greedy decoding
|
||||
if do_sample is False:
|
||||
if num_beams == 1:
|
||||
# no_beam_search greedy generation conditions
|
||||
@ -580,13 +581,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
cur_len = shape_list(input_ids)[1]
|
||||
vocab_size = self.config.vocab_size
|
||||
|
||||
if num_return_sequences != 1 and do_sample:
|
||||
# Expand input to num return sequences
|
||||
input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_return_sequences, cur_len))
|
||||
# set effective batch size and effective batch multiplier according to do_sample
|
||||
if do_sample:
|
||||
effective_batch_size = batch_size * num_return_sequences
|
||||
input_ids = tf.reshape(input_ids, (effective_batch_size, cur_len))
|
||||
effective_batch_mult = num_return_sequences
|
||||
else:
|
||||
effective_batch_size = batch_size
|
||||
effective_batch_mult = 1
|
||||
|
||||
# Expand input ids if num_beams > 1 or num_return_sequences > 1
|
||||
if num_return_sequences > 1 or num_beams > 1:
|
||||
input_ids_len = shape_list(input_ids)[-1]
|
||||
input_ids = tf.broadcast_to(
|
||||
tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
|
||||
)
|
||||
input_ids = tf.reshape(
|
||||
input_ids, (effective_batch_size * num_beams, input_ids_len)
|
||||
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
||||
|
||||
if num_beams > 1:
|
||||
output = self._generate_beam_search(
|
||||
@ -701,12 +712,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
# unfinished_sents is set to zero if eos in sentence
|
||||
unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos
|
||||
|
||||
cur_len = cur_len + 1
|
||||
|
||||
# stop when there is a </s> in each sentence, or if we exceed the maximul length
|
||||
if tf.math.reduce_max(unfinished_sents) == 0:
|
||||
break
|
||||
|
||||
cur_len = cur_len + 1
|
||||
|
||||
# if there are different sentences lengths in the batch, some batches have to be padded
|
||||
min_sent_length = tf.math.reduce_min(sent_lengths)
|
||||
max_sent_length = tf.math.reduce_max(sent_lengths)
|
||||
@ -750,10 +761,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
""" Generate sequences for each example with beam search.
|
||||
"""
|
||||
|
||||
# Expand input to num beams
|
||||
input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_beams, cur_len))
|
||||
input_ids = tf.reshape(input_ids, (batch_size * num_beams, cur_len)) # (batch_size * num_beams, cur_len)
|
||||
|
||||
# generated hypotheses
|
||||
generated_hyps = [
|
||||
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
|
||||
@ -768,7 +775,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
|
||||
|
||||
beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))
|
||||
|
||||
# cache compute states
|
||||
past = None
|
||||
|
||||
@ -813,6 +819,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
) # (batch_size, 2 * num_beams)
|
||||
# Compute next scores
|
||||
next_scores = tf.gather(_scores, next_tokens, batch_dims=1) # (batch_size, 2 * num_beams)
|
||||
|
||||
# sort the sampled vector to make sure that the first num_beams samples are the best
|
||||
next_scores_indices = tf.argsort(next_scores, direction="DESCENDING", axis=1)
|
||||
next_scores = tf.gather(next_scores, next_scores_indices, batch_dims=1) # (batch_size, num_beams * 2)
|
||||
next_tokens = tf.gather(next_tokens, next_scores_indices, batch_dims=1) # (batch_size, num_beams * 2)
|
||||
else:
|
||||
# do greedy beam search
|
||||
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
||||
@ -826,6 +837,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
next_scores = tf.reshape(
|
||||
next_scores, (batch_size, num_beams * vocab_size)
|
||||
) # (batch_size, num_beams * vocab_size)
|
||||
|
||||
next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)
|
||||
|
||||
assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
|
||||
@ -861,14 +873,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
beam_id = idx // vocab_size
|
||||
token_id = idx % vocab_size
|
||||
|
||||
effective_beam_id = batch_idx * num_beams + beam_id
|
||||
# add to generated hypotheses if end of sentence or last iteration
|
||||
if eos_token_ids is not None and token_id.numpy() in eos_token_ids:
|
||||
generated_hyps[batch_idx].add(
|
||||
tf.identity(input_ids[batch_idx * num_beams + beam_id, :cur_len]), score.numpy()
|
||||
)
|
||||
generated_hyps[batch_idx].add(tf.identity(input_ids[effective_beam_id]), score.numpy())
|
||||
else:
|
||||
# add next predicted token if it is not eos_token
|
||||
next_sent_beam.append((score, token_id, batch_idx * num_beams + beam_id))
|
||||
next_sent_beam.append((score, token_id, effective_beam_id))
|
||||
|
||||
# the beam for next step is full
|
||||
if len(next_sent_beam) == num_beams:
|
||||
@ -893,24 +904,34 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
if past:
|
||||
past = self._reorder_cache(past, beam_idx)
|
||||
|
||||
# update current length
|
||||
cur_len = cur_len + 1
|
||||
|
||||
# stop when we are done with each sentence
|
||||
if all(done):
|
||||
break
|
||||
|
||||
# update current length
|
||||
cur_len = cur_len + 1
|
||||
|
||||
# finalize all open beam hypotheses and end to generated hypotheses
|
||||
for batch_idx in range(batch_size):
|
||||
# Add all open beam hypothesis to generated_hyps
|
||||
if not done[batch_idx]:
|
||||
for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
|
||||
if done[batch_idx]:
|
||||
continue
|
||||
# test that beam scores match previously calculated scores if not eos and batch_idx not done
|
||||
if eos_token_ids is not None and all(
|
||||
(token_id % vocab_size).numpy().item() not in eos_token_ids for token_id in next_tokens[batch_idx]
|
||||
):
|
||||
assert tf.reduce_all(
|
||||
next_scores[batch_idx, :num_beams] == tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
|
||||
), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
|
||||
next_scores[:, :num_beams][batch_idx], tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
|
||||
)
|
||||
|
||||
# get beam and token IDs
|
||||
beam_id = idx // vocab_size
|
||||
token_id = idx % vocab_size
|
||||
generated_hyps[batch_idx].add(
|
||||
tf.identity(input_ids[batch_idx * num_beams + beam_id, :cur_len]), score.numpy()
|
||||
)
|
||||
# need to add best num_beams hypotheses to generated hyps
|
||||
for beam_id in range(num_beams):
|
||||
effective_beam_id = batch_idx * num_beams + beam_id
|
||||
final_score = beam_scores[effective_beam_id].numpy().item()
|
||||
final_tokens = input_ids[effective_beam_id]
|
||||
generated_hyps[batch_idx].add(final_tokens, final_score)
|
||||
|
||||
# depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
|
||||
output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
|
||||
|
Loading…
Reference in New Issue
Block a user