fix: sampling in flax keeps EOS (#28378)

This commit is contained in:
Boris Dayma 2024-01-15 11:12:09 -07:00 committed by GitHub
parent 7e0ddf89f4
commit 735968b61c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -716,8 +716,8 @@ class FlaxGenerationMixin:
next_token = jax.random.categorical(prng_key, logits, axis=-1)
next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
next_token = next_token[:, None]
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))