Fix beam score calculation issue for JAX version (#27816)

* Fix beam score calculation issue for JAX

* Fix abstract tracer value errors
This commit is contained in:
Xin Qiu 2023-12-07 13:34:18 +08:00 committed by GitHub
parent 9660e27cd0
commit 7fc80724da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -809,6 +809,9 @@ class FlaxGenerationMixin:
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
cur_len = jnp.array(cur_len)
# record the decoder prompt length, i.e. the number of tokens already present in `input_ids`
decoder_prompt_len = input_ids.shape[-1]
# per batch, beam-item holding current token in loop.
sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
@ -861,9 +864,13 @@ class FlaxGenerationMixin:
# early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
# length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
if early_stopping == "never" and length_penalty > 0.0:
best_running_score = state.running_scores[:, :1] / (max_length**length_penalty)
best_running_score = state.running_scores[:, :1] / (
(max_length - decoder_prompt_len) ** length_penalty
)
else:
best_running_score = state.running_scores[:, :1] / (state.cur_len**length_penalty)
best_running_score = state.running_scores[:, :1] / (
(state.cur_len - decoder_prompt_len) ** length_penalty
)
worst_finished_score = jnp.where(
state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
)
@ -953,7 +960,7 @@ class FlaxGenerationMixin:
# - add length penalty
# - make sure no scores can be added anymore if beam is full
# - make sure still running sequences cannot be chosen as finalized beam
topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
topk_log_probs = topk_log_probs / ((state.cur_len + 1 - decoder_prompt_len) ** length_penalty)
beams_in_batch_are_full = jnp.broadcast_to(
state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape
) & (early_stopping is True)
@ -990,9 +997,10 @@ class FlaxGenerationMixin:
model_kwargs=next_model_kwargs,
)
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
if input_ids.shape[-1] > 1:
state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)
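# Always run the first iteration outside of `lax.while_loop` to avoid calling `beam_search_cond_fn`
# when `state.cur_len` equals `decoder_prompt_len`: at that point the length-penalty denominator
# `(state.cur_len - decoder_prompt_len) ** length_penalty` would be computed for a generated length of zero.
# This also helps to comply with TPU when the very first prompt has sequence length > 1.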
state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)
if not trace:
state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)