Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Fix beam score calculation issue for Tensorflow version (#27814)
* Fix beam score calculation issue for tensorflow version
* fix transition score computation error
* make cur_len represent the entire sequence length including decoder prompt
This commit is contained in:
parent 4c5ed1d0c9
commit 3ac9945e56
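In short, the diff makes the TF beam search normalize beam scores by the number of generated tokens only (`cur_len - decoder_prompt_len`) instead of the full sequence length, which also counts the decoder prompt. Below is a minimal sketch of that normalization; it is not code from this commit, the helper name and the numbers are purely illustrative.

# Minimal sketch (not code from this commit) of the length normalization being fixed.
# Assumes `sum_log_probs` covers only the generated tokens, while `cur_len` counts
# the decoder prompt as well; the helper name and numbers are illustrative.
import tensorflow as tf

def normalized_beam_score(sum_log_probs, cur_len, decoder_prompt_len, length_penalty):
    # Before this commit the denominator was `cur_len`, so a long decoder prompt
    # inflated the normalization length and skewed beam selection.
    gen_len = tf.cast(cur_len - decoder_prompt_len, dtype=tf.float32)
    return sum_log_probs / (gen_len**length_penalty)

# 3 generated tokens after a 5-token prompt, length_penalty=1.0:
# -2.4 / 3 = -0.8, whereas dividing by the full cur_len=8 would give -0.3.
print(float(normalized_beam_score(tf.constant(-2.4), 8, 5, 1.0)))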
@@ -2268,6 +2268,8 @@ class TFGenerationMixin:

         # 3. init tensors to use for "xla-compileable" generate function
         batch_size, num_beams, cur_len = shape_list(input_ids)
+        # store the prompt length of decoder
+        decoder_prompt_len = cur_len

         # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
         input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (
@@ -2286,8 +2288,8 @@ class TFGenerationMixin:
         scores = tf.ones((batch_size, num_beams)) * -1.0e9

         # per batch beam indices
-        running_beam_indices = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * -1
-        beam_indices = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * -1
+        running_beam_indices = tf.ones((batch_size, num_beams, max_length - decoder_prompt_len), dtype=tf.int32) * -1
+        beam_indices = tf.ones((batch_size, num_beams, max_length - decoder_prompt_len), dtype=tf.int32) * -1

         # flatten beam dim
         if "encoder_outputs" in model_kwargs:
@@ -2308,6 +2310,7 @@ class TFGenerationMixin:
            scores,
            beam_indices,
            is_sent_finished,
+           decoder_prompt_len,
            model_kwargs,
        ):
            """
@@ -2318,15 +2321,17 @@ class TFGenerationMixin:
            not_max_length_yet = cur_len < max_length

            # 2. can the new beams still improve?
-           # early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion
+           # early_stopping == False -> apply heuristic = always get the best score from `cur_len - decoder_prompt_len`. See the discussion
            # below for more details.
            # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
            # early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
            #   length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
            if early_stopping == "never" and length_penalty > 0.0:
-               best_running_score = running_scores[:, :1] / (max_length**length_penalty)
+               best_running_score = running_scores[:, :1] / ((max_length - decoder_prompt_len) ** length_penalty)
            else:
-               best_running_score = running_scores[:, :1] / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
+               best_running_score = running_scores[:, :1] / (
+                   tf.cast(cur_len - decoder_prompt_len, dtype=tf.float32) ** length_penalty
+               )
            worst_finished_score = tf.where(
                is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9
            )
@@ -2346,6 +2351,7 @@ class TFGenerationMixin:
            scores,
            beam_indices,
            is_sent_finished,
+           decoder_prompt_len,
            model_kwargs,
        ):
            """
@@ -2387,7 +2393,9 @@ class TFGenerationMixin:
                if output_scores:
                    all_scores.append(
                        logits_warper(
-                           flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs_processed), cur_len
+                           flatten_beam_dim(running_sequences),
+                           flatten_beam_dim(log_probs_processed),
+                           cur_len,
                        )
                    )
                if output_attentions and self.config.is_encoder_decoder:
@@ -2439,6 +2447,14 @@ class TFGenerationMixin:
            batch_modified_indices = topk_current_beam_indices + tf.broadcast_to(
                tf.expand_dims(tf.range(batch_size) * num_beams, axis=1), topk_current_beam_indices.shape
            )
+           update_indices = tf.stack(
+               [
+                   indices_batch,
+                   indices_beam,
+                   tf.broadcast_to(cur_len - decoder_prompt_len, [batch_size * beams_to_keep]),
+               ],
+               axis=-1,
+           )
            topk_beam_indices = tf.tensor_scatter_nd_update(
                tensor=topk_running_beam_indices,
                indices=update_indices,
@@ -2455,7 +2471,8 @@ class TFGenerationMixin:
            eos_in_next_token = tf.math.reduce_any(
                tf.equal(
                    tf.broadcast_to(
-                       topk_sequences[:, :, cur_len], [len(eos_token_id)] + topk_sequences[:, :, cur_len].shape
+                       topk_sequences[:, :, cur_len],
+                       [len(eos_token_id)] + topk_sequences[:, :, cur_len].shape,
                    ),
                    tf.expand_dims(tf.expand_dims(eos_token_id, -1), -1),
                ),
@@ -2483,7 +2500,9 @@ class TFGenerationMixin:
            # - add length penalty
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
-           topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
+           topk_log_probs = topk_log_probs / (
+               tf.cast(cur_len + 1 - decoder_prompt_len, dtype=tf.float32) ** length_penalty
+           )
            beams_in_batch_are_full = tf.broadcast_to(
                tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
            ) & (early_stopping is True)
@@ -2546,6 +2565,7 @@ class TFGenerationMixin:
                next_scores,
                next_beam_indices,
                next_is_sent_finished,
+               decoder_prompt_len,
                next_model_kwargs,
            )

@@ -2560,6 +2580,7 @@ class TFGenerationMixin:
            scores,
            beam_indices,
            is_sent_finished,
+           decoder_prompt_len,
            model_kwargs,
        ) = beam_search_body_fn(
            cur_len,
@@ -2570,6 +2591,7 @@ class TFGenerationMixin:
            scores,
            beam_indices,
            is_sent_finished,
+           decoder_prompt_len,
            model_kwargs,
        )

@@ -2585,6 +2607,7 @@ class TFGenerationMixin:
            scores,
            beam_indices,
            is_sent_finished,
+           decoder_prompt_len,
            _,
        ) = tf.while_loop(
            beam_search_cond_fn,
@@ -2598,6 +2621,7 @@ class TFGenerationMixin:
                scores,
                beam_indices,
                is_sent_finished,
+               decoder_prompt_len,
                model_kwargs,
            ),
            maximum_iterations=maximum_iterations,
@@ -2611,7 +2635,7 @@ class TFGenerationMixin:
        beam_indices = tf.where(none_finished[:, None, None], beam_indices, running_beam_indices)

        # Apply the length penalty so that running scores match the finalized scores if they are used
-       running_scores = running_scores / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
+       running_scores = running_scores / (tf.cast(cur_len - decoder_prompt_len, dtype=tf.float32) ** length_penalty)
        scores = tf.where(none_finished[:, None], scores, running_scores)

        # Take best beams for each batch (the score is sorted in descending order)
@@ -2622,7 +2646,7 @@ class TFGenerationMixin:
        if not use_xla:
            # Cut for backward compatibility
            sequences = sequences[:, :cur_len]
-           beam_indices = beam_indices[:, :cur_len]
+           beam_indices = beam_indices[:, : cur_len - decoder_prompt_len]

        if return_dict_in_generate:
            if self.config.is_encoder_decoder: