Generate: multi-device support for contrastive search (#24635)

This commit is contained in:
Joao Gante 2023-07-03 16:08:20 +01:00 committed by GitHub
parent 4b26a61631
commit 9934bb1f42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2060,8 +2060,10 @@ class GenerationMixin:
context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
# model confidence
# model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't
# introduce (noticeable) slowdowns on single-device runs.
selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)
selected_idx = selected_idx.to("cpu")
# prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
# the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores