diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 482e8209f3b..604eb0ffd55 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1835,7 +1835,7 @@ class GenerationMixin: next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True ) - next_indices = next_tokens // vocab_size + next_indices = (next_tokens / vocab_size).long() next_tokens = next_tokens % vocab_size # stateless