Fix beam score calculation issue for Tensorflow version (#27814)

* Fix beam score calculation issue for tensorflow version

* fix transition score computation error

* make cur_len represent the entire sequence length including decoder prompt
Xin Qiu 2023-12-08 21:10:13 +08:00 committed by GitHub
parent 4c5ed1d0c9
commit 3ac9945e56

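In short, the diff below normalizes beam scores by the number of generated tokens (cur_len - decoder_prompt_len) instead of the full sequence length, which now also counts the decoder prompt. A minimal sketch of the corrected length-penalty normalization, using made-up values and only the variable names that appear in the diff (not the actual generate() internals):

import tensorflow as tf

# Made-up values for illustration; the names mirror the variables in the diff below.
length_penalty = 1.0
running_scores = tf.constant([[-2.0, -3.5, -4.0]])  # cumulative log-probs, shape (batch, beams)
cur_len = tf.constant(12)            # full sequence length: decoder prompt + generated tokens
decoder_prompt_len = tf.constant(5)  # length of the decoder prompt

# Previously the penalty divided by the full sequence length, so the decoder
# prompt length skewed the normalized scores.
old_scores = running_scores / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)

# With the fix, only the generated portion counts toward the length penalty.
new_scores = running_scores / (
    tf.cast(cur_len - decoder_prompt_len, dtype=tf.float32) ** length_penalty
)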

@@ -2268,6 +2268,8 @@ class TFGenerationMixin:
# 3. init tensors to use for "xla-compileable" generate function
batch_size, num_beams, cur_len = shape_list(input_ids)
+# store the prompt length of decoder
+decoder_prompt_len = cur_len
# per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (
@@ -2286,8 +2288,8 @@
scores = tf.ones((batch_size, num_beams)) * -1.0e9
# per batch beam indices
-running_beam_indices = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * -1
-beam_indices = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * -1
+running_beam_indices = tf.ones((batch_size, num_beams, max_length - decoder_prompt_len), dtype=tf.int32) * -1
+beam_indices = tf.ones((batch_size, num_beams, max_length - decoder_prompt_len), dtype=tf.int32) * -1
# flatten beam dim
if "encoder_outputs" in model_kwargs:
@@ -2308,6 +2310,7 @@
scores,
beam_indices,
is_sent_finished,
+decoder_prompt_len,
model_kwargs,
):
"""
@@ -2318,15 +2321,17 @@
not_max_length_yet = cur_len < max_length
# 2. can the new beams still improve?
-# early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion
+# early_stopping == False -> apply heuristic = always get the best score from `cur_len - decoder_prompt_len`. See the discussion
# below for more details.
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
# early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
# length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
if early_stopping == "never" and length_penalty > 0.0:
-best_running_score = running_scores[:, :1] / (max_length**length_penalty)
+best_running_score = running_scores[:, :1] / ((max_length - decoder_prompt_len) ** length_penalty)
else:
-best_running_score = running_scores[:, :1] / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
+best_running_score = running_scores[:, :1] / (
+tf.cast(cur_len - decoder_prompt_len, dtype=tf.float32) ** length_penalty
+)
worst_finished_score = tf.where(
is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9
)
@@ -2346,6 +2351,7 @@
scores,
beam_indices,
is_sent_finished,
+decoder_prompt_len,
model_kwargs,
):
"""
@@ -2387,7 +2393,9 @@
if output_scores:
all_scores.append(
logits_warper(
-flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs_processed), cur_len
+flatten_beam_dim(running_sequences),
+flatten_beam_dim(log_probs_processed),
+cur_len,
)
)
if output_attentions and self.config.is_encoder_decoder:
@@ -2439,6 +2447,14 @@
batch_modified_indices = topk_current_beam_indices + tf.broadcast_to(
tf.expand_dims(tf.range(batch_size) * num_beams, axis=1), topk_current_beam_indices.shape
)
+update_indices = tf.stack(
+[
+indices_batch,
+indices_beam,
+tf.broadcast_to(cur_len - decoder_prompt_len, [batch_size * beams_to_keep]),
+],
+axis=-1,
+)
topk_beam_indices = tf.tensor_scatter_nd_update(
tensor=topk_running_beam_indices,
indices=update_indices,
@@ -2455,7 +2471,8 @@
eos_in_next_token = tf.math.reduce_any(
tf.equal(
tf.broadcast_to(
-topk_sequences[:, :, cur_len], [len(eos_token_id)] + topk_sequences[:, :, cur_len].shape
+topk_sequences[:, :, cur_len],
+[len(eos_token_id)] + topk_sequences[:, :, cur_len].shape,
),
tf.expand_dims(tf.expand_dims(eos_token_id, -1), -1),
),
@@ -2483,7 +2500,9 @@
# - add length penalty
# - make sure no scores can be added anymore if beam is full
# - make sure still running sequences cannot be chosen as finalized beam
-topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
+topk_log_probs = topk_log_probs / (
+tf.cast(cur_len + 1 - decoder_prompt_len, dtype=tf.float32) ** length_penalty
+)
beams_in_batch_are_full = tf.broadcast_to(
tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
) & (early_stopping is True)
@@ -2546,6 +2565,7 @@
next_scores,
next_beam_indices,
next_is_sent_finished,
+decoder_prompt_len,
next_model_kwargs,
)
@@ -2560,6 +2580,7 @@
scores,
beam_indices,
is_sent_finished,
+decoder_prompt_len,
model_kwargs,
) = beam_search_body_fn(
cur_len,
@@ -2570,6 +2591,7 @@
scores,
beam_indices,
is_sent_finished,
+decoder_prompt_len,
model_kwargs,
)
@@ -2585,6 +2607,7 @@
scores,
beam_indices,
is_sent_finished,
+decoder_prompt_len,
_,
) = tf.while_loop(
beam_search_cond_fn,
@@ -2598,6 +2621,7 @@
scores,
beam_indices,
is_sent_finished,
+decoder_prompt_len,
model_kwargs,
),
maximum_iterations=maximum_iterations,
@@ -2611,7 +2635,7 @@
beam_indices = tf.where(none_finished[:, None, None], beam_indices, running_beam_indices)
# Apply the length penalty so that running scores match the finalized scores if they are used
-running_scores = running_scores / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
+running_scores = running_scores / (tf.cast(cur_len - decoder_prompt_len, dtype=tf.float32) ** length_penalty)
scores = tf.where(none_finished[:, None], scores, running_scores)
# Take best beams for each batch (the score is sorted in descending order)
@@ -2622,7 +2646,7 @@
if not use_xla:
# Cut for backward compatibility
sequences = sequences[:, :cur_len]
-beam_indices = beam_indices[:, :cur_len]
+beam_indices = beam_indices[:, : cur_len - decoder_prompt_len]
if return_dict_in_generate:
if self.config.is_encoder_decoder: