Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Fix beam score calculation issue for JAX version (#27816)
* Fix beam score calculation issue for JAX

* Fix abstract tracer value errors
parent 9660e27cd0
commit 7fc80724da
@@ -809,6 +809,9 @@ class FlaxGenerationMixin:
         pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
         cur_len = jnp.array(cur_len)
 
+        # record the prompt length of decoder
+        decoder_prompt_len = input_ids.shape[-1]
+
         # per batch,beam-item holding current token in loop.
         sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
         running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
@@ -861,9 +864,13 @@ class FlaxGenerationMixin:
             # early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
             # length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
             if early_stopping == "never" and length_penalty > 0.0:
-                best_running_score = state.running_scores[:, :1] / (max_length**length_penalty)
+                best_running_score = state.running_scores[:, :1] / (
+                    (max_length - decoder_prompt_len) ** length_penalty
+                )
             else:
-                best_running_score = state.running_scores[:, :1] / (state.cur_len**length_penalty)
+                best_running_score = state.running_scores[:, :1] / (
+                    (state.cur_len - decoder_prompt_len) ** length_penalty
+                )
             worst_finished_score = jnp.where(
                 state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
             )
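For intuition, here is a minimal standalone sketch of the normalization change above (not part of the commit; all values are made up): with a decoder prompt of `decoder_prompt_len` tokens, the length penalty should divide by the number of generated tokens rather than the full sequence length, otherwise the prompt tokens dilute the penalty.

import jax.numpy as jnp

# Hypothetical values for illustration only.
running_scores = jnp.array([[-4.2]])  # summed log-probs of the best running beam
cur_len = jnp.array(10)               # total decoder length so far (prompt + generated)
decoder_prompt_len = 4                # tokens that came from the prompt, not generation
length_penalty = 1.0

# Before the fix: prompt tokens inflate the denominator -> -0.42
before = running_scores / (cur_len**length_penalty)
# After the fix: normalize by generated tokens only -> -0.70
after = running_scores / ((cur_len - decoder_prompt_len) ** length_penalty)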
@@ -953,7 +960,7 @@ class FlaxGenerationMixin:
             # - add length penalty
             # - make sure no scores can be added anymore if beam is full
             # - make sure still running sequences cannot be chosen as finalized beam
-            topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
+            topk_log_probs = topk_log_probs / ((state.cur_len + 1 - decoder_prompt_len) ** length_penalty)
             beams_in_batch_are_full = jnp.broadcast_to(
                 state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape
             ) & (early_stopping is True)
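A note on the `+ 1` in this hunk (my reading of the code, not stated in the commit message): at this point in the body function the top-k candidates already include the token being appended at the current step, so a candidate's generated length is `state.cur_len + 1 - decoder_prompt_len`. For example, with `cur_len = 10` and `decoder_prompt_len = 4`, each candidate holds 10 + 1 - 4 = 7 generated tokens, and its log-probs are divided by `7**length_penalty`.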
@@ -990,9 +997,10 @@ class FlaxGenerationMixin:
             model_kwargs=next_model_kwargs,
         )
 
-        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
-        if input_ids.shape[-1] > 1:
-            state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)
+        # Always run first iteration outside of `lax.while_loop` to avoid calling `beam_search_cond_fn`
+        # when `state.cur_len` equals `decoder_prompt_len`. This also helps to comply with TPU when
+        # the very first prompt has sequence length > 1.
+        state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)
 
         if not trace:
             state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
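Regarding the control-flow change in the last hunk: `jax.lax.while_loop` evaluates the condition on the initial state before the first call to the body, so running one iteration eagerly guarantees at least one decoding step even when `state.cur_len` starts out equal to `decoder_prompt_len`. A minimal sketch of the pattern with toy condition and body functions (not the actual ones from `FlaxGenerationMixin`):

import jax.numpy as jnp
from jax import lax

def cond_fn(state):
    cur_len, _ = state
    return cur_len < 6  # stand-in for the real termination checks

def body_fn(state):
    cur_len, score = state
    return cur_len + 1, score - 0.5  # pretend one token was decoded

state = (jnp.array(3), jnp.array(0.0))
state = body_fn(state)  # first iteration runs outside the loop, as in the commit
state = lax.while_loop(cond_fn, body_fn, state)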