mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix: sampling in flax keeps EOS (#28378)
This commit is contained in:
parent
7e0ddf89f4
commit
735968b61c
@ -716,8 +716,8 @@ class FlaxGenerationMixin:
|
||||
|
||||
next_token = jax.random.categorical(prng_key, logits, axis=-1)
|
||||
|
||||
next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
|
||||
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
|
||||
next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
|
||||
next_token = next_token[:, None]
|
||||
|
||||
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
|
||||
|
Loading…
Reference in New Issue
Block a user